## Fine-tuning the Audio Spectrogram Transformer on GTZAN

This notebook was inspired by:
https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/audio_classification.ipynb#scrollTo=5WMEawzyCEyG

See also the original paper: https://arxiv.org/abs/2104.01778

See also the Hugging Face documentation: https://huggingface.co/docs/transformers/v4.40.0/en/model_doc/audio-spectrogram-transformer#transformers.ASTConfig

In [None]:
from transformers import AutoFeatureExtractor, ASTForAudioClassification, AutoModelForAudioClassification, TrainingArguments, Trainer, get_scheduler
from datasets import load_dataset, Audio, load_metric
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
import torchmetrics

In [None]:
SEED = 42

In [None]:
#https://huggingface.co/datasets/marsyas/gtzan
df_raw = load_dataset("marsyas/gtzan", trust_remote_code=True)

In [None]:
df_raw = df_raw['train'].train_test_split(seed = SEED, shuffle = True,
                                  test_size = .2)

In [None]:
# Obtaining human-readable label
id2label_function = df_raw['train'].features['genre'].int2str
print("genre: ", id2label_function(df_raw['train'][0]['genre']))

In [None]:
sampling_rate_check = None
all_same = True

# Iterating through each sample
for set_name in ['train', 'test']: # Iterating through both sets
    for sample in df_raw[set_name]:
        sampling_rate = sample['audio']['sampling_rate']

        if sampling_rate_check is None:
            sampling_rate_check = sampling_rate
        elif sampling_rate != sampling_rate_check:
            all_same = False
            break  # stop scanning this split once a mismatch is found

# Printing result
if all_same:
    print(f"All samples have the same sampling rate: {sampling_rate_check} Hz")
else:
    print("The samples in the dataframe have different sampling rates.")
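The same check can be written more compactly by collecting every sampling rate into a set. The sketch below runs on hypothetical stand-in rows (the `samples` list and its 22050 Hz values are made up for illustration); in the notebook the rows would come from `df_raw['train']` and `df_raw['test']`.

```python
# Hypothetical stand-ins for dataset rows; real rows come from df_raw[set_name].
samples = [
    {"audio": {"sampling_rate": 22050}},
    {"audio": {"sampling_rate": 22050}},
    {"audio": {"sampling_rate": 22050}},
]


def unique_sampling_rates(rows):
    """Return the set of distinct sampling rates found in the rows."""
    return {row["audio"]["sampling_rate"] for row in rows}


rates = unique_sampling_rates(samples)
if len(rates) == 1:
    print(f"All samples have the same sampling rate: {next(iter(rates))} Hz")
else:
    print(f"Found several sampling rates: {sorted(rates)}")
```

A set also reports which rates occur, which helps when deciding on a resampling target.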

## Feature extraction

We load the Audio Spectrogram Transformer (AST), which has been pretrained on AudioSet. This corresponds to model 1 in the *Pretrained Models* section of the original paper's GitHub repository:
https://github.com/YuanGongND/ast/tree/master

In [None]:
model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"

We use the feature extractor from Hugging Face, which extracts mel filter bank features from the raw audio, pads or truncates them to a fixed length, and normalises them.

In [None]:
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
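To make the "pad or truncate, then normalise" steps concrete, here is a minimal pure-Python sketch on a tiny made-up feature matrix. The helper names, the zero padding value, and the plain `(x - mean) / std` scaling are assumptions chosen for illustration; the library's own padding and scaling may differ in detail.

```python
def pad_or_truncate(frames, max_frames, pad_value=0.0):
    """Cut a list of feature frames to max_frames, or pad it with constant frames."""
    n_bins = len(frames[0])
    kept = frames[:max_frames]
    padding = [[pad_value] * n_bins for _ in range(max_frames - len(kept))]
    return kept + padding


def normalise(frames, mean, std):
    """Standardise every value with a fixed (dataset-level) mean and std."""
    return [[(value - mean) / std for value in frame] for frame in frames]


# Three frames with two (hypothetical) mel bins each.
features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

padded = pad_or_truncate(features, max_frames=4)     # one zero frame appended
truncated = pad_or_truncate(features, max_frames=2)  # last frame dropped
scaled = normalise(truncated, mean=3.0, std=2.0)

print(len(padded), len(truncated))
print(scaled)
```

Whatever the length of the input clip, every example ends up with the same number of frames, which is what lets the examples be stacked into batches.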

We ensure that the data has a sampling rate of 16,000 Hz, which is the sampling rate the AST expects.

In [None]:
sampling_rate = feature_extractor.sampling_rate
print(f'AST sampling rate: {sampling_rate} Hz')

# Resampling data to the rate expected by the feature extractor
df_raw = df_raw.cast_column("audio", Audio(sampling_rate = sampling_rate))
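Resampling keeps the duration of a clip but changes how many samples it contains. The arithmetic below is a rough sketch that assumes an original rate of 22,050 Hz and clips of exactly 30 seconds; both numbers are assumptions for illustration, and the sampling-rate check earlier in the notebook prints the real rate. `resampled_length` is a hypothetical helper, not a library function.

```python
def resampled_length(num_samples, orig_rate, target_rate):
    """Approximate number of samples after resampling (duration is preserved)."""
    return round(num_samples * target_rate / orig_rate)


orig_rate = 22_050    # assumed original rate in Hz
target_rate = 16_000  # rate expected by the AST feature extractor
clip_seconds = 30     # assumed clip length

orig_samples = orig_rate * clip_seconds
new_samples = resampled_length(orig_samples, orig_rate, target_rate)

print(f"{orig_samples} samples at {orig_rate} Hz -> {new_samples} samples at {target_rate} Hz")
```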

The audio arrays do not all contain exactly the same number of samples, so we pass `max_length` (the number of samples in 30 seconds) and `truncation = True` to the feature extractor.

In [None]:
max_duration = 30.0 # 30 seconds

def preprocess_function(examples):
    # Extracting and saving arrays
    audio_arrays = [x['array'] for x in examples['audio']]

    # Preprocessing audio inputs
    inputs = feature_extractor(audio_arrays,
                              sampling_rate = feature_extractor.sampling_rate,
                              return_tensors="pt", # output pytorch tensors
                              max_length = int(feature_extractor.sampling_rate * max_duration),
                              truncation = True)

    return inputs

In [None]:
df = df_raw.map(preprocess_function,
                   remove_columns = ['audio', 'file'],
                   batched = True,
                   batch_size = 100,
                   num_proc = 1)

In [None]:
print(f"Size of spectrogram: {len(df['train'][0]['input_values'][0])}, {len(df['train'][0]['input_values'])}")

In [None]:
# Renaming genre column
df = df.rename_column('genre', 'labels')

# Id to label
id2label = {str(i): id2label_function(i)
           for i in range(len(df['train'].features['labels'].names))}

# Label to id
label2id = {v: k for k, v in id2label.items()}

# Check that it works
integer = 8 # Picking an arbitrary class id
label = id2label[str(integer)] # Obtaining label

print(f'\nId: {integer}')
print(f'\nLabel: {label}')
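The two dictionaries are inverses of each other, which a quick round trip can confirm. The sketch below uses a hard-coded genre list whose names and order are assumptions about the dataset's label feature, and it uses integer keys under the hypothetical names `id2label_int` and `label2id_int` so that it does not overwrite the notebook's own mappings.

```python
# Assumed genre names and order; the notebook reads them from the dataset features.
genres = ["blues", "classical", "country", "disco", "hiphop",
          "jazz", "metal", "pop", "reggae", "rock"]

# Forward mapping (id -> label) and its inverse (label -> id).
id2label_int = dict(enumerate(genres))
label2id_int = {label: idx for idx, label in id2label_int.items()}

# Round trip: every id maps to a label that maps back to the same id.
round_trip_ok = all(label2id_int[id2label_int[i]] == i for i in id2label_int)

print(id2label_int[8], label2id_int["jazz"], round_trip_ok)
```

The inverse is only well defined because every label name is unique; duplicate names would silently collapse entries in the label-to-id dictionary.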

## Fine-tune using Torch

We load the pretrained AST model that we are going to fine-tune to classify music genres in GTZAN.

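When loading, Transformers warns about mismatched sizes: the classification head of the checkpoint was trained on 527 classes, whereas GTZAN has only 10. With `ignore_mismatched_sizes=True` the head is re-initialised with 10 outputs, which is why the model has to be fine-tuned before its predictions are meaningful.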

In [None]:
num_labels = len(id2label) # Obtaining the total number of labels

# Loading model
ast_model = AutoModelForAudioClassification.from_pretrained(model_checkpoint,
                                                         num_labels = num_labels,
                                                         label2id=label2id,
                                                         id2label=id2label,
                                                         ignore_mismatched_sizes=True)

In [None]:
df.set_format("torch")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
ast_model.to(device)

In [None]:
# Set up data loaders so we can train in batches
train_dataloader = DataLoader(df['train'], shuffle=True, batch_size=8)
eval_dataloader = DataLoader(df['test'], batch_size=8)

In [None]:
# Setting up fine-tuning training hyperparams
optimizer = AdamW(ast_model.parameters(), lr=5e-5) # use AdamW as optimizer

num_epochs = 20
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps)
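With `name="linear"` and no warmup, the learning rate falls in a straight line from its initial value to zero over the course of training. The function below is a hand-written sketch of that shape, not the library's implementation; `linear_lr`, `total_steps`, and the figure of 100 steps per epoch are assumptions for illustration.

```python
def linear_lr(step, base_lr, total_steps, warmup_steps=0):
    """Learning rate at a given step for linear warmup followed by linear decay."""
    if step < warmup_steps:
        # Ramp up from 0 to base_lr during warmup.
        return base_lr * step / max(1, warmup_steps)
    # Decay from base_lr down to 0 at total_steps.
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


base_lr = 5e-5
steps_per_epoch = 100              # assumed: roughly 800 training clips / batch size 8
total_steps = 20 * steps_per_epoch

for step in (0, total_steps // 2, total_steps):
    print(step, linear_lr(step, base_lr, total_steps))
```

Because the schedule is defined per optimizer step, the scheduler has to be stepped once per batch rather than once per epoch.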

We fine-tuned the model for 20 epochs on Kaggle's GPU T4 x2 accelerator, which took 1 hour and 20 minutes.

In [None]:
# Fine-tune
progress_bar = tqdm(total=num_training_steps)

ast_model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    num_batches = 0

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = ast_model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        num_batches += 1

        progress_bar.update(1)

    # Calculate average loss for the epoch
    average_loss = total_loss / num_batches
    print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {average_loss:.4f}")

In [None]:
#Evaluate the accuracy
accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_labels)

ast_model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = ast_model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)

    # Convert predictions and references to CPU if necessary
    predictions = predictions.cpu()
    references = batch["labels"].cpu()

    # Add batch to Accuracy metric
    accuracy.update(predictions, references)

# Compute accuracy
accuracy_result = accuracy.compute()
print("Accuracy:", accuracy_result.item())
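The metric object accumulates counts over batches and only divides at the end, so the result does not depend on how the test set is split into batches. A minimal pure-Python sketch of that update/compute pattern, using a hypothetical `RunningAccuracy` class and made-up predictions:

```python
class RunningAccuracy:
    """Accumulates correct/total counts across batches."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, predictions, references):
        # Count positions where the predicted class equals the reference class.
        self.correct += sum(p == r for p, r in zip(predictions, references))
        self.total += len(references)

    def compute(self):
        return self.correct / self.total


# Two made-up batches of (predictions, references).
batches = [([0, 1, 2], [0, 1, 1]), ([3, 3], [3, 3])]

running = RunningAccuracy()
for predictions, references in batches:
    running.update(predictions, references)

print("Accuracy:", running.compute())
```

Averaging per-batch accuracies instead would weight a small final batch as heavily as a full one, which is why the counts are summed first.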

The model reaches an accuracy of 0.89 on the test set.

In [None]:
torch.save(ast_model.state_dict(), "trained_model.pth")

In [None]:
# Download model from Kaggle
# from IPython.display import FileLink
# import os

# os.chdir(r'/kaggle/working')
# %cd /kaggle/working
# FileLink(r'trained_model.pth')

## Test predictions

In [None]:
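# Load ast_model
# see here
#https://colab.research.google.com/github/YuanGongND/ast/blob/master/colab/AST_Inference_Demo.ipynb#scrollTo=QiB9y5oKUQBV
# Rebuild the architecture with the 10-class head, then load the fine-tuned weights
ast_model = AutoModelForAudioClassification.from_pretrained(model_checkpoint,
                                                         num_labels = num_labels,
                                                         label2id=label2id,
                                                         id2label=id2label,
                                                         ignore_mismatched_sizes=True)
ast_model.load_state_dict(torch.load("trained_model.pth", map_location="cpu"))
ast_model.eval()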

In [None]:
test_input = feature_extractor(df_raw['test'][0]['audio']['array'], sampling_rate=sampling_rate, return_tensors="pt")

with torch.no_grad():
    processed = ast_model(**test_input, output_attentions=True)

predicted_class_ids = torch.argmax(processed.logits, dim=-1).item()
predicted_label = ast_model.config.id2label[str(predicted_class_ids)]
predicted_label
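The prediction step boils down to taking the index of the largest logit and looking it up in the label mapping. The pure-Python sketch below also applies a softmax to turn the logits into probabilities; the logits and the three-label mapping are made up for illustration and are not model outputs.

```python
import math


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]  # shift for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for a hypothetical three-class model.
example_logits = [0.5, 2.0, -1.0]
example_id2label = {0: "blues", 1: "classical", 2: "country"}

probabilities = softmax(example_logits)
predicted_id = max(range(len(example_logits)), key=lambda i: example_logits[i])

print(example_id2label[predicted_id], round(probabilities[predicted_id], 3))
```

The softmax does not change which class wins, so `argmax` on the raw logits is enough when only the label is needed.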

## Alternative fine-tuning with the Hugging Face Trainer

In [None]:
# batch_size=4

# training_args = TrainingArguments(
#     output_dir = 'ast_gtzan',
#     evaluation_strategy = 'epoch',
#     save_strategy = 'epoch',
#     load_best_model_at_end = True,
#     metric_for_best_model = 'accuracy',
#     learning_rate = 5e-5,
#     seed = SEED,
#     per_device_train_batch_size = batch_size,
#     per_device_eval_batch_size = batch_size,
#     gradient_accumulation_steps = 1,
#     num_train_epochs = 15,
#     warmup_ratio = 0.1,
#     # fp16 = True,
#     save_total_limit = 2,
#     report_to = 'none'
#     )

# # Loading the `accuracy` metric (`load_metric` is deprecated in datasets; `evaluate.load` is its replacement)
# metric = load_metric("accuracy")

# def compute_metrics(eval_pred):
#     """Computes accuracy on a batch of predictions"""
#     predictions = np.argmax(eval_pred.predictions, axis=1)
#     return metric.compute(predictions=predictions, references=eval_pred.label_ids)

# trainer = Trainer(
#      model=ast_model,
#      args = training_args,
#      train_dataset = df['train'],
#      eval_dataset = df['test'],
#      tokenizer = feature_extractor,
#      compute_metrics = compute_metrics)

# trainer.train()