In [None]:
# Colab environment setup. Versions are unpinned, so re-runs may pick up newer releases.
# NOTE(review): gdown does not appear to be used anywhere else in this notebook.
!pip install --upgrade --no-cache-dir gdown --quiet
!pip install wandb --quiet
!pip install pytorch-lightning --quiet

In [None]:
# Mount Google Drive so the metadata/class files and the checkpoint directory
# under /content/drive/MyDrive, referenced further down, are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import random
import wandb
import numpy as np
from torch.utils.data import Dataset
from typing import List, Dict, Any, Tuple

In [None]:
def collate_list_of_dicts(input_set) -> Dict:
    """Collate a list of per-item dicts into a single batch dict.

    Args:
        input_set: sequence of dicts, each holding a hashable ``'label'``, an
            ``'audio'`` tensor and a ``'target'`` tensor. Tensors must share a
            shape across items so they can be stacked.

    Returns:
        Dict with
            - ``'classlist'``: the distinct labels, in order of first appearance
            - ``'audio'``: the items' audio stacked along a new leading dim
            - ``'target'``: the items' targets stacked along a new leading dim

    Raises:
        ValueError: if ``input_set`` is empty (there is nothing to stack).
    """
    items = list(input_set)
    if not items:
        raise ValueError("cannot collate an empty list of items")

    # dict.fromkeys de-duplicates while keeping first-seen order, so the class
    # list is reproducible; a set's iteration order depends on string hashing,
    # which is randomised per process.
    classlist = list(dict.fromkeys(item['label'] for item in items))

    return {
        'classlist': classlist,
        'audio': torch.stack([item['audio'] for item in items]),
        'target': torch.stack([item['target'] for item in items]),
    }


In [None]:
class ClassConditionalDataset(Dataset):
    """Abstract interface for datasets whose items can be looked up by class.

    Concrete subclasses override ``__getitem__`` plus the ``class_list`` and
    ``class_to_indices`` properties; every base implementation here raises
    ``NotImplementedError``.
    """

    def __getitem__(self, index: int) -> Dict[Any, Any]:
        """Return the item stored under ``index``; items are always dictionaries."""
        raise NotImplementedError

    @property
    def class_list(self) -> List[str]:
        """Every class label the dataset can provide.

        Returns:
            List[str]: the available class labels.
        """
        raise NotImplementedError

    @property
    def class_to_indices(self) -> Dict[str, List[int]]:
        """Index of the dataset grouped by class.

        Returns:
            Dict[str, List[int]]: maps each class label to the dataset indices
            of the items belonging to it, so callers can pull examples of a
            specific class directly.
        """
        raise NotImplementedError


