In [39]:
import os
import librosa
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pytorch_lightning as pl

import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

import torch.nn as nn
from torch.optim import AdamW
import torch.nn.functional as F
import torchmetrics

import timm

In [40]:
# Select the 'high' float32 matmul precision mode for the whole process.
# NOTE(review): per the PyTorch docs this lets float32 matmuls use faster,
# lower-precision kernels on supporting hardware - confirm that is acceptable here.
torch.set_float32_matmul_precision('high')

In [41]:
class Config:
    """Container for every tunable value used by the notebook.

    All settings are plain instance attributes assigned in ``__init__``;
    the class takes no constructor arguments.
    """

    def __init__(self):
        # --- spectrogram settings ---
        self.nfft = 1024
        self.num_fold = 5
        self.window_duration_in_sec = 5
        self.n_mels = 128
        self.width = 256

        # --- audio settings ---
        clip_seconds = 5
        samples_per_second = 32000
        self.max_time = clip_seconds
        self.sample_rate = samples_per_second
        # Number of samples in one clip of ``max_time`` seconds.
        self.audio_length = clip_seconds * samples_per_second
        self.min_frequency = 0
        self.max_frequency = 16000

        # --- model / optimisation settings ---
        self.model_name = 'tf_efficientnet_b1_ns'
        # Prefer the GPU whenever torch reports one as available.
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        self.batch_size = 64
        self.epochs = 10
        self.learning_rate = 1e-3
        self.num_classes = 182

In [42]:
def crop_or_pad(y, length, is_train=True, start=None):
    """Return exactly ``length`` samples of ``y`` by zero-padding or cropping.

    Args:
        y: 1-D sequence of audio samples (a NumPy array is expected).
        length: Desired number of samples in the result.
        is_train: When True and ``start`` is not given, the crop offset is
            drawn at random; when False the crop starts at the beginning.
        start: Explicit crop offset. ``0`` is honoured as a real offset
            (it used to be treated as "not given" in training mode).

    Returns:
        ``y`` itself when it already has ``length`` samples; otherwise a
        zero-padded copy (shorter input) or a ``length``-sample slice
        (longer input). Padding keeps the dtype of ``y`` instead of
        upcasting float32 audio to float64.
    """
    n_samples = len(y)

    if n_samples < length:
        # Pad with trailing zeros. The former "repeat the clip" code that
        # followed the zero-pad never had any effect (the clip was already
        # full length by then), so it has been removed.
        return np.pad(y, (0, length - n_samples))

    if n_samples > length:
        if start is None:
            # randint's upper bound is exclusive; +1 makes the last valid
            # offset reachable as well.
            start = np.random.randint(n_samples - length + 1) if is_train else 0
        return y[start:start + length]

    return y


def mix_up(original_melspecs, original_labels, alpha=1.0):
    """Blend every sample of a batch with a randomly chosen partner.

    A single mixing weight is drawn from ``Beta(alpha, alpha)`` and applied
    to the whole batch; partners come from a random permutation of the
    batch dimension.

    Returns:
        A tuple ``(blended, label_parts)`` where ``blended`` is the weighted
        sum of each sample and its partner, and ``label_parts`` is a
        two-element list: the weighted original labels and the weighted
        partner labels (they are NOT summed here).
    """
    batch_size = original_melspecs.size(0)
    partner = torch.randperm(batch_size)

    weight = np.random.beta(alpha, alpha)
    complement = 1 - weight

    blended = original_melspecs * weight + original_melspecs[partner] * complement
    label_parts = [
        original_labels * weight,
        original_labels[partner] * complement,
    ]
    return blended, label_parts
        

def cut_mix(original_melspecs, original_labels, alpha=1.0):
    """Mix each sample with a random partner by swapping a band along dim 2.

    The previous implementation could not run: it read ``mask`` before
    creating it, referenced an undefined ``index`` and tried to reshape an
    empty array. It also weighted the labels with a random value unrelated
    to how much of each sample was actually kept.

    This version draws the share to keep from ``Beta(alpha, alpha)``,
    rounds it to a whole number of rows along dimension 2, keeps that band
    from the original sample and fills everything else from the partner.
    The label weight is the share that was really kept.

    Args:
        original_melspecs: 4-D tensor; the band is cut along dimension 2
            (assumed layout ``(batch, channels, mels, frames)`` - confirm
            against the caller). The input is not modified.
        original_labels: Tensor whose first dimension is the batch.
        alpha: Parameter of the symmetric Beta distribution.

    Returns:
        ``(augmented_melspecs, augmented_labels)`` where
        ``augmented_labels`` is the two-element list
        ``[labels * lam, partner_labels * (1 - lam)]``, the same structure
        ``mix_up`` returns.
    """
    indices = torch.randperm(original_melspecs.size(0))

    axis_size = original_melspecs.size(2)
    # Width of the band kept from the original sample, in whole rows.
    band = int(round(np.random.beta(alpha, alpha) * axis_size))
    index1 = np.random.randint(0, axis_size - band + 1)
    index2 = index1 + band

    # Indexing with a tensor copies, so writing into the result leaves the
    # caller's batch untouched.
    augmented_melspecs = original_melspecs[indices]
    augmented_melspecs[:, :, index1:index2, :] = original_melspecs[:, :, index1:index2, :]

    # Label weight = share of the sample that still comes from the original.
    lam = band / axis_size if axis_size else 1.0
    augmented_labels = [original_labels * lam, original_labels[indices] * (1 - lam)]

    return augmented_melspecs, augmented_labels
    
                             
