# 5. Training Songs

This notebook adds a different classification head to the AST model so that it can be trained on
all six snippets of a song at once. The training setup and parameters are the same as in the previous
training notebook, so only the model modifications and the results are analyzed in detail.

In [None]:
!pip install evaluate transformers[torch] torchaudio wandb

In [None]:
import json
import torch
import wandb
import random
import evaluate
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from seaborn import heatmap
from datasets import DatasetDict
from torch import tensor, Tensor, matmul
from torch.nn import Module, Linear, LayerNorm, Embedding
from torch.nn.functional import softmax
from torchaudio import transforms
from torch.nn import CrossEntropyLoss
from sklearn.metrics import confusion_matrix
from torch.nn.functional import interpolate, pad
from transformers.trainer_utils import EvalPrediction
from sklearn.utils.class_weight import compute_class_weight
from transformers import ASTFeatureExtractor, ASTConfig, ASTForAudioClassification, TrainingArguments, Trainer
from transformers.modeling_outputs import SequenceClassifierOutput

In [None]:
%env WANDB_PROJECT=genre_classification

In [None]:
wandb.login(key="...", host="https://wandb.justinkonratt.com")

In [None]:
with open("./dataset/genres.json", "r", encoding="utf-8") as f:
    genres = json.load(f)
genres

Load the song dataset instead of the default one...

In [None]:
preprocessed_dataset = DatasetDict.load_from_disk("./dataset/music_lib_songs")
preprocessed_dataset

In [None]:
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)

model_input_name = feature_extractor.model_input_names[0]
labels_name = "labels"
paths_name = "paths"

In [None]:
config = ASTConfig.from_pretrained(pretrained_model)
config.num_labels = len(genres)
config.label2id = { genre: idx for idx, genre in enumerate(genres) }
config.id2label = { idx: genre for idx, genre in enumerate(genres) }
config.hidden_dropout_prob = 0.00
config.attention_probs_dropout_prob = 0.00

model = ASTForAudioClassification.from_pretrained(pretrained_model, config=config, ignore_mismatched_sizes=True)

## The Model

A few modifications have been tested here:

1. Adding temporal attention to the model and weighting the snippet CLS tokens to obtain the song
   CLS token.
2. Using the mean of the six snippet CLS tokens instead of weighting them.
3. Adding six position embeddings (one for each snippet CLS token) and using temporal attention.
4. Adding a CLS embedding and seven position embeddings, then using simple self-attention and
   taking the resulting CLS token as the song's CLS token.
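
The difference between variants 1 and 2 can be illustrated with a small, dependency-free sketch.
It is not the notebook's implementation: in the model the per-snippet scores come from a learned
linear layer, whereas here they are passed in by hand, and the helper names (`mean_pool`,
`attention_pool`) and the toy vectors are made up for illustration.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def mean_pool(cls_tokens):
    """Variant 2: average the snippet CLS vectors element-wise."""
    n = len(cls_tokens)
    return [sum(column) / n for column in zip(*cls_tokens)]

def attention_pool(cls_tokens, scores):
    """Variant 1: softmax the per-snippet scores, then take the weighted sum."""
    weights = softmax(scores)
    return [sum(w * x for w, x in zip(weights, column)) for column in zip(*cls_tokens)]

# Toy song with 3 snippets and 2-dimensional CLS vectors
# (the notebook uses 6 snippets and hidden_size dimensions).
toy_cls = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

uniform = attention_pool(toy_cls, [0.0, 0.0, 0.0])  # equal scores: identical to the mean
peaked = attention_pool(toy_cls, [10.0, 0.0, 0.0])  # first snippet dominates the song vector
```

With equal scores the weighted sum reduces to the plain mean, so mean pooling is the special case
of temporal attention in which every snippet is weighted the same.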

For this to work, a custom model wraps AST. Additionally, the classifier of the model is replaced
by a custom torch module. The batch size is reduced during training to fit on the GPU. A batch size
of 12 was chosen because 12 * 6 = 72, which is the number of snippets per batch. Thus, the batch
has a shape of (12, 6, 1024, 128). Before it is fed into the model, the batch is flattened to
(72, 1024, 128), such that each snippet is processed separately. After that, the original grouping
by song is restored for the classification.
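
The flatten-and-restore step can be sketched without tensors. The following is only an
illustration with nested lists and hypothetical helper names (`flatten_songs`, `group_by_song`);
the notebook itself does this with `reshape`. The point is that the snippets of one song must
stay consecutive in the flat batch, otherwise regrouping would mix snippets of different songs.

```python
def flatten_songs(batch):
    """(songs, snippets, ...) -> (songs * snippets, ...): one entry per snippet."""
    return [snippet for song in batch for snippet in song]

def group_by_song(flat, snippets_per_song):
    """(songs * snippets, ...) -> (songs, snippets, ...): consecutive entries form a song."""
    if len(flat) % snippets_per_song != 0:
        raise ValueError("number of snippets is not a multiple of snippets_per_song")
    return [flat[i:i + snippets_per_song] for i in range(0, len(flat), snippets_per_song)]

# Toy batch: 2 songs with 3 snippets each; strings stand in for spectrograms.
toy_batch = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]

flat = flatten_songs(toy_batch)                      # song-major order: a0, a1, a2, b0, b1, b2
restored = group_by_song(flat, snippets_per_song=3)  # back to one list per song
```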

Although all results seemed quite good, taking the mean of the CLS tokens worked best. However,
none of these methods was better than classifying each snippet separately and aggregating the
classification results. [See here](./6-results.ipynb)
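
For comparison, aggregating per-snippet classification results could look like the sketch below.
This is an assumption for illustration only: the aggregation actually used is described in the
results notebook and may differ. The two helpers (`average_probabilities`, `majority_vote`) and
the toy probabilities are hypothetical.

```python
from collections import Counter

def average_probabilities(snippet_probs):
    """Average class probabilities over snippets, then pick the most likely class."""
    n = len(snippet_probs)
    mean = [sum(column) / n for column in zip(*snippet_probs)]
    return max(range(len(mean)), key=mean.__getitem__)

def majority_vote(snippet_probs):
    """Let every snippet vote for its top class; ties go to the class seen first."""
    votes = [max(range(len(p)), key=p.__getitem__) for p in snippet_probs]
    return Counter(votes).most_common(1)[0][0]

# Made-up probabilities for 3 snippets of one song over 3 classes.
toy_probs = [
    [0.6, 0.4, 0.0],
    [0.6, 0.4, 0.0],
    [0.0, 1.0, 0.0],
]

song_by_mean = average_probabilities(toy_probs)  # one very confident snippet can tip the mean
song_by_vote = majority_vote(toy_probs)          # voting ignores how confident each snippet is
```

The two strategies can disagree, as in this toy case, because averaging takes the confidence of
each snippet into account while voting only counts the top class.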

In [None]:
class TemporalAttention(Module):
    def __init__(self):
        super(TemporalAttention, self).__init__()
        self.attention_weights = Linear(config.hidden_size, 1)

    def forward(self, cls_tokens):
        scores = self.attention_weights(cls_tokens)
        attention_weights = softmax(scores, dim=-2)
        return torch.sum(attention_weights * cls_tokens, dim=-2)
    
class SelfAttention(Module):
    def __init__(self):
        super(SelfAttention, self).__init__()
        self.h_sqrt = config.hidden_size ** 0.5
        self.query = Linear(config.hidden_size, config.hidden_size)
        self.key = Linear(config.hidden_size, config.hidden_size)
        self.value = Linear(config.hidden_size, config.hidden_size)

    def forward(self, cls_tokens):
        Q = self.query(cls_tokens)
        K = self.key(cls_tokens)
        V = self.value(cls_tokens)
        scores = matmul(Q, K.transpose(-1, -2)) / self.h_sqrt
        # Normalize over the key dimension so that the weights of each query sum to 1.
        attention_weights = softmax(scores, dim=-1)
        return cls_tokens + matmul(attention_weights, V)

class AggregateSnipptesClassifier(Module):
    def __init__(self, snippets_per_song: int):
        super(AggregateSnipptesClassifier, self).__init__()
        self.snippets_per_song = snippets_per_song
        self.layernorm = LayerNorm(normalized_shape=(config.hidden_size,), eps=1e-12, elementwise_affine=True)
        # self.position_embeddings = Embedding(1 + snippets_per_song, config.hidden_size)
        # self.cls_embedding = Embedding(1, config.hidden_size)
        self.attention = TemporalAttention() # SelfAttention()
        self.classifier = Linear(in_features=config.hidden_size, out_features=config.num_labels, bias=True)

    def forward(self, cls_tokens: Tensor):
        cls_tokens = cls_tokens.reshape(cls_tokens.shape[0] // self.snippets_per_song, self.snippets_per_song, cls_tokens.shape[1])
        normalized_cls_tokens = self.layernorm(cls_tokens)
        # cls_token = self.cls_embedding(torch.arange(1, device=cls_tokens.device)).expand(cls_tokens.shape[0], 1, cls_tokens.shape[2])
        # cls_tokens = torch.cat([cls_token, normalized_cls_tokens], dim=1)
        song_cls_token = self.attention(normalized_cls_tokens) # normalized_cls_tokens.mean(dim=-2)
        # song_cls_tokens = self.attention(cls_tokens) # self.attention(self.position_embeddings(torch.arange(self.snippets_per_song, device=cls_tokens.device)) + normalized_cls_tokens)
        return self.classifier(song_cls_token)
        # return self.classifier(song_cls_tokens[:,0,:])

class ASTSnippetModel(Module):
    def __init__(self, snippets_per_song: int = 6):
        super(ASTSnippetModel, self).__init__()
        self.model = ASTForAudioClassification.from_pretrained(pretrained_model, config=config, ignore_mismatched_sizes=True)
        self.model.classifier = AggregateSnipptesClassifier(snippets_per_song)
        self.loss = CrossEntropyLoss()

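    def forward(self, input_values: Tensor, labels: Tensor = None):
        # input_values shape: (batch_size, snippets_per_song, time, freq)
        # Flatten songs and snippets so that AST processes every snippet separately.
        batch = input_values.reshape(input_values.shape[0] * input_values.shape[1], input_values.shape[2], input_values.shape[3])
        # The custom classifier regroups the snippets, so logits has one row per song.
        logits = self.model(batch).logits
        # Only compute the loss when labels are given (e.g. not during pure inference).
        loss = self.loss(logits, labels) if labels is not None else None
        return SequenceClassifierOutput(loss=loss, logits=logits)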

model = ASTSnippetModel()

## Different Metrics

We don't need all the aggregated metrics here because we already have song predictions. Thus, the
main metric is now `accuracy`.

In [None]:
def calc_aggregated_accuracy(predictions: Tensor, labels: Tensor):
    predicted_labels = predictions.argmax(dim=1)
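    # Pass all class indices so the matrix stays len(genres) x len(genres) even if a genre is missing.
    cm = confusion_matrix(labels, predicted_labels, labels=list(range(len(genres))))
    # Row-normalize; genres without samples keep an all-zero row instead of dividing by zero.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized_row = cm.astype(float) / np.maximum(row_sums, 1)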

    plt.figure(figsize=(8, 6))
    heatmap(cm_normalized_row, xticklabels=genres, yticklabels=genres)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

    threshold = 0.1
    confusions = []
    for i in range(cm_normalized_row.shape[0]):
        for j in range(cm_normalized_row.shape[1]):
            if i != j and cm_normalized_row[i, j] > threshold:
                confusions.append((genres[i], genres[j], cm_normalized_row[i, j]))
    confusion_df = pd.DataFrame(confusions, columns=['True Class', 'Predicted Class', 'Confusion Value'])
    confusion_df = confusion_df.sort_values(by='Confusion Value', ascending=False)
    print(confusion_df)

    correctly_predicted_per_genre = [0] * len(genres)
    songs_per_genre = [0] * len(genres)
    for idx, label in enumerate(labels):
        if label == predicted_labels[idx]:
            correctly_predicted_per_genre[label] += 1
        songs_per_genre[label] += 1

    # Skip genres without any songs in this split to avoid a division by zero.
    return {
        f"{genre}_accuracy": correctly_predicted_per_genre[idx] / songs_per_genre[idx]
        for idx, genre in enumerate(genres)
        if songs_per_genre[idx] > 0
    }

accuracy = evaluate.load("accuracy")
recall = evaluate.load("recall")
precision = evaluate.load("precision")
f1 = evaluate.load("f1")

AVERAGE = "macro" if config.num_labels > 2 else "binary"

def compute_metrics(eval_pred: EvalPrediction):
    logits = eval_pred.predictions
    labels = eval_pred.label_ids
    predictions = np.argmax(logits, axis=1)
    metrics = accuracy.compute(predictions=predictions, references=labels)
    metrics.update(precision.compute(predictions=predictions, references=labels, average=AVERAGE))
    metrics.update(recall.compute(predictions=predictions, references=labels, average=AVERAGE))
    metrics.update(f1.compute(predictions=predictions, references=labels, average=AVERAGE))
    metrics.update(calc_aggregated_accuracy(tensor(logits), tensor(labels)))
    return metrics

In [None]:
class SpecAugmentPipeline:
    def __init__(
            self,
            p=0.5,
            effects_p=0.5,
            time_mask_param=30,
            freq_mask_param=20,
            noise_level=0.05,
            stretch_range=(0.97, 1.03),
            shift_range=3,
            amplitude_range=(0.8, 1.2)
    ):
        self.p = p
        self.effects_p = effects_p
        self.time_mask = transforms.TimeMasking(time_mask_param)
        self.freq_mask = transforms.FrequencyMasking(freq_mask_param)
        self.noise_level = noise_level
        self.stretch_range = stretch_range
        self.shift_range = shift_range
        self.amplitude_range = amplitude_range

    def add_noise(self, spec):
        return spec + torch.randn_like(spec) * self.noise_level
    
    def time_stretch(self, spec):
        spec = spec.unsqueeze(0)
        factor = random.uniform(*self.stretch_range)
        new_steps = int(spec.size(-1) * factor)
        new_spec = interpolate(spec, (spec.size(-2), new_steps), mode="bilinear", align_corners=False).squeeze(0)
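        # Crop when stretched and zero-pad when shrunk, so the time axis keeps its original length.
        # (resize_ would reinterpret the underlying storage and scramble the rows instead of cropping.)
        if factor >= 1:
            return new_spec[..., :spec.size(-1)]
        return pad(new_spec, (0, spec.size(-1) - new_steps))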

    def frequency_shift(self, spec):
        shift = random.randint(-self.shift_range, self.shift_range)
        return torch.roll(spec, shifts=shift, dims=-2)

    def amplitude_scaling(self, spec):
        return spec * random.uniform(*self.amplitude_range)

    def __call__(self, spec):
        if random.random() >= self.p:
            return spec
        
        spec = spec.transpose(-1, -2)
        if random.random() < self.effects_p:
            spec = self.time_mask(spec)
        
        if random.random() < self.effects_p:
            spec = self.freq_mask(spec)

        if random.random() < self.effects_p:
            spec = self.add_noise(spec)

        # bad loss behavior
        # if random.random() < self.effects_p:
        #     spec = self.time_stretch(spec)

        if random.random() < self.effects_p:
            spec = self.frequency_shift(spec)

        if random.random() < self.effects_p:
            spec = self.amplitude_scaling(spec)

        return spec.transpose(-1, -2)

In [None]:
aug_pipe = SpecAugmentPipeline(p=1)
def augmentation(sample):
    if model_input_name in sample:
        sample[model_input_name] = aug_pipe(tensor(sample[model_input_name]))
    return sample

preprocessed_dataset["train"].set_transform(augmentation)

In [None]:
batch_size = 12
devices = 1
epochs = 5
warmup_epochs = epochs / 10

training_args = TrainingArguments(
    output_dir="./runs/ast_classifier",
    report_to="wandb",
    run_name="songs_self_attn",
    learning_rate=5e-5,
    weight_decay=0.00,
    warmup_steps=round(len(preprocessed_dataset["train"]) / batch_size * warmup_epochs),
    push_to_hub=False,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size // devices,
    per_device_eval_batch_size=batch_size // devices,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    logging_strategy="steps",
    logging_steps=1,
    fp16=True,
    save_total_limit=3
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=preprocessed_dataset["train"],
    eval_dataset=preprocessed_dataset["validate"],
    compute_metrics=compute_metrics,
)

## Training with the temporal attention

This run shows the best results obtained with temporal weighting of the snippets' CLS tokens. On
each evaluation, a confusion matrix is plotted and the most common confusions are printed. You can
clearly see that the highly represented classes are the most predicted ones in the beginning, but
the model improves over time and also learns the other genres.
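
How the list of most common confusions is derived can be shown with a small, dependency-free
sketch. It mirrors the idea of the metric function above (row-normalize the confusion matrix,
then report off-diagonal cells above a threshold), but the helper names and the counts are made
up for illustration and are not results of this notebook.

```python
def row_normalize(cm):
    """Divide every row by its sum so each row shows how one true class is predicted."""
    normalized = []
    for row in cm:
        total = sum(row)
        normalized.append([value / total if total else 0.0 for value in row])
    return normalized

def frequent_confusions(cm, names, threshold=0.1):
    """List (true, predicted, share) for off-diagonal cells above the threshold."""
    found = [
        (names[i], names[j], share)
        for i, row in enumerate(row_normalize(cm))
        for j, share in enumerate(row)
        if i != j and share > threshold
    ]
    return sorted(found, key=lambda item: item[2], reverse=True)

# Made-up counts: rows are true classes, columns are predicted classes.
toy_names = ["Rock", "Folk", "Country"]
toy_cm = [
    [18, 1, 1],  # 20 Rock songs, 18 correct
    [0, 1, 3],   # 4 Folk songs, mostly predicted as Country
    [1, 0, 9],   # 10 Country songs, 9 correct
]

toy_confusions = frequent_confusions(toy_cm, toy_names)
```

Because every row is normalized by its own total, a rare genre with only a handful of songs can
still show up as a strong confusion even though its absolute counts are small.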

In [None]:
trainer.train()

On the test set, we can see below which genres are often confused. For example, Folk is always
classified as Country, which is quite similar. Also, Epic and Lo-Fi are classified as Chillstep
half of the time, again a very similar genre. Thus, we can see that the model has learned important
characteristics of the music, but still fails to pick up on finer details. Maybe spectrograms are
not the best feature for this task. Also, the dataset is quite small and opinionated, which further
limits what the model can learn from it. The test recall is only around 70 %, making it a rather
weak model for this task, although the accuracy is around 86.2 %. The high accuracy is mainly due
to the highly represented genres.
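
Why accuracy can be much higher than macro-averaged recall on an imbalanced dataset is easiest to
see with a toy calculation. The numbers below are invented and unrelated to the actual test
results; the helpers are a simplified sketch that skips classes without true samples.

```python
def accuracy(cm):
    """Share of all samples on the diagonal of the confusion matrix."""
    correct = sum(cm[i][i] for i in range(len(cm)))
    total = sum(sum(row) for row in cm)
    return correct / total

def macro_recall(cm):
    """Unweighted mean of the per-class recalls (classes without samples are skipped here)."""
    recalls = [row[i] / sum(row) for i, row in enumerate(cm) if sum(row) > 0]
    return sum(recalls) / len(recalls)

# Made-up counts: rows are true classes, columns are predicted classes.
toy_cm = [
    [90, 0],  # 90 songs of a common genre, all classified correctly
    [8, 2],   # 10 songs of a rare genre, only 2 classified correctly
]

toy_accuracy = accuracy(toy_cm)          # dominated by the common genre
toy_macro_recall = macro_recall(toy_cm)  # every genre counts the same
```

In this toy case 92 of 100 songs are correct, yet the rare genre is recognized only 20 % of the
time, which pulls the macro recall down to 60 %.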

In [None]:
print(trainer.state.best_model_checkpoint)
results = trainer.evaluate(preprocessed_dataset["test"], metric_key_prefix="test")
print(results)

wandb.finish()