class EpisodeDataset(Dataset):
    """
        A dataset for sampling few-shot learning tasks from a class-conditional dataset.

    Each index maps to one episode, and every random draw for that episode is
    made from a generator seeded with the index. Indexing twice with the same
    value therefore yields the same classes and the same items.

    Args:
        dataset (ClassConditionalDataset): The dataset to sample episodes from.
        n_way (int): The number of classes to sample per episode.
            Default: 5.
        n_support (int): The number of samples per class to use as support.
            Default: 5.
        n_query (int): The number of samples per class to use as query.
            Default: 20.
        n_episodes (int): The number of episodes to generate.
            Default: 100.
    """

    def __init__(self,
                 dataset: ClassConditionalDataset,
                 n_way: int = 5,
                 n_support: int = 5,
                 n_query: int = 20,
                 n_episodes: int = 100):
        self.dataset = dataset
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query
        self.n_episodes = n_episodes

    def __getitem__(self, index: int) -> Tuple[Dict, Dict]:
        """Sample the episode identified by ``index``.

        An episode is a support set and a query set drawn from the same
        ``n_way`` classes. Every item gets ``'target'`` (the class's position in
        the episode class list, as a tensor) and ``'label'`` (the class name).
        Both collated dicts also carry the episode's ``'class_list'``.

        Returns:
            Tuple[Dict[str, Any], Dict[str, Any]]: the collated support set and
            query set for the episode.

        Raises:
            ValueError: if fewer than ``n_way`` classes have any examples.
        """
        # every draw below goes through this index-seeded generator, so the
        # episode is reproducible (e.g. validation sees the same episodes each pass)
        rng = random.Random(index)
        class_to_indices = self.dataset.class_to_indices
        n_per_class = self.n_support + self.n_query

        # Only classes with at least one example can contribute items. Sampling
        # among them keeps every episode at exactly n_way classes, so targets
        # 0..n_way-1 line up with the class list.
        candidates = [c for c in self.dataset.class_list if len(class_to_indices[c]) > 0]
        episode_class_list = rng.sample(candidates, self.n_way)

        support, query = [], []
        for target, c in enumerate(episode_class_list):
            # sorted() makes the pool order independent of the container's
            # iteration order (class_to_indices values may be sets)
            pool = sorted(class_to_indices[c])
            if len(pool) >= n_per_class:
                # without replacement: no item can be in both support and query
                indices = rng.sample(pool, n_per_class)
            else:
                # too few examples for distinct items: fall back to replacement
                indices = rng.choices(pool, k=n_per_class)
            items = [self.dataset[i] for i in indices]

            # tag each item with its episode class
            for item in items:
                item["target"] = torch.tensor(target)
                item["label"] = c  # MTGJamendo items are multi-label, hence restriction to the class of interest

            # split the support and query sets
            support.extend(items[:self.n_support])
            query.extend(items[self.n_support:])

        # collate the support and query sets
        support = collate_list_of_dicts(support)
        query = collate_list_of_dicts(query)

        support["class_list"] = episode_class_list
        query["class_list"] = episode_class_list

        return support, query

    def __len__(self):
        """Number of episodes this dataset exposes."""
        return self.n_episodes

    def print_episode(self, support, query):
        """Print a summary of the support and query sets for an episode.

        Args:
            support (Dict[str, Any]): The support set for an episode.
            query (Dict[str, Any]): The query set for an episode.
        """
        print("Support Set:")
        print(f"  Class list: {support['class_list']}")
        print(f"  Audio Shape: {support['audio'].shape}")
        print(f"  Target Shape: {support['target'].shape}")
        print()
        print("Query Set:")
        print(f"  Class list: {query['class_list']}")
        print(f"  Audio Shape: {query['audio'].shape}")
        print(f"  Target Shape: {query['target'].shape}")

In [None]:
import os

def main_download(out_path: str = "."):
    """Download and unpack the 100 MTG-Jamendo mood/theme mel-spectrogram archives.

    Archive ``autotagging_moodtheme_melspecs-NN.tar`` (NN = 00..99) is fetched
    from the Freesound CDN into ``out_path``, extracted there, then deleted. A
    part is skipped entirely (no download) when ``out_path/NN`` already exists.

    Args:
        out_path: directory that receives the archives and their contents.
    """
    import tarfile
    import urllib.request

    base_url = "https://cdn.freesound.org/mtg-jamendo/autotagging_moodtheme/melspecs/"
    os.makedirs(out_path, exist_ok=True)

    for i in range(100):
        part = f"{i:02d}"
        # assumes each archive unpacks into a directory named after its part
        # number (the same assumption the previous shell-based version made)
        if os.path.exists(os.path.join(out_path, part)):
            continue

        filename = "autotagging_moodtheme_melspecs-" + part + ".tar"
        url = base_url + filename
        archive = os.path.join(out_path, filename)
        print(url)

        if not os.path.exists(archive):
            # download under a temporary name so an interrupted transfer is
            # never mistaken for a complete archive on the next run
            partial = archive + ".part"
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, archive)

        with tarfile.open(archive) as tar:
            # the 'data' filter (available on newer Pythons) rejects absolute
            # paths and members that would land outside out_path
            if hasattr(tarfile, "data_filter"):
                tar.extractall(out_path, filter="data")
            else:
                tar.extractall(out_path)
        os.remove(archive)

In [None]:
import csv
from collections import defaultdict

# Tag categories; read_file gives every track an (initially empty) set for each one.
CATEGORIES = ['genre', 'instrument', 'mood/theme']
# Separator between category and tag name in raw TSV tag strings, e.g. 'mood/theme---<tag>'.
TAG_HYPHEN = '---'
# Human-readable description of the metadata TSV layout (not referenced by the visible code).
METADATA_DESCRIPTION = 'TSV file with such columns: TRACK_ID, ARTIST_ID, ALBUM_ID, PATH, DURATION, TAGS'


def get_id(value):
    """Parse the integer after the first underscore, e.g. ``'track_0000214'`` -> 214."""
    pieces = value.split('_')
    numeric_part = pieces[1]
    return int(numeric_part)


def get_length(values):
    """Number of characters in the string form of the largest element of ``values``."""
    largest = max(values)
    as_text = str(largest)
    return len(as_text)


