In [None]:
import sys

sys.path.append("../src/")
import os
import numpy as np
import pandas as pd
from typing import Iterable
import librosa
import torch
import torchaudio
import torchvision

# Number of audio samples every waveform is zero-padded or truncated to before
# its spectrogram is computed (see SpectrogramsDataset.path_to_spectrogram).
THRESHOLD_LENGTH = int(0.05 * 1e7)
# Mini-batch size of the training DataLoader (the validation loader uses 2x).
BATCH_SIZE = 4
# Original, misspelled name; kept as an alias so cells that reference it still run.
BATH_SIZE = BATCH_SIZE
# Side length (in pixels) spectrograms are resized to before entering the model.
RESIZE_SIZE = 256
# os.environ["PYTORCH_CUDA_ALLOC_CONF"]="max_split_size_mb:512"

In [None]:
# Load the dataset index and prefix every stored path with "../".
# NOTE(review): presumably the CSV paths are relative to the repository root and
# the prefix makes them resolvable from this notebook's directory - confirm.
df = pd.read_csv("../data/metadata.csv").assign(
    path=lambda frame: frame["path"].map(lambda rel_path: "../" + rel_path)
)
df.head()

In [None]:
# Derive a destination path per file by swapping the "/data/" path component
# for "/img_data/".
# NOTE(review): presumably the target for pre-rendered spectrogram images; the
# column is not read by any cell visible here - confirm.
df["dst_path"] = [src.replace("/data/", "/img_data/") for src in df["path"]]

In [None]:
# plt.xticks(np.linspace(0, 3, 30)*1e7, rotation=90);

In [None]:
class Normalize(torch.nn.Module):
    """Min-max scale a tensor into the ``[0, 1]`` range.

    The minimum and maximum are taken over the *whole* tensor (all dimensions),
    so a batched input is scaled with one global min/max, not per sample.
    """

    def forward(self, x):
        """Return ``(x - min(x)) / (max(x) - min(x))``.

        A constant input (``max == min``) would otherwise produce ``0 / 0``
        and fill the result with NaN; in that case the denominator is replaced
        by 1 so the output is all zeros.
        """
        lowest = torch.min(x)
        span = torch.max(x) - lowest
        # Only the degenerate zero-span case is altered; any positive span is
        # used unchanged, so regular inputs scale exactly as before.
        safe_span = torch.where(span > 0, span, torch.ones_like(span))
        return (x - lowest) / safe_span


