In [None]:
# import os
import gc
import timm
import numpy as np
import pandas as pd
import sklearn.metrics
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts #, ReduceLROnPlateau, OneCycleLR, CosineAnnealingLR

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint , EarlyStopping

import wandb
import albumentations as A
from torchtoolbox.tools import mixup_data, mixup_criterion
import soundfile as sf
from sklearn import *
import pickle

import warnings
warnings.filterwarnings('ignore')


In [None]:

# Experiment configuration. It is handed to wandb.init below, and the rest of the
# notebook reads hyper-parameters back through `wandb.config[...]`.
# Keys not read anywhere else in this notebook: use_aug, DEVICE, data_root, SR,
# DURATION, MAX_READ_SAMPLES, bird_name_path.
config = dict(
    use_aug=False,
    num_classes=264,                # length of the one-hot label / number of model outputs
    PRECISION=16,                   # Lightning Trainer `precision`
    PATIENCE=10,                    # EarlyStopping patience
    seed=2023,                      # global seed for pl.seed_everything
    pretrained=True,                # passed to timm.create_model
    weight_decay=1e-3,              # Adam weight decay
    DEVICE=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    data_root='./',
    SR=32000,
    DURATION=5,
    MAX_READ_SAMPLES=5,
    save_path='./exp1/',            # ModelCheckpoint directory
    bird_name_path='bird_names.pickle3',
    pickle_file_path='train_mel(dB,sr=32k,bin=128).pickle3',  # mel dataset (see ogg2mel)
    random_state=42,                # seed of the train/val split
    test_size=0.2,                  # held-out fraction of the split
    cross_val_data_index=0,
    LR=5e-4,
    use_mixup=True,
    mixup_alpha=0.2,
    batch_size=64,
    epochs=12,                      # max epochs and cosine-restart period
    model='tf_efficientnet_b2_ns',  # timm model id
    exp_name='efficientnet_b2',
)

# Open the initial W&B run ('base'); from here on wandb.config holds the values above.
wandb.init(project="Bird2023", config=config, name='base')

In [None]:
pl.seed_everything(seed=wandb.config['seed'], workers=True)  # seed all RNGs (incl. dataloader workers) from the run config

In [None]:
def load_pickle(fname):
    """Load and return the object stored in the pickle file ``fname``.

    The file is opened in a ``with`` block so the handle is closed even when
    ``pickle.load`` raises (e.g. on a truncated or corrupt file).

    SECURITY: unpickling can execute arbitrary code. Only call this on files
    you produced yourself, never on untrusted input.

    :param fname: path to the pickle file.
    :returns: the unpickled object.
    :raises FileNotFoundError: if ``fname`` does not exist.
    """
    with open(fname, 'rb') as f:
        return pickle.load(f)


In [None]:
# Load the pre-computed mel-spectrogram dataset (see ogg2mel for how to convert
# the .ogg recordings into this pickle). Every later cell needs `data`.
try:
    data = load_pickle(wandb.config['pickle_file_path'])
    # Sorted unique species names.
    bird_names = np.unique(data['primary_label'])
except FileNotFoundError as err:
    # Abort explicitly instead of `exit(-1)`: inside an IPython/Jupyter kernel
    # `exit` does not behave like sys.exit. Raising stops "Run All" in a
    # notebook and exits non-zero when run as a script.
    raise FileNotFoundError(
        "dataset not found! you can generate one by using ogg2mel.py"
    ) from err