def read_file(tsv_file):
    """Parse an MTG-Jamendo metadata TSV.

    The first row is a header and is skipped; blank rows are ignored. Each
    remaining row is TRACK_ID, ARTIST_ID, ALBUM_ID, PATH, DURATION followed by
    one ``<category>---<tag>`` string per extra column.

    Args:
        tsv_file: path to the tab-separated metadata file.

    Returns:
        tracks: track id -> dict with 'artist_id', 'album_id', 'path',
            'duration', the raw 'tags' list, and one set of tag names per category.
        tags: category -> tag -> set of track ids carrying that tag.
        extra: digit counts of the largest track / artist / album ids.

    Raises:
        ValueError: on malformed ids, durations or tag strings.
    """
    tracks = {}
    tags = defaultdict(dict)

    # For statistics
    artist_ids = set()
    albums_ids = set()

    # newline='' is what the csv module expects for files it reads
    with open(tsv_file, newline='', encoding='utf-8') as fp:
        reader = csv.reader(fp, delimiter='\t')
        next(reader, None)  # skip header
        for row in reader:
            if not row:
                continue  # blank line: nothing to parse (row[0] would raise IndexError)

            track_id = get_id(row[0])
            artist_id = get_id(row[1])
            album_id = get_id(row[2])
            raw_tags = row[5:]

            track = {
                'artist_id': artist_id,
                'album_id': album_id,
                'path': row[3],
                'duration': float(row[4]),
                'tags': raw_tags,  # raw tags, not sure if will be used
            }
            track.update({category: set() for category in CATEGORIES})
            tracks[track_id] = track

            artist_ids.add(artist_id)
            albums_ids.add(album_id)

            for tag_str in raw_tags:
                category, tag = tag_str.split(TAG_HYPHEN)
                tags[category].setdefault(tag, set()).add(track_id)
                # categories outside CATEGORIES get their set created on first use;
                # a tag field may hold several comma-separated names
                track.setdefault(category, set()).update(tag.split(","))

    print("Reading: {} tracks, {} albums, {} artists".format(len(tracks), len(albums_ids), len(artist_ids)))

    extra = {
        'track_id_length': get_length(tracks.keys()),
        'artist_id_length': get_length(artist_ids),
        'album_id_length': get_length(albums_ids)
    }
    return tracks, tags, extra

In [None]:
def load_melspectrogram(path, n_frames: int = 512) -> Dict:
    """Load a 2-D spectrogram from a ``.npy`` file as a fixed-width tensor.

    The array is cut to its first ``n_frames`` columns. Shorter arrays are
    right-padded to ``n_frames`` with their own minimum value, so every item has
    the same shape and can be stacked into a batch.

    Args:
        path: path to the ``.npy`` file. Assumed 2-D, presumably
            (n_mels, frames) — TODO confirm.
        n_frames: number of columns to keep. Defaults to 512, the previously
            hard-coded width.

    Returns:
        Dict with key ``'audio'``: a tensor of shape ``(rows, n_frames)`` with
        the array's dtype.
    """
    spec = np.load(path)[:, :n_frames]
    missing = n_frames - spec.shape[1]
    if missing > 0:
        # Pad with the clip's own floor rather than 0, so the padding matches
        # its lowest value whether the data is linear or log-scaled.
        fill = spec.min() if spec.size else 0
        spec = np.pad(spec, ((0, 0), (0, missing)), mode="constant", constant_values=fill)
    return {'audio': torch.from_numpy(spec)}

