In [None]:
# import os
import gc
import timm
import numpy as np
import pandas as pd
import sklearn.metrics
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts #, ReduceLROnPlateau, OneCycleLR, CosineAnnealingLR

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint #, EarlyStopping,BackboneFinetuning

import wandb
import albumentations as A
from torchtoolbox.tools import mixup_data, mixup_criterion
import soundfile as sf

import warnings
warnings.filterwarnings('ignore')


In [None]:
class Config:
    """Hyper-parameters, feature switches and file locations shared by the notebook cells."""

    # --- reproducibility / hardware -------------------------------------
    seed = 2023
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    PRECISION = 16  # forwarded to Trainer(precision=...)

    # --- model ----------------------------------------------------------
    model = "tf_efficientnet_b2_ns"  # timm model name
    pretrained = True
    num_classes = 264  # initial value; a later cell overwrites it from the training table

    # --- optimisation ---------------------------------------------------
    epochs = 50
    batch_size = 64
    LR = 5e-4
    weight_decay = 1e-3
    PATIENCE = 8  # only referenced by the commented-out EarlyStopping callback

    # --- augmentation ---------------------------------------------------
    use_aug = False
    use_mixup = True
    mixup_alpha = 0.2

    # --- audio / sampling -----------------------------------------------
    SR = 32000
    DURATION = 5
    MAX_READ_SAMPLES = 5  # BirdDataset keeps at most this many leading entries of each .npy file

    # --- paths ----------------------------------------------------------
    data_root = "./"
    train_images = "./specs/train/"
    valid_images = "./specs/valid/"
    train_path = "./train.csv"
    valid_path = "./valid.csv"
    save_path = "./exp1/"
    bird_name_path = 'bird_names.pickle3'
    pickle_file_path = 'train_mel(dB,sr=32k,bin=128).shuffled_subset1.pickle3'

In [None]:
# Seed the run from Config.seed through Lightning's helper; workers=True is passed so that
# dataloader worker processes are covered as well.
pl.seed_everything(Config.seed, workers=True)

In [None]:
import pickle
def load_pickle(fname):
    """Load and return the object stored in the pickle file ``fname``.

    The file is opened in binary mode inside a ``with`` block so the handle is
    closed even when unpickling raises.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file; only call this on files you produced or otherwise trust.
    """
    with open(fname, 'rb') as f:
        return pickle.load(f)


In [None]:
def config_to_dict(cfg):
    """Return a dict of every attribute of ``cfg`` whose name does not start with ``__``.

    Names come from ``dir(cfg)``, so inherited attributes are included and the
    keys appear in ``dir``'s order.
    """
    public_names = [attr for attr in dir(cfg) if not attr.startswith('__')]
    return {attr: getattr(cfg, attr) for attr in public_names}

In [None]:
# Read the training and validation metadata tables from the CSV paths set in Config.
df_train = pd.read_csv(Config.train_path)
df_valid = pd.read_csv(Config.valid_path)
# Cell output: preview the first rows of the training table.
df_train.head()

In [None]:
# Overwrite the configured class count with the number of distinct primary_label values
# found in the training table (mutates the shared Config class in place).
Config.num_classes = len(df_train.primary_label.unique())
print(Config.num_classes, type(df_train))

In [None]:
# Convert the class name into a one-hot encoding: append one indicator column per distinct
# primary_label value to each table.
# NOTE(review): the result is assigned back to the same names, so re-running this cell
# appends a second set of indicator columns.
df_train = pd.concat([df_train, pd.get_dummies(df_train['primary_label'])], axis=1)
df_valid = pd.concat([df_valid, pd.get_dummies(df_valid['primary_label'])], axis=1)

# Cell output: preview the training table with the new indicator columns.
df_train.head()

## Create zero-filled label columns for birds that have no samples in the validation set

In [None]:
# Distinct primary labels present in the training table, in order of first appearance.
birds = list(df_train.primary_label.unique())
print(len(birds))

In [None]:
# missing_birds: labels that occur in the training table but never in the validation table.
# non_missing_birds: the remaining training labels (those that also occur in validation).
# Both lists come from sets, so their element order is arbitrary.
missing_birds = list(set(list(df_train.primary_label.unique())).difference(list(df_valid.primary_label.unique())))
non_missing_birds = list(set(list(df_train.primary_label.unique())).difference(missing_birds))
print(len(missing_birds), len(non_missing_birds))