In [None]:
class MyDataset(Dataset):
    """Mel-spectrogram dataset backed by the in-memory dict loaded from the pickle file.

    ``data`` must provide parallel sequences under ``'mel'`` (2-D arrays; the
    second axis is the frame/time axis that gets cropped or padded) and
    ``'primary_label'`` (species names). Every instance re-runs the same seeded
    train/held-out split and keeps the part selected by ``mode``.

    Items are ``(mel, label)``:

    - ``mel`` is a float32 tensor of shape ``(3, rows, 313)``. The single
      channel is repeated 3x for 3-channel image backbones.
    - ``label`` is a one-hot float tensor of length ``num_classes``.
    """

    def __init__(self, data, mode, cross_val_data_index=0, random_state=42, test_size=0.1, num_classes=264):
        """
        :param data: dict with ``'mel'`` and ``'primary_label'`` sequences.
        :param mode: ``'train'``, ``'val'`` or ``'test'``. ``'test'`` uses the
            held-out split (the same samples as ``'val'``), because no separate
            test split exists.
        :param cross_val_data_index: stored only; it does not influence the split.
        :param random_state: seed for the split. Use the same value for the
            train and val instances so their splits are complementary.
        :param test_size: held-out fraction.
        :param num_classes: length of the one-hot label vector.
        """
        assert mode in ['train', 'test', 'val'], f'invalid mode {mode}!, mode must be [train | val | test]'
        self.mode = mode

        self.sr = 32000
        self.duration = 5
        self.audio_length = self.duration * self.sr
        self.num_classes = num_classes

        # Dict mapping a species name (a 'primary_label' value) to its integer class index.
        self.name_label_2_int_label = load_pickle("name_label_2_int_label.pickle3")

        self.data = data
        self.cross_val_data_index = cross_val_data_index
        self.random_state = random_state
        self.test_size = test_size
        self.split_data()

    def __getitem__(self, index):
        """Return ``(mel, one_hot_label)`` for sample ``index``."""
        # Fix the time axis to exactly 313 frames, then rescale to [0, 1].
        # No augmentation is applied here.
        mel = self.normalize(self.crop_or_pad(self.mel[index]))
        # (rows, frames) -> (3, rows, frames)
        mel = torch.from_numpy(mel).unsqueeze(0).float().repeat(3, 1, 1)

        class_idx = self.name_label_2_int_label[self.label[index]]  # 'bird name' -> idx
        label = torch.nn.functional.one_hot(torch.tensor(class_idx), num_classes=self.num_classes).float()
        return mel, label

    def __len__(self):
        return len(self.label)

    def crop_or_pad(self, m, th=313):  # 313 ~= 5s*32000Hz/512, rounded up
        """Return ``m`` with its second axis resized to exactly ``th`` columns.

        Shorter inputs are repeated end-to-end and truncated. Longer inputs are
        randomly cropped, and every start offset in ``[0, length - th]`` can be
        chosen. An input with zero columns yields zeros.
        """
        length = m.shape[1]
        if length == 0:
            # Nothing to repeat: the original doubling loop never terminated here.
            return np.zeros((m.shape[0], th), dtype=m.dtype)
        if length <= th:
            reps = -(-th // length)  # ceil(th / length)
            return np.tile(m, (1, reps))[:, :th]
        # randint's upper bound is exclusive; +1 makes the last window reachable.
        start = np.random.randint(length - th + 1)
        return m[:, start:start + th]

    def normalize(self, X, eps=1e-6, mean=None, std=None):
        """Standardize ``X``, then min-max rescale it to ``[0, 1]``.

        ``mean``/``std`` default to the statistics of ``X`` itself. Explicit
        values, including 0, are honoured. If the standardized array spans at
        most ``eps``, an all-zero array of the same shape is returned.
        """
        if mean is None:
            mean = X.mean()
        if std is None:
            std = X.std()
        X = (X - mean) / (std + eps)

        lo, hi = X.min(), X.max()
        if (hi - lo) > eps:
            # No clip needed: lo/hi are X's own extremes.
            return (X - lo) / (hi - lo)
        return np.zeros_like(X)

    def split_data(self):
        """Split ``self.data`` and keep the part selected by ``self.mode``.

        After the seeded split, any class with no training sample gets one
        sample moved from the held-out part into the training part, so every
        class is seen during training.
        """
        # Local import instead of relying on the notebook's `from sklearn import *`.
        from sklearn.model_selection import train_test_split

        # NOTE(review): .index/.pop/.append below assume the split returns lists,
        # i.e. data['mel'] / data['primary_label'] are lists -- confirm.
        train_mels, val_mels, train_mel_tags, val_mel_tags = train_test_split(
            self.data['mel'], self.data['primary_label'],
            test_size=self.test_size, random_state=self.random_state)

        # Classes in the same order as before (most frequent first).
        all_classes = pd.Series(self.data['primary_label']).value_counts().index.tolist()
        train_classes = set(train_mel_tags)  # O(1) lookups instead of a list scan per class

        for key in all_classes:
            if key not in train_classes:
                idx = val_mel_tags.index(key)
                train_mels.append(val_mels.pop(idx))
                train_mel_tags.append(val_mel_tags.pop(idx))

        if self.mode == 'train':
            self.mel, self.label = train_mels, train_mel_tags
        else:  # 'val' and 'test' both read the held-out split
            self.mel, self.label = val_mels, val_mel_tags


In [None]:

def get_optimizer(lr, params):
    """Build the optimizer/scheduler config returned from ``configure_optimizers``.

    The optimizer is Adam over the trainable parameters, with weight decay
    taken from ``wandb.config``. It is paired with cosine annealing with warm
    restarts, using a period of ``wandb.config['epochs']`` and a floor of
    1e-6, stepped once per epoch.

    :param lr: initial learning rate.
    :param params: iterable of parameters. Those with ``requires_grad=False``
        are skipped.
    :returns: dict with ``"optimizer"`` and ``"lr_scheduler"`` entries.
    """
    trainable = (p for p in params if p.requires_grad)
    optimizer = torch.optim.Adam(trainable, lr=lr, weight_decay=wandb.config['weight_decay'])

    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=wandb.config['epochs'],
        T_mult=1,
        eta_min=1e-6,
        last_epoch=-1,
    )

    # 'monitor' matters only for metric-driven schedulers; kept as in the original config.
    scheduler_config = dict(scheduler=scheduler, interval="epoch", monitor="val_loss", frequency=1)
    return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

