## Music Style Detector : CNN with Mel Spectogram
### Sweeps

This notebook contains code to perform sweep over hyperparameters for CNN with Mel Spectogram.

### 1. Imports and setup

In [55]:
# IMPORTS
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader, Dataset, random_split
import pytorch_lightning as pl
from torchmetrics.functional import accuracy
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

import wandb
import pandas as pd
import random
import os

# CONSTANTS
DATA_DIR = os.path.abspath(os.path.join(os.getcwd(), '../..', 'models', 'genre_detector', 'data'))
AUDIO_DIR = os.path.join(DATA_DIR, 'raw', 'audio')
TRAIN_DF = pd.read_csv(os.path.join(DATA_DIR, 'prepared', 'train_genres.csv'))
NB_CLASSES = len(TRAIN_DF['genre_id'].unique())
ID_TO_LABEL = TRAIN_DF.set_index('genre_id')['genre_label'].to_dict()

# SWEEP CONFIG
sweep_config = {
    'name': 'CNN Mel Spectogram Sweep',
    'method': 'bayes',
    'metric': {
        'name': 'val_loss',
        'goal': 'minimize'
    },
    'parameters': {
        'audio_duration': {
            'values': [10000, 20000, 30000]
        },
        'sample_rate': {
            'values': [44100, 48000]
        },
        'n_channels': {
            'values': [1, 2]
        },
        'time_shift': {
            'values': [0.2, 0.3, 0.4]
        },
        'batch_size': {
            'values': [16, 32, 64]
        },
        'lr': {
            'values': [0.001, 0.0001]
        },
        'dropout': {
            'values': [0, 0.2, 0.5]
        }
    }
}

In [56]:
sweep_id = wandb.sweep(sweep_config, project='genre-detector_sweep', entity='mlodimage')

Create sweep with ID: 8j0jr9ql
Sweep URL: https://wandb.ai/mlodimage/genre-detector_sweep/sweeps/8j0jr9ql


### 2. Data pre-processing

In [57]:
class AudioUtils():
    """
    Utility class for audio processing.
    """
    @staticmethod
    def open(audio_file: str):
        """
        Load an audio file. Return the signal as a tensor and the sample rate.
        :param audio_file : Path to the audio file.
        :type audio_file : str
        :return: signal as a tensor and the sample rate
        :rtype: Tuple[torch.Tensor, int]
        """
        signal, sample_rate = torchaudio.load(audio_file)
        return signal, sample_rate
    
    @staticmethod
    def rechannel(audio, new_channel):
        """
        Convert a given audio to the specified number of channels.
        :param audio: the audio, composed of the signal and the sample rate
        :type audio: Tuple[torch.Tensor, int]
        :param new_channel: the target number of channels
        :type new_channel: int
        :return: the audio with the target number of channels
        :rtype: Tuple[torch.Tensor, int]
        """
        signal, sample_rate = audio

        if signal.shape[0] == new_channel:
            # nothing to do as the signal already has the target number of channels
            return audio
        if new_channel == 1:
            # convert to mono by selecting only the first channel
            signal = signal[:1, :]
        else:
            # convert to stereo by duplicating the first channel
            signal = torch.cat([signal, signal])
        return signal, sample_rate
    
    @staticmethod
    def resample(audio, new_sample_rate):
        """
        Change the sample rate of the audio signal.
        :param audio: the audio, composed of the signal and the sample rate
        :type audio: Tuple[torch.Tensor, int]
        :param new_sample_rate: the target sample rate
        :type new_sample_rate: int
        :return: the audio with the target sample rate
        :rtype: Tuple[torch.Tensor, int]
        """
        signal, sample_rate = audio
        if sample_rate == new_sample_rate:
            # nothing to do
            return audio
        resample = torchaudio.transforms.Resample(sample_rate, new_sample_rate)
        signal = resample(signal)
        return signal, new_sample_rate
    
    @staticmethod
    def pad_truncate(audio, length):
        """
        Pad or truncate an audio signal to a fixed length (in ms).
        :param audio: the audio, composed of the signal and the sample rate
        :type audio: Tuple[torch.Tensor, int]
        :param length: the target length in ms
        :type length: int
        :return: the audio with the target length
        :rtype: Tuple[torch.Tensor, int]
        """
        signal, sample_rate = audio
        max_length = sample_rate//1000 * length

        if signal.shape[1] > max_length:
            signal = signal[:, :max_length]
        elif signal.shape[1] < max_length:
            padding = max_length - signal.shape[1]
            signal = F.pad(signal, (0, padding))
        return signal, sample_rate

    @staticmethod
    def time_shift(audio, shift_limit):
        """
        Shift the signal to the left or right by some percent. Values at the end
        are 'wrapped around' to the start of the transformed signal.
        :param audio: the audio, composed of the signal and the sample rate
        :type audio: Tuple[torch.Tensor, int]
        :param shift_limit: the maximum shift to apply (in percent)
        :type shift_limit: int
        :return: the shifted audio
        :rtype: Tuple[torch.Tensor, int]
        """
        signal, sample_rate = audio
        _, signal_length = signal.shape
        shift_amount = int(random.random() * shift_limit * signal_length)
        return (signal.roll(shift_amount), sample_rate)
    
    @staticmethod
    def mel_spectrogram(audio, n_mels=64, n_fft=2048, hop_length=None):
        """
        Create the mel spectogram for the given audio signal.
        :param audio: the audio, composed of the signal and the sample rate
        :type audio: Tuple[torch.Tensor, int]
        :param n_mels: the number of mel filterbanks
        :type n_mels: int
        :param n_fft: the size of the FFT
        :type n_fft: int
        :param hop_length: the length of hop between STFT windows
        :type hop_length: int
        :return: the mel spectogram
        :rtype: torch.Tensor
        """
        signal, sample_rate = audio
        
        mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels
        )(signal)

        # convert to decibels
        mel_spectrogram = torchaudio.transforms.AmplitudeToDB(top_db=80)(mel_spectrogram)

        return mel_spectrogram

