# 2. Training

In this step, we move on to the actual training of the model.

## 2.1 Preliminaries

We first have to do some installations and imports.

In [1]:
!pip install datasets evaluate "transformers[torch]" torchaudio wandb scikit-learn seaborn

In [2]:
import json
import random

import evaluate
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import wandb

from collections import Counter
from datasets import DatasetDict
from seaborn import heatmap
from sklearn.metrics import confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from torch import tensor, Tensor
from torch.nn import CrossEntropyLoss
from torch.nn.functional import interpolate, pad
from torchaudio import transforms
from transformers import ASTFeatureExtractor, ASTConfig, ASTForAudioClassification, TrainingArguments, Trainer
from transformers.trainer_utils import EvalPrediction

Since I used Weights and Biases (wandb) to monitor the training process, we need to set the
following environment variable to select the correct project.

In [3]:
%env WANDB_PROJECT=genre_classification

Then we log in to my wandb instance. Sorry, no API key leaked here ;)

In [None]:
wandb.login(key="...", host="https://wandb.justinkonratt.com")

Now we can load the genres and the dataset that were saved by the previous notebook.

In [5]:
with open("./dataset/genres.json", "r", encoding="utf-8") as f:
    genres = json.load(f)
genres

In [None]:
preprocessed_dataset = DatasetDict.load_from_disk("./dataset/music_lib_balanced")
preprocessed_dataset

## 2.2 The model

First, we load the feature extractor of the pretrained model and define the column names of the dataset.

In [7]:
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)

model_input_name = feature_extractor.model_input_names[0]
labels_name = "labels"
paths_name = "paths"

The AST model works with an `ASTConfig`, which can be used to make changes to the model. Here we can
set the number of classes `num_labels` and their respective names and IDs.

During this work, the dropout values `hidden_dropout_prob` and `attention_probs_dropout_prob` as
well as the temporal dimension of the input, `max_length`, were modified. Changing `max_length`
changes the number of input patches, so the pretrained position embeddings no longer fit and are
lost. It should thus be avoided.
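The reason becomes clear from the patch arithmetic. The following sketch counts the position embeddings for a given `max_length`, assuming the values of this checkpoint's `ASTConfig` (`num_mel_bins=128`, `patch_size=16`, `frequency_stride=10`, `time_stride=10`, default `max_length=1024`) and the two extra tokens (CLS and distillation) used by the Hugging Face AST implementation. The function name is hypothetical.

```python
def ast_num_position_embeddings(
    max_length: int = 1024,
    num_mel_bins: int = 128,
    patch_size: int = 16,
    frequency_stride: int = 10,
    time_stride: int = 10,
) -> int:
    """Number of position embeddings of an AST model for a given input size.

    The patch embedding is a convolution with kernel `patch_size` and strides
    (frequency_stride, time_stride) over a (num_mel_bins, max_length) spectrogram;
    two extra tokens (CLS and distillation) are added on top of the patches.
    """
    frequency_patches = (num_mel_bins - patch_size) // frequency_stride + 1
    time_patches = (max_length - patch_size) // time_stride + 1
    return frequency_patches * time_patches + 2


print(ast_num_position_embeddings())                # 1214 (12 x 101 patches + 2 tokens)
print(ast_num_position_embeddings(max_length=512))  # 602 (12 x 50 patches + 2 tokens)
```

Since the shape of the position-embedding tensor depends on `max_length`, loading the checkpoint with a different value and `ignore_mismatched_sizes=True` drops the pretrained embeddings and initializes new ones.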

Changing the dropout did not improve performance, as you can see in
[notebook 6](./6-results.ipynb).

In [16]:
config = ASTConfig.from_pretrained(pretrained_model)
config.num_labels = len(genres)
config.label2id = { genre: idx for idx, genre in enumerate(genres) }
config.id2label = { idx: genre for idx, genre in enumerate(genres) }
config.hidden_dropout_prob = 0.00
config.attention_probs_dropout_prob = 0.00

model = ASTForAudioClassification.from_pretrained(pretrained_model, config=config, ignore_mismatched_sizes=True)

## 2.3 Metrics

The following functions compute the metrics in the evaluation step. Since the model predicts a genre
for every snippet of a song, the metric that matters most is the accuracy of the song-level
prediction obtained by aggregating the snippet predictions. Several aggregation methods can be used:
majority voting, confidence-weighted scoring, and mean or max pooling of the snippet logits. All of
them are reported to see which works best. [The results](./6-results.ipynb) show that
`mean_pooling_score` and `weighting_score` work best. Since `mean_pooling_score` reaches higher
values at the end of training, it is used as the primary performance metric.
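To make the four aggregation rules concrete, here is a minimal numpy sketch for a single song. It mirrors the torch code of `calc_aggregated_accuracy` below on a made-up logit matrix; `aggregate_song` is a hypothetical helper, and its tie handling for voting (lowest class id wins) may differ from `torch.mode`.

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    # numerically stable softmax
    shifted = np.exp(logits - logits.max(axis=axis, keepdims=True))
    return shifted / shifted.sum(axis=axis, keepdims=True)


def aggregate_song(song_logits: np.ndarray) -> dict:
    """Turn the (num_snippets, num_classes) logits of one song into song-level predictions."""
    num_classes = song_logits.shape[1]
    # voting: majority of the snippet-level argmax predictions (ties -> lowest class id)
    votes = np.bincount(song_logits.argmax(axis=1), minlength=num_classes)
    # weighting: scale each snippet's logits by its top softmax probability, then average
    confidence = softmax(song_logits, axis=1).max(axis=1)
    weighted = song_logits * confidence[:, None]
    return {
        "voting": int(votes.argmax()),
        "weighting": int(weighted.mean(axis=0).argmax()),
        "mean_pooling": int(song_logits.mean(axis=0).argmax()),
        "max_pooling": int(song_logits.max(axis=0).argmax()),
    }


# two moderately confident snippets for class 0, one very confident snippet for class 2
song = np.array([[2.0, 0.0, 0.0],
                 [2.0, 0.0, 0.0],
                 [0.0, 0.0, 9.0]])
print(aggregate_song(song))
# {'voting': 0, 'weighting': 2, 'mean_pooling': 2, 'max_pooling': 2}
```

In this toy example, the two weaker snippets win the vote, while the pooling and weighting rules follow the single strong snippet.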

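The typical classification metrics, namely snippet-level accuracy, precision, recall, and F1 score,
are reported as well. Furthermore, a row-normalized confusion matrix of the song-level (mean pooling)
predictions is plotted, and all confusions above 10 % are printed.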
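The confusion analysis at the end of `calc_aggregated_accuracy` normalizes each row of the confusion matrix by the number of songs of that true genre and lists the off-diagonal cells above a threshold. A numpy sketch of the same idea on hypothetical counts; unlike the notebook code, it maps empty rows to 0 instead of dividing by zero, and the genre names and the helper `frequent_confusions` are made up:

```python
import numpy as np


def frequent_confusions(cm: np.ndarray, class_names: list, threshold: float = 0.1) -> list:
    """Return (true, predicted, rate) for off-diagonal cells of the row-normalized matrix above threshold."""
    row_sums = cm.sum(axis=1, keepdims=True)
    # normalize each row by the number of true samples of that class; empty rows stay 0
    normalized = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)
    confusions = [
        (class_names[i], class_names[j], float(normalized[i, j]))
        for i in range(cm.shape[0])
        for j in range(cm.shape[1])
        if i != j and normalized[i, j] > threshold
    ]
    return sorted(confusions, key=lambda item: item[2], reverse=True)


# hypothetical counts for three genres (rows: true genre, columns: predicted genre)
cm = np.array([[8, 2, 0],
               [1, 9, 0],
               [0, 3, 7]])
for true_name, predicted_name, rate in frequent_confusions(cm, ["rock", "metal", "jazz"]):
    print(f"{true_name} -> {predicted_name}: {rate:.2f}")
# jazz -> metal: 0.30
# rock -> metal: 0.20
```

The cell `metal -> rock` (exactly 0.10) is not listed because the comparison is strict, as in the notebook code.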

In [9]:
def calc_aggregated_accuracy(predictions: Tensor, labels: Tensor):
    v_score = 0
    w_score = 0
    a_score = 0
    m_score = 0

    songs_per_genre = [0 for _ in range(len(genres))]
    correctly_predicted_per_genre = [0 for _ in range(len(genres))]

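    # Find the split that is being evaluated by its length
    # (assumes that the dataset splits differ in size).
    eval_subset = None
    for subset in preprocessed_dataset.values():
        if len(subset) == predictions.shape[0]:
            eval_subset = subset

    true_labels = []
    predicted_labels = []

    # Number of snippets per song in order of first appearance;
    # assumes that all snippets of a song are stored next to each other.
    entries_per_song = list(Counter(eval_subset[paths_name]).values())
    song_count = len(entries_per_song)
    start_idx = 0
    for song_entries in entries_per_song:
        # all snippets of a song share its label; .item() turns it into a plain int
        label = labels[start_idx].item()
        true_labels.append(label)
        songs_per_genre[label] += 1
        song_logits = predictions[start_idx:start_idx + song_entries, :]
        start_idx += song_entries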

        # voting
        v_pred = song_logits.argmax(dim=1).mode().values.item()
        if v_pred == label:
            v_score += 1

        # weighting
        confidence = song_logits.softmax(dim=1).max(dim=1).values
        weighted_logits = (song_logits.T * confidence).T
        w_pred = weighted_logits.mean(dim=0).argmax().item()
        if w_pred == label:
            w_score += 1

        # average
        a_pred = song_logits.mean(dim=0).argmax().item()
        predicted_labels.append(a_pred)
        if a_pred == label:
            a_score += 1
            correctly_predicted_per_genre[label] += 1

        # max
        m_pred = song_logits.max(dim=0).values.argmax().item()
        if m_pred == label:
            m_score += 1

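    # pass all class ids so the matrix stays len(genres) x len(genres) even if a genre does not occur
    cm = confusion_matrix(true_labels, predicted_labels, labels=list(range(len(genres))))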
    cm_normalized_row = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    plt.figure(figsize=(8, 6))
    heatmap(cm_normalized_row, xticklabels=genres, yticklabels=genres)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

    threshold = 0.1
    confusions = []
    for i in range(cm_normalized_row.shape[0]):
        for j in range(cm_normalized_row.shape[1]):
            if i != j and cm_normalized_row[i, j] > threshold:
                confusions.append((genres[i], genres[j], cm_normalized_row[i, j]))
    confusion_df = pd.DataFrame(confusions, columns=['True Class', 'Predicted Class', 'Confusion Value'])
    confusion_df = confusion_df.sort_values(by='Confusion Value', ascending=False)
    print(confusion_df)

    return {
        "voting_score": v_score / song_count,
        "weighting_score": w_score / song_count,
        "mean_pooling_score": a_score / song_count,
        "max_pooling_score": m_score / song_count,
        **{ f"{genre}_a_accuracy": correctly_predicted_per_genre[idx] / songs_per_genre[idx] for idx, genre in enumerate(genres) }
    }

accuracy = evaluate.load("accuracy")
recall = evaluate.load("recall")
precision = evaluate.load("precision")
f1 = evaluate.load("f1")

AVERAGE = "macro" if config.num_labels > 2 else "binary"

def compute_metrics(eval_pred: EvalPrediction):
    logits = eval_pred.predictions
    labels = eval_pred.label_ids
    predictions = np.argmax(logits, axis=1)
    metrics = accuracy.compute(predictions=predictions, references=labels)
    metrics.update(precision.compute(predictions=predictions, references=labels, average=AVERAGE))
    metrics.update(recall.compute(predictions=predictions, references=labels, average=AVERAGE))
    metrics.update(f1.compute(predictions=predictions, references=labels, average=AVERAGE))
    metrics.update(calc_aggregated_accuracy(tensor(logits), tensor(labels)))
    return metrics

## 2.4 Augmentation

As for images, data augmentation is considered a good way to make audio models more robust. It adds
noise, removes some features, or changes the length of the signal, which helps the model to learn
different aspects of the data and to generalize better. For example, frequency masking forces the
model to use different frequency ranges, so that it learns important features from bass, mids, and
highs equally. Noise and frequency shifts can be added for better generalization, and amplitude
scaling lets the model learn the task at different loudness levels. All these augmentations are
applied directly to the spectrograms from the preprocessing.
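The effects described above can be sketched in numpy on a `(freq, time)` array. This illustrates the ideas and is not the torchaudio-based pipeline used below: the mask sampling only approximates SpecAugment-style masking, all helper names are hypothetical, and all parameter values are illustrative.

```python
import numpy as np


def mask_band(spec: np.ndarray, max_width: int, axis: int, rng: np.random.Generator) -> np.ndarray:
    """Zero a random contiguous band of width in [0, max_width) along `axis` (SpecAugment-style masking)."""
    out = spec.copy()
    width = int(rng.integers(0, max_width))
    start = int(rng.integers(0, out.shape[axis] - width + 1))
    index = [slice(None)] * out.ndim
    index[axis] = slice(start, start + width)
    out[tuple(index)] = 0.0
    return out


def add_noise(spec: np.ndarray, level: float, rng: np.random.Generator) -> np.ndarray:
    # additive Gaussian noise scaled by `level`
    return spec + rng.standard_normal(spec.shape) * level


def frequency_shift(spec: np.ndarray, max_shift: int, rng: np.random.Generator) -> np.ndarray:
    # roll the frequency axis (axis 0 of a (freq, time) array) by up to +/- max_shift bins
    return np.roll(spec, int(rng.integers(-max_shift, max_shift + 1)), axis=0)


def amplitude_scaling(spec: np.ndarray, low: float, high: float, rng: np.random.Generator) -> np.ndarray:
    # multiply the whole spectrogram by one random gain
    return spec * rng.uniform(low, high)


rng = np.random.default_rng(0)
spec = np.ones((128, 1024))  # hypothetical (freq, time) spectrogram
freq_masked = mask_band(spec, max_width=20, axis=0, rng=rng)
time_masked = mask_band(spec, max_width=30, axis=1, rng=rng)
print((~freq_masked.any(axis=1)).sum(), "frequency bins and",
      (~time_masked.any(axis=0)).sum(), "time steps masked")
```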

The only augmentation technique that should not be used is time stretching: even a small stretch of
±3 % led to unlearnable samples, as ["Using Time Stretching" in notebook 6](./6-results.ipynb)
shows.
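One detail that matters when stretching a fixed-length spectrogram is how the result is brought back to the original number of frames. The numpy sketch below (an illustration, not the notebook's code; linear interpolation stands in for torch's bilinear `interpolate`, and the helper names are hypothetical) compares cropping the time axis with truncating the flattened data in row-major order. The latter is what `torch.Tensor.resize_` amounts to when it shrinks a contiguous tensor, because it reinterprets the existing storage as C-contiguous, and it shifts every frequency row into its neighbor.

```python
import numpy as np


def stretch_time(spec: np.ndarray, factor: float) -> np.ndarray:
    """Linearly resample the time axis (last axis) of a (freq, time) spectrogram by `factor`."""
    steps = spec.shape[-1]
    new_steps = int(steps * factor)
    old_positions = np.linspace(0.0, 1.0, steps)
    new_positions = np.linspace(0.0, 1.0, new_steps)
    return np.stack([np.interp(new_positions, old_positions, row) for row in spec])


def crop_or_pad(spec: np.ndarray, steps: int) -> np.ndarray:
    """Bring the time axis back to `steps` frames by cropping or zero-padding at the end."""
    if spec.shape[-1] >= steps:
        return spec[:, :steps]
    return np.pad(spec, ((0, 0), (0, steps - spec.shape[-1])))


def flat_truncate(spec: np.ndarray, steps: int) -> np.ndarray:
    """Keep the first freq * steps values in row-major order (what shrinking via resize_ amounts to)."""
    return spec.reshape(-1)[: spec.shape[0] * steps].reshape(spec.shape[0], steps)


# each frequency row holds a constant value (its row index), so mixing between rows is visible
spec = np.repeat(np.arange(4.0)[:, None], 8, axis=1)
stretched = stretch_time(spec, 1.25)   # 8 -> 10 frames
print(crop_or_pad(stretched, 8)[1])    # row 1 stays all ones
print(flat_truncate(stretched, 8)[1])  # row 1 starts with two values taken from row 0
```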

The following augmentation pipeline is used when training with augmentation. Within the pipeline,
each effect is applied with a probability of 50 % (`effects_p`), while `p` controls whether a sample
is augmented at all. [Results](./6-results.ipynb) have shown that more augmentation is better. That's
why augmentation is applied to every sample (`p=1`) instead of to 0 % or 50 % of them.
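With `p=1` and the default `effects_p=0.5`, a quick calculation shows how strongly a sample is augmented, assuming the five effects enabled in `__call__` (time stretching is commented out) are drawn independently:

```python
from math import comb

effects_p = 0.5  # probability of each individual effect
num_effects = 5  # time mask, frequency mask, noise, frequency shift, amplitude scaling

# the number of applied effects follows a binomial distribution
distribution = {
    k: comb(num_effects, k) * effects_p**k * (1 - effects_p) ** (num_effects - k)
    for k in range(num_effects + 1)
}
print(distribution[0])          # 0.03125: share of samples that pass through unchanged
print(num_effects * effects_p)  # 2.5: expected number of effects per sample
```

So even with `p=1`, about 3 % of the samples reach the model without any effect.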

In [10]:
class SpecAugmentPipeline:
    def __init__(
            self,
            p=0.5,
            effects_p=0.5,
            time_mask_param=30,
            freq_mask_param=20,
            noise_level=0.05,
            stretch_range=(0.97, 1.03),
            shift_range=3,
            amplitude_range=(0.8, 1.2)
    ):
        self.p = p
        self.effects_p = effects_p
        self.time_mask = transforms.TimeMasking(time_mask_param)
        self.freq_mask = transforms.FrequencyMasking(freq_mask_param)
        self.noise_level = noise_level
        self.stretch_range = stretch_range
        self.shift_range = shift_range
        self.amplitude_range = amplitude_range

    def add_noise(self, spec):
        return spec + torch.randn_like(spec) * self.noise_level
    
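    def time_stretch(self, spec):
        # spec: (..., freq, time); bilinear interpolation needs a 4D (N, C, H, W) input
        num_steps = spec.size(-1)
        factor = random.uniform(*self.stretch_range)
        new_steps = int(num_steps * factor)
        stretched = interpolate(
            spec.reshape(1, -1, spec.size(-2), num_steps),
            size=(spec.size(-2), new_steps), mode="bilinear", align_corners=False,
        ).reshape(*spec.shape[:-1], new_steps)
        if new_steps >= num_steps:
            # crop the time axis (Tensor.resize_ would reinterpret the memory instead of cropping)
            return stretched[..., :num_steps]
        # zero-pad the time axis back to its original length
        return pad(stretched, (0, num_steps - new_steps))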

    def frequency_shift(self, spec):
        shift = random.randint(-self.shift_range, self.shift_range)
        return torch.roll(spec, shifts=shift, dims=-2)

    def amplitude_scaling(self, spec):
        return spec * random.uniform(*self.amplitude_range)

    def __call__(self, spec):
        if random.random() >= self.p:
            return spec
        
        spec = spec.transpose(-1, -2)
        if random.random() < self.effects_p:
            spec = self.time_mask(spec)
        
        if random.random() < self.effects_p:
            spec = self.freq_mask(spec)

        if random.random() < self.effects_p:
            spec = self.add_noise(spec)

        # bad loss behavior
        # if random.random() < self.effects_p:
        #     spec = self.time_stretch(spec)

        if random.random() < self.effects_p:
            spec = self.frequency_shift(spec)

        if random.random() < self.effects_p:
            spec = self.amplitude_scaling(spec)

        return spec.transpose(-1, -2)

The augmentation pipeline is now instantiated and set as the transform of the train split, so that
it is applied on the fly whenever training samples are loaded.

In [11]:
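aug_pipe = SpecAugmentPipeline(p=1)

def augmentation(batch):
    # set_transform passes a batch (a dict of lists), so augment every sample individually;
    # otherwise all samples passed in one call would share the same random effects and masks
    if model_input_name in batch:
        batch[model_input_name] = [aug_pipe(tensor(spec)) for spec in batch[model_input_name]]
    return batch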

preprocessed_dataset["train"].set_transform(augmentation)

## 2.5 Weighted Loss Function

Because the dataset is unbalanced, one idea was to weight the loss function so that
under-represented classes get more focus. The weights can be computed with scikit-learn's
`compute_class_weight` function.
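For `class_weight="balanced"`, scikit-learn documents the weights as `n_samples / (n_classes * np.bincount(y))`. A minimal numpy sketch of that formula (not scikit-learn's code; `balanced_class_weights` is a hypothetical helper) shows that rare classes get proportionally larger weights:

```python
import numpy as np


def balanced_class_weights(y) -> np.ndarray:
    """n_samples / (n_classes * count(c)) for every class c present in y, ordered like np.unique(y)."""
    classes, counts = np.unique(y, return_counts=True)
    return len(y) / (len(classes) * counts)


# class 1 (one sample) gets three times the weight of class 0 (three samples)
print(balanced_class_weights([0, 0, 0, 1]))
```

On a perfectly balanced split, every weight is 1, so the weighted loss reduces to the plain one.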

In [49]:
labels = preprocessed_dataset["train"][labels_name]
class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(labels), y=labels)
class_weights = torch.tensor(class_weights, dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
class_weights

They can then be applied through a custom loss function that is passed to the trainer as
`compute_loss_func`. Because we want to predict a single genre per song or snippet, the standard
`CrossEntropyLoss` is used together with the class weights.
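According to the PyTorch documentation, `CrossEntropyLoss` with a `weight` tensor and the default `reduction="mean"` divides the weighted sum of per-sample losses by the sum of the weights of the target classes, not by the batch size. A numpy sketch of that computation on made-up numbers (`weighted_cross_entropy` is a hypothetical helper, not PyTorch's code):

```python
import numpy as np


def weighted_cross_entropy(logits: np.ndarray, labels: np.ndarray, weights: np.ndarray) -> float:
    """Class-weighted cross-entropy with mean reduction, normalized by the summed target weights."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))  # log-softmax
    per_sample = -log_probs[np.arange(len(labels)), labels]
    sample_weights = weights[labels]
    return float((sample_weights * per_sample).sum() / sample_weights.sum())


logits = np.array([[2.0, 0.0], [0.0, 1.0]])
labels = np.array([0, 1])
weights = np.array([1.0, 3.0])  # class 1 counts three times as much
print(weighted_cross_entropy(logits, labels, weights))
```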

In [50]:
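def weighted_loss_func(outputs, labels, num_items_in_batch):
    # signature expected by the Trainer's compute_loss_func; num_items_in_batch is not needed here
    logits = outputs.get("logits")
    # cross-entropy where each sample is weighted by the class weight of its true label
    loss_fct = CrossEntropyLoss(weight=class_weights.to(logits.device))
    return loss_fct(logits, labels)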

## 2.6 Training

For training, we use 1 GPU with a batch size of 64, because this fills the GPU best. Logs are sent
to Weights and Biases. We use the default learning rate of `5e-5` and 10 fine-tuning epochs, of
which one tenth (the first epoch) is warm-up. `eval_strategy` and `save_strategy` are set to
`"epoch"` to evaluate and save the model after every epoch, while `logging_steps` is set to 1 to
report every training step. `fp16` mixed precision is used because it reduces GPU memory usage
through lower-precision floats while maintaining the same prediction accuracy.
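The warm-up length in `TrainingArguments` is computed from the size of the train split. A small worked example, assuming a hypothetical train split of 10,000 snippets, a single device, no gradient accumulation, and the default `dataloader_drop_last=False`:

```python
import math

train_samples = 10_000  # hypothetical size of the train split
batch_size = 64
num_train_epochs = 10
warmup_epochs = 1

steps_per_epoch = math.ceil(train_samples / batch_size)           # the last partial batch is still a step
warmup_steps = round(train_samples / batch_size * warmup_epochs)  # same formula as in the notebook
total_steps = steps_per_epoch * num_train_epochs
print(steps_per_epoch, warmup_steps, total_steps)    # 157 156 1570
print(f"warm-up share: {warmup_steps / total_steps:.3f}")  # warm-up share: 0.099
```

The rounded value is slightly smaller than the actual number of steps per epoch, so the warm-up covers just under one tenth of training. `TrainingArguments` also has a `warmup_ratio` argument that expresses the warm-up as a fraction of the total steps.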

See the [results](./6-results.ipynb) to verify these decisions.

In [17]:
batch_size = 64
devices = 1
warmup_epochs = 1

training_args = TrainingArguments(
    output_dir="./runs/ast_classifier",
    report_to="wandb",
    run_name="some_name_10",
    learning_rate=5e-5,
    weight_decay=0.00,
    warmup_steps=round(len(preprocessed_dataset["train"]) / batch_size * warmup_epochs),
    push_to_hub=False,
    num_train_epochs=10,
    per_device_train_batch_size=batch_size // devices,
    per_device_eval_batch_size=batch_size // devices,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="mean_pooling_score",
    greater_is_better=True,
    logging_strategy="steps",
    logging_steps=1,
    fp16=True,
    save_total_limit=3
)

Initialize the trainer...

In [18]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=preprocessed_dataset["train"],
    eval_dataset=preprocessed_dataset["validate"],
    compute_metrics=compute_metrics,
    # compute_loss_func=weighted_loss_func
)

...and train the model.

In [19]:
trainer.train()

After training, we evaluate the model on the test split with `trainer.evaluate` and
`metric_key_prefix="test"`, so that the results are logged correctly to `wandb`, and finish the run
with `wandb.finish()`. To know which model is used, we print `trainer.state.best_model_checkpoint`;
because `load_best_model_at_end=True`, this is the checkpoint that is evaluated on the test split.

The confusion matrices and the dataframes show which genres the model confuses. Don't look at them
in too much detail; we will examine them more closely for better runs in the song prediction
notebook. This run was made with a balanced dataset, whose creation you can see in the
[next notebook](./3-balanced-dataset.ipynb).

In [20]:
print(trainer.state.best_model_checkpoint)
results = trainer.evaluate(preprocessed_dataset["test"], metric_key_prefix="test")
print(results)

wandb.finish()