In [None]:
class MTGJamendo(ClassConditionalDataset):
    """Mood/theme subset of MTG-Jamendo exposed as a class-conditional dataset.

    Items are addressed by integer track id (the ids found in
    ``class_to_indices``), not by a position in ``range(len(self))``.

    Args:
        download: if true, fetch and unpack the spectrogram archives into
            ``outputdir`` before reading the metadata.
        outputdir: directory under which the TSV's track paths (with a ``.npy``
            extension) are resolved.
        input_file: metadata TSV passed to ``read_file``.
        class_file: text file with one class label per line; used when
            ``classes`` is None.
        classes: explicit list of class labels, or None to read ``class_file``.
    """
    def __init__(self, download, outputdir, input_file, class_file, classes):
        if download:
            # extract where __getitem__ will look for the spectrograms
            main_download(outputdir)
        self.tracks, self.tags, self.extra = read_file(input_file)
        self.class_file = class_file
        self.output_dir = outputdir
        self.classes = classes

    def __len__(self):
        """Number of tracks carrying at least one tag from ``class_list``."""
        # class_list (not self.classes) so classes=None falls back to class_file
        wanted = set(self.class_list)
        return sum(
            1 for track in self.tracks.values()
            if any(raw.split(TAG_HYPHEN)[-1] in wanted for raw in track['tags'])
        )

    def __getitem__(self, index):
        """Load the spectrogram of the track whose id is ``index``.

        Returns:
            Dict with ``'audio'`` (from ``load_melspectrogram``) and ``'label'``
            (the track's raw TSV tag strings).
        """
        item = self.tracks[index]
        npy_path = os.path.splitext(item['path'])[0] + ".npy"
        data = load_melspectrogram(os.path.join(self.output_dir, npy_path))
        data["label"] = item['tags']
        return data

    @property
    def class_list(self) -> List[str]:
        """The configured classes, or the lines of ``class_file`` when none were given."""
        if self.classes is None:
            with open(self.class_file) as f:
                lines = f.read().splitlines()
            return lines
        else:
            return self.classes

    @property
    def class_to_indices(self) -> Dict[str, List[int]]:
        """Mood/theme tag -> track ids carrying it.

        NOTE: the values are sets (as built by ``read_file``), not lists, and
        the mapping covers every mood/theme tag in the TSV, not only ``class_list``.
        """
        return self.tags['mood/theme']

In [None]:
from torch import nn
import torch
import pytorch_lightning as pl
from torchmetrics import Accuracy


def count_parameters(model):
    """Total element count over the parameters of ``model`` that require grad."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


class PrototypicalNet(nn.Module):
    """Prototypical network: classify queries by distance to per-class mean embeddings.

    Args:
        backbone: module mapping a batch of inputs to a ``(batch, n_features)``
            embedding tensor.
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, support: dict, query: dict):
        """
        Forward pass through the protonet.

        Args:
            support (dict): A dictionary containing the support set.
                The support set dict must contain the following keys:
                    - audio: a batch the backbone accepts, with n_support items
                    - target: A tensor of shape (n_support) with class indices
                      in range(len(classlist))
                    - classlist: the episode's classes; only its length is used
            query (dict): A dictionary containing the query set.
                The query set dict must contain the following keys:
                    - audio: a batch the backbone accepts, with n_query items

        Returns:
            logits (torch.Tensor): A tensor of shape (n_query, n_classes) holding
            the negative squared euclidean distance to each prototype.

        Raises:
            ValueError: if some class index has no support examples.

        After the forward pass, the support dict is updated with the following keys:
            - embeddings: A tensor of shape (n_support, n_features) containing the embeddings
            - prototypes: A tensor of shape (n_classes, n_features) containing the prototypes

        The query dict is updated with
            - embeddings: A tensor of shape (n_query, n_features) containing the embeddings

        """
        # compute the embeddings for the support and query sets
        support["embeddings"] = self.backbone(support["audio"])
        query["embeddings"] = self.backbone(query["audio"])

        # One prototype per class: the mean embedding of its support items.
        # Averaging per class (instead of stacking first) also handles classes
        # with different numbers of support examples.
        prototypes = []
        for idx in range(len(support["classlist"])):
            class_embeddings = support["embeddings"][support["target"] == idx]
            if class_embeddings.shape[0] == 0:
                raise ValueError(f"no support examples for class index {idx}")
            prototypes.append(class_embeddings.mean(dim=0))
        prototypes = torch.stack(prototypes)
        support["prototypes"] = prototypes

        # euclidean distance between every query embedding and every prototype
        distances = torch.cdist(
            query["embeddings"].unsqueeze(0),
            prototypes.unsqueeze(0),
            p=2
        ).squeeze(0)

        # closer prototype -> larger logit (negative squared euclidean distance)
        logits = -distances.pow(2)
        return logits