### 3. Training data loading

In [58]:
class GenreDataset(Dataset):
    """
    Dataset for the FMA dataset.
    """
    def __init__(self, df, audio_dir, config):
        """
        Constructor.
        :param df: the dataframe containing the audio files ids and their genre label
        :type df: pandas.DataFrame
        :param audio_dir: the directory containing the audio files
        :type audio_dir: str
        """
        self.fma_df = df
        self.audio_dir = audio_dir
        self.config = config
        
    def __len__(self):
        """
        Get the length of the dataset.
        :return: the length of the dataset
        :rtype: int
        """
        return len(self.fma_df)
    
    def __getitem__(self, idx):
        """
        Get the idx-th sample of the dataset.
        :param idx: the index of the sample
        :type idx: int
        :return: the idx-th sample of the dataset and its genre label
        :rtype: Tuple[torch.Tensor, int]
        """ 
        audio_file_path = os.path.join(self.audio_dir, str(self.fma_df.iloc[idx]['filename']))
        # get the genre class id
        genre_id = self.fma_df.iloc[idx]['genre_id']

        # load the audio file and apply the preprocessing
        audio = AudioUtils.open(audio_file_path)
        audio = AudioUtils.rechannel(audio, self.config.n_channels)
        audio = AudioUtils.resample(audio, self.config.sample_rate)
        audio = AudioUtils.pad_truncate(audio, self.config.audio_duration)
        audio = AudioUtils.time_shift(audio, self.config.time_shift)
        mel_spectrogram = AudioUtils.mel_spectrogram(audio)

        return (mel_spectrogram, genre_id)

In [59]:
def build_dataset(config):
    # load the data
    full_dataset = GenreDataset(TRAIN_DF, AUDIO_DIR, config)

    # random split
    nb_samples = len(full_dataset)
    nb_train_samples = int(nb_samples * 0.8)
    nb_val_samples = nb_samples - nb_train_samples
    train_dataset, val_dataset = random_split(full_dataset, [nb_train_samples, nb_val_samples])

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=20, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False,  num_workers=20)

    return train_loader, val_loader

### 4. Model creation