def spec_augment(original_melspec,
                 freq_masking_max_percentage = 0.15,
                 time_masking_max_percentage = 0.3):
    """Zero out one random band of columns and one random band of rows.

    The 2-D input is treated as ``(frames, freqs)``: columns are masked
    using ``freq_masking_max_percentage`` and rows using
    ``time_masking_max_percentage``. Each band's width is a random share
    (up to the given maximum) of the corresponding axis.

    The input is modified IN PLACE and the same object is returned.
    """

    def _draw_band(axis_length, max_fraction):
        # Two uniform draws per band: first its width, then its start.
        fraction = np.random.uniform(0.0, max_fraction)
        band = int(fraction * axis_length)
        begin = int(np.random.uniform(low = 0.0, high = (axis_length - band)))
        return begin, begin + band

    n_frames, n_freqs = original_melspec.shape

    # Draw order matters for seeded runs: frequency band first, then time band.
    freq_from, freq_to = _draw_band(n_freqs, freq_masking_max_percentage)
    time_from, time_to = _draw_band(n_frames, time_masking_max_percentage)

    original_melspec[:, freq_from:freq_to] = 0
    original_melspec[time_from:time_to, :] = 0

    return original_melspec

In [43]:
class BirdCLEFDataset(Dataset):
    """Dataset turning audio files into 3-channel mel-spectrogram images.

    Each row of ``data`` must provide a ``'path'`` entry (file path relative
    to ``path_to_data``); every column after the first
    ``N_METADATA_COLUMNS`` is read as a label value.

    In training mode (``valid=False``) one image is produced per file. In
    validation mode (``valid=True``) the clip is cut into consecutive
    windows of ``sample_rate * window_duration_in_sec`` samples and one
    image is produced per window.
    """

    # Leading columns of a row that are metadata rather than labels
    # (the notebook builds rows as primary_label, type, path, <labels...>).
    N_METADATA_COLUMNS = 3

    def __init__(self, data, path_to_data: str, config: "Config", valid=False):
        super().__init__()
        self.path_to_data = path_to_data
        self.data = data
        self.config = config
        self.audio_length = self.config.sample_rate * self.config.window_duration_in_sec
        self.valid = valid

    def make_melspec(self, audio_data):
        """Return the dB-scaled mel spectrogram of ``audio_data`` as float32."""
        # NOTE(review): config.nfft is not passed here, so librosa's default
        # FFT size is used - confirm whether that is intended.
        melspec = librosa.feature.melspectrogram(
            y=audio_data, sr=self.config.sample_rate, n_mels=self.config.n_mels,
            fmin=self.config.min_frequency, fmax=self.config.max_frequency,
        )

        return librosa.power_to_db(melspec).astype(np.float32)

    def normalize(self, image):
        """Scale, crop/pad to ``config.width`` columns and stack to 3 channels.

        Args:
            image: 2-D array; rows beyond ``config.n_mels`` are dropped.

        Returns:
            float32 array of shape ``(3, rows, config.width)``.
        """
        # NOTE(review): the input is a dB spectrogram, not an 8-bit image, so
        # dividing by 255 does not map it to [0, 1] - kept as in the original.
        image = image.astype("float32", copy=False) / 255.0
        # Use the configured number of mel bands instead of a hard-coded 128.
        image = image[:self.config.n_mels]

        width = self.config.width
        n_frames = image.shape[1]
        if n_frames > width:
            # randint's upper bound is exclusive; +1 makes the last valid
            # offset reachable.
            offset = np.random.randint(0, n_frames - width + 1)
            image = image[:, offset:offset + width]
        else:
            # Right-pad with zeros; np.pad keeps the float32 dtype.
            image = np.pad(image, ((0, 0), (0, width - n_frames)))

        image = np.stack([image, image, image], axis=0)
        return image

    def audio_to_image(self, audio):
        """Convert raw samples into a normalized image tensor."""
        melspec = self.make_melspec(audio)
        image = self.normalize(melspec)
        return torch.tensor(image)

    def _split_into_windows(self, audio):
        """Cut ``audio`` into consecutive full windows of ``audio_length`` samples.

        A trailing partial window is dropped, as before. A clip shorter than
        one window (including an empty clip) used to yield no windows and
        crash the caller; it is now zero-padded to a single full window.
        """
        window = self.audio_length
        n_full = len(audio) // window
        if n_full == 0:
            return [np.pad(audio, (0, window - len(audio)))]
        return [audio[i * window:(i + 1) * window] for i in range(n_full)]

    def read_data(self, row):
        """Load the file referenced by ``row`` and return ``(images, labels)``."""
        path = os.path.join(self.path_to_data, row['path'])
        audio, _ = librosa.load(path, sr=self.config.sample_rate)

        if self.valid:
            windows = self._split_into_windows(audio)
            images = [self.audio_to_image(window) for window in windows]
            # The number of windows depends on the clip length, so batches of
            # validation items do not share a common shape.
            images = np.stack(images)

        else:
            images = self.audio_to_image(audio)

        # Positional slice past the metadata columns. Converting through
        # NumPy with an explicit dtype also handles object/bool rows.
        label_values = row.iloc[self.N_METADATA_COLUMNS:].to_numpy(dtype=np.float32)
        labels = torch.tensor(label_values)

        return images, labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Positional lookup: works even if the frame's index was not reset.
        return self.read_data(self.data.iloc[idx])