In [None]:
class BirdClefModel(pl.LightningModule):
    """timm backbone with a replaced classification head, trained on mel spectrograms.

    Validation does two things:

    - logs ``val_loss`` per batch;
    - at epoch end, logs label-ranking average precision (LRAP) over the whole
      validation set as ``val_accuracy``, plus the best value so far as
      ``best_acc``.
    """

    def __init__(self, model_name=wandb.config['model'], num_classes = wandb.config['num_classes'], pretrained = wandb.config['pretrained']):
        """
        :param model_name: timm model id containing 'efficientnet', 'mobilenet' or 'convnext'.
        :param num_classes: number of output logits.
        :param pretrained: passed to ``timm.create_model``.
        :raises ValueError: if ``model_name`` matches none of the supported families.

        NOTE: the defaults are read from ``wandb.config`` once, when the class is defined.
        """
        super().__init__()
        self.num_classes = num_classes
        self.backbone = timm.create_model(model_name, pretrained=pretrained)

        # Spell out each membership test: `'a' or 'b' in s` is always truthy.
        if 'efficientnet' in model_name or 'mobilenet' in model_name:
            self.in_features = self.backbone.classifier.in_features
            self.backbone.classifier = self._make_head(self.in_features, num_classes)
        elif 'convnext' in model_name:
            self.in_features = self.backbone.head.fc.in_features
            self.backbone.head.fc = self._make_head(self.in_features, num_classes)
        else:
            raise ValueError(f'No valid model name: {model_name}')

        self.loss_function = nn.BCEWithLogitsLoss()
        self.global_acc = 0
        # One dict per validation batch; consumed and cleared in on_validation_epoch_end.
        self.validation_step_outputs = []

    @staticmethod
    def _make_head(in_features, num_classes):
        """Build the Linear -> Dropout(0.4) -> Linear head that replaces the backbone classifier."""
        return nn.Sequential(
            nn.Linear(in_features, num_classes),
            nn.Dropout(0.4),
            nn.Linear(num_classes, num_classes))

    def forward(self, images):
        """Return raw logits of shape ``(batch, num_classes)``."""
        return self.backbone(images)

    def configure_optimizers(self):
        return get_optimizer(lr=wandb.config['LR'], params=self.parameters())

    def train_with_mixup(self, X, y):
        """Mix the batch (and its targets) and return the mixed loss."""
        X, y_a, y_b, lam = mixup_data(X, y, alpha=wandb.config['mixup_alpha'])
        y_pred = self(X)
        # NOTE(review): mixup trains with softmax cross-entropy, while the
        # non-mixup path and validation use BCEWithLogitsLoss -- confirm intended.
        return mixup_criterion(F.cross_entropy, y_pred, y_a, y_b, lam)

    def training_step(self, batch, batch_idx):
        image, target = batch
        if wandb.config['use_mixup']:
            loss = self.train_with_mixup(image, target)
        else:
            loss = self.loss_function(self(image), target)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the batch loss and buffer predictions for the epoch-level metric."""
        image, target = batch
        y_pred = self(image)
        val_loss = self.loss_function(y_pred, target)

        # Accumulate every batch (on CPU) so the epoch metric covers the whole
        # validation set, not just the last batch.
        step_output = {
            "val_loss": val_loss.detach().cpu(),
            "logits": y_pred.detach().cpu(),
            "targets": target.detach().cpu(),
        }
        self.validation_step_outputs.append(step_output)
        self.log("val_loss", val_loss)

        return step_output

    # NOTE(review): the two methods below read attributes that are never assigned
    # in this notebook; train() passes the dataloaders to trainer.fit instead.
    # Lightning's validation hook is named `val_dataloader`, not `validation_dataloader`.
    def train_dataloader(self):
        return self._train_dataloader

    def validation_dataloader(self):
        return self._validation_dataloader

    def on_validation_epoch_end(self):
        """Compute epoch-level validation metrics from all buffered batches.

        Logs ``val_accuracy`` (LRAP of sigmoid scores against the targets) and
        ``best_acc`` (best LRAP so far), then clears the buffer.

        :returns: ``{'val_loss': mean of per-batch losses, 'val_cmap': LRAP}``,
            or ``None`` when no validation batch was recorded.
        """
        outputs = self.validation_step_outputs
        if not outputs:
            return None

        avg_loss = torch.stack([o["val_loss"] for o in outputs]).mean()
        # float() first: logits may be fp16 under mixed precision.
        output_val = torch.cat([o["logits"] for o in outputs]).float().sigmoid().numpy()
        target_val = torch.cat([o["targets"] for o in outputs]).numpy()

        avg_score = sklearn.metrics.label_ranking_average_precision_score(target_val, output_val)

        self.log('val_accuracy', avg_score)
        if avg_score > self.global_acc:
            self.global_acc = avg_score
        self.log('best_acc', self.global_acc)
        self.validation_step_outputs.clear()  # free memory

        return {'val_loss': avg_loss, 'val_cmap': avg_score}
    
    
    

In [None]:


def train(new_wandb, devices=None):
    """Train one ``BirdClefModel`` run with hyper-parameters from ``new_wandb.config``.

    :param new_wandb: the ``wandb`` module (or any object exposing ``.config``)
        with an active run. Its config supplies all model, data and trainer
        settings.
    :param devices: GPU indices for the Lightning ``Trainer``. Defaults to
        ``[7]``, the previously hard-coded value; override on machines with
        fewer GPUs.
    :returns: the fitted ``Trainer``.
    """
    cfg = new_wandb.config
    if devices is None:
        devices = [7]

    # define the logger
    logger = WandbLogger(project='Bird2023', log_model="all", name=cfg['exp_name'])

    # Both datasets re-run the same seeded split, so they must share these arguments.
    split_kwargs = dict(
        cross_val_data_index=cfg['cross_val_data_index'],
        random_state=cfg['random_state'],
        test_size=cfg['test_size'],
    )
    train_dataset = MyDataset(data=data, mode='train', **split_kwargs)
    val_dataset = MyDataset(data=data, mode='val', **split_kwargs)

    train_dataloader = DataLoader(dataset=train_dataset, batch_size=cfg['batch_size'], shuffle=True, pin_memory=True)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=cfg['batch_size'], shuffle=False, pin_memory=True, drop_last=False)

    # Pass every setting explicitly so this run's config wins over the
    # class-definition-time defaults of BirdClefModel.
    audio_model = BirdClefModel(model_name=cfg['model'],
                                num_classes=cfg['num_classes'],
                                pretrained=cfg['pretrained'])

    # Both callbacks track val_accuracy (validation LRAP), higher is better.
    early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=cfg['PATIENCE'], verbose=True, mode="max")
    # NOTE: the "_loss" filename suffix is historical; kept so existing checkpoint paths stay valid.
    checkpoint_callback = ModelCheckpoint(dirpath=cfg['save_path'],
                                          save_top_k=1,
                                          save_last=True,
                                          save_weights_only=False,
                                          filename=f'./{cfg["model"]}_loss',
                                          verbose=True,
                                          monitor='val_accuracy',
                                          mode='max',
                                          auto_insert_metric_name=True)

    trainer = Trainer(
        val_check_interval=1.0,
        deterministic=True,
        max_epochs=cfg['epochs'],
        logger=logger,
        callbacks=[checkpoint_callback, early_stop_callback],
        precision=cfg['PRECISION'],
        accelerator="gpu",
        devices=devices,
        num_sanity_val_steps=0,
    )

    trainer.fit(audio_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
    return trainer
    # audio_model.eval()
    # with torch.no_grad():                
    #     trainer.test(dataloaders=val_dataloader)  

In [None]:
# Experiment "aa": EfficientNet-B3 for 5 epochs. For further cross-validation or
# ablation runs, copy this cell and change the name and overrides below. The
# train/val split itself is controlled by 'random_state' and 'test_size'.
new_exp_name = 'aa'
if wandb.run is not None:
    wandb.finish()  # close any run still open (e.g. 'base') so a fresh one starts
# Pass overrides at init time: changing wandb.config values after init raises
# ConfigError outside Jupyter.
wandb.init(project="Bird2023", name=new_exp_name,
           config={**config, 'epochs': 5, 'model': 'tf_efficientnet_b3_ns', 'exp_name': new_exp_name})
try:
    train(wandb)
finally:
    wandb.finish()  # always close the run, even if training fails

gc.collect()
torch.cuda.empty_cache()


In [None]:
# Experiment "bb": EfficientNet-B0 for 3 epochs.
new_exp_name = 'bb'
if wandb.run is not None:
    wandb.finish()  # close any run still open so a fresh one starts
# Pass overrides at init time: changing wandb.config values after init raises
# ConfigError outside Jupyter.
wandb.init(project="Bird2023", name=new_exp_name,
           config={**config, 'epochs': 3, 'model': 'tf_efficientnet_b0_ns', 'exp_name': new_exp_name})
try:
    train(wandb)
finally:
    wandb.finish()  # always close the run, even if training fails

gc.collect()
torch.cuda.empty_cache()