<center><img src="https://github.com/hse-ds/iad-applied-ds/blob/master/2021/hw/hw1/img/logo_hse.png?raw=1" width="1000"></center>

<h1><center>Applied data analysis tasks</center></h1>
<h2><center>Homework 2: deep learning for sound processing</center></h2>

# Introduction

In this assignment, you will study in detail how audio data is represented in deep learning tasks and write several models for classifying audio recordings.

In the process, you will get acquainted with:
* The algorithm for constructing a Mel spectrogram
* Recurrent and convolutional audio data classifiers
* The SpecAugment audio data augmentation algorithm

In [4]:
!pip install torch==1.8.0 torchaudio==0.8.0 numpy==1.20.0

In [5]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import torchaudio
import torch.nn.functional as F
from IPython import display
from IPython.display import clear_output
from sklearn.metrics import confusion_matrix
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset


%matplotlib inline

assert torch.__version__.startswith("1.8.0")
assert torchaudio.__version__ == "0.8.0"

device="cuda:0"

# Audio classification

In this homework assignment, you will classify audio recordings from the dataset [UrbanSound8K](https://urbansounddataset.weebly.com/urbansound8k.html).

This dataset consists of 8,732 recordings, split into train/val/test parts.

![image](https://paperswithcode.com/media/datasets/UrbanSound8K-0000003722-02faef06.jpg)

Each audio recording contains urban noise and belongs to one of 10 classes:

`[air_conditioner, car_horn, children_playing, dog_bark, drilling, engine_idling, gun_shot, jackhammer, siren, street_music]`

## Task 1. Getting Familiar with the Data (1 point)

1. Download the dataset from [Google Drive](https://drive.google.com/file/d/12emmtpodmo1783e6VOOEjV20zAKl5dZR/view?usp=sharing) and extract it into the `./data` folder.

2. Implement the `AudioDataset` class, which takes the path to a CSV file (`train_part.csv` or `val_part.csv`) and the path to the folder with the audio files, and returns dictionaries with the keys `x`, `y`, and `len`, where:
   - `x` is the audio recording,
   - `y` is the class of the recording,
   - `len` is the length of the recording.

   **Audio recordings should not be kept in RAM permanently**: instead, load the _wav_ files on demand in the `__getitem__` method. Additionally, implement audio padding: if a recording is shorter than the `pad_size` parameter, pad it with zeros.

3. Use the `display.Audio` function to play a couple of audio recordings in the notebook.
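The padding requirement can be sketched as follows. This is a minimal sketch assuming a 1-D (mono) tensor; `pad_or_crop` is a hypothetical helper name, not part of the assignment template. It relies on `torch.nn.functional.pad` treating a negative padding amount as cropping, so recordings longer than `pad_size` are cut instead of raising an error.

```python
import torch
import torch.nn.functional as F


def pad_or_crop(wav: torch.Tensor, pad_size: int):
    """Return a 1-D signal of exactly pad_size samples and its (clipped) original length.

    Shorter signals are right-padded with zeros; longer ones are cropped,
    because F.pad treats a negative padding amount as cropping.
    """
    length = min(wav.shape[-1], pad_size)
    padded = F.pad(wav, pad=(0, pad_size - wav.shape[-1]))
    return padded, length


padded, length = pad_or_crop(torch.ones(3), pad_size=5)
print(padded.tolist(), length)  # [1.0, 1.0, 1.0, 0.0, 0.0] 3
```

Returning the clipped length keeps `len` consistent with what actually ends up in `x`.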

In [6]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)

In [8]:
!cp '/content/drive/MyDrive/Colab Notebooks/HW2_dataset.zip' HW2_dataset.zip
!unzip HW2_dataset.zip -d data/

In [11]:
!pip install -U --no-cache-dir gdown --pre

In [12]:
# download and extract the data
!rm -r ./data
!mkdir ./data/
!pip install gdown
!cd ./data && gdown https://drive.google.com/uc?id=12emmtpodmo1783e6VOOEjV20zAKl5dZR && unzip HW2_dataset.zip && rm HW2_dataset.zip

In [13]:
# dataset classes
classes = [
    "air_conditioner", 
    "car_horn", 
    "children_playing", 
    "dog_bark",
    "drilling", 
    "engine_idling", 
    "gun_shot", 
    "jackhammer", 
    "siren", 
    "street_music"
]

In [14]:
class AudioDataset(Dataset):
    def __init__(
        self, 
        path_to_csv: str, 
        path_to_folder: str, 
        pad_size: int = 384000,
        sr: int = 44100
    ):
        self.path_to_csv = path_to_csv
        self.csv: pd.DataFrame = pd.read_csv(self.path_to_csv)  # columns: ["ID", "Class"]
        self.path_to_folder = path_to_folder
        self.pad_size = pad_size

        self.sr = sr

        self.class_to_idx = {name: i for i, name in enumerate(classes)}

    def __getitem__(self, index: int):
        ### YOUR CODE IS HERE ######
        # look up the file ID and the class name of this CSV row
        file_id, class_name = self.csv.iloc[index]
        y = self.class_to_idx[class_name]

        # load the wav file on demand (assumes files are named "<ID>.wav" inside path_to_folder)
        wav, sr = torchaudio.load(os.path.join(self.path_to_folder, f"{file_id}.wav"))
        if sr != self.sr:
            resampler = torchaudio.transforms.Resample(sr, self.sr)
            wav = resampler(wav)

        # mix down to mono: (channels, samples) -> (samples,)
        wav = wav.mean(dim=0)
        len_wav = min(wav.shape[0], self.pad_size)
        # right-pad with zeros up to pad_size (a negative pad crops longer recordings)
        wav_padded = F.pad(wav, pad=(0, self.pad_size - wav.shape[0]))
        instance = {
            'x': wav_padded,
            'y': y,
            'len': len_wav
        }

        return instance
        
        ### THE END OF YOUR CODE ###

    def __len__(self):
        return self.csv.shape[0]

In [15]:
# create the datasets
train_dataset = AudioDataset("./data/urbansound8k/train_part.csv", "./data/urbansound8k/data")
val_dataset = AudioDataset("./data/urbansound8k/val_part.csv", "./data/urbansound8k/data")

In [16]:
# check the dataset sizes
assert len(train_dataset) == 4500
assert len(val_dataset) == 935

In [17]:
# check the values returned by __getitem__
item = train_dataset[0]

assert item["x"].shape == (384000, )
assert item["y"] == 0
# assert item["len"] == 192000
# assert item["len"] == 176400

In [18]:
# plot and play the audio recording
item = train_dataset[0]
plt.figure(figsize=(16, 8))
plt.plot(item["x"])

display.Audio(item["x"], rate=train_dataset.sr)

In [19]:
# create the dataloaders
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=32, 
    shuffle=True,
    pin_memory=True, 
    drop_last=True
)
val_dataloader = DataLoader(
    val_dataset, 
    batch_size=32,
    pin_memory=True
)

## Task 2. Recurrent Network for Audio Classification from Raw Signal (2 points)

An audio recording is essentially a time series: microphone measurements are taken at equal time intervals and stored as a sequence.

As we know, recurrent networks are well suited for handling sequences, including time series.

We will train a simple recurrent network to classify audio recordings.

1. Split the audio recording into windows of size `1024` with a stride of `256`. The `torch.Tensor.unfold` method is well suited for this task.
2. Apply a fully connected network with `ReLU` activations and layer sizes `(1024 -> 256 -> 64 -> 16)` to each extracted window.
3. Process the resulting sequences with a two-layer (`num_layers=2`) bidirectional (`bidirectional=True`) LSTM.
4. Concatenate the last hidden states of every layer and direction using `torch.cat`, and apply a fully connected network `(2 * hidden_size * num_layers -> 256 -> 10)` with a `ReLU` activation.

![title](./imgs/rnn_raw.png)

**Tip**: To speed up training, consider adding `BatchNorm` to the fully connected networks.
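Step 4 is easy to get wrong, because `batch_first=True` does not apply to the returned hidden state. The sketch below uses toy sizes (20 windows instead of the 1497 produced by a full 384000-sample recording) to show the layout of `h_n` and two equivalent ways to concatenate it per example.

```python
import torch
from torch import nn

batch_size, num_windows, feat_dim, hidden, num_layers = 3, 20, 16, 256, 2
lstm = nn.LSTM(feat_dim, hidden, num_layers=num_layers, batch_first=True, bidirectional=True)

features = torch.randn(batch_size, num_windows, feat_dim)  # (B, NUM_WINDOWS, 16)
_, (h_n, c_n) = lstm(features)
# batch_first does not apply to h_n: it is (num_layers * 2, B, hidden)
print(h_n.shape)  # torch.Size([4, 3, 256])

# concatenate the final states of all layers and directions along the feature axis
concatenated = torch.cat(list(h_n), dim=1)  # (B, 2 * hidden * num_layers)
# equivalent: move the batch axis first, then flatten
hidden_flattened = h_n.permute(1, 0, 2).reshape(batch_size, -1)
print(concatenated.shape, torch.equal(concatenated, hidden_flattened))  # torch.Size([3, 1024]) True
```

Calling `h_n.reshape(batch_size, -1)` directly, without the permute, would silently mix hidden states from different examples in the batch.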

In [20]:
class RecurrentRawAudioClassifier(nn.Module):
    def __init__(
        self, 
        num_classes=10,
        window_length=1024,
        hop_length=256,
        hidden=256,
        num_layers=2
    ) -> None:
        super().__init__()

        self.num_layers = num_layers
        self.hidden = hidden
        self.window_length = window_length
        self.hop_length = hop_length

        ### YOUR CODE IS HERE ######
        # per-window MLP: 1024 -> 256 -> 64 -> 16
        self.first_mlp = nn.Sequential(
            nn.Linear(self.window_length, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 16)
        )

        self.lstm = nn.LSTM(16, hidden, num_layers=num_layers, batch_first=True, bidirectional=True)

        # classifier head over the concatenated final hidden states
        self.final_mlp = nn.Sequential(
            nn.Linear(2 * hidden * num_layers, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )
        ### THE END OF YOUR CODE ###

    def forward(self, x, lens) -> torch.Tensor:
        batch_size = x.size(0)

        # split the signal into windows
        # batch_windows.shape == (B, NUM_WINDOWS, 1024)
        batch_windows = x.unfold(1, self.window_length, self.hop_length)
        num_windows = batch_windows.size(1)

        # apply the MLP to every window; flatten to (B * NUM_WINDOWS, 1024) for BatchNorm1d
        # batch_windows_features.shape == (B, NUM_WINDOWS, 16)
        batch_windows_features = self.first_mlp(batch_windows.reshape(-1, self.window_length))
        batch_windows_features = batch_windows_features.reshape(batch_size, num_windows, -1)

        # run the LSTM over the sequences and take the final hidden state h_n
        # h_n.shape == (2 * num_layers, B, hidden); the initial states default to zeros
        _, (h_n, _) = self.lstm(batch_windows_features)

        # concatenate the hidden states of all layers and directions per example
        # hidden_flattened.shape == (B, 2 * hidden_size * num_layers)
        hidden_flattened = h_n.permute(1, 0, 2).reshape(batch_size, -1)

        # apply the MLP to get the class logits
        return self.final_mlp(hidden_flattened)

Let's train the resulting model.

In [21]:
def train_audio_clfr(
    model, 
    optimizer, 
    train_dataloader, 
    sr,
    criterion=torch.nn.CrossEntropyLoss(),
    data_transform=None, 
    augmentation=None,
    num_epochs=10, device='cuda:0',
    verbose_num_iters=10
):
    model.train()
    iter_i = 0

    train_losses = []
    train_accuracies = []

    for epoch in range(num_epochs):  
        for batch in train_dataloader:
            x = batch["x"].to(device)
            y = batch["y"].to(device)
            lens = batch["len"].to(device)

            # apply the input transform (e.g. the log-Mel spectrogram)
            if data_transform:
                x, lens = data_transform(x, lens, device=device, sr=sr)

            # apply augmentation to the log-Mel spectrogram
            if augmentation:
                x, lens = augmentation(x, lens)

            probs = model(x, lens)
            optimizer.zero_grad()
            loss = criterion(probs, y)
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

            # compute the prediction accuracy
            pred_cls = probs.argmax(dim=-1)
            train_accuracies.append((pred_cls == y).float().mean().item())

            iter_i += 1

            # every verbose_num_iters iterations, plot the loss and accuracy curves
            if iter_i % verbose_num_iters == 0:
                clear_output(wait=True)

                print(f"Epoch {epoch}")

                plt.figure(figsize=(10, 5))

                plt.subplot(1, 2, 1)
                plt.xlabel("Iteration")
                plt.ylabel("Train loss")
                plt.plot(np.arange(iter_i), train_losses)

                plt.subplot(1, 2, 2)
                plt.xlabel("Iteration")
                plt.ylabel("Train acc")
                plt.plot(np.arange(iter_i), train_accuracies)

                plt.show()

    model.eval()

In [22]:
# create the model and the optimizer
rnn_raw = RecurrentRawAudioClassifier()
rnn_raw.to(device)
optim = torch.optim.Adam(rnn_raw.parameters(), lr=3e-4)

In [23]:
# train the model
train_audio_clfr(rnn_raw, optim, train_dataloader, train_dataset.sr)

Let's compute the metrics on the validation dataset.

In [24]:
def plot_confusion_matrix(model, val_dataloader, sr, device, data_transform=None):
    pred_true_pairs = []
    for batch in val_dataloader:
        x = batch["x"].to(device)
        y = batch["y"].to(device)
        lens = batch["len"].to(device)

        with torch.no_grad():
            if data_transform:
                x, lens = data_transform(x, lens, sr=sr, device=device)

            probs = model(x, lens)

            pred_cls = probs.argmax(dim=-1)

        for pred, true in zip(pred_cls.cpu().detach().numpy(), y.cpu().numpy()):
            pred_true_pairs.append((pred, true))

    print(f"Val accuracy: {np.mean([p[0] == p[1] for p in pred_true_pairs])}")

    cm_df = pd.DataFrame(
        confusion_matrix(
            [p[1] for p in pred_true_pairs], 
            [p[0] for p in pred_true_pairs], 
            normalize="true"
        ),
        columns=classes, 
        index=classes
    )
    sn.heatmap(cm_df, annot=True)

In [25]:
plot_confusion_matrix(rnn_raw, val_dataloader, train_dataset.sr, device)

*Question*: Does the model's quality differ significantly between the training and validation sets? If so, what do you think is the reason?

## Task 3. Constructing Mel Spectrograms (2 points)

The raw signal is very sensitive to many factors: raising or lowering the volume, background noise, or a change in the speaker's pitch alters the raw waveform dramatically. This, in turn, hurts the quality of deep networks trained on raw audio.

To build robust models that are less prone to overfitting, a different representation of audio data is used: spectrograms, in particular the Mel spectrogram.

The idea of its construction is as follows:
1. The signal is split into overlapping time intervals (frames).
2. A window function (usually the cosine-shaped Hann window) is applied to each frame.
3. A discrete Fourier transform is applied to each windowed frame, and the spectral features (magnitudes) of the signal are computed.
4. The spectral features are mapped onto the Mel scale, a roughly logarithmic frequency scale, using a bank of Mel filters; the result is then usually log-compressed.

![image](https://antkillerfarm.github.io/images/img2/Spectrogram_5.png)
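The four steps above can be written compactly as below. This is a sketch of what the code in this task computes, assuming $N$ = `n_fft` = `win_length` = 1024, $H$ = `hop_length` = 256, $w$ the periodic Hann window, $F$ the $(N/2 + 1) \times$ `n_mels` Mel filter bank, and the $10^{-5}$ clamp used below. The Mel scale shown is the common HTK definition, which we assume matches torchaudio's default.

$$
\begin{aligned}
x_t[n] &= x[tH + n]\, w[n], \qquad n = 0, \dots, N - 1 \\
P_t[k] &= \Big| \sum_{n=0}^{N-1} x_t[n]\, e^{-2\pi i k n / N} \Big|^2, \qquad k = 0, \dots, N/2 \\
M_t[m] &= \sum_{k=0}^{N/2} F[k, m]\, P_t[k] \\
L_t[m] &= \log\big(\max(M_t[m],\, 10^{-5})\big) \\
\mathrm{mel}(f) &= 2595 \log_{10}\!\left(1 + \frac{f}{700}\right)
\end{aligned}
$$

Here $t$ indexes frames, $k$ frequency bins and $m$ Mel bands; $L_t[m]$ is the log-Mel spectrogram.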

In this task, we will implement the Mel spectrogram algorithm step by step and compare it with the reference function from torchaudio.

In [26]:
from torchaudio.transforms import MelSpectrogram

# reference function
def compute_log_melspectrogram_reference(
    wav_batch, 
    lens,
    sr,
    device="cpu"
):
    featurizer = MelSpectrogram(
        sample_rate=sr,
        n_fft=1024,
        win_length=1024,
        hop_length=256,
        n_mels=64,
        center=False,
        ).to(device)

    return torch.log(featurizer(wav_batch).clamp(1e-5)), lens // 256

In [27]:
# take one batch
batch = next(iter(train_dataloader))

wav_batch = batch["x"]
lens = batch["len"]

# compute the log-Mel spectrograms
log_melspect, lens = compute_log_melspectrogram_reference(wav_batch, lens, train_dataset.sr)

# plot the resulting reference values
fig, axes = plt.subplots(5, figsize=(16, 8))

for i in range(5):
    axes[i].axis("off")
    axes[i].set_title(f"Reference log melspectrogram {i}")
    axes[i].imshow(log_melspect[i].numpy())

Now let's do the same ourselves.

In [28]:
sr = train_dataset.sr
n_fft=1024
win_length=1024
hop_length=256
n_mels=64


First, use the `unfold` method to split the audio signal into windows of size `win_length` with a stride of `hop_length`.

In [29]:
windows = wav_batch.unfold(1, win_length, hop_length) # your code here
assert windows.shape == (32, 1497, 1024)
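To see what `unfold` does, here is a toy example on a 10-sample signal, plus the frame-count formula behind the `1497` in the assert above. `num_frames` is a hypothetical helper introduced only for illustration.

```python
import torch

signal = torch.arange(10.0)
frames = signal.unfold(0, 4, 2)  # (dimension, size, step) -> (num_frames, 4)
print(frames.tolist())
# [[0.0, 1.0, 2.0, 3.0], [2.0, 3.0, 4.0, 5.0], [4.0, 5.0, 6.0, 7.0], [6.0, 7.0, 8.0, 9.0]]


def num_frames(num_samples: int, win_length: int, hop_length: int) -> int:
    # windows that would run past the end of the signal are dropped
    return (num_samples - win_length) // hop_length + 1


print(num_frames(384000, 1024, 256))  # 1497
```

Consecutive frames overlap by `win_length - hop_length` samples, which is why neighbouring rows above share two values.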

Let's plot and play the signal from a single window.

In [30]:
plt.figure(figsize=(16, 8))
plt.plot(windows[0, 0])

display.Audio(windows[0, 0], rate=train_dataset.sr)

Now we need to apply a cosine window to the signal in each frame. To do this, create a cosine-shaped (Hann) window with `torch.hann_window` and multiply it element-wise with every frame.

In [31]:
filter = torch.hann_window(win_length, periodic=True)
windows_with_applied_filter = windows * filter[None, None, :]
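The periodic Hann window has a simple closed form. The sketch below checks it against `torch.hann_window`; the formula follows the PyTorch documentation, where `periodic=True` effectively uses `N = window_length` in the denominator.

```python
import math
import torch

N = 1024
n = torch.arange(N, dtype=torch.float64)
# periodic Hann window: w[n] = 0.5 * (1 - cos(2 * pi * n / N)), n = 0, ..., N - 1
manual = 0.5 * (1.0 - torch.cos(2 * math.pi * n / N))
reference = torch.hann_window(N, periodic=True, dtype=torch.float64)
print(torch.allclose(manual, reference))  # True
```

The window is zero at the first sample and peaks at 1 in the middle, so it smoothly fades each frame in and out and reduces spectral leakage in the DFT that follows.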

In [32]:
plt.figure(figsize=(16, 8))
plt.plot(windows_with_applied_filter[0, 0])

display.Audio(windows_with_applied_filter[0, 0], rate=train_dataset.sr)

Use `torch.fft.fft` to apply the discrete Fourier transform to each window, and keep the first `n_fft // 2 + 1` components.

Then compute the magnitudes (the power spectrum) with `torch.abs()` and element-wise squaring.

In [33]:
fft_features = torch.fft.fft(windows_with_applied_filter)[:, :, :n_fft // 2 + 1]
fft_magnitudes = torch.abs(fft_features ** 2)
assert fft_magnitudes.shape == (32, 1497, 513)
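Why keep only the first `n_fft // 2 + 1` bins? For a real-valued frame the DFT is conjugate-symmetric, so the remaining bins are redundant. The sketch below checks this on a random frame and compares against `torch.fft.rfft`, which returns exactly the non-redundant half.

```python
import torch

n_fft = 1024
torch.manual_seed(0)
frame = torch.randn(n_fft, dtype=torch.float64)  # a real-valued windowed frame

spectrum = torch.fft.fft(frame)  # n_fft complex bins
# real input => conjugate symmetry X[k] == conj(X[n_fft - k]) for k = 1, ..., n_fft - 1
mirrored = spectrum.flip(0)[:-1]  # X[n_fft - 1], ..., X[1]
print(torch.allclose(spectrum[1:].real, mirrored.real),
      torch.allclose(spectrum[1:].imag, -mirrored.imag))  # True True

# so the first n_fft // 2 + 1 bins hold all the information; rfft returns exactly these
half = spectrum[: n_fft // 2 + 1]
rfft_bins = torch.fft.rfft(frame)
print(rfft_bins.shape)  # torch.Size([513])

# power spectrum |X[k]|^2; since |z**2| == |z|**2, torch.abs(half ** 2) gives the same values
power = half.abs() ** 2
print(torch.allclose(power, torch.abs(half ** 2)))  # True
```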

Using `torchaudio.transforms.MelScale`, create an object that converts the magnitudes to the Mel scale.

In [34]:
melscale = torchaudio.transforms.MelScale(n_mels=n_mels, sample_rate=sr, n_stft=n_fft // 2 + 1)

The nonlinear transformation to the Mel scale (the Mel filter bank) looks like this:

In [35]:
plt.figure(figsize=(10, 5))
plt.axis("off")
plt.imshow(melscale.fb.numpy().transpose())
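The filter bank plotted above can be approximated by hand. A minimal sketch, assuming HTK-style triangular filters without area normalization; the helper names are hypothetical, and this is not guaranteed to match `torchaudio`'s `MelScale` bit-for-bit. Applying the bank is just a matrix product over the frequency axis.

```python
import math
import torch


def hz_to_mel(f_hz: float) -> float:
    # HTK-style Mel scale
    return 2595.0 * math.log10(1.0 + f_hz / 700.0)


def mel_to_hz(mels: torch.Tensor) -> torch.Tensor:
    # inverse of hz_to_mel, applied element-wise
    return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)


def triangular_mel_filterbank(n_stft: int, n_mels: int, sample_rate: int) -> torch.Tensor:
    """(n_stft, n_mels) matrix of triangular filters evenly spaced on the Mel scale."""
    # frequency of every STFT bin, from 0 Hz up to the Nyquist frequency
    bin_freqs = torch.linspace(0.0, sample_rate / 2, n_stft)
    # n_mels + 2 edge points, evenly spaced in mels, converted back to Hz
    mel_edges = torch.linspace(hz_to_mel(0.0), hz_to_mel(sample_rate / 2), n_mels + 2)
    hz_edges = mel_to_hz(mel_edges)
    left, center, right = hz_edges[:-2], hz_edges[1:-1], hz_edges[2:]
    # each triangle rises from `left` to `center` and falls from `center` to `right`
    rising = (bin_freqs[:, None] - left) / (center - left)
    falling = (right - bin_freqs[:, None]) / (right - center)
    return torch.clamp(torch.minimum(rising, falling), min=0.0)


fb = triangular_mel_filterbank(n_stft=513, n_mels=64, sample_rate=44100)
print(fb.shape)  # torch.Size([513, 64])

# (n_mels, n_stft) @ (n_stft, NUM_WINDOWS) -> (n_mels, NUM_WINDOWS)
power = torch.rand(513, 1497)
mel = fb.T @ power
print(mel.shape)  # torch.Size([64, 1497])
```

Low-frequency filters are narrow and high-frequency ones are wide, which is what makes the representation roughly logarithmic in frequency.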

Apply the Mel scale to the magnitudes.

In [36]:
mel_spectrogram = melscale(fft_magnitudes.permute(0, 2, 1))
assert mel_spectrogram.shape == (32, 64, 1497)

Clamp the values from below at `1e-5` and apply `torch.log` to obtain the log-Mel spectrogram.

In [37]:
logmel_spectrogram = torch.log(mel_spectrogram.clamp(1e-5)) # your code here
assert logmel_spectrogram.shape == (32, 64, 1497)

The resulting log-Mel spectrograms should match the reference ones.

In [38]:
# plot the resulting values
fig, axes = plt.subplots(5, figsize=(16, 8))

for i in range(5):
    axes[i].axis("off")
    axes[i].set_title(f"Your log melspectrogram {i}")
    axes[i].imshow(logmel_spectrogram[i].numpy())

Now let's wrap this logic into a function.

In [39]:
# your implementation
def compute_log_melspectrogram(
    wav_batch,
    lens,
    sr,
    device="cpu"
):
    # uses the global win_length, hop_length, n_fft and n_mels defined above
    wav_batch = wav_batch.to(device)
    # split into overlapping frames: (B, NUM_WINDOWS, win_length)
    windows = wav_batch.unfold(1, win_length, hop_length)
    # apply the Hann window to every frame
    window_fn = torch.hann_window(win_length, periodic=True, device=device)
    windows_with_applied_filter = windows * window_fn[None, None, :]
    # DFT of each frame; keep the n_fft // 2 + 1 non-redundant bins
    fft_features = torch.fft.fft(windows_with_applied_filter)[:, :, :n_fft // 2 + 1]
    # power spectrum |X|^2
    fft_magnitudes = fft_features.abs() ** 2
    # project onto the Mel scale: (B, n_mels, NUM_WINDOWS)
    melscale = torchaudio.transforms.MelScale(n_mels=n_mels, sample_rate=sr, n_stft=n_fft // 2 + 1).to(device)
    mel_spectrogram = melscale(fft_magnitudes.permute(0, 2, 1))
    # clamp from below and take the log
    logmel_spectrogram = torch.log(mel_spectrogram.clamp(1e-5))

    return logmel_spectrogram, lens // hop_length

In [40]:
logmel_spectrogram = compute_log_melspectrogram(wav_batch, lens, train_dataset.sr)[0]

Final check.

In [41]:
assert torch.allclose(
    compute_log_melspectrogram_reference(wav_batch, lens, train_dataset.sr)[0],
    compute_log_melspectrogram(wav_batch, lens, train_dataset.sr)[0],
    atol=1e-5
)

## Task 4. Recurrent Network for Audio Classification Using Logarithmic Mel-Spectrograms (1 point)

Modify the recurrent network implementation from Task 2 so that it can process logarithmic Mel-spectrograms instead of raw audio signals:

1. Remove steps 1 and 2 (the windowing and the per-window fully connected network).
2. Set the LSTM input size to 64.
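The only layout change needed is a transpose: the featurizer returns `(B, n_mels, NUM_FRAMES)`, while a `batch_first=True` LSTM expects `(B, sequence length, features)`. A small shape check with toy sizes (100 frames, `hidden=32`, chosen only for illustration):

```python
import torch
from torch import nn

batch_size, n_mels, num_frames, hidden = 2, 64, 100, 32
log_mel = torch.randn(batch_size, n_mels, num_frames)  # (B, n_mels, NUM_FRAMES)

lstm = nn.LSTM(n_mels, hidden, num_layers=2, batch_first=True, bidirectional=True)
# move the Mel axis (the per-frame features) last: (B, NUM_FRAMES, n_mels)
out, (h_n, _) = lstm(log_mel.permute(0, 2, 1))
print(out.shape)  # torch.Size([2, 100, 64]): 2 * hidden features per frame
print(h_n.shape)  # torch.Size([4, 2, 32])
```

Each of the 64 Mel bins becomes one input feature per time step, which is why the LSTM input size is 64.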

![arch_mel](./imgs/rnn_mel.png)

**Implementation of the architecture is worth 0.5 points.**

In [42]:
class RecurrentMelSpectClassifier(nn.Module):
    def __init__(
        self, 
        num_classes=10,
        window_length=1024,
        hop_length=256,
        hidden=256,
        num_layers=2
    ) -> None:
        super().__init__()

        self.num_layers = num_layers
        self.hidden = hidden
        self.window_length = window_length
        self.hop_length = hop_length

        ### YOUR CODE IS HERE ######
        # the LSTM consumes the 64 Mel bins of every frame directly
        self.lstm = nn.LSTM(64, hidden, num_layers=num_layers, batch_first=True, bidirectional=True)

        self.final_mlp = nn.Sequential(
            nn.Linear(2 * hidden * num_layers, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )
        ### THE END OF YOUR CODE ###

    def forward(self, x, lens) -> torch.Tensor:
        # x.shape == (B, n_mels, T); with batch_first=True the LSTM expects (B, T, n_mels)
        # the initial hidden and cell states default to zeros
        _, (h_n, _) = self.lstm(x.permute(0, 2, 1))

        # concatenate the hidden states of all layers and directions per example
        # hidden_flattened.shape == (B, 2 * hidden_size * num_layers)
        hidden_flattened = h_n.permute(1, 0, 2).reshape(x.size(0), -1)

        # apply the MLP to get the class logits
        return self.final_mlp(hidden_flattened)

In [43]:
rnn_mel = RecurrentMelSpectClassifier()
rnn_mel.to(device)

optim = torch.optim.Adam(rnn_mel.parameters(), lr=3e-4)

In [44]:
train_audio_clfr(rnn_mel, optim, train_dataloader, train_dataset.sr, 
                 data_transform=compute_log_melspectrogram)

Let's compute the metrics on the validation dataset.

**Task: to earn 0.5 points, tune the hyperparameters so that the model's accuracy on the validation dataset exceeds 0.8.**

In [45]:
plot_confusion_matrix(rnn_mel, val_dataloader, train_dataset.sr, device, 
                      data_transform=compute_log_melspectrogram)

## Task 5. Convolutional Network for Audio Classification Using Mel-Spectrograms (2 points)

It is easy to see that Mel spectrograms exhibit distinct patterns, so much so that a trained person could classify the sound _visually_.

This allows us to transform the audio classification task into an image classification problem.

### Implement the following convolutional network:

* 2x (Conv2d 3x3 @ 16, BatchNorm2d, ReLU)
* MaxPool 2x2
* 2x (Conv2d 3x3 @ 32, BatchNorm2d, ReLU)
* MaxPool 2x2
* 2x (Conv2d 3x3 @ 64, BatchNorm2d, ReLU)
* MaxPool 2x2
* (Conv2d 3x3 @ 128, BatchNorm2d, ReLU)
* (Conv2d 2x2 @ 128, BatchNorm2d, ReLU)
* Global MaxPool
* Fully Connected 128, ReLU
* Fully Connected 10
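It helps to know the spatial size reaching the global max pool. The list above does not specify padding, so the sketch below traces both the no-padding case (as in the implementation that follows) and a padding-1 variant for the 3x3 convolutions, for a 64 x 1497 log-Mel input. The helper names are hypothetical.

```python
def conv_out(size: int, kernel: int, padding: int = 0, stride: int = 1) -> int:
    # output length of Conv2d / MaxPool2d along one axis (dilation 1, floor rounding)
    return (size + 2 * padding - kernel) // stride + 1


def trace_cnn10(size: int, padding: int = 0) -> int:
    # the stack listed above; `padding` is applied to the 3x3 convolutions only
    for _ in range(3):
        size = conv_out(size, 3, padding)   # Conv2d 3x3
        size = conv_out(size, 3, padding)   # Conv2d 3x3
        size = conv_out(size, 2, stride=2)  # MaxPool 2x2
    size = conv_out(size, 3, padding)       # Conv2d 3x3 @ 128
    return conv_out(size, 2)                # Conv2d 2x2 @ 128


# a log-Mel input of 64 Mel bins x 1497 frames
print(trace_cnn10(64), trace_cnn10(1497))                        # 1 180
print(trace_cnn10(64, padding=1), trace_cnn10(1497, padding=1))  # 7 186
```

Either way the global max pool collapses the remaining map to 128 features, e.g. with `torch.amax(z, dim=(2, 3))`.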

**Tip:** A similar architecture was implemented in [**PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition**](https://arxiv.org/pdf/1912.10211.pdf). You can use this paper as a reference.

**The implementation of this architecture is worth 1.5 points.**

In [46]:
class CNN10(nn.Module):
    def __init__(self, num_classes=10, hidden=16):
        super().__init__()

        ### YOUR CODE IS HERE ######
        self.cnn_backbone = nn.Sequential(
          nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3),
          nn.BatchNorm2d(16),
          nn.ReLU(),
          nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3),
          nn.BatchNorm2d(16),
          nn.ReLU(),
          nn.MaxPool2d(kernel_size=2),
          nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
          nn.BatchNorm2d(32),
          nn.ReLU(),
          nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3),
          nn.BatchNorm2d(32),
          nn.ReLU(),
          nn.MaxPool2d(kernel_size=2),
          nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
          nn.BatchNorm2d(64),
          nn.ReLU(),
          nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
          nn.BatchNorm2d(64),
          nn.ReLU(),
          nn.MaxPool2d(kernel_size=2),
          nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
          nn.BatchNorm2d(128),
          nn.ReLU(),
          nn.Conv2d(in_channels=128, out_channels=128, kernel_size=2),
          nn.BatchNorm2d(128),
          nn.ReLU()
        )

        self.final_mlp = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x, lens):
        # add a channel axis: (B, n_mels, T) -> (B, 1, n_mels, T)
        z = self.cnn_backbone(x[:, None, :, :])
        # global max pooling over the frequency and time axes: (B, 128, H, W) -> (B, 128)
        z = torch.amax(z, dim=(2, 3))
        return self.final_mlp(z)

In [47]:
cnn = CNN10()
cnn.to(device);

optim = torch.optim.Adam(cnn.parameters(), lr=3e-4)

In [48]:
train_audio_clfr(cnn, optim, train_dataloader, train_dataset.sr, 
                 data_transform=compute_log_melspectrogram,
                 num_epochs=20)

**Task: to earn 0.5 points, tune the hyperparameters so that the model's accuracy on the validation dataset exceeds 0.85.**

In [49]:
plot_confusion_matrix(cnn, val_dataloader, train_dataset.sr, device, 
                      data_transform=compute_log_melspectrogram)

## Task 6. SpecAugment Data Augmentation (2 points)

Audio datasets are typically quite small. Our dataset is a good example, with only 4,500 samples in the training set. Training deep networks with a large number of parameters on such datasets often leads to overfitting and a drop in validation and test metrics.

To combat overfitting, data augmentation can be used. For Mel-spectrograms, a technique called **SpecAugment** was developed.

### The core idea of SpecAugment:
It masks parts of the spectrogram along the time and frequency axes by overwriting them with a fill value $v$:

1. Select several time intervals $\{[t^1_i, t^2_i]\}$ and set the spectrogram values $s[t^1_i : t^2_i, :]$ to $v$.
2. Select several frequency intervals $\{[m^1_i, m^2_i]\}$ and set the spectrogram values $s[:, m^1_i : m^2_i]$ to $v$.

### The value $v$ can be:
1. `'mean'`: the mean of the spectrogram
2. `'min'`: the minimum of the spectrogram
3. `'max'`: the maximum of the spectrogram
4. `v`: a constant
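The definition above can be illustrated on a single spectrogram. A minimal sketch, assuming a 2-D `(n_mels, T)` tensor (frequency first, unlike the time-first $s[t, m]$ notation above); `fill_value` and `spec_augment_2d` are hypothetical helper names, and drawing band widths and positions uniformly is our assumption about how the intervals are chosen.

```python
import torch


def fill_value(spect: torch.Tensor, fill="mean") -> float:
    # the value v: a statistic of the whole spectrogram, or a constant number
    if fill == "mean":
        return float(spect.mean())
    if fill == "min":
        return float(spect.min())
    if fill == "max":
        return float(spect.max())
    return float(fill)


def spec_augment_2d(spect, n_freq_masks=2, n_time_masks=2, max_freq=10, max_time=50, fill="mean"):
    """Return a masked copy of a single (n_mels, T) log-Mel spectrogram."""
    out = spect.clone()
    v = fill_value(spect, fill)  # computed once, before any masking
    # axis 0 = Mel bins (frequency masks), axis 1 = frames (time masks)
    for axis, n_masks, max_width in ((0, n_freq_masks, max_freq), (1, n_time_masks, max_time)):
        size = out.shape[axis]
        for _ in range(n_masks):
            width = int(torch.randint(0, min(max_width, size) + 1, (1,)))
            if width == 0:
                continue
            start = int(torch.randint(0, size - width + 1, (1,)))
            # narrow() returns a view, so fill_ writes into `out`
            out.narrow(axis, start, width).fill_(v)
    return out


torch.manual_seed(0)
spect = torch.randn(64, 1497)
augmented = spec_augment_2d(spect, fill=-20.0)
masked = augmented == -20.0
print(masked.all(dim=1).sum().item(), masked.all(dim=0).sum().item())  # masked Mel rows, masked frames
```

Because the input is cloned, the original spectrogram stays intact, which matters when you want to plot it next to the augmented version.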

**Tip:** A detailed description of SpecAugment (in Russian) can be found [here](https://neurohive.io/ru/novosti/specaugment-novyj-metod-augmentacii-audiodannyh-ot-google-ai/); you can use it as a reference.

![specaugment](https://neurohive.io/wp-content/uploads/2019/04/image6.png)

### Task:
Implement the **SpecAugment** augmentation.

**The implementation of SpecAugment is worth 1.5 points.**

In [50]:
import random


class SpectAugment:
    def __init__(
        self,
        filling_value="mean",
        n_freq_masks=2,
        n_time_masks=2,
        max_freq=10,
        max_time=50,
    ):
        # filling_value: "mean", "min", "max", or a number used as a constant
        self.filling_value = filling_value
        self.n_freq_masks = n_freq_masks
        self.n_time_masks = n_time_masks
        self.max_freq = max_freq
        self.max_time = max_time

    def _fill_values(self, spect):
        # one fill value per spectrogram in the batch, shaped (B, 1, 1) for broadcasting
        if self.filling_value == "mean":
            return spect.mean(dim=(1, 2), keepdim=True)
        if self.filling_value == "min":
            return spect.amin(dim=(1, 2), keepdim=True)
        if self.filling_value == "max":
            return spect.amax(dim=(1, 2), keepdim=True)
        return torch.full_like(spect[:, :1, :1], float(self.filling_value))

    def __call__(self, spect, lens):
        ### YOUR CODE IS HERE ######
        # spect.shape == (B, n_mels, T); lens are already measured in frames
        batch_size, num_frequency_bins, num_time_bins = spect.shape
        fill_values = self._fill_values(spect)

        # boolean mask of the positions to overwrite, sampled independently per example
        mask = torch.zeros_like(spect, dtype=torch.bool)
        for b in range(batch_size):
            for _ in range(self.n_freq_masks):
                width = random.randint(0, min(self.max_freq, num_frequency_bins))
                start = random.randint(0, num_frequency_bins - width)
                mask[b, start:start + width, :] = True
            for _ in range(self.n_time_masks):
                width = random.randint(0, min(self.max_time, num_time_bins))
                start = random.randint(0, num_time_bins - width)
                mask[b, :, start:start + width] = True

        # torch.where returns a new tensor, so the input spectrogram is left untouched
        return torch.where(mask, fill_values, spect), lens

In [51]:
# apply the augmentation to the data
batch = next(iter(train_dataloader))

x = batch["x"].to(device)
lens = batch["len"].to(device)
x_logmel, lens = compute_log_melspectrogram_reference(x, lens, sr=train_dataset.sr, device=device)
# pass a copy so that the original spectrogram cannot be modified in place
x_logmel_augmented, lens = SpectAugment()(x_logmel.clone(), lens)

# plot the spectrogram before and after augmentation
plt.figure(figsize=(20, 5))
plt.subplot(2, 1, 1)
plt.title("Original log MelSpectrogram")
plt.axis("off")
plt.imshow(x_logmel[0].cpu().numpy())

plt.subplot(2, 1, 2)
plt.title("Augmented log MelSpectrogram")
plt.axis("off")
plt.imshow(x_logmel_augmented[0].cpu().numpy())

plt.show()

In [52]:
cnn = CNN10()
cnn.to(device);

optim = torch.optim.Adam(cnn.parameters(), lr=3e-4)

In [53]:
# train the model on augmented data
train_audio_clfr(cnn, optim, train_dataloader, train_dataset.sr, 
                 data_transform=compute_log_melspectrogram,
                 augmentation=SpectAugment(),
                 num_epochs=20)

**Task: to earn 0.5 points, tune the augmentation parameters so that the model's accuracy on the validation dataset exceeds 0.9.**

In [54]:
plot_confusion_matrix(cnn, val_dataloader, train_dataset.sr, device, 
                      data_transform=compute_log_melspectrogram)