In [None]:
# Add an all-zero indicator column to the validation table for every training label that
# has no validation sample.
df_valid[missing_birds] = 0.0

# Re-select the validation columns in the training table's column order so both tables
# share the same layout (BirdDataset slices the label vector by column position).
df_valid = df_valid[df_train.columns] ## Fix order

In [None]:
# Cell output: the columns from position 17 onward, i.e. the same positional slice that
# BirdDataset.__getitem__ returns as the label vector.
df_train.iloc[:,17:]

In [None]:

def get_train_transform():
    """Build the albumentations pipeline for training images.

    The pipeline is a horizontal flip applied with probability 0.5, followed by
    a ``OneOf`` block (also p=0.5) that picks either ``Cutout`` or
    ``CoarseDropout``.
    """
    flip = A.HorizontalFlip(p=0.5)
    occlusion_choices = [
        A.Cutout(max_h_size=5, max_w_size=16),
        A.CoarseDropout(max_holes=4),
    ]
    occlusion = A.OneOf(occlusion_choices, p=0.5)
    return A.Compose([flip, occlusion])

In [None]:
class BirdDataset(Dataset):
    """Dataset backed by a metadata dataframe and one ``.npy`` file per row.

    Each item is ``(image, labels)``:

    * ``image``: the selected array from the row's ``.npy`` file, stacked three
      times along a new leading axis and divided by 255.
    * ``labels``: the row's values from column position ``LABEL_START_COL``
      onward, as a float32 tensor.

    Args:
        df: dataframe with a ``filename`` column and the label columns starting
            at position ``LABEL_START_COL``.
        sr: stored on the instance; not used by the methods in this class.
        duration: stored on the instance; not used by the methods in this class.
        augmentations: optional albumentations pipeline (or any callable that
            accepts ``image=<ndarray>`` and returns a dict with an ``"image"``
            entry). ``None`` disables augmentation.
        train: selects ``Config.train_images`` vs ``Config.valid_images`` as the
            file directory and random vs first-entry selection.
    """

    # Column position where the one-hot label columns begin in ``df``.
    LABEL_START_COL = 17

    def __init__(self, df, sr = Config.SR, duration = Config.DURATION, augmentations = None, train = True):
        self.df = df
        self.sr = sr
        self.train = train
        self.duration = duration
        self.augmentations = augmentations
        if train:
            self.img_dir = Config.train_images
        else:
            self.img_dir = Config.valid_images

    def __len__(self):
        """Number of rows in the backing dataframe."""
        return len(self.df)

    @staticmethod
    def normalize(image):
        """Scale ``image`` by 1/255."""
        image = image / 255.0
        return image

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for the row at position ``idx``."""
        row = self.df.iloc[idx]
        impath = self.img_dir + f"{row.filename}.npy"
        # Keep at most MAX_READ_SAMPLES leading entries of the stored stack.
        image = np.load(str(impath))[:Config.MAX_READ_SAMPLES]

        # Training draws one entry at random; validation always uses the first.
        if self.train:
            image = image[np.random.choice(len(image))]
        else:
            image = image[0]

        if self.augmentations:
            # albumentations pipelines must be called with keyword arguments on
            # numpy arrays and return a dict; the previous positional call on a
            # torch tensor could not work.
            # NOTE(review): assumes the selected entry is a 2-D (H, W) array in a
            # dtype the configured transforms accept - confirm.
            image = self.augmentations(image=image)["image"]

        # ascontiguousarray guards against flipped/strided views from augmentation.
        image = torch.tensor(np.ascontiguousarray(image)).float()
        image = torch.stack([image, image, image])
        image = self.normalize(image)

        # Explicit positional slice + numeric conversion: the row is a mixed
        # (object-dtype) Series indexed by column name.
        labels = row.iloc[self.LABEL_START_COL:].to_numpy(dtype=np.float32)
        return image, torch.tensor(labels)


In [None]:
def get_fold_dls(df_train, df_valid):
    """Create the BirdDataset pair and their DataLoaders.

    Both datasets are built without augmentation; the training loader shuffles,
    the validation loader does not. Both use ``Config.batch_size`` and two
    worker processes.

    NOTE(review): ``Config.use_aug`` and ``get_train_transform`` are not
    consulted here - confirm that training without augmentation is intended.

    Returns:
        ``(dl_train, dl_val, ds_train, ds_val)``
    """
    shared_kwargs = dict(sr=Config.SR, duration=Config.DURATION, augmentations=None)

    ds_train = BirdDataset(df_train, train=True, **shared_kwargs)
    ds_val = BirdDataset(df_valid, train=False, **shared_kwargs)

    dl_train = DataLoader(ds_train, batch_size=Config.batch_size, shuffle=True, num_workers=2)
    dl_val = DataLoader(ds_val, batch_size=Config.batch_size, num_workers=2)

    return dl_train, dl_val, ds_train, ds_val

In [None]:
# our data
def load_pickle(fname):
    """Load and return the object stored in the pickle file ``fname``.

    This redefines the helper of the same name from an earlier cell with the
    same behaviour. The file is opened inside a ``with`` block so the handle is
    closed even when unpickling raises.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file; only call this on files you produced or otherwise trust.
    """
    with open(fname, 'rb') as f:
        return pickle.load(f)

In [None]:
# Load the pickled dataset dictionary named in Config and list the distinct labels in it.
print(Config.pickle_file_path)
try:
    data = load_pickle(Config.pickle_file_path)  # see ogg2mel for how to convert ogg to mel_dict.pickle3
except FileNotFoundError as err:
    # Fail with a regular exception carrying the path and the remedy, rather than
    # calling exit(): the cell then stops with a readable traceback and `data`
    # is never left bound to None for later cells to trip over.
    raise FileNotFoundError(
        f"dataset not found at {Config.pickle_file_path!r}! you can generate one by using ogg2mel.py"
    ) from err

# Sorted distinct primary labels in the loaded data.
bird_names = np.unique(data['primary_label'])

In [None]:
class MyDataset(Dataset):
    """Dataset over an in-memory dictionary of mel spectrograms (train/val) or a
    dataframe of audio file paths (test).

    Args:
        data: for ``'train'``/``'val'`` a dict of equally long sequences that has
            at least the keys ``'mel'`` and ``'primary_label'``; the first 80% of
            every sequence is the training split and the rest the validation
            split. For ``'test'`` a dataframe with a ``"path"`` column.
        mode: one of ``'train'``, ``'val'``, ``'test'``.
        num_classes: length of the one-hot label vector (defaults to 264, the
            value that used to be hard-coded).

    Items:
        * train/val: ``(mel, label)`` where ``mel`` is a ``(3, n_mels, 313)``
          float tensor scaled to [0, 1] and ``label`` a float one-hot vector.
        * test: a numpy stack with one dB mel spectrogram per full 5-second
          chunk of the audio file.
    """

    def __init__(self, data, mode, num_classes=264):
        assert mode in ['train', 'test', 'val'], f'invalid mode {mode}!, mode must be [train | val | test]'
        self.mode = mode
        self.num_classes = num_classes

        self.sr = 32000
        self.duration = 5
        self.audio_length = self.duration * self.sr  # samples per chunk

        # dict mapping a bird-name label to its integer class index
        self.name_label_2_int_label = load_pickle("name_label_2_int_label.pickle3")

        if mode in ('train', 'val'):
            split_at = int(0.8 * len(data['primary_label']))
            if mode == 'train':
                self.data = {k: v[:split_at] for k, v in data.items()}  # 0~80% as train set
            else:
                self.data = {k: v[split_at:] for k, v in data.items()}  # 80%~100% as val set
        elif mode == 'test':  # in test mode, data is in DataFrame form
            self.data_test_df = data
        else:
            # Only reachable when assertions are disabled (python -O).
            raise ValueError(f'no such mode {mode}')

    @staticmethod
    def normalize(image):
        """Scale ``image`` by 1/255."""
        image = image / 255.0
        return image

    def __getitem__(self, index):
        """Return ``(mel, label)`` in train/val mode, or the chunked spectrogram
        stack of one file in test mode."""
        if self.mode == 'train' or self.mode == 'val':
            mel = self.data['mel'][index]

            mel = self.crop_or_pad(mel)
            mel = self.mono_to_color(mel)
            mel = self.normalize(mel)
            mel = torch.from_numpy(mel).unsqueeze(0).float()
            mel = mel.repeat(3, 1, 1)  # replicate the single channel three times

            label = self.name_label_2_int_label[self.data['primary_label'][index]]  # 'bird name' -> idx
            # one_hot only accepts an integer (long) tensor; as_tensor handles both
            # plain ints and tensors coming out of the mapping.
            label = torch.as_tensor(label, dtype=torch.long)
            label = torch.nn.functional.one_hot(label, num_classes=self.num_classes).float()
            return mel, label

        elif self.mode == 'test':
            return self.read_file(self.data_test_df.loc[index, "path"])

    def __len__(self):
        """Number of items in the active split (train/val) or dataframe (test)."""
        if self.mode == 'train' or self.mode == 'val':
            return len(self.data['primary_label'])
        elif self.mode == 'test':
            return len(self.data_test_df)

    def crop_or_pad(self, m, th=313):  # 313=5s*32000Hz/512
        """Return ``m`` with exactly ``th`` columns along axis 1.

        Shorter inputs are repeated end-to-end and truncated; longer inputs are
        cropped at a random offset.

        Raises:
            ValueError: if ``m`` has no columns (repeat-padding cannot terminate).
        """
        length = m.shape[1]
        if length == 0:
            raise ValueError("cannot pad a spectrogram with 0 frames")
        if length <= th:  # pad short
            repeats = -(-th // length)  # ceil(th / length)
            m = np.concatenate([m] * repeats, axis=1)[:, 0:th]
        else:  # crop longer audio
            start = np.random.randint(length - th)
            m = m[:, start:start + th]
        return m

    ########## following methods are for test use only ##################

    def audio_to_image(self, audio, ):
        """Return the dB-scaled, 128-bin mel spectrogram of ``audio`` as float32.

        NOTE(review): the train/val path additionally applies ``mono_to_color``
        and ``normalize``; this path returns the raw dB spectrogram, as the
        original return statement did - confirm which representation inference
        expects.
        """
        # librosa is only needed on the test path and is not imported at file level.
        import librosa

        melspec = librosa.feature.melspectrogram(y=audio, sr=self.sr, n_mels=128, fmin=0, fmax=self.sr//2)
        melspec = librosa.power_to_db(melspec).astype(np.float32)
        return melspec

    def read_file(self, filepath):
        """Read an audio file and return a stack of spectrograms, one per full
        ``audio_length``-sample chunk; a trailing partial chunk is dropped.

        NOTE(review): the file's own sample rate is read but not used - no
        resampling to ``self.sr`` happens here.

        Raises:
            ValueError: if the file holds less than one full chunk of audio.
        """
        audio, orig_sr = sf.read(filepath, dtype="float32")

        audios = [
            audio[start:start + self.audio_length]
            for start in range(0, len(audio), self.audio_length)
        ]

        if audios and len(audios[-1]) < self.audio_length:
            audios = audios[:-1]

        if not audios:
            raise ValueError(
                f"{filepath!r} is shorter than one chunk of {self.audio_length} samples"
            )

        images = [self.audio_to_image(chunk) for chunk in audios]
        return np.stack(images)

    def mono_to_color(self, X, eps=1e-6, mean=None, std=None):
        """Standardise ``X`` and rescale it to the uint8 range 0..255.

        ``mean``/``std`` default to the statistics of ``X``; an explicit 0 is
        honoured. An input whose standardised range is at most ``eps`` maps to
        all zeros.
        """
        mean = X.mean() if mean is None else mean
        std = X.std() if std is None else std
        X = (X - mean) / (std + eps)

        _min, _max = X.min(), X.max()

        if (_max - _min) > eps:
            V = 255 * (X - _min) / (_max - _min)
            V = V.astype(np.uint8)
        else:
            V = np.zeros_like(X, dtype=np.uint8)

        return V


In [None]:
# Build the train/val datasets from the same loaded dictionary; MyDataset takes the first
# 80% of every entry for 'train' and the remainder for 'val'.
train_dataset = MyDataset(data=data, mode='train')
val_dataset = MyDataset(data=data, mode='val')

# Only the training loader shuffles; both use the configured batch size and pinned memory.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=Config.batch_size, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=Config.batch_size, shuffle=False, pin_memory=True, drop_last=False)

In [None]:
def show_batch(img_ds, num_items, num_rows, num_cols, predict_arr=None):
    """Plot ``num_items`` randomly chosen images from ``img_ds`` on a grid.

    Args:
        img_ds: indexable dataset whose items are ``(image, label)`` pairs with
            the image laid out channels-first.
        num_items: number of samples to draw (with replacement).
        num_rows, num_cols: subplot grid shape; must provide at least
            ``num_items`` cells.
        predict_arr: accepted for interface compatibility; currently unused.
    """
    fig = plt.figure(figsize=(12, 6))
    # randint's upper bound is exclusive, so len(img_ds) makes every index
    # (including the last) eligible and also works for a one-item dataset.
    sample_indices = np.random.randint(0, len(img_ds), num_items)
    for plot_pos, sample_idx in enumerate(sample_indices):
        img, lb = img_ds[sample_idx]
        ax = fig.add_subplot(num_rows, num_cols, plot_pos + 1, xticks=[], yticks=[])
        if isinstance(img, torch.Tensor):
            # .cpu() keeps this working if the dataset ever yields GPU tensors.
            img = img.detach().cpu().numpy()
        if isinstance(img, np.ndarray):
            img = img.transpose(1, 2, 0)  # channels-first -> channels-last for imshow
            ax.imshow(img)

        ax.set_title("Spec")

In [None]:
# Build the spectrogram-file datasets/loaders and preview 8 random validation items on a 2x4 grid.
dl_train, dl_val, ds_train, ds_val = get_fold_dls(df_train, df_valid)
show_batch(ds_val, 8, 2, 4)

In [None]:

def get_optimizer(lr, params):
    """Build the optimizer/scheduler dictionary returned from ``configure_optimizers``.

    Args:
        lr: learning rate for Adam.
        params: iterable of parameters; only those with ``requires_grad`` set
            are handed to the optimizer.

    Returns:
        A dict with the Adam optimizer (weight decay from ``Config``) under
        ``"optimizer"`` and, under ``"lr_scheduler"``, a
        ``CosineAnnealingWarmRestarts`` schedule (first cycle of
        ``Config.epochs`` steps, stepped once per epoch) with its settings.
    """
    trainable_params = (p for p in params if p.requires_grad)
    optimizer = torch.optim.Adam(
        trainable_params,
        lr=lr,
        weight_decay=Config.weight_decay,
    )

    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=Config.epochs,
        T_mult=1,
        eta_min=1e-6,
        last_epoch=-1,
    )

    scheduler_config = {
        "scheduler": scheduler,
        "interval": "epoch",
        "monitor": "val_loss",
        "frequency": 1,
    }
    return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

In [None]:
class BirdClefModel(pl.LightningModule):
    """LightningModule wrapping a timm backbone with a ``num_classes``-way linear head.

    Training and validation both use ``BCEWithLogitsLoss`` on float targets;
    validation additionally reports the label-ranking average precision over
    the whole validation pass under the key ``val_accuracy``.

    Args:
        model_name: timm model name; must contain ``'res'``, ``'dense'`` or
            ``'efficientnet'`` so the matching classifier attribute can be replaced.
        num_classes: number of output logits.
        pretrained: forwarded to ``timm.create_model``.

    Raises:
        ValueError: if ``model_name`` matches none of the supported families.
    """

    def __init__(self, model_name=Config.model, num_classes = Config.num_classes, pretrained = Config.pretrained):
        super().__init__()
        self.num_classes = num_classes

        self.backbone = timm.create_model(model_name, pretrained=pretrained)

        if 'res' in model_name:
            self.in_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Linear(self.in_features, num_classes)
        elif 'dense' in model_name:
            self.in_features = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Linear(self.in_features, num_classes)
        elif 'efficientnet' in model_name:
            self.in_features = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Sequential(
                nn.Linear(self.in_features, num_classes)
            )
        else:
            # Without this the backbone would silently keep its original head.
            raise ValueError(
                f"unsupported model_name {model_name!r}: expected a name containing "
                "'res', 'dense' or 'efficientnet'"
            )

        self.loss_function = nn.BCEWithLogitsLoss()
        # Per-batch validation results, collected until on_validation_epoch_end.
        self.validation_step_outputs = []

    def forward(self,images):
        """Return the backbone's logits for a batch of images."""
        logits = self.backbone(images)
        return logits

    def configure_optimizers(self):
        """Delegate to ``get_optimizer`` with the configured learning rate."""
        return get_optimizer(lr=Config.LR, params=self.parameters())

    def train_with_mixup(self, X, y):
        """Return the mixup loss for one batch.

        The inputs are mixed with ``mixup_data`` and the loss is the
        ``mixup_criterion`` combination of ``self.loss_function`` on both target
        sets, so the mixup path optimises the same objective as the non-mixup
        path and the validation loss.
        """
        X, y_a, y_b, lam = mixup_data(X, y, alpha=Config.mixup_alpha)
        y_pred = self(X)
        loss_mixup = mixup_criterion(self.loss_function, y_pred, y_a, y_b, lam)
        return loss_mixup

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss (with mixup when ``Config.use_mixup``)."""
        image, target = batch
        if Config.use_mixup:
            loss = self.train_with_mixup(image, target)
        else:
            y_pred = self(image)
            loss = self.loss_function(y_pred,target)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch and stash the batch results."""
        image, target = batch
        y_pred = self(image)
        val_loss = self.loss_function(y_pred, target)

        step_output = {"val_loss": val_loss, "logits": y_pred, "targets": target}
        # Keep a detached float32 CPU copy of every batch so the epoch-end metric
        # covers the whole validation pass, not just the final batch.
        self.validation_step_outputs.append(
            {key: value.detach().float().cpu() for key, value in step_output.items()}
        )
        self.log("val_loss", val_loss)

        return step_output

    def train_dataloader(self):
        # NOTE(review): _train_dataloader is never assigned in this class; this
        # only works if something else sets it. The notebook passes loaders to
        # trainer.fit instead.
        return self._train_dataloader

    def validation_dataloader(self):
        # NOTE(review): _validation_dataloader is never assigned in this class,
        # and Lightning's hook is presumably named val_dataloader - confirm
        # whether this method is needed at all.
        return self._validation_dataloader

    def on_validation_epoch_end(self):
        """Aggregate the stashed batches, log the ranking score, and reset the stash.

        Returns:
            ``None`` when no validation batch was recorded, otherwise a dict with
            the mean per-batch loss (``val_loss``) and the label-ranking average
            precision (``val_cmap``). The score is logged as ``val_accuracy``.
        """
        outputs = self.validation_step_outputs
        if not outputs:
            return None

        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        output_val = torch.cat([out['logits'] for out in outputs]).sigmoid().numpy()
        target_val = torch.cat([out['targets'] for out in outputs]).numpy()

        avg_score = sklearn.metrics.label_ranking_average_precision_score(target_val,output_val)

        self.log('val_accuracy', avg_score)
        outputs.clear()  # free memory

        return {'val_loss': avg_loss,'val_cmap':avg_score}
    
    
    

