In [1]:
!pip install --upgrade --no-cache-dir gdown --quiet
!pip install wandb --quiet
!pip install pytorch-lightning --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m195.4/195.4 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.8/258.8 kB[0m [31m26.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m801.6/801.6 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.4/840.4 kB[0m [31m40.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import click
import random
import torch
import gdown
import os
import zipfile
import torchaudio
import wandb
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from typing import List, Dict, Any, Tuple

In [4]:
def load_audio(index, duration) -> Dict:
    audio_path = os.path.join('PMEmo2019/chorus/', str(index) + '.mp3')
    waveform, sample_rate = torchaudio.load(audio_path)
    num_samples = int(duration * sample_rate)
    waveform = waveform[:, :num_samples]
    waveform_mono = torch.mean(waveform, dim=0).unsqueeze(0)
    return {'audio': waveform_mono}


def collate_list_of_dicts(input_set) -> Dict:
    dictionary = {}
    class_list_set = set()

    for item in input_set:
        # class_list_set.update(item['label'])  # in case of list in 'label
        class_list_set.add(item['label'])

    dictionary['classlist'] = list(class_list_set)
    dictionary['audio'] = torch.stack([item['audio'] for item in input_set])
    dictionary['target'] = torch.stack([item['target'] for item in input_set])
    return dictionary


In [5]:
def export_zip_file(zip_path):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall()


def download_dataset(url, dataset_name, file_name, export):
    full_path = dataset_name
    if not os.path.exists(full_path):
        os.makedirs(full_path)
    !gdown $url
    if export:
      !unzip -qq $file_name

def assign_octant_label(arousal, valence):
    octant_labels = [(['O1', 'O2'], ['O4', 'O3']), (['O7', 'O8'], ['O5', 'O6'])]
    octant = octant_labels[valence < 0.5][arousal < 0.5][abs(valence) < abs(arousal)]
    return octant

In [6]:
class ClassConditionalDataset(Dataset):

    def __getitem__(self, index: int) -> Dict[Any, Any]:
        """
        Grab an item from the dataset. The item returned must be a dictionary.
        """
        raise NotImplementedError

    @property
    def class_list(self) -> List[str]:
        """
        The classlist property returns a list of class labels available in the dataset.
        This property enables users of the dataset to easily access a list of all the classes in the dataset.

        Returns:
            List[str]: A list of class labels available in the dataset.
        """
        raise NotImplementedError

    @property
    def class_to_indices(self) -> Dict[str, List[int]]:
        """
        Returns a dictionary where the keys are class labels and the values are
        lists of indices in the dataset that belong to that class.
        This property enables users of the dataset to easily access
        examples that belong to specific classes.

        Implement me!

        Returns:
            Dict[str, List[int]]: A dictionary mapping class labels to lists of dataset indices.
        """
        raise NotImplementedError


class EpisodeDataset(Dataset):
    """
        A dataset for sampling few-shot learning tasks from a class-conditional dataset.

    Args:
        dataset (ClassConditionalDataset): The dataset to sample episodes from.
        n_way (int): The number of classes to sample per episode.
            Default: 5.
        n_support (int): The number of samples per class to use as support.
            Default: 5.
        n_query (int): The number of samples per class to use as query.
            Default: 20.
        n_episodes (int): The number of episodes to generate.
            Default: 100.
    """

    def __init__(self,
                 dataset: ClassConditionalDataset,
                 n_way: int = 5,
                 n_support: int = 5,
                 n_query: int = 20,
                 n_episodes: int = 100):
        self.dataset = dataset
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query
        self.n_episodes = n_episodes

    def __getitem__(self, index: int) -> Tuple[Dict, Dict]:
        """Sample an episode from the class-conditional dataset.

        Each episode is a tuple of two dictionaries: a support set and a query set.
        The support set contains a set of samples from each of the classes in the
        episode, and the query set contains another set of samples from each of the
        classes. The class labels are added to each item in the support and query
        sets, and the list of classes is also included in each dictionary.

        Yields:
            Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing the support
            set and the query set for an episode.
        """
        # seed the random number generator so we can reproduce this episode
        rng = random.Random(index)

        # sample the list of classes for this episode
        episode_class_list = rng.sample(self.dataset.class_list, self.n_way)

        # sample the support and query sets for this episode
        support, query = [], []
        for c in episode_class_list:
            # grab the dataset indices for this class
            all_indices = self.dataset.class_to_indices[c]

            if not len(all_indices) > 0:
                continue

            # sample the support and query sets for this class
            indices = rng.sample(all_indices, self.n_support + self.n_query)
            items = [self.dataset[i] for i in indices]

            # add the class label to each item
            for item in items:
                item["target"] = torch.tensor(episode_class_list.index(c))
                item["label"] = c  # MTGJamendo items are multiclass, hence restriction to the class of interest

            # split the support and query sets
            support.extend(items[:self.n_support])
            query.extend(items[self.n_support:])

        # collate the support and query sets
        support = collate_list_of_dicts(support)
        query = collate_list_of_dicts(query)

        support["class_list"] = episode_class_list
        query["class_list"] = episode_class_list

        return support, query

    def __len__(self):
        return self.n_episodes

    def print_episode(self, support, query):
        """Print a summary of the support and query sets for an episode.

        Args:
            support (Dict[str, Any]): The support set for an episode.
            query (Dict[str, Any]): The query set for an episode.
        """
        print("Support Set:")
        print(f"  Class list: {support['class_list']}")
        print(f"  Audio Shape: {support['audio'].shape}")
        print(f"  Target Shape: {support['target'].shape}")
        print()
        print("Query Set:")
        print(f"  Class list: {query['class_list']}")
        print(f"  Audio Shape: {query['audio'].shape}")
        print(f"  Target Shape: {query['target'].shape}")

In [7]:
class PMEmo(ClassConditionalDataset):
    def __init__(self, download, classes):
        if download:
            pme_mo_data_url = 'https://drive.google.com/uc?id=1UzC3NCDj30j9Ba7i5lkMzWO5gFqSr0OJ'
            pme_mo_readme_url = 'https://drive.google.com/uc?id=1KQ0zjRiBQynnHyVPU7DGpUWvtPmCBOcq'
            download_dataset(pme_mo_readme_url, "PMEmo", "README.txt", False)
            download_dataset(pme_mo_data_url, "PMEmo", "PMEmo2019.zip", True)
        self.classes = classes
        self.annotations_csv = os.path.join('PMEmo2019/annotations/', 'static_annotations.csv')
        self.static_annotations = pd.read_csv(self.annotations_csv)
        for index, row in self.static_annotations.iterrows():
            self.static_annotations.at[index, 'label'] = assign_octant_label(row['Arousal(mean)'], row['Valence(mean)'])

    def __len__(self):
        return self.static_annotations.shape[0]

    def __getitem__(self, index):
        annotations = self.static_annotations[self.static_annotations['musicId'] == index]
        item = load_audio(index, 11)
        item['label'] = annotations['label'].values[0]
        return item

    @property
    def class_list(self) -> List[str]:
        return self.classes

    @property
    def class_to_indices(self) -> Dict[str, List[int]]:
        class_indices = {}
        for label in self.class_list:
            items = self.static_annotations[self.static_annotations['label'] == label]
            class_indices[label] = items['musicId'].to_list()
        return class_indices

In [8]:
from torch import nn
import torch
import pytorch_lightning as pl
from torchmetrics import Accuracy


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class PrototypicalNet(nn.Module):

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, support: dict, query: dict):
        """
        Forward pass through the protonet.

        Args:
            support (dict): A dictionary containing the support set.
                The support set dict must contain the following keys:
                    - audio: A tensor of shape (n_support, n_channels, n_samples)
                    - label: A tensor of shape (n_support) with label indices
                    - classlist: A tensor of shape (n_classes) containing the list of classes in this episode
            query (dict): A dictionary containing the query set.
                The query set dict must contain the following keys:
                    - audio: A tensor of shape (n_query, n_channels, n_samples)

        Returns:
            logits (torch.Tensor): A tensor of shape (n_query, n_classes) containing the logits

        After the forward pass, the support dict is updated with the following keys:
            - embeddings: A tensor of shape (n_support, n_features) containing the embeddings
            - prototypes: A tensor of shape (n_classes, n_features) containing the prototypes

        The query dict is updated with
            - embeddings: A tensor of shape (n_query, n_features) containing the embeddings

        """
        # compute the embeddings for the support and query sets
        support["embeddings"] = self.backbone(support["audio"])
        query["embeddings"] = self.backbone(query["audio"])

        # group the support embeddings by class
        support_embeddings = []
        for idx in range(len(support["classlist"])):
            embeddings = support["embeddings"][support["target"] == idx]
            support_embeddings.append(embeddings)
        support_embeddings = torch.stack(support_embeddings)

        # compute the prototypes for each class
        prototypes = support_embeddings.mean(dim=1)
        support["prototypes"] = prototypes

        # print("Prototypes Shape: ", prototypes.shape)
        # print("Embeddings Shape: ", query["embeddings"].shape)
        # compute the distances between each query and prototype
        distances = torch.cdist(
            query["embeddings"].unsqueeze(0),
            prototypes.unsqueeze(0),
            p=2
        ).squeeze(0)

        # square the distances to get the sq euclidean distance
        distances = distances ** 2
        logits = -distances

        # return the logits
        return logits


class FewShotLearner(pl.LightningModule):

    def __init__(self,
                 protonet: nn.Module,
                 num_classes,
                 learning_rate: float = 1e-3,
                 ):
        super().__init__()
        self.save_hyperparameters()
        self.protonet = protonet
        self.learning_rate = learning_rate
        self.num_classes = num_classes

        self.loss = nn.CrossEntropyLoss()
        self.metrics = nn.ModuleDict({
            'accuracy': Accuracy(task="multiclass", num_classes=self.num_classes)
        })

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def step(self, batch, batch_idx, tag: str):
        support, query = batch

        logits = self.protonet(support, query)
        loss = self.loss(logits, query["target"])

        output = {"loss": loss}
        for k, metric in self.metrics.items():
          output[k] = metric(torch.argmax(logits.squeeze(), dim=1), query["target"])

        for k, v in output.items():
            self.log(f"{k}/{tag}", v)
        return output

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "test")

In [9]:
from torchaudio.transforms import MelSpectrogram

class PMEmoConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, num_groups, max_pool_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.gn = nn.GroupNorm(num_groups, out_channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(max_pool_size)

    def forward(self, x):
        x = self.conv(x)
        x = self.gn(x)
        x = self.relu(x)
        x = self.maxpool(x)
        return x


class Backbone(nn.Module):
    def __init__(self, sample_rate: int):
        super().__init__()
        self.melspec = MelSpectrogram(n_mels=32, sample_rate=sample_rate)

        self.conv1 = PMEmoConvBlock(1, 32, 3, 1, 'same', 8, 2)
        self.conv2 = PMEmoConvBlock(32, 64, 3, 1, 'same', 16, 2)
        self.conv3 = PMEmoConvBlock(64, 128, 3, 1, 'same', 32, 2)
        self.conv4 = PMEmoConvBlock(128, 256, 3, 1, 'same', 64, 2)
        self.conv5 = PMEmoConvBlock(256, 512, 1, 1, 'same', 128, 2)

    def forward(self, x: torch.Tensor):
        assert x.ndim == 3, "Expected a batch of audio samples shape (batch, channels, samples)"

        x = self.melspec(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)

        # pool over the time dimension
        # squeeze the (t, f) dimensions
        x = x.mean(dim=-1)
        x = x.squeeze(-1).squeeze(-1)  # (batch, 512)

        return x

In [10]:
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [11]:
wandb.finish()

In [None]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

sr = 44100  # sample rate of the audio
n_way = 4  # number of classes per episode
n_support = 5  # number of support examples per class
n_query = 20  # number of samples per class to use as query
n_train_episodes = int(50000)  # number of episodes to generate for training
n_val_episodes = 50  # number of episodes to generate for validation
num_workers = 2  # number of workers to use for data loading

TRAIN_CLASSES = ["O1", "O2", "O8", "O5"]

TEST_CLASSES = ['O6', 'O4']

train_data = PMEmo(True, TRAIN_CLASSES)

val_data = PMEmo(False, TEST_CLASSES)

train_episodes = EpisodeDataset(
    dataset=train_data,
    n_way=n_way,
    n_support=n_support,
    n_query=n_query,
    n_episodes=n_train_episodes
)

val_episodes = EpisodeDataset(
    dataset=val_data,
    n_way=2,
    n_support=5,
    n_query=n_query,
    n_episodes=n_val_episodes
)

train_loader = DataLoader(train_episodes, batch_size=None, num_workers=num_workers)
val_loader = DataLoader(val_episodes, batch_size=None, num_workers=num_workers, persistent_workers=True)

backbone = Backbone(sr)
protonet = PrototypicalNet(backbone)

learner = FewShotLearner(protonet, num_classes=len(TRAIN_CLASSES))

wandb_logger = WandbLogger(project='FSL_PMEmo', job_type='train')
checkpoint_callback = ModelCheckpoint(dirpath='/content/drive/MyDrive/PMEmo_checkpoints',
                                      monitor="step",
                                      mode="max",
                                      filename='latest-{step}',
                                      every_n_train_steps=500)

trainer = pl.Trainer(
    accelerator="auto",
    max_epochs=1,
    log_every_n_steps=1,
    val_check_interval=100,
    logger=wandb_logger,
    callbacks=[checkpoint_callback]
)

trainer.fit(learner, train_loader, val_dataloaders=val_loader)
wandb.finish()

Downloading...
From: https://drive.google.com/uc?id=1KQ0zjRiBQynnHyVPU7DGpUWvtPmCBOcq
To: /content/README.txt
100% 1.52k/1.52k [00:00<00:00, 5.48MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1UzC3NCDj30j9Ba7i5lkMzWO5gFqSr0OJ
From (redirected): https://drive.usercontent.google.com/download?id=1UzC3NCDj30j9Ba7i5lkMzWO5gFqSr0OJ&confirm=t&uuid=b1a1f186-c885-49d7-b4f4-6613da0d70b5
To: /content/PMEmo2019.zip
100% 680M/680M [00:09<00:00, 70.2MB/s]


  octant = octant_labels[valence < 0.5][arousal < 0.5][abs(valence) < abs(arousal)]
  octant = octant_labels[valence < 0.5][arousal < 0.5][abs(valence) < abs(arousal)]
  octant = octant_labels[valence < 0.5][arousal < 0.5][abs(valence) < abs(arousal)]
  octant = octant_labels[valence < 0.5][arousal < 0.5][abs(valence) < abs(arousal)]
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'protonet' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['protonet'])`.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mabrzeszczynska[0m ([33

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type             | Params
----------------------------------------------
0 | protonet | PrototypicalNet  | 521 K 
1 | loss     | CrossEntropyLoss | 0     
2 | metrics  | ModuleDict       | 0     
----------------------------------------------
521 K     Trainable params
0         Non-trainable params
521 K     Total params
2.086     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 10. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Training: |          | 0/? [00:00<?, ?it/s]