In [60]:
class AudioCNN(pl.LightningModule):
    """
    Audio classification model.
    """
    def __init__(self, n_channels, dropout, lr):
        """
        Constructor.
        :param nb_channels: the number of channels in the input data
        :param nb_classes: the number of classes
        :type nb_classes: int
        """
        super(AudioCNN, self).__init__()

        print(n_channels)

        self.conv1 = nn.Conv2d(n_channels, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
        self.relu1 = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(8)
        self.drop1 = nn.Dropout2d(p=dropout)
        nn.init.kaiming_normal_(self.conv1.weight, a=0.1)
        self.conv1.bias.data.zero_()

        self.conv2 = nn.Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.relu2 = nn.ReLU()
        self.bn2 = nn.BatchNorm2d(16)
        self.drop2 = nn.Dropout2d(p=dropout)
        nn.init.kaiming_normal_(self.conv2.weight, a=0.1)
        self.conv2.bias.data.zero_()

        self.conv3 = nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.relu3 = nn.ReLU()
        self.bn3 = nn.BatchNorm2d(32)
        self.drop3 = nn.Dropout2d(p=dropout)
        nn.init.kaiming_normal_(self.conv3.weight, a=0.1)
        self.conv3.bias.data.zero_()

        self.conv4 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.relu4 = nn.ReLU()
        self.bn4 = nn.BatchNorm2d(64)
        self.drop4 = nn.Dropout2d(p=dropout)
        nn.init.kaiming_normal_(self.conv4.weight, a=0.1)
        self.conv4.bias.data.zero_()

        self.conv5 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.relu5 = nn.ReLU()
        self.bn5 = nn.BatchNorm2d(128)
        self.drop5 = nn.Dropout2d(p=dropout)
        nn.init.kaiming_normal_(self.conv5.weight, a=0.1)
        self.conv5.bias.data.zero_()

        self.ap = nn.AdaptiveAvgPool2d(output_size=1)
        self.linear = nn.Linear(in_features=128, out_features=NB_CLASSES, bias=True)

        # loss function
        self.loss = nn.CrossEntropyLoss()
        # optimizer parameters
        self.lr = lr

        # save hyperparameters
        self.save_hyperparameters()

    def forward(self, x):
        """
        Forward pass.
        :param x: the input
        :type x: torch.Tensor
        :return: the output
        :rtype: torch.Tensor
        """
        x = self.drop1(self.bn1(self.relu1(self.conv1(x))))
        x = self.drop2(self.bn2(self.relu2(self.conv2(x))))
        x = self.drop3(self.bn3(self.relu3(self.conv3(x))))
        x = self.drop4(self.bn4(self.relu4(self.conv4(x))))
        x = self.drop5(self.bn5(self.relu5(self.conv5(x))))
        x = self.ap(x)
        x = x.view(x.shape[0], -1)
        return self.linear(x)
    
    def training_step(self, batch, batch_idx):
        """
        Training step.
        :param batch: the batch
        :type batch: Tuple[torch.Tensor, torch.Tensor]
        :param batch_idx: the batch index
        :type batch_idx: int
        :return: the loss
        :rtype: torch.Tensor
        """
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=False, on_epoch=True, prog_bar=True)

        return loss
    
    def validation_step(self, batch, batch_idx):
        """
        Validation step.
        :param batch: the batch
        :type batch: Tuple[torch.Tensor, torch.Tensor]
        :param batch_idx: the batch index
        :type batch_idx: int
        :return: the loss
        :rtype: torch.Tensor
        """
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    
    def test_step(self, batch, batch_idx):
        """
        Test step.
        :param batch: the batch
        :type batch: Tuple[torch.Tensor, torch.Tensor]
        :param batch_idx: the batch index
        :type batch_idx: int
        :return: the loss
        :rtype: torch.Tensor
        """
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        self.log('test_loss', loss)
        self.log('test_acc', acc)
    
    def configure_optimizers(self):
        """
        Configure optimizers.
        :return: the optimizer
        :rtype: torch.optim.Optimizer
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
    
    def _get_preds_loss_accuracy(self, batch):
        """
        Get predictions, loss and accuracy."""
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y, 'multiclass', num_classes=NB_CLASSES)
        return preds, loss, acc

### 5. Training sweep

In [61]:
def train(config=None):
    with wandb.init(config=config):
        config = wandb.config

        # data
        print('Done')
        train_loader, val_loader = build_dataset(config)
        print('Done')
        # model
        model = AudioCNN(n_channels=config.n_channels, dropout=config.dropout, lr=config.lr)
        print('Done creating model')

        trainer = Trainer(
            max_epochs=2,
            callbacks=[EarlyStopping(monitor='val_loss', patience=5)],
            logger=WandbLogger())
        
        trainer.fit(model, train_loader, val_loader)

In [62]:
wandb.agent(sweep_id, train, count=3)
wandb.finish()

[34m[1mwandb[0m: Agent Starting Run: i71lwgaw with config:
[34m[1mwandb[0m: 	audio_duration: 20000
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	dropout: 0
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	n_channels: 1
[34m[1mwandb[0m: 	sample_rate: 44100
[34m[1mwandb[0m: 	time_shift: 0.4
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Done
Done
1


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


AttributeError: 'ZMQDisplayPublisher' object has no attribute '_orig_publish'