In [None]:
# define the logger; the run name is derived from Config so it always matches the
# model and epoch count that are actually trained
wandb_logger = WandbLogger(project='Bird2023', log_model="all", name=f'{Config.model}_epoch_{Config.epochs}')
logger = wandb_logger


# define the data
# NOTE(review): these spectrogram-file loaders are rebuilt here but the fit call below
# trains on train_dataloader / val_dataloader (the MyDataset loaders) - confirm which
# pair is intended.
dl_train, dl_val, ds_train, ds_val = get_fold_dls(df_train, df_valid)

# define the model
audio_model = BirdClefModel()


# Keep the best checkpoint by the logged 'val_accuracy' value (higher is better) plus the last one.
# NOTE(review): the filename still ends in "_loss" although the monitored metric is
# val_accuracy - rename if checkpoint consumers allow it.
checkpoint_callback = ModelCheckpoint(dirpath=Config.save_path,
                                        save_top_k=1,
                                        save_last= True,
                                        save_weights_only=False,
                                        filename= f'./{Config.model}_loss',
                                        verbose= True,
                                        monitor='val_accuracy',
                                        mode='max',
                                        auto_insert_metric_name = True)

callbacks_to_use = [checkpoint_callback]

# define the trainer
# NOTE(review): accelerator and device index are hard-coded (GPU #7) and ignore
# Config.DEVICE - adjust for the machine this runs on.
trainer = Trainer(
    val_check_interval=0.5,
    deterministic=True,
    max_epochs=Config.epochs,
    logger=logger,
    callbacks=callbacks_to_use,
    precision=Config.PRECISION, accelerator="gpu",
    devices=[7],
    num_sanity_val_steps=0
)

# train the model; the cleanup runs even when training raises or is interrupted, so the
# wandb run is always closed and cached GPU memory released
try:
    trainer.fit(audio_model, train_dataloaders = train_dataloader, val_dataloaders = val_dataloader)
finally:
    # close the wandb run and free memory
    wandb.finish()
    gc.collect()
    torch.cuda.empty_cache()