class SpectrogramsDataset(torch.utils.data.Dataset):
    """Dataset that turns audio files into fixed-length dB spectrograms.

    Each item is a ``(spectrogram, label)`` pair. The audio file is loaded
    lazily in ``__getitem__``, padded or truncated to ``threshold_len``
    samples, converted to a dB-scaled STFT magnitude and finally passed
    through ``transforms`` when one is given.

    Args:
        paths: Iterable of audio file paths readable by ``torchaudio.load``.
        labels: Iterable of integer class indices, one per path.
        threshold_len: Number of samples every waveform is forced to.
        transforms: Optional callable applied to the spectrogram tensor.

    Raises:
        ValueError: If ``paths`` and ``labels`` differ in length.
    """

    def __init__(
        self, paths, labels: Iterable[int], threshold_len: int, transforms=None
    ):
        super().__init__()
        self.paths = list(paths)
        # Class-index targets must be int64 for torch.nn.CrossEntropyLoss;
        # a float 1-D target is rejected by the loss.
        self.labels = torch.tensor(list(labels), dtype=torch.long)
        if len(self.paths) != len(self.labels):
            raise ValueError(
                f"paths and labels differ in length: "
                f"{len(self.paths)} != {len(self.labels)}"
            )
        self.threshold_len = threshold_len
        # TODO move transforms somewhere
        self.transforms = transforms

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(spectrogram_tensor, label)``."""
        x = torch.tensor(self.path_to_spectrogram(self.paths[idx]))
        if self.transforms is not None:
            x = self.transforms(x)
        return x, self.labels[idx]

    def path_to_spectrogram(self, path: str) -> np.ndarray:
        """Load ``path`` and return its dB spectrogram as a NumPy array.

        The waveform is right-padded with zeros or cut so that it holds exactly
        ``threshold_len`` samples. The STFT uses librosa's default parameters
        and the dB values are referenced to the peak magnitude (``ref=np.max``).
        """
        waveform, _ = torchaudio.load(path)
        n_samples = waveform.size(1)
        if n_samples < self.threshold_len:
            # Pad only the end of the time axis. Unlike concatenating a
            # (1, padding) block of zeros, this works for any channel count.
            waveform = torch.nn.functional.pad(
                waveform, (0, int(self.threshold_len - n_samples))
            )
        elif n_samples > self.threshold_len:
            waveform = waveform[:, : self.threshold_len]
        magnitude = np.abs(librosa.stft(waveform.numpy()))
        return librosa.amplitude_to_db(magnitude, ref=np.max)


# Preprocessing applied to every spectrogram: resize to a square
# RESIZE_SIZE x RESIZE_SIZE image, then min-max scale it to [0, 1].
transforms = torch.nn.Sequential(
    torchvision.transforms.Resize((RESIZE_SIZE, RESIZE_SIZE)), Normalize()
)

# Split the metadata by the predefined "subset" column.
train_df = df[df["subset"] == "train"]
val_df = df[df["subset"] == "validation"]

train_ds = SpectrogramsDataset(
    train_df["path"], train_df["label"], THRESHOLD_LENGTH, transforms
)
val_ds = SpectrogramsDataset(
    val_df["path"], val_df["label"], THRESHOLD_LENGTH, transforms
)

train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=BATH_SIZE)
# Validation only aggregates metrics over the full set, so shuffling brings no
# benefit; a fixed order keeps evaluation passes reproducible. No gradients are
# kept during evaluation, hence the larger batch size.
val_dl = torch.utils.data.DataLoader(val_ds, shuffle=False, batch_size=BATH_SIZE * 2)

In [None]:
class SimpleClassifier(torch.nn.Module):
    """Small three-stage CNN classifier for single-channel square images.

    ``forward`` returns raw, unnormalised logits. ``torch.nn.CrossEntropyLoss``
    applies log-softmax itself, so feeding it softmax outputs (as ``forward``
    previously did) squashes the gradients and hampers training. Use
    ``predict_proba`` when class probabilities are needed.

    Args:
        num_classes: Number of output classes.
        input_size: Side length of the square input image. Defaults to the
            notebook-level ``RESIZE_SIZE`` when omitted.
    """

    def __init__(self, num_classes, input_size=None):
        super().__init__()
        if input_size is None:
            input_size = RESIZE_SIZE

        # Shape comments are (channels, height, width) for a 256 x 256 input.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(
                1, 16, kernel_size=7, stride=1, padding=3
            ),  # 16, 256, 256
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),  # 16, 128, 128
            torch.nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),  # 32, 128, 128
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),  # 32, 64, 64
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 64, 64, 64
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),  # 64, 32, 32
        )
        self.flatten = torch.nn.Flatten()
        # Three stride-2 poolings shrink each spatial side by a factor of 8.
        pooled_side = input_size // 8
        self.fc = torch.nn.Linear(64 * pooled_side**2, num_classes)
        # Not used in forward(); kept for predict_proba() and existing callers.
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.conv(x)
        x = self.flatten(x)
        return self.fc(x)

    def predict_proba(self, x):
        """Return softmax class probabilities of shape ``(batch, num_classes)``."""
        return self.softmax(self.forward(x))

In [None]:
# Sanity check: load, transform and display the first training spectrogram.
train_ds[0][0]

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
from typing import Tuple
from tqdm import tqdm


def count_correct(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Count predictions whose highest-scoring class equals the target.

    Args:
        y_pred: Per-class scores; the predicted class is the argmax over dim 1.
        y_true: Target class per sample, compared element-wise with the argmax.

    Returns:
        Scalar float tensor holding the number of matching predictions.
    """
    predicted_classes = y_pred.argmax(dim=1)
    hits = torch.eq(predicted_classes, y_true)
    return hits.to(torch.float32).sum()


def validate(
    model: torch.nn.Module, loss_fn: torch.nn.CrossEntropyLoss, dataloader
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the mean per-sample loss and the accuracy over ``dataloader``.

    Batches are moved to the device the model's parameters live on (CPU when
    the model has no parameters), so this also works without CUDA. The model
    is put in eval mode and gradients are disabled for the duration of the
    call; the previous train/eval mode is restored afterwards.

    Args:
        model: Network mapping a batch of inputs to per-class scores.
        loss_fn: Loss module. With ``reduction="mean"`` the batch loss is
            re-weighted by the batch size so that unequal batch sizes still
            give the true per-sample mean; otherwise its output is summed.
        dataloader: Iterable of ``(X_batch, y_batch)`` pairs.

    Returns:
        ``(mean_loss, accuracy)`` as scalar tensors.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    first_param = next(model.parameters(), None)
    target_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )
    reduction = getattr(loss_fn, "reduction", "mean")

    total_loss = 0
    n_correct = 0
    n_samples = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for X_batch, y_batch in dataloader:
                y_batch = y_batch.to(target_device)
                y_pred = model(X_batch.to(target_device))
                batch_size = len(y_pred)
                batch_loss = loss_fn(y_pred, y_batch)
                if reduction == "mean":
                    # Undo the per-batch averaging before accumulating.
                    batch_loss = batch_loss * batch_size
                else:
                    batch_loss = batch_loss.sum()
                total_loss += batch_loss
                n_correct += count_correct(y_pred, y_batch)
                n_samples += batch_size
                torch.cuda.empty_cache()
    finally:
        model.train(was_training)
    return total_loss / n_samples, n_correct / n_samples


def fit(
    model: torch.nn.Module,
    optimiser: torch.optim.Optimizer,
    loss_fn: torch.nn.CrossEntropyLoss,
    train_dl,
    val_dl,
    epochs: int,
    print_metrics: bool = True,
):
    """Train ``model`` for ``epochs`` epochs and print loss/accuracy metrics.

    Batches are moved to the device the model's parameters live on (CPU when
    the model has no parameters). The model is switched back to train mode at
    the start of every epoch, because the per-epoch evaluation puts it in eval
    mode. After the last epoch a final evaluation is always printed and the
    model is left in eval mode.

    Args:
        model: Network to optimise in place.
        optimiser: Optimiser constructed over ``model``'s parameters.
        loss_fn: Loss applied to ``(model_output, targets)``.
        train_dl: Iterable of ``(X_batch, y_batch)`` training batches.
        val_dl: Iterable of ``(X_batch, y_batch)`` validation batches.
        epochs: Number of passes over ``train_dl``.
        print_metrics: When true, also evaluate and print after every epoch.
    """

    def _metrics_summary() -> str:
        # Evaluate on both splits in eval mode without tracking gradients.
        model.eval()
        with torch.no_grad():
            train_loss, train_acc = validate(
                model=model, loss_fn=loss_fn, dataloader=train_dl
            )
            val_loss, val_acc = validate(
                model=model, loss_fn=loss_fn, dataloader=val_dl
            )
        return (
            f"train loss = {train_loss:.3f} (acc: {train_acc:.3f}), "
            f"validation loss = {val_loss:.3f} (acc: {val_acc:.3f})"
        )

    first_param = next(model.parameters(), None)
    target_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )

    for epoch in tqdm(range(epochs)):
        # The evaluation at the end of the previous epoch left the model in
        # eval mode; training must run in train mode.
        model.train()
        for X_batch, y_batch in train_dl:
            # Clear gradients first so stale ones never leak into a step.
            optimiser.zero_grad()
            y_pred = model(X_batch.to(target_device))
            loss = loss_fn(y_pred, y_batch.to(target_device))

            loss.backward()
            optimiser.step()
            torch.cuda.empty_cache()

        if print_metrics:
            print(f"Epoch {epoch}: " + _metrics_summary())

    print("Training ended: " + _metrics_summary())

In [None]:
# Fixed seed so weight initialisation and DataLoader shuffling are repeatable.
SEED = 0
# Size of the classifier's output layer.
# NOTE(review): presumably the number of distinct labels in metadata.csv - confirm.
NUM_CLASSES = 66

torch.manual_seed(SEED)
model = SimpleClassifier(NUM_CLASSES).to(device)
fit(
    model=model,
    optimiser=torch.optim.Adam(model.parameters()),
    loss_fn=torch.nn.CrossEntropyLoss(),
    train_dl=train_dl,
    val_dl=val_dl,
    epochs=5,
    print_metrics=True,
)