class FewShotLearner(pl.LightningModule):
    """Lightning wrapper that trains a prototypical network on episodes.

    Each batch is one ``(support, query)`` episode. The loss is cross-entropy
    between the protonet's query logits and ``query["target"]``.

    Args:
        protonet: module called as ``protonet(support, query)`` that returns
            ``(n_query, num_classes)`` logits.
        num_classes: classes per episode; sizes the accuracy metric.
        learning_rate: Adam learning rate.
    """
    def __init__(self,
                 protonet: nn.Module,
                 num_classes,
                 learning_rate: float = 1e-3,
                 ):
        super().__init__()
        self.save_hyperparameters()
        self.protonet = protonet
        self.learning_rate = learning_rate
        self.num_classes = num_classes

        self.loss = nn.CrossEntropyLoss()
        self.metrics = nn.ModuleDict({
            'accuracy': Accuracy(task="multiclass", num_classes=self.num_classes)
        })

    def configure_optimizers(self):
        """Adam over all parameters at ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def step(self, batch, batch_idx, tag: str):
        """Compute loss and metrics for one episode and log each as ``<name>/<tag>``."""
        support, query = batch
        targets = query["target"]

        logits = self.protonet(support, query)
        loss = self.loss(logits, targets)

        output = {"loss": loss}
        for k, metric in self.metrics.items():
            output[k] = metric(logits, targets)

        # The batch is a (support, query) pair of dicts. State the real batch
        # size (number of query items) rather than letting Lightning infer it
        # from the nested collection.
        n_query = targets.shape[0]
        for k, v in output.items():
            self.log(f"{k}/{tag}", v, batch_size=n_query)
        return output

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "test")

In [None]:
class MTGConvBlock(nn.Module):
    """One conv stage: Conv2d -> BatchNorm2d -> ELU -> MaxPool2d.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: passed to nn.Conv2d.
        max_pool_size: kernel size of the MaxPool2d (its stride defaults to the same).
    """
    def __init__(self,
                 in_channels, out_channels,
                 kernel_size, stride, padding, max_pool_size):
        super().__init__()
        # Attribute names and registration order fix the state_dict keys and
        # parameter order, so they are kept unchanged.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.elu = nn.ELU()
        self.maxpool = nn.MaxPool2d(max_pool_size)

    def forward(self, x):
        for layer in (self.conv, self.bn, self.elu, self.maxpool):
            x = layer(x)
        return x


class Backbone(nn.Module):
    """Five-stage CNN mapping a batch of spectrograms to one 64-d embedding each.

    Input is ``(batch, n_mels, n_frames)``; a channel axis is added internally.
    The pools shrink each spatial axis by 2*2*2*2*4 = 64 overall (with
    flooring), so both input axes must be at least 64. Whatever frequency/time
    extent remains is averaged away, giving a ``(batch, 64)`` output.
    """
    def __init__(self):
        super().__init__()

        self.conv1 = MTGConvBlock(1, 64, 3, 1, 'same', 2)
        self.conv2 = MTGConvBlock(64, 128, 3, 1, 'same', 2)
        self.conv3 = MTGConvBlock(128, 128, 3, 1, 'same', 2)
        self.conv4 = MTGConvBlock(128, 128, 3, 1, 'same', 2)
        self.conv5 = MTGConvBlock(128, 64, 3, 1, 'same', 4)

    def forward(self, x: torch.Tensor):
        assert x.ndim == 3, "Expected a batch of spectrograms shaped (batch, n_mels, n_frames)"

        x = x.unsqueeze(1)  # (batch, 1, n_mels, n_frames)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)  # (batch, 64, f, t)

        # Average over both remaining (freq, time) axes -> (batch, 64). The old
        # mean-over-time + squeeze only produced a 2-D embedding when f was
        # exactly 1 (e.g. 96 mel bins); any other height left a 3-D tensor.
        x = x.mean(dim=(-2, -1))

        return x

In [None]:
# Authenticate with Weights & Biases (interactive) so the WandbLogger used for training can log runs.
!wandb login

In [None]:
# Close any W&B run left open by an earlier execution of the training cell
# (presumably a no-op when no run is active).
wandb.finish()

In [None]:
# Meta-train on TRAIN_CLASSES and validate on TEST_CLASSES. The two lists must
# be disjoint, so validation episodes only contain moods never seen in training.
TRAIN_CLASSES = [
    'ambiental',
    'background',
    'ballad',
    'calm',
    'cool',
    'dark',
    'deep',
    'dramatic',
    'dream',
    'emotional',
    'energetic',
    'epic',
    'fast',
    'fun',
    'funny',
    'groovy',
    'happy',
    'heavy',
    'hopeful',
    'horror',
    'inspiring',
    'love',
    'meditative',
    'melancholic',
    'mellow',
    'melodic',
    'motivational',
    'nature',
    'party',
    'positive',
    'powerful',
    'relaxing',
    'retro',
    'romantic',
    'sad',
]

TEST_CLASSES = [
    'slow',
    'soft',
    'soundscape',
    'upbeat',
    'uplifting'
]

assert not set(TRAIN_CLASSES) & set(TEST_CLASSES), "train/test mood classes overlap"

MTG_DIR = '/content/drive/MyDrive/MTG_Jamendo'
METADATA_TSV = f'{MTG_DIR}/autotagging_moodtheme.tsv'
CLASS_FILE = f'{MTG_DIR}/moodtheme.txt'
SPECTROGRAM_DIR = '.'  # root under which the extracted .npy spectrograms live


def make_mtg_dataset(classes):
    """Build an MTGJamendo dataset restricted to ``classes`` (no download)."""
    return MTGJamendo(download=False,
                      outputdir=SPECTROGRAM_DIR,
                      input_file=METADATA_TSV,
                      class_file=CLASS_FILE,
                      classes=classes)


train_data = make_mtg_dataset(TRAIN_CLASSES)
# the held-out TEST_CLASSES serve as the validation split during training
val_data = make_mtg_dataset(TEST_CLASSES)


In [None]:
# Sanity check: number of tracks tagged with at least one of each split's classes.
len(val_data), len(train_data)

In [None]:
n_way = 5  # number of classes per episode
n_support = 5  # number of support examples per class
n_query = 20  # number of samples per class to use as query
n_train_episodes = 50_000  # number of episodes to generate for training
n_val_episodes = 100  # number of episodes to generate for validation

# Both splits share the episode shape; only the source data and episode count differ.
episode_shape = dict(n_way=n_way, n_support=n_support, n_query=n_query)

train_episodes = EpisodeDataset(dataset=train_data, n_episodes=n_train_episodes, **episode_shape)
val_episodes = EpisodeDataset(dataset=val_data, n_episodes=n_val_episodes, **episode_shape)

In [None]:
from pytorch_lightning.callbacks import LearningRateFinder


class FineTuneLearningRateFinder(LearningRateFinder):
    """Re-run Lightning's learning-rate search periodically during training.

    The base class's search at fit start is suppressed. Instead the search runs
    at the end of every (non sanity-check) validation pass whose global step is
    a multiple of ``milestones``.

    Args:
        milestones: run the search when ``trainer.global_step % milestones == 0``.
        *args, **kwargs: forwarded to ``LearningRateFinder``.
    """
    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        # suppress the base class's search at fit start
        return

    def on_validation_end(self, trainer, pl_module):
        # the sanity-check pass also ends at global_step 0; don't search there
        if trainer.sanity_checking:
            return
        if trainer.global_step % self.milestones == 0:
            # lr_find() returns None. The suggestion is kept in self.optimal_lr
            # and, with the base class's update_attr behaviour, written to the
            # module's lr/learning_rate attribute. Assigning the return value
            # (as before) stored None in hparams.lr.
            self.lr_find(trainer, pl_module)

In [None]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor

SEED = 0
# Seed python/numpy/torch (and dataloader workers) so shuffling and weight init are reproducible.
pl.seed_everything(SEED, workers=True)

num_workers = 2  # number of workers to use for data loading
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

train_loader = DataLoader(train_episodes, batch_size=None, shuffle=True, num_workers=num_workers)
# Lightning warns about shuffled val/test loaders, and the epoch-level
# validation metric does not depend on episode order anyway.
val_loader = DataLoader(val_episodes, batch_size=None, shuffle=False, num_workers=num_workers, persistent_workers=True)

backbone = Backbone()
protonet = PrototypicalNet(backbone).to(DEVICE)

learner = FewShotLearner(protonet, num_classes=n_way, learning_rate=1e-4).to(DEVICE)

wandb_logger = WandbLogger(project='FSL_MTG_Jamendo', job_type='train', log_model="all")
checkpoint_callback = ModelCheckpoint(dirpath='/content/drive/MyDrive/MTG_Jamendo_checkpoints',
                                      monitor="step",
                                      mode="max",
                                      filename='latest-{step}',
                                      every_n_train_steps=100)
lr_monitor = LearningRateMonitor(logging_interval='step')
# NOTE: lr_finder is constructed but not passed to the Trainer below, so it has
# no effect; add it to `callbacks` to enable periodic LR searches.
lr_finder = FineTuneLearningRateFinder(milestones=40, early_stop_threshold=None, max_lr=1)

trainer = pl.Trainer(
    accelerator="auto",
    max_epochs=1,
    log_every_n_steps=1,
    val_check_interval=20,
    logger=wandb_logger,
    callbacks=[checkpoint_callback, lr_monitor]
)

trainer.fit(learner, train_loader, val_dataloaders=val_loader)
wandb.finish()