In [44]:
# Fixed seed so the train/validation split is the same after every kernel restart.
SPLIT_SEED = 42

metadata = pd.read_csv('/home/asphodel/Downloads/birdclef-2024/train_metadata.csv')

# Keep only the columns the dataset uses and expose the file name as 'path'.
# (The former pd.concat(..., names=[...]) call built the same frame; `names`
# had no effect on the column names there.)
data = metadata[['primary_label', 'type', 'filename']].rename(columns={'filename': 'path'})

# One indicator column per species, computed once and reused below.
label_columns = pd.get_dummies(data['primary_label'])

birds = list(label_columns.columns)
filenames = data.path.values.tolist()

# Final layout: primary_label, type, path, then one column per species.
data = pd.concat([data, label_columns], axis=1)

train_data, valid_data = train_test_split(
    data, train_size=0.7, shuffle=True, random_state=SPLIT_SEED
)
# The dataset indexes rows by position 0..len-1, so drop the shuffled index.
train_data = train_data.reset_index(drop=True)
valid_data = valid_data.reset_index(drop=True)

In [45]:
# Shared settings object used by the datasets and the model.
config = Config()

# Both splits read their audio from the same directory.
audio_dir = '/home/asphodel/Downloads/birdclef-2024/train_audio'

train_dataset = BirdCLEFDataset(data=train_data, path_to_data=audio_dir, config=config)
test_dataset = BirdCLEFDataset(data=valid_data, path_to_data=audio_dir, config=config)

# Only the training loader shuffles; both use three worker processes.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=3,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=3,
)

In [46]:
class MyModel(pl.LightningModule):
    """Lightning wrapper around a timm backbone with a fresh linear classifier.

    The backbone named by ``config.model_name`` is created with pretrained
    weights and its ``classifier`` layer is replaced by a linear layer with
    ``config.num_classes`` outputs.

    Metrics are built with ``task='multiclass'`` so that ``num_classes`` and
    ``average='macro'`` actually take effect; with ``task='binary'`` the
    torchmetrics wrappers do not use those two arguments.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.model = timm.create_model(config.model_name, pretrained=True)
        self.model.classifier = torch.nn.Linear(self.model.num_features, config.num_classes)

        metric_kwargs = dict(task='multiclass', num_classes=config.num_classes, average='macro')
        self.f1 = torchmetrics.F1Score(**metric_kwargs)
        self.precision = torchmetrics.Precision(**metric_kwargs)
        self.recall = torchmetrics.Recall(**metric_kwargs)
        self.aug_roc = torchmetrics.AUROC(**metric_kwargs)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.model(x)

    def step(self, batch, stage: str):
        """Compute the loss, log loss and metrics under ``{stage}_*``, return the loss.

        Args:
            batch: ``(images, labels)``; labels are either per-class
                scores/one-hot rows of shape ``(N, C)`` or class indices
                of shape ``(N,)``.
            stage: Prefix for the logged names, e.g. ``'train'``.
        """
        x, y = batch
        predict = self(x)
        loss = F.cross_entropy(predict, y)

        # The metrics need integer class indices, not float one-hot rows.
        target = y.argmax(dim=1) if y.ndim > 1 else y.long()

        # torchmetrics call order is (preds, target); AUROC previously
        # received the two arguments swapped.
        auc_roc = self.aug_roc(predict, target)
        precision = self.precision(predict, target)
        recall = self.recall(predict, target)
        f1 = self.f1(predict, target)

        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_auc_roc', auc_roc)
        self.log(f'{stage}_precision', precision)
        self.log(f'{stage}_recall', recall)
        self.log(f'{stage}_f1', f1)

        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, 'train')

    def test_step(self, batch, batch_idx):
        return self.step(batch, 'test')

    def configure_optimizers(self):
        """Use AdamW over all parameters with ``config.learning_rate``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.learning_rate)
        return optimizer

In [None]:
# Build the Lightning module from the shared config; MyModel creates its timm
# backbone with pretrained=True.
model = MyModel(config)

In [None]:
# Take the epoch count from the config (currently 10) instead of repeating the
# literal here, so changing Config.epochs actually changes the training length.
trainer = pl.Trainer(max_epochs=config.epochs)
# Only the training loader is passed, so no validation loop runs during fit.
trainer.fit(model, train_dataloader)