In [12]:
import os, gc, random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import librosa

from tqdm.notebook import tqdm
from glob import glob

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
import torch.nn.functional as F

import transformers
from transformers import ASTConfig, ASTFeatureExtractor, ASTModel

from sklearn.model_selection import GroupKFold
from sklearn.metrics import roc_auc_score

from time import time

import wandb

In [2]:
# Root folder for data and logs: "." when running locally
# (on Colab this was "/content/drive/MyDrive/Colab Notebooks").
DRIVE_FOLDER = "."

# Metadata columns of interest (not referenced by the cells below).
KEEP_COLS = [
    'category_number', 'common_name', 'audio_length', 'type', 'remarks',
    'quality', 'scientific_name', 'mp3_link', 'region',
]

class Config:
    """Paths and hyper-parameters for the BirdAST 5-fold baseline run.

    All values are plain class attributes. ``Config.copy()`` derives a subclass
    that holds its own copies of them, so per-session overrides such as
    ``config.n_classes = ...`` land on the copy rather than on ``Config``.
    """

    # --- paths / naming ---
    dataset_dir = f"{DRIVE_FOLDER}/Audio_XenoCanto"
    labels_list = f"{DRIVE_FOLDER}/xeno_labels.csv"
    log_dir = f"{DRIVE_FOLDER}/training_logs"
    model_name = "BirdAST_Baseline_5folds"
    backbone_name = "MIT/ast-finetuned-audioset-10-10-0.4593"

    # --- labels ---
    n_classes = 728  # default; the metadata cell overwrites it with the observed class count

    # --- audio / spectrogram ---
    audio_sr = 16000          # Hz
    segment_length = 10       # s
    fft_window = 0.025        # s
    hop_window_length = 0.01  # s
    n_mels = 128
    low_cut = 1000            # Hz
    high_cut = 8000           # Hz
    top_db = 100

    # --- data loading / cross-validation ---
    batch_size = 16
    num_workers = 0
    n_splits = 5

    # --- training ---
    max_lr = 1e-5
    epochs = 15
    weight_decay = 0.01
    lr_final_div = 1000
    amp = True
    grad_accum_steps = 1
    max_grad_norm = 1e7

    # --- logging / reproducibility ---
    print_epoch_freq = 1
    print_freq = 500
    random_seed = 2046

    @classmethod
    def copy(cls):
        """Return a new subclass ``CustomConfig`` of ``cls`` whose own namespace
        holds every non-dunder, non-callable entry of ``cls.__dict__``."""
        own_values = {}
        for attr_name, attr_value in cls.__dict__.items():
            if attr_name.startswith('__') or callable(attr_value):
                continue
            own_values[attr_name] = attr_value
        return type('CustomConfig', (cls,), own_values)
    
config = Config.copy()


def seed_everything(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU + all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable


# `random_seed` is declared in Config; apply it so stochastic steps are reproducible.
seed_everything(config.random_seed)

# exist_ok avoids the check-then-create race and is a no-op when the folder already exists.
os.makedirs(config.log_dir, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
# Load the clip-level metadata (one row per wav clip).
# NOTE(review): dropna() without `subset` removes a clip if *any* column is missing,
# including free-text fields such as 'remarks'. Confirm this is intended before changing
# it: the surviving class set determines n_classes, which must match saved checkpoints.
df_audio_meta = (
    pd.read_csv(f"{config.dataset_dir}/metadata.csv", nrows=None)
    .dropna()
    .reset_index(drop=True)
)

# Keep only clips whose audio file is present under dataset_dir.
df_audio_meta['file_exists'] = df_audio_meta['file_name'].map(
    lambda rel_path: os.path.exists(f"{config.dataset_dir}/{rel_path}")
)
df_audio_meta = df_audio_meta.loc[df_audio_meta['file_exists']].reset_index(drop=True)

# "Genus species" -> "Genus_species"
df_audio_meta['scientific_name'] = df_audio_meta['scientific_name'].map(lambda name: name.replace(" ", "_"))

# Drop species represented by fewer than 2 clips.
# NOTE(review): counts are per clip, not per recording (mp3_link); a species whose clips all
# come from one recording passes this filter but ends up entirely inside one GroupKFold fold.
class_counts = df_audio_meta['scientific_name'].value_counts()
n_singleton_classes = int((class_counts < 2).sum())
print(f"Number of classes with less than 2 samples: {n_singleton_classes}")

species_to_keep = class_counts.index[class_counts.values >= 2]
df_audio_meta = (
    df_audio_meta[df_audio_meta['scientific_name'].isin(species_to_keep)]
    .copy()
    .reset_index(drop=True)
)

# Integer label ids follow the alphabetical order of the scientific names.
label_ids_list = sorted(df_audio_meta['scientific_name'].unique().tolist())
label_to_id = {name: label_id for label_id, name in enumerate(label_ids_list)}
df_audio_meta['species_id'] = df_audio_meta['scientific_name'].map(label_to_id)

# To (re)write the label map that the inference cell reads back:
# pd.DataFrame(label_to_id.items(), columns=['scientific_name', 'species_id']).to_csv(
#     f"{config.log_dir}/{config.model_name}_label_map.csv", index=False)

# Clips sharing a source recording (mp3_link) share a group id, so GroupKFold keeps them together.
group_ids = df_audio_meta['mp3_link'].unique().tolist()
group_ids_map = {link: group_idx for group_idx, link in enumerate(group_ids)}
df_audio_meta['group_id'] = df_audio_meta['mp3_link'].map(group_ids_map)

# Defensive: every remaining name was mapped above, so this normally drops nothing.
df_audio_meta = df_audio_meta.dropna(subset=['species_id']).reset_index(drop=True)
df_audio_meta['species_id'] = df_audio_meta['species_id'].astype(int)

n_found_classes = df_audio_meta['species_id'].nunique()
print(f"Number of classes in dataset: {n_found_classes}")
print('Number of samples:', len(df_audio_meta))

# Size the classification head from the data rather than the Config default.
config.n_classes = n_found_classes

df_audio_meta.head(5)

Number of classes with less than 2 samples: 72
Number of classes in dataset: 728
Number of samples: 11171


Unnamed: 0,file_name,category_number,common_name,audio_length,type,remarks,quality,mp3_link,scientific_name,region,file_exists,species_id
0,data/XC228210-Blue-crowned_Manakin_B_9369_0.wav,XC228210,Blue-crowned Manakin,0:20,call,ID certainty 80%. (Archiv. tape 393 side A tra...,B,//xeno-canto.org/sounds/uploaded/OOECIWCSWV/XC...,Lepidothrix_coronata,amazonas,True,329
1,data/XC228210-Blue-crowned_Manakin_B_9369_1.wav,XC228210,Blue-crowned Manakin,0:20,call,ID certainty 80%. (Archiv. tape 393 side A tra...,B,//xeno-canto.org/sounds/uploaded/OOECIWCSWV/XC...,Lepidothrix_coronata,amazonas,True,329
2,data/XC200163-PIPCOR03_0.wav,XC200163,Blue-crowned Manakin,0:42,"call, song","left bank of rio Negro - terra firme forest, w...",C,//xeno-canto.org/sounds/uploaded/DGVLLRYDXS/XC...,Lepidothrix_coronata,amazonas,True,329
3,data/XC200163-PIPCOR03_1.wav,XC200163,Blue-crowned Manakin,0:42,"call, song","left bank of rio Negro - terra firme forest, w...",C,//xeno-canto.org/sounds/uploaded/DGVLLRYDXS/XC...,Lepidothrix_coronata,amazonas,True,329
4,data/XC200163-PIPCOR03_2.wav,XC200163,Blue-crowned Manakin,0:42,"call, song","left bank of rio Negro - terra firme forest, w...",C,//xeno-canto.org/sounds/uploaded/DGVLLRYDXS/XC...,Lepidothrix_coronata,amazonas,True,329


In [10]:
# Peek at the first clip of species id 220: its wav file and its source recording link.
df_audio_meta.loc[df_audio_meta['species_id'] == 220, ['file_name', 'mp3_link']].iloc[0].tolist()

['data/XC649601-furnariusfigulus2_0.wav',
 '//xeno-canto.org/sounds/uploaded/DFYFSPXYGT/XC649601-furnariusfigulus2.mp3']

In [14]:
# Sanity check on the grouping used by GroupKFold: group_id and mp3_link must be in
# one-to-one correspondence, otherwise clips of one recording could straddle folds.
assert df_audio_meta.groupby('group_id')['mp3_link'].nunique().eq(1).all(), \
    "a group_id spans several recordings"
assert df_audio_meta.groupby('mp3_link')['group_id'].nunique().eq(1).all(), \
    "a recording is split across several group_ids"

df_audio_meta.head()

In [15]:
# Assign every clip to exactly one validation fold; GroupKFold keeps all clips with the same
# group_id (i.e. the same recording) in the same fold.
# NOTE(review): GroupKFold does not stratify by species, so rare species can be missing from
# a fold's training split. StratifiedGroupKFold would address that, but it would also change
# the fold assignment the saved checkpoints were (presumably) trained with.
gkf = GroupKFold(n_splits=config.n_splits)

# split() yields *positional* indices, so fill a positional array instead of relying on
# .loc labels happening to coincide with positions.
fold_ids = np.full(len(df_audio_meta), -1, dtype=np.int64)
for fold, (_, valid_positions) in enumerate(gkf.split(df_audio_meta, groups=df_audio_meta['group_id'])):
    fold_ids[valid_positions] = fold
assert (fold_ids >= 0).all(), "some clips were not assigned to any fold"
df_audio_meta['fold'] = fold_ids

df_audio_meta['fold'].value_counts()

fold
0    2235
1    2234
4    2234
3    2234
2    2234
Name: count, dtype: int64

In [16]:
# Leakage check: no recording (mp3_link) may appear in both the train and valid split of a fold.
for fold in range(config.n_splits):
    in_valid = df_audio_meta['fold'] == fold
    train_group_ids = set(df_audio_meta.loc[~in_valid, 'mp3_link'])
    valid_group_ids = set(df_audio_meta.loc[in_valid, 'mp3_link'])
    common_groups = train_group_ids & valid_group_ids
    print(f"Fold {fold}: Number of common groups between train and valid: {len(common_groups)}")
    # Fail loudly instead of relying on someone reading the printed counts.
    assert not common_groups, f"Fold {fold}: {len(common_groups)} recordings leak between train and valid"

Fold 0: Number of common groups between train and valid: 0
Fold 1: Number of common groups between train and valid: 0
Fold 2: Number of common groups between train and valid: 0
Fold 3: Number of common groups between train and valid: 0
Fold 4: Number of common groups between train and valid: 0


In [4]:
class BirdSongDataset(Dataset):
    """Map-style dataset yielding ``(AST input features, species id)``, one item per metadata row.

    Args:
        df_audio_meta: frame with at least a 'file_name' column (path relative to
            ``config.dataset_dir``) and a 'species_id' column.
        config: object exposing ``dataset_dir`` and ``audio_sr``.
    """

    def __init__(self, df_audio_meta, config):
        self.config = config
        self.df_audio_meta = df_audio_meta
        # Default-constructed extractor (not loaded from the backbone checkpoint).
        self.feature_extractor = ASTFeatureExtractor()

    def __len__(self):
        return self.df_audio_meta.shape[0]

    def __getitem__(self, idx):
        record = self.df_audio_meta.iloc[idx]
        waveform, sample_rate = self.get_audio(f"{self.config.dataset_dir}/{record['file_name']}")
        features = self.feature_extractor(
            waveform,
            sampling_rate=sample_rate,
            padding="max_length",
            return_tensors="pt",
        )
        # squeeze(0) drops the leading size-1 axis (the extractor's batch axis for a single clip).
        return features['input_values'].squeeze(0), record['species_id']

    def get_audio(self, audio_path):
        """Load ``audio_path`` resampled to ``config.audio_sr``; returns ``(waveform, sample_rate)``."""
        waveform, sample_rate = librosa.load(audio_path, sr=self.config.audio_sr)
        return waveform, sample_rate

def collate_fn(batch):
    """Turn a list of ``(spectrogram, label)`` samples into a batch dict.

    Spectrograms are zero-padded along their first axis to the longest one and stacked
    batch-first under the key "input_ids" (the key the inference loop reads); labels
    become a 1-D tensor under "labels".
    """
    spectrograms, labels = [], []
    for sample in batch:
        spectrograms.append(sample[0])
        labels.append(sample[1])
    padded = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True, padding_value=0)
    return {"input_ids": padded, "labels": torch.tensor(labels)}

In [12]:
class BirdAST(nn.Module):
    """Pretrained Audio Spectrogram Transformer with an MLP head on the CLS token.

    Args:
        backbone_name: Hugging Face model id or local path of a pretrained AST checkpoint.
        n_classes: number of output logits.
        n_mlp_layers: number of hidden ``Linear(hidden, hidden) + activation`` layers
            placed before the final ``Linear(hidden, n_classes)``.
        activation: 'relu' or 'silu'.

    Raises:
        ValueError: for any other ``activation``. This is checked before the pretrained
            backbone is fetched/loaded.

    The submodule names (``ast``, ``activation``, ``mlp``) and the layer order inside
    ``mlp`` define the state-dict keys of saved checkpoints; keep them stable.
    """

    def __init__(self, backbone_name, n_classes, n_mlp_layers=1, activation='silu'):
        super().__init__()

        # Validate cheap arguments first so a typo fails before the slow pretrained load.
        if activation == 'relu':
            activation_cls = nn.ReLU
        elif activation == 'silu':
            activation_cls = nn.SiLU
        else:
            raise ValueError("Unsupported activation function. Choose 'relu' or 'silu'.")

        # pre-trained backbone
        backbone_config = ASTConfig.from_pretrained(backbone_name)
        self.ast = ASTModel.from_pretrained(backbone_name, config=backbone_config)
        self.hidden_size = backbone_config.hidden_size

        # A single parameter-free activation instance, reused by every hidden layer.
        self.activation = activation_cls()

        # MLP head: [Linear, activation] * n_mlp_layers, then the output Linear.
        layers = []
        for _ in range(n_mlp_layers):
            layers.append(nn.Linear(self.hidden_size, self.hidden_size))
            layers.append(self.activation)
        layers.append(nn.Linear(self.hidden_size, n_classes))
        self.mlp = nn.Sequential(*layers)

    def forward(self, spectrogram):
        """Return ``{'logits': tensor of shape (batch_size, n_classes)}``.

        ``spectrogram`` is fed to ``ASTModel`` as its ``input_values``. For Hugging Face AST
        this is laid out as (batch_size, time_frames, num_mel_bins), time axis first, as
        produced by ``ASTFeatureExtractor``.
        """
        hidden = self.ast(spectrogram, output_hidden_states=False).last_hidden_state
        # Sequence position 0 holds the CLS token.
        logits = self.mlp(hidden[:, 0, :])
        return {'logits': logits}


In [16]:
# Spot-check the per-fold checkpoints on the first few clips of the metadata.
N_INSPECT = 12  # number of clips to inspect

# Checkpoint paths are derived from config instead of a machine-specific absolute path.
model_files = [
    f"{config.log_dir}/{config.model_name}_fold_{i}.pth"
    for i in range(config.n_splits)
]

df_label_mapping = pd.read_csv(f"{config.log_dir}/{config.model_name}_label_map.csv")
id_to_label = dict(zip(df_label_mapping['species_id'], df_label_mapping['scientific_name']))

# True labels come from the species ids recomputed in this session, while the checkpoints were
# (presumably) trained with the saved map; make sure both assign the same id to each species.
current_id_to_label = dict(zip(df_audio_meta['species_id'], df_audio_meta['scientific_name']))
mismatched_ids = [sid for sid, name in current_id_to_label.items() if id_to_label.get(sid) != name]
assert not mismatched_ids, f"{len(mismatched_ids)} species ids differ from the saved label map"

model = BirdAST(config.backbone_name, config.n_classes, n_mlp_layers=1, activation='silu').to(DEVICE)
model.eval()

bird_dataset = BirdSongDataset(df_audio_meta, config)
bird_loader = DataLoader(bird_dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=collate_fn)

# Materialise the inspected clips once so each checkpoint is loaded a single time,
# rather than once per clip.
loader_iter = iter(bird_loader)
inspect_batches = [next(loader_iter) for _ in range(min(N_INSPECT, len(bird_loader)))]

predictions = [[] for _ in inspect_batches]  # predictions[sample_idx][model_idx]
with torch.no_grad():
    for model_file in model_files:
        model.load_state_dict(torch.load(model_file, map_location='cpu'))
        for sample_preds, data in zip(predictions, inspect_batches):
            logits = model(data['input_ids'].to(DEVICE))['logits']
            sample_preds.append(logits.argmax(-1).item())

for idx, (data, sample_preds) in enumerate(zip(inspect_batches, predictions)):
    true_id = data['labels'].item()
    print(f"Sample {idx}: True label: {true_id} | {id_to_label[true_id]}")
    for i, pred in enumerate(sample_preds):
        print(f"Model {i}: Predicted label: {pred} | {id_to_label.get(pred, '?')}")
    print("\n")




Sample 0: True label: 329 | Lepidothrix_coronata
Model 0: , Predicted label: 673
Model 1: , Predicted label: 279
Model 2: , Predicted label: 279
Model 3: , Predicted label: 638
Model 4: , Predicted label: 493


Sample 1: True label: 329 | Lepidothrix_coronata
Model 0: , Predicted label: 307
Model 1: , Predicted label: 279
Model 2: , Predicted label: 279
Model 3: , Predicted label: 493
Model 4: , Predicted label: 304


Sample 2: True label: 329 | Lepidothrix_coronata
Model 0: , Predicted label: 291
Model 1: , Predicted label: 537
Model 2: , Predicted label: 521
Model 3: , Predicted label: 537
Model 4: , Predicted label: 214


Sample 3: True label: 329 | Lepidothrix_coronata
Model 0: , Predicted label: 497
Model 1: , Predicted label: 279
Model 2: , Predicted label: 327
Model 3: , Predicted label: 317
Model 4: , Predicted label: 284


Sample 4: True label: 329 | Lepidothrix_coronata
Model 0: , Predicted label: 453
Model 1: , Predicted label: 537
Model 2: , Predicted label: 642
Model 3: , 