# Imports

In [11]:
import lightning.pytorch as pl
import torchmetrics
import torch
import torchvision.models as models
import numpy as np
import librosa as lr
from torch import nn

# Constants

In [6]:
# Output label names. Their order fixes which output unit each label maps to:
# LastLayer defaults num_labels to len(CLASS_LABELS), and FullModel.predict
# pairs the i-th predicted probability with the i-th name here.
# NOTE(review): these look like instrument codes (cello, clarinet, flute, ...)
# — confirm against the dataset the head was trained on.
CLASS_LABELS = tuple(
    'cel cla flu gac gel org pia sax tru vio voi'.split()
)

# Model

## Last layer

In [7]:
class LastLayer(pl.LightningModule):
    '''Trainable classification head: one linear layer followed by a sigmoid.

    Maps a feature vector to one independent probability per label
    (multi-label classification). Only ``fc`` has trainable parameters.

    Args:
        loss: Loss module, called as ``loss(probs, target)`` on the sigmoid
            outputs. May be ``None`` when the model is only used for
            inference (e.g. loaded from a checkpoint for prediction); the
            train/validation/test steps require a real loss.
        num_features: Size of the input feature vector.
        num_labels: Number of output labels.
        learning_rate: Learning rate of the Adam optimizer.
    '''

    def __init__(self, loss: nn.Module, num_features: int = 2048,
                 num_labels: int = len(CLASS_LABELS),
                 learning_rate: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.learning_rate = learning_rate
        self.fc = nn.Linear(num_features, num_labels)
        self.sigmoid = nn.Sigmoid()

        # nn.ModuleDict (instead of a plain dict) registers the loss and the
        # metrics as child modules, so `.to(device)` and the Trainer move their
        # internal state together with the model. With a plain dict the
        # torchmetrics states stay on the CPU, and updating them with GPU
        # predictions fails with a device mismatch.
        # NOTE: the same metric objects serve all three stages. Only the
        # per-batch values are logged (Lightning reduces them over the epoch),
        # so the metrics' accumulated state is never read.
        self.metrics = nn.ModuleDict({
            'loss': loss,
            'accuracy': torchmetrics.Accuracy(task='multilabel',
                                              num_labels=num_labels),
            'precision': torchmetrics.Precision(task="multilabel",
                                                num_labels=num_labels),
            'recall': torchmetrics.Recall(task='multilabel',
                                          num_labels=num_labels),
            'hamming_distance': torchmetrics.HammingDistance(task='multilabel',
                                                             num_labels=num_labels)
        })

    def _get_metrics(self, preds, target, label):
        '''Evaluate every entry of ``self.metrics`` on one batch.

        Returns a dict keyed ``'<label>_<metric name>'``, e.g. ``'train_loss'``.
        '''
        return {f'{label}_{name}': metric(preds, target)
                for name, metric in self.metrics.items()}

    def _shared_step(self, batch, stage):
        '''Run the forward pass, log the stage's metrics per epoch, return its loss.'''
        x, y = batch
        y_hat = self(x)

        metrics = self._get_metrics(y_hat, y, stage)
        self.log_dict(metrics, on_step=False, on_epoch=True)

        return metrics[f'{stage}_loss']

    def forward(self, x):
        '''Return per-label probabilities: sigmoid of the linear layer output.'''
        return self.sigmoid(self.fc(x))

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'validation')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        # Adam over the linear layer, the only trainable component.
        return torch.optim.Adam(self.fc.parameters(), lr=self.learning_rate)

# Full model

In [12]:
class FullModel(pl.LightningModule):
    '''Full model, including the resnet, ready for usage.

    Chains a pretrained torchvision ResNet, whose ``fc`` head is replaced by an
    identity so that it returns features, with a trained ``LastLayer`` loaded
    from a checkpoint. The ResNet variant is picked from the checkpoint's
    ``num_features`` hyperparameter: 512 -> resnet34, 2048 -> resnet50.

    Args:
        last_layer_ckpt: Path to a ``LastLayer`` checkpoint.
        melspec_hop_length: ``hop_length`` passed to librosa's mel spectrogram.
        melspec_n_mels: ``n_mels`` passed to librosa's mel spectrogram.

    Raises:
        AttributeError: If the checkpoint's ``num_features`` matches neither
            supported ResNet.
    '''

    def __init__(self, last_layer_ckpt: str, melspec_hop_length: int = 256,
                 melspec_n_mels: int = 256):
        super().__init__()

        self.melspec_hop_length = melspec_hop_length
        self.melspec_n_mels = melspec_n_mels

        # The loss is only needed for training, so none is supplied here.
        self.last_layer = LastLayer.load_from_checkpoint(last_layer_ckpt, loss=None)

        num_features = self.last_layer.hparams['num_features']
        if num_features == 512:
            self.resnet = models.resnet34(weights='DEFAULT')
        elif num_features == 2048:
            self.resnet = models.resnet50(weights='DEFAULT')
        else:
            raise AttributeError(
                f"Provided checkpoint expects {num_features} input features. "
                "Don't know which base model to use.")

        # Drop the ResNet's classifier so it outputs its feature vector.
        self.resnet.fc = torch.nn.Identity()
        # Inference-only model: put this module and every submodule in eval mode.
        self.eval()

    def _preprocess(self, signal: np.ndarray, sample_rate: int) -> torch.Tensor:
        '''Turn an audio signal into a batched 3-channel ResNet input.

        Assumes a mono 1-D ``signal``. The result is then
        ``(1, 3, n_mels, frames)`` on ``self.device``.
        '''
        melspec = torch.Tensor(lr.feature.melspectrogram(
            y=signal, sr=sample_rate,
            hop_length=self.melspec_hop_length,
            n_mels=self.melspec_n_mels))

        # Min-max scale to [0, 1]. A constant spectrogram (e.g. digital
        # silence) would otherwise divide 0 / 0 and yield NaN everywhere.
        low, high = torch.min(melspec), torch.max(melspec)
        span = high - low
        if span > 0:
            normalized = (melspec - low) / span
        else:
            normalized = torch.zeros_like(melspec)

        # The ResNet takes 3 input channels: repeat the single spectrogram
        # channel, add a batch dimension, and follow the model's device.
        tripled = torch.stack((normalized, normalized, normalized), dim=0)
        return torch.unsqueeze(tripled, dim=0).to(self.device)

    @torch.no_grad()  # inference only: skip building the autograd graph
    def predict(self, signal: np.ndarray, sample_rate: int) -> dict:
        '''Return ``{label: probability}`` for every name in CLASS_LABELS.

        Probabilities are Python floats taken from the sigmoid outputs of the
        last layer, paired with CLASS_LABELS by position.
        '''
        resnet_features = self.resnet(self._preprocess(signal, sample_rate))
        probs = self.last_layer(torch.squeeze(resnet_features, dim=0)).tolist()

        return {label: probs[idx] for idx, label in enumerate(CLASS_LABELS)}