# Auditory Classification of Birds using ConvMixer Architecture.
##### author: amgawishx@github.com

### Section 1
This section contains the library imports and the hyperparameters of the entire notebook;
you can customize any aspect of the code from here.

In [1]:
### Section 1
import os # for reading files & pathes
import tqdm # for pretty progress bar

# the main components
import torch
import torchaudio
import torchvision
import torch.nn as nn

import polars as ps # for reading & manipulating csv files
import matplotlib.pyplot as plt # visualization

# helpers to manage preprocessing & memory
from torch.cuda import OutOfMemoryError
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel
from torch.nn.utils.rnn import pad_sequence

from functools import reduce

# --- dataset locations ---
data_path = "/kaggle/input/birdclef-2023/train_audio/"
meta_path = "/kaggle/input/birdclef-2023/train_metadata.csv"

# --- compute device: prefer the GPU whenever torch can see one ---
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- audio front-end ---
sr = 32_000 # audio sample rate
nfft = 2 ** 10 # number of FFT points (1024)

# --- network shape ---
no_classes = 264 # number of bird classes to predict
dims = 32 # latent output dimension of the network
depth = 10 # depth of the network
n_fc = 1024 # the width of the fully connected layers

# --- optimisation ---
lr = 0.001 # learning rate
bsize = 32 # batch size
epochs = 1000 # number of epochs (theoretically)

### Section 2
This section defines the class used to load the training data and the
collate function used to preprocess it. Since every audio file is of variable length,
we need a way to standardize the waveform length, and we do this in the
`len_collate` function via the `rnn.pad_sequence` operation.

In [2]:
### Section 2

class BirdsDataset(Dataset):
    """
    A map-style dataset that loads and preprocesses the training audio files.
    ---
    meta_path: str -> path to the train metadata file; its first column
                      is read as the class names.
    train_dir: str -> directory that holds one sub-directory of audio
                      files per class name.
    sample_rate: int -> sample rate every waveform is resampled to.
    transform: nn.Module -> optional transform applied to each waveform while loading.
    device: str -> device the loaded waveforms are moved to.

    Every item is a pair `(audio, label)`, where `label` is the position of
    the file's class name in `self.classes` (the sorted unique class names).
    """
    def __init__(self, meta_path: str,
                 train_dir: str,
                 sample_rate: int = 32000,
                 transform: nn.Module = None,
                 device: str = "cpu") -> None:
        # Sort the class names: `unique()` does not promise any ordering, and
        # the name -> index mapping has to be identical on every run.
        named_labels = sorted(ps.read_csv(meta_path)[:, 0].unique())
        self.classes = named_labels
        self.data = []
        # populate self.data with pairs of an audio file path and its label index
        for index, label in enumerate(named_labels):
            files_path = os.path.join(train_dir, label)
            # sorted so that the item order does not depend on the file system
            for file in sorted(os.listdir(files_path)):
                self.data.append((os.path.join(files_path, file), index))
        self.dir = train_dir
        self.transform = transform
        self.sample_rate = sample_rate
        self.device = device

    def __len__(self) -> int:
        """Number of audio files found under `train_dir`."""
        return len(self.data)

    def __getitem__(self, idx) -> tuple:
        """Loads the `idx`-th file and returns `(audio, label)`."""
        path, label = self.data[idx]
        audio = self._audio_loader(path)
        # explicit None check: container modules such as an empty
        # nn.Sequential are falsy and would otherwise be skipped
        if self.transform is not None:
            audio = self.transform(audio)
        return audio, label

    def _audio_loader(self, audio_path):
        """
        Reads `audio_path`, keeps its first channel, moves it to `self.device`
        and resamples it to `self.sample_rate` when the file's rate differs.
        """
        waveform, native_sr = torchaudio.load(audio_path, normalize = True)
        waveform = waveform[0, :].to(self.device) # first channel only
        if native_sr == self.sample_rate:
            return waveform
        # the resampler must live on the same device as the waveform it filters
        resampler = torchaudio.transforms.Resample(native_sr, self.sample_rate)
        return resampler.to(self.device)(waveform)

def len_collate(batch, device):
    """
    Collate function for the dataloader: pads every waveform of a batch to
    the length of the longest one and returns the whole batch on `device`.
    ---
    batch: list -> `(waveform, label)` pairs.
    device: str -> device the returned tensors are placed on.

    Returns `(padded_waveforms, labels)`. `padded_waveforms` gets a singleton
    axis inserted right after the batch axis, i.e. `(batch, 1, time)` for
    1-D waveforms; `labels` holds one entry per batch item.
    """
    waveforms, labels = [], []
    for waveform, label in batch:
        # Only inputs with at least two axes can have their first two axes
        # swapped (the original comment describes them as frequency & time);
        # 1-D waveforms are already time-first and are left untouched.
        if waveform.dim() > 1:
            waveform = torch.transpose(waveform, 0, 1)
        waveforms.append(waveform)
        labels.append(torch.as_tensor(label))
    # zero-pad along the first axis of every item, then add the singleton axis
    padded_waveforms = pad_sequence(waveforms, batch_first = True).unsqueeze(1)
    # stack first so the labels reach `device` in a single transfer
    return padded_waveforms.to(device), torch.stack(labels, dim = 0).to(device)

### Section 3
Here we define the model as described in the paper [Patches Are All You Need?](https://arxiv.org/pdf/2201.09792.pdf). The paper defines a convolutional neural network that uses convolutional layers to mix spatial and channel-wise information in order to better understand the underlying representation of the data, akin to puzzle solving.

![The ConvMixer model](https://raw.githubusercontent.com/IbrahimSobh/Transformers/main/images/conmixer01.png)

We also define the MEL Spectrogram transform that translates an audio problem to a vision one by producing an image like the following:

![MEL Spectrogram](https://librosa.org/doc/main/_images/librosa-feature-melspectrogram-1.png)

In [3]:
### Section 3

class Residual(nn.Module):
    """
    Wraps a callable `fn` so that the input is added back onto `fn`'s
    output; this is the skip connection used inside the ConvMixer blocks.
    """
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x):
        transformed = self.fn(x)
        return transformed + x

def SpectrumTransform(nfft):
    """
    Builds the module that turns a waveform into a MEL Spectrogram in dB,
    used to extract auditory features of a sound file.
    ---
    nfft: int -> number of FFT points of the underlying spectrogram.

    The sample rate and the target device are taken from the notebook-level
    `sr` and `device` variables.
    """
    frequency_bins = nfft // 2 + 1
    power_spectrogram = torchaudio.transforms.Spectrogram(n_fft=nfft, power=2)
    mel_projection = torchaudio.transforms.MelScale(n_stft = frequency_bins,
                                                    sample_rate = sr)
    to_decibel = torchaudio.transforms.AmplitudeToDB()
    pipeline = nn.Sequential(power_spectrogram, mel_projection, to_decibel)
    return pipeline.to(device)

def ConvMixer(dim, depth, no_classes,
              resize_shape = (100, 200),
              kernel_size = 9, patch_size = 5):
    """
    The main classifier of the model: takes a single-channel MEL Spectrogram
    in dB, performs convolutional spatial and channel mixing on it and passes
    the result to a classification head.
    ---
    dim: int -> number of channels kept throughout the mixer layers.
    depth: int -> number of mixer layers.
    no_classes: int -> number of output classes.
    resize_shape: tuple -> (height, width) every input is resized to.
    kernel_size: int -> kernel size of the depthwise (spatial mixing) convolution.
    patch_size: int -> kernel size and stride of the patch-embedding convolution.

    The hidden width of the classification head is read from the
    notebook-level `n_fc` variable.

    The returned module outputs raw, unnormalized class scores (logits).
    `nn.CrossEntropyLoss` applies log-softmax itself, so a softmax layer here
    would normalize the scores twice and flatten the gradients. Apply
    `softmax(dim=1)` to the output when probabilities are needed.
    """
    return nn.Sequential(
        # preprocessing tail: fixed-size input, then patch embedding
        torchvision.transforms.Resize(resize_shape),
        nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),

        # mixer layers
        *[nn.Sequential(
            # spatial mixing: depthwise convolution wrapped in a skip connection
            Residual(nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size,
                          groups=dim, padding="same"),
                nn.GELU(),
                nn.BatchNorm2d(dim)
            )),
            # channel mixing: pointwise (1x1) convolution
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim)
        ) for _ in range(depth)],

        # classification head (emits logits, see the docstring)
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_fc),
        nn.ReLU(),
        nn.Linear(n_fc, no_classes),
    )

### Section 4
In this section we define helper functions to evaluate metrics of performance of our network and train it.

In [4]:
### Section 4

def accuracy(output, target):
    """
    A helper function to evaluate the accuracy of the
    model's prediction over a batch of inputs.
    ---
    output: Tensor -> per-class scores, one row per batch item.
    target: Tensor -> either one-hot / per-class rows shaped like `output`,
                      or a 1-D tensor of class indices.

    Returns the percentage (0-100) of items whose highest-scoring class
    matches the target; an empty batch yields 0.0.
    """
    with torch.no_grad():
        total = target.size(0)
        if total == 0:
            # nothing to score; avoid dividing by zero
            return 0.0
        pred = output.argmax(dim=1)
        # one-hot rows are reduced to class indices, index targets are used as they are
        truth = target.argmax(dim=1) if target.dim() > 1 else target
        correct = (pred == truth).sum().item()
    return correct / total * 100

def train(dataloader, model, loss_fn, optimizer):
    """
    Trains `model` for one pass over `dataloader`, given a loss function
    and an optimizer.
    ---
    dataloader: DataLoader -> yields `(X, y)` batches, `y` holding class indices.
    model: nn.Module -> the network to optimize.
    loss_fn: callable -> called as `loss_fn(pred, one_hot_targets)`.
    optimizer: Optimizer -> updates the parameters of `model`.

    Returns `(losses, accuracies)`: the per-batch loss values and per-batch
    accuracies (in percent), in iteration order.
    """
    # pass the number of batches explicitly so the bar can show real progress
    data = tqdm.tqdm(dataloader, total=len(dataloader))
    size = len(dataloader.dataset)
    accuracies = [] # store accuracy history
    losses = [] # store losses history
    seen = 0 # samples processed so far; batches may differ in size
    loss_sum = 0.0
    accuracy_sum = 0.0
    model.train() # enable training mode for the network
    for X, y in data:
        y = torch.nn.functional.one_hot(y, num_classes=no_classes).float() # one-hot encode the classes
        pred = model(X)
        loss = loss_fn(pred, y)

        # backpropagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # get metrics; running sums avoid re-summing the whole history each step
        seen += len(X)
        batch_loss = loss.item()
        batch_accuracy = accuracy(pred, y)
        losses.append(batch_loss)
        accuracies.append(batch_accuracy)
        loss_sum += batch_loss
        accuracy_sum += batch_accuracy
        mean_loss = loss_sum / len(losses)
        mean_accuracy = accuracy_sum / len(accuracies)
        data.set_description(f"loss: {mean_loss:<7f} [{seen:>5d}/{size:>5d}],"+\
              f" accuracy: {mean_accuracy:<3f}%")
        # clear memory
        del X, y
        torch.cuda.empty_cache()
    return losses, accuracies # return losses & accuracies for performance graph

### Section 5
This is the final section, where we load our data, instantiate our model and start the training loop. Training was originally performed on 2 T4 GPUs with 16GB each and 13GB of system RAM. If you are using a single GPU, just remove the `DataParallel` line.

In [8]:
### Section 5

data = BirdsDataset(meta_path, data_path,
                    sample_rate = sr, device = device) # load training dataset

dataloader = DataLoader(data, batch_size = bsize,
                        collate_fn = lambda x : len_collate(x, device),
                        shuffle = True) # preprocess and prepare it for loading

model = nn.Sequential(
    SpectrumTransform(nfft),
    ConvMixer(dims, depth, no_classes)
).to(device) # instantiate the model

# Train in parallel on every visible GPU. With a single GPU or on the CPU
# the bare model is used, so no device id that does not exist is requested.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    model = DataParallel(model, device_ids=list(range(gpu_count)))

# run the training loop with cross-entropy loss and RMS Propagation optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr = lr)
total_losses = [] # per-batch losses of all completed epochs
total_accuracies = [] # per-batch accuracies of all completed epochs
for epoch in range(epochs):
    if epoch % 10 == 0: print(f"Epoch {epoch+1}:")
    try:
        current_loss, current_accuracy = train(dataloader, model, loss_fn, optimizer)
        total_losses += current_loss
        total_accuracies += current_accuracy
    except OutOfMemoryError: # sometimes the gpu's memory get congested so we need to clear it
        # the metrics of the interrupted epoch are dropped; say so instead of skipping silently
        print(f"Epoch {epoch+1}: ran out of GPU memory, clearing the cache and moving on")
        torch.cuda.empty_cache()