## Fine-tuning the Audio Spectrogram Transformer on GTZAN

This notebook was inspired by:
https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/audio_classification.ipynb#scrollTo=5WMEawzyCEyG

See also the original paper: https://arxiv.org/abs/2104.01778

See also the Hugging Face documentation: https://huggingface.co/docs/transformers/v4.40.0/en/model_doc/audio-spectrogram-transformer#transformers.ASTConfig

In [None]:
from transformers import AutoFeatureExtractor, ASTForAudioClassification, AutoModelForAudioClassification, TrainingArguments, Trainer
from datasets import load_metric, Dataset, load_dataset, Audio
from datasets.dataset_dict import DatasetDict
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
import torchmetrics
import torchaudio
import wandb


import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sn
import os
import re
from IPython.display import FileLink

In [None]:
#set up genre names and their codes
genre_names = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]
genre_codes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Import these into your working script to make sure that we all have the same codes
id2label = {id_: label for id_, label in zip(genre_codes, genre_names)}
label2id = {label: id_ for label, id_ in zip(genre_names, genre_codes)}

In [None]:
SEED = 42

## Load Datasets with Huggingface

This section is used to load datasets with the command *load_dataset*. Read more here https://huggingface.co/docs/datasets/audio_dataset#audiofolder

If you run it on Kaggle, please upload your datasets and specify the directories correctly.

In [None]:
os.environ['DATA_BASELINE_TRAIN'] = '/kaggle/input/data-train-val-test/data_train_val_test'
os.environ['DATA_BASELINE_NOISY_TRAIN'] = '/kaggle/input/data-noisy-train-val-test/data_noisy_train_val_test'
os.environ['DATA_BASELINE_GENERATED_TRAIN'] = '/kaggle/input/aml24mst/data_train_val_test'

In [None]:
df_baseline = load_dataset(os.getenv('DATA_BASELINE_TRAIN'))
test_set = df_baseline.pop('test')
df_baseline = df_baseline.shuffle(seed=SEED)  # seeded so the shuffle is reproducible

In [None]:
df_noise = load_dataset(os.getenv('DATA_BASELINE_NOISY_TRAIN'))
test_set = df_noise.pop('test')
df_noise = df_noise.shuffle(seed=SEED)

In [None]:
df_gen = load_dataset(os.getenv('DATA_BASELINE_GENERATED_TRAIN'))
test_set = df_gen.pop('test')
df_gen = df_gen.shuffle(seed=SEED)

## Feature extraction

We load the Audio Spectrogram Transformer (AST), which has been pretrained on AudioSet. This corresponds to model 1 in the section *Pretrained Models* of the original paper's GitHub repository:
https://github.com/YuanGongND/ast/tree/master

In [None]:
model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"

We use the feature extractor from Hugging Face, which extracts mel-filter bank features from raw audio, pads/truncates them to a fixed length, and normalises them.

In [None]:
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

The audio files do not all contain exactly the same number of samples, so when calling the feature extractor we set max_length to 30 seconds and truncate.

The feature extractor expects a sampling rate of 16,000 Hz, so we resample all audio to that rate first.

In [None]:
sampling_rate = feature_extractor.sampling_rate
print(f'AST sampling rate: {sampling_rate} Hz')

# Resample all audio to the rate the feature extractor expects
df_baseline = df_baseline.cast_column("audio", Audio(sampling_rate = sampling_rate))
df_noise = df_noise.cast_column("audio", Audio(sampling_rate = sampling_rate))
df_gen = df_gen.cast_column("audio", Audio(sampling_rate = sampling_rate))

In [None]:
max_duration = 30.0 # 30 seconds

def preprocess_function(examples):
    # Extracting and saving arrays
    audio_arrays = [x['array'] for x in examples['audio']]

    # Preprocessing audio inputs
    inputs = feature_extractor(audio_arrays,
                              sampling_rate = feature_extractor.sampling_rate,
                              return_tensors="pt", # output pytorch tensors
                              max_length = int(feature_extractor.sampling_rate * max_duration),
                              truncation = True)

    return inputs
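
To make the fixed-length idea concrete, here is a minimal sketch on raw samples. It assumes a plain Python list of samples. `fix_length` is a hypothetical helper that is not part of this notebook or of the transformers library, and the feature extractor handles its own padding and truncation internally, so this only illustrates the principle.

```python
def fix_length(samples, sampling_rate=16000, max_duration=30.0, pad_value=0.0):
    """Truncate or zero-pad a list of samples to exactly max_duration seconds."""
    # Target number of samples, e.g. 16000 Hz * 30 s = 480000
    target = int(sampling_rate * max_duration)
    if len(samples) >= target:
        # Too long: keep only the first `target` samples
        return list(samples[:target])
    # Too short: append padding values until the target length is reached
    return list(samples) + [pad_value] * (target - len(samples))


# Made-up clips: one shorter and one longer than 30 seconds at 16 kHz
short_clip = [0.1] * 100
long_clip = [0.1] * 500000
fixed_short = fix_length(short_clip)
fixed_long = fix_length(long_clip)
print(len(fixed_short), len(fixed_long))
```

Both clips end up with the same length, which is what allows them to be stacked into one batch.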

In [None]:
def apply_preprocess(df):
    df = df.map(preprocess_function,
                    remove_columns = ['audio'],
                    batched = True,
                    batch_size = 100)

    return df

In [None]:
df_baseline = apply_preprocess(df_baseline)

In [None]:
df_noise = apply_preprocess(df_noise)

In [None]:
df_gen = apply_preprocess(df_gen)

In [None]:
print(f"Size of spectrogram, Train: {len(df_baseline['train'][0]['input_values'][0])}, {len(df_baseline['train'][0]['input_values'])}")
print(f"Size of spectrogram, Val: {len(df_baseline['validation'][0]['input_values'][0])}, {len(df_baseline['validation'][0]['input_values'])}")

## Fine-tune

*Skip this section if you have already fine-tuned the models and go straight to the Inference section.*

We load the pretrained AST model that we are going to fine-tune to classify music genres in GTZAN.

The output warns about mismatched sizes: the checkpoint's classification head was pretrained on 527 classes, whereas the model for GTZAN has only 10 classes. The head is therefore newly initialised, which is why we need to fine-tune the model.

In [None]:
num_labels = len(id2label) # Obtaining the total number of labels

# Loading model
ast_model = AutoModelForAudioClassification.from_pretrained(model_checkpoint,
                                                         num_labels = num_labels,
                                                         label2id=label2id,
                                                         id2label=id2label,
                                                         ignore_mismatched_sizes=True)

In [None]:
wandb.login()
%env WANDB_LOG_MODEL=true

In [None]:
batch_size=4 #ran out of memory with batch_size=8

training_args = TrainingArguments(
    output_dir = 'ast_gtzan',
    evaluation_strategy = 'epoch',
    save_strategy = 'epoch',
    load_best_model_at_end = True,
    metric_for_best_model = 'accuracy',
    learning_rate = 5e-5,
    seed = SEED,
    per_device_train_batch_size = batch_size,
    per_device_eval_batch_size = batch_size,
    gradient_accumulation_steps = 1,
    num_train_epochs = 15,
    warmup_ratio = 0.1,
    fp16 = True,
    save_total_limit = 2,
    report_to = 'wandb',
    run_name = 'gen_model'
    )
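
As a rough guide to what `warmup_ratio` and the batch settings imply, the sketch below estimates the number of optimizer steps. `schedule_summary` is a hypothetical helper, the 800 training clips are a made-up number, and the rounding the Trainer applies internally may differ slightly, so treat the result as an estimate only.

```python
import math


def schedule_summary(num_examples, per_device_batch_size, num_devices,
                     gradient_accumulation_steps, num_epochs, warmup_ratio):
    """Estimate optimizer steps per epoch, total steps and warmup steps."""
    # Number of examples consumed by one optimizer step
    effective_batch = per_device_batch_size * num_devices * gradient_accumulation_steps
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    total_steps = steps_per_epoch * num_epochs
    # The learning rate ramps up linearly during the first warmup_steps steps
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return steps_per_epoch, total_steps, warmup_steps


# 800 training clips and a single device are assumptions used only for illustration
summary = schedule_summary(num_examples=800, per_device_batch_size=4, num_devices=1,
                           gradient_accumulation_steps=1, num_epochs=15, warmup_ratio=0.1)
print(summary)
```

With two GPUs the effective batch size doubles, so the number of steps per epoch roughly halves.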

In [None]:
# Loading the `accuracy` metric from the evaluate library
# (`datasets.load_metric` is deprecated in favour of `evaluate.load`)
import evaluate
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Computes accuracy on a batch of predictions"""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)
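
What `compute_metrics` does can be seen on a toy example without any libraries. The logits and labels below are made up, and `accuracy_from_logits` is a hypothetical stand-in for the argmax followed by the accuracy metric.

```python
def argmax(row):
    """Index of the largest value in a list of scores."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy_from_logits(logits, labels):
    """Fraction of rows whose highest-scoring class equals the reference label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# Four made-up examples with three classes each
toy_logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.0, 0.9, 0.4], [1.2, 1.1, 0.3]]
toy_labels = [0, 2, 2, 0]
toy_accuracy = accuracy_from_logits(toy_logits, toy_labels)
print(toy_accuracy)
```

Three of the four rows have their largest logit at the reference label, so the accuracy is 0.75.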

We fine-tuned the models on Kaggle's GPU T4 x2 for 20 epochs, which took 1 hour and 20 minutes per model.

In [None]:
trainer_baseline = Trainer(
     model=ast_model,
     args = training_args,
     train_dataset = df_baseline['train'],
     eval_dataset = df_baseline['validation'],
     tokenizer = feature_extractor,
     compute_metrics = compute_metrics)

trainer_baseline.train()

In [None]:
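# Reload the pretrained checkpoint so this run starts from the same weights as
# the baseline run instead of continuing from an already fine-tuned `ast_model`
ast_model_noise = AutoModelForAudioClassification.from_pretrained(model_checkpoint,
                                                         num_labels = num_labels,
                                                         label2id=label2id,
                                                         id2label=id2label,
                                                         ignore_mismatched_sizes=True)

trainer_noise = Trainer(
     model=ast_model_noise,
     args = training_args,
     train_dataset = df_noise['train'],
     eval_dataset = df_noise['validation'],
     tokenizer = feature_extractor,
     compute_metrics = compute_metrics)

trainer_noise.train()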

In [None]:
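# Reload the pretrained checkpoint so this run starts from the same weights as
# the baseline run instead of continuing from an already fine-tuned `ast_model`
ast_model_gen = AutoModelForAudioClassification.from_pretrained(model_checkpoint,
                                                         num_labels = num_labels,
                                                         label2id=label2id,
                                                         id2label=id2label,
                                                         ignore_mismatched_sizes=True)

trainer_gen = Trainer(
     model=ast_model_gen,
     args = training_args,
     train_dataset = df_gen['train'],
     eval_dataset = df_gen['validation'],
     tokenizer = feature_extractor,
     compute_metrics = compute_metrics)

trainer_gen.train()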

We save each fine-tuned model, together with the feature extractor, on Kaggle.

In [None]:
save_path = 'ast_finetune_baseline'
!mkdir {save_path}
trainer_baseline.save_model(save_path)
feature_extractor.save_pretrained(save_path)
!ls {save_path}

In [None]:
save_path = 'ast_finetune_noise'
!mkdir {save_path}
trainer_noise.save_model(save_path)
feature_extractor.save_pretrained(save_path)
!ls {save_path}

In [None]:
save_path = 'ast_finetune_gen'
!mkdir {save_path}
trainer_gen.save_model(save_path)
feature_extractor.save_pretrained(save_path)
!ls {save_path}

**If you use Kaggle, remember to click *Save Version*.**

Alternatively, if you want to download a model from Kaggle, run:

In [None]:
#from IPython.display import FileLink
#import os
#!zip -r ast_finetuned.zip {save_path}
#os.chdir(r'/kaggle/working/')
#FileLink(r'ast_finetuned.zip')

## Inference

We start by loading the fine-tuned models.

In [None]:
num_labels = len(id2label) # Obtaining the total number of labels

In [None]:
# Loading model
ast_model_baseline = AutoModelForAudioClassification.from_pretrained('ast_finetune_baseline',
                                                         num_labels = num_labels,
                                                         label2id=label2id,
                                                         id2label=id2label,
                                                         ignore_mismatched_sizes=True)

In [None]:
# Loading model
ast_model_noise = AutoModelForAudioClassification.from_pretrained('ast_finetune_noise',
                                                         num_labels = num_labels,
                                                         label2id=label2id,
                                                         id2label=id2label,
                                                         ignore_mismatched_sizes=True)

In [None]:
# Loading model
ast_model_gen = AutoModelForAudioClassification.from_pretrained('ast_finetune_gen',
                                                         num_labels = num_labels,
                                                         label2id=label2id,
                                                         id2label=id2label,
                                                         ignore_mismatched_sizes=True)

We load the test-set.

In [None]:
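df_baseline_infe = load_dataset(os.getenv('DATA_BASELINE_TRAIN'))
test_set = df_baseline_infe.pop('test')
# Resample to the rate used during training, which the feature extractor expects
test_set = test_set.cast_column("audio", Audio(sampling_rate = sampling_rate))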

In [None]:
test_set

We define a function that loops over the test set and returns the argmax prediction for each clip.

In [None]:
def get_argmax_inference(df, model):
    """Return the predicted genre name (argmax over the logits) for every example in df."""
    predictions = []
    # Progress bar over all test examples
    for i in tqdm(range(df.num_rows)):
        # Convert the raw waveform into the spectrogram features the model expects
        test_input = feature_extractor(df[i]['audio']['array'], sampling_rate=sampling_rate, return_tensors="pt")

        # No gradients are needed at inference time
        with torch.no_grad():
            processed = model(**test_input)

        predicted_class_ids = torch.argmax(processed.logits, dim=-1).item()
        predicted_label = id2label[predicted_class_ids]
        predictions.append(predicted_label)

    return predictions

In [None]:
baseline_predictions = get_argmax_inference(test_set, ast_model_baseline)

In [None]:
noise_predictions = get_argmax_inference(test_set, ast_model_noise)

In [None]:
gen_predictions = get_argmax_inference(test_set, ast_model_gen)

We then save the predictions.

In [None]:
import pickle

def save_predictions_to_pickle(predictions, filename):
    with open(filename, 'wb') as file:
        pickle.dump(predictions, file)

# Save baseline predictions as a pickle file
save_predictions_to_pickle(baseline_predictions, 'baseline_predictions.pkl')

# Save noise predictions as a pickle file
save_predictions_to_pickle(noise_predictions, 'noise_predictions.pkl')

# Save generated-data predictions as a pickle file
save_predictions_to_pickle(gen_predictions, 'gen_predictions.pkl')
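
To read the predictions back later, the counterpart to the save function could look like the sketch below. `load_predictions_from_pickle` is a hypothetical helper that is not defined elsewhere in this notebook, and the example uses made-up predictions in a temporary directory. Only unpickle files that you created yourself.

```python
import os
import pickle
import tempfile


def load_predictions_from_pickle(filename):
    """Read back a list of predictions written with pickle.dump."""
    with open(filename, 'rb') as file:
        return pickle.load(file)


# Round trip with made-up predictions in a temporary directory
example_predictions = ['blues', 'rock', 'jazz']
with tempfile.TemporaryDirectory() as tmp_dir:
    example_path = os.path.join(tmp_dir, 'example_predictions.pkl')
    with open(example_path, 'wb') as file:
        pickle.dump(example_predictions, file)
    restored_predictions = load_predictions_from_pickle(example_path)

# The restored list should be identical to the one that was saved
print(restored_predictions == example_predictions)
```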

In [None]:
os.chdir(r'/kaggle/working/')
FileLink(r'baseline_predictions.pkl')

In [None]:
FileLink(r'noise_predictions.pkl')

In [None]:
FileLink(r'gen_predictions.pkl')

Finally, we plot the confusion matrices and compute the accuracy.

In [None]:
#true labels
test_true = [id2label[x] for x in test_set['label']]

In [None]:
def confusion_matrix_(predictions, save_name=''):
    genres = list(label2id.keys())

    # Pass the label order explicitly so rows and columns always match `genres`
    grid_cm = pd.DataFrame(confusion_matrix(test_true, predictions, labels=genres),
                           index=genres,
                           columns=genres)
    plt.figure(figsize=(8,6))
    plt.title(f"Confusion matrix of {save_name}")
    # Specify the colormap as 'viridis'
    sn.heatmap(grid_cm, annot=True, cmap="viridis")
    # Add labels to y-axis and x-axis
    plt.ylabel('True')
    plt.xlabel('Predicted')
    # Save before plt.show(), otherwise the saved image can come out blank
    plt.savefig(f'confusion_matrix_{save_name}.png')
    plt.show()
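
For readers unfamiliar with how the matrix is laid out, the following sketch counts the entries by hand on made-up labels. `count_confusions` and `accuracy_from_matrix` are hypothetical helpers written only for this illustration. Rows are the true genres and columns the predicted genres, matching the axis labels in the plot.

```python
def count_confusions(true_labels, predicted_labels, labels):
    """Return a matrix where entry [i][j] counts clips of true class i predicted as class j."""
    index = {label: position for position, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for true, predicted in zip(true_labels, predicted_labels):
        matrix[index[true]][index[predicted]] += 1
    return matrix


def accuracy_from_matrix(matrix):
    """Correct predictions sit on the diagonal of the confusion matrix."""
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return correct / total


# Made-up example with three genres and five clips
cm_labels = ['blues', 'jazz', 'rock']
cm_true = ['blues', 'blues', 'jazz', 'rock', 'rock']
cm_pred = ['blues', 'jazz', 'jazz', 'rock', 'blues']
toy_matrix = count_confusions(cm_true, cm_pred, cm_labels)
toy_matrix_accuracy = accuracy_from_matrix(toy_matrix)
for label, row in zip(cm_labels, toy_matrix):
    print(label, row)
print(toy_matrix_accuracy)
```

Off-diagonal entries show which genres are mistaken for which, for example one blues clip predicted as jazz in the toy data.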

In [None]:
confusion_matrix_(baseline_predictions, 'baseline')

In [None]:
confusion_matrix_(noise_predictions, 'noise')

In [None]:
confusion_matrix_(gen_predictions, 'generated')

In [None]:
os.chdir(r'/kaggle/working/')
FileLink(r'confusion_matrix_baseline.png')

In [None]:
FileLink(r'confusion_matrix_noise.png')

In [None]:
FileLink(r'confusion_matrix_generated.png')

In [None]:
print(f"Accuracy Baseline: {accuracy_score(test_true, baseline_predictions)}")
print(f"Accuracy Noise: {accuracy_score(test_true, noise_predictions)}")
print(f"Accuracy Generated: {accuracy_score(test_true, gen_predictions)}")

## PyTorch

Fine-tuning with a plain PyTorch training loop instead of the Trainer. May be useful.

In [None]:
# from transformers import get_scheduler

# df.set_format("torch")

# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ast_model.to(device)

# #Set up data-loaders so we can train in batches
# train_dataloader = DataLoader(df['train'], shuffle=True, batch_size=8)
# eval_dataloader = DataLoader(df['test'], batch_size=8)

# # Setting up fine-tuning training hyperparams
# optimizer = AdamW(ast_model.parameters(), lr=5e-5) # set Adam as optimizer

# num_epochs = 20
# num_training_steps = num_epochs * len(train_dataloader)
# lr_scheduler = get_scheduler(
#     name="linear",
#     optimizer=optimizer,
#     num_warmup_steps=0,
#     num_training_steps=num_training_steps)

# #Fine-tune 
# progress_bar = tqdm(total=num_training_steps)

# ast_model.train()
# for epoch in range(num_epochs):
#     total_loss = 0.0
#     num_batches = 0

#     for batch in train_dataloader:
#         batch = {k: v.to(device) for k, v in batch.items()}
#         outputs = ast_model(**batch)
#         loss = outputs.loss
#         loss.backward()

#         optimizer.step()
#         lr_scheduler.step()
#         optimizer.zero_grad()

#         total_loss += loss.item()
#         num_batches += 1

#         progress_bar.update(1)

#     # Calculate average loss for the epoch
#     average_loss = total_loss / num_batches
#     print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {average_loss:.4f}")
    
# #Evaluate the accuracy
# accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_labels)

# ast_model.eval()
# for batch in eval_dataloader:
#     batch = {k: v.to(device) for k, v in batch.items()}
#     with torch.no_grad():
#         outputs = ast_model(**batch)

#     logits = outputs.logits
#     predictions = torch.argmax(logits, dim=-1)

#     # Convert predictions and references to CPU if necessary
#     predictions = predictions.cpu()
#     references = batch["labels"].cpu()

#     # Add batch to Accuracy metric
#     accuracy.update(predictions, references)

# # Compute accuracy
# accuracy_result = accuracy.compute()
# print("Accuracy:", accuracy_result.item())

# #torch.save(ast_model.state_dict(), "ast_finetuned_gtzan.pth")

## Load baseline dataset manually

Another way to load the data. May be useful.

In [None]:
# train_files = []
# test_files = []

# if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
#     path = '/kaggle/input/data-train-test/data_train_test'
# else:
#     path = './data_train_test/'

# sampling_rate = 0

# for set in ['train', 'test']:
#     path_set = os.path.join(path, set) 
#     for genre_name, genre_code in zip(genre_names, genre_codes):
#         # Construct full path to genre directory
#         genre_dir = os.path.join(path_set, genre_name)
        
#         # Check if the directory exists
#         if not os.path.isdir(genre_dir):
#             print(f"Directory '{genre_dir}' does not exist.")
#             continue
        
#         # Loop over files in the genre directory
#         for file_name in os.listdir(genre_dir):
#             # Construct full path to file
#             file_path = os.path.join(genre_dir, file_name)

#             # Check if it's a file
#             if os.path.isfile(file_path):
#                 wave_form = 0
#                 try:
#                     waveform, sample_rate_file = torchaudio.load(file_path)
#                     waveform_np = waveform.numpy()
#                 except Exception:
#                     print(f'Could not load {file_name}')
#                     continue
                
#                 if set == 'train':
#                     train_files.append(file_name)
#                 else:
#                     test_files.append(file_name)
                    
#                 #check that sampling rate is the same for all audio files
#                 if sampling_rate == 0:
#                     sampling_rate = sample_rate_file
#                     print(f'Sampling rate is {sampling_rate} hz')
#                 else:
#                     if sampling_rate != sample_rate_file:
#                         print("Sampling rates do not match")
#             else:
#                 print(f"'{file_path}' is not a file.")

In [None]:
#https://huggingface.co/datasets/marsyas/gtzan
# df = load_dataset("marsyas/gtzan", trust_remote_code=True)

In [None]:
# def file_name_search(file_name):
#     '''
#     Find file-name in the path of the file. 
#     '''
#     if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
#         pattern = r'([^/]+\.wav)'
#     else:
#         pattern = r'([^\\/]+\.wav)'
    
#     match = re.search(pattern, file_name)
#     return match.group(1)

In [None]:
#Now we define our train dataset by filtering only file names that occur in our list of training file names. 
#We do the same for test file names.

# df_train = df.filter(lambda x: file_name_search(x['file']) in train_files)
# df_test = df.filter(lambda x: file_name_search(x['file']) in test_files)

In [None]:
#Double check to see if we got the correct files
# train_files_ = [file_name_search(x) for x in df_train['train']['file']]
# test_files_ = [file_name_search(x) for x in df_test['train']['file']]

# if train_files == train_files_ and test_files == test_files_:
#     print('Files in df_train and df_test match the training and test files from the folder data_train_test')

In [None]:
# df_train = df_train.shuffle()
# df_test = df_test.shuffle()