In [None]:
# Optional: uncomment to install/upgrade dependencies. Prefer `%pip` over `! pip` so the
# packages are installed into this kernel's environment; note that `datasets` (imported
# below) is not in this list, and pinning versions would make re-runs reproducible.
# ! pip install -U evaluate transformers accelerate


In [None]:
import os

# Restrict this process to GPU 0. CUDA reads this variable when it is first
# initialised, so set it before torch (or any library that imports torch) is loaded.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from pprint import pprint

import evaluate
import numpy as np
import torch
from datasets import Audio, load_dataset
from huggingface_hub import login, notebook_login
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    Trainer,
    TrainingArguments,
)

# Interactive Hugging Face Hub login (the training cell pushes the model to the Hub).
notebook_login()
# Sanity check: torch version and number of visible GPUs (should be <= 1 given the mask above).
print(torch.__version__, torch.cuda.device_count())


In [None]:
gtzan = load_dataset("marsyas/gtzan", "all")
gtzan


In [None]:
gtzan = gtzan["train"].train_test_split(seed=42, shuffle=True, test_size=0.1)
gtzan


In [None]:
# Inspect one training record (its audio payload plus the "file" and "genre" fields).
example = gtzan["train"][0]
pprint(example)


In [None]:
# Hub checkpoint shared by the feature extractor here and the model built in `trial`.
model_id = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = AutoFeatureExtractor.from_pretrained(
    model_id,
    do_normalize=True,
    return_attention_mask=False,
)


In [None]:
# Display the feature extractor's configuration (its sampling_rate is used by the cells below).
feature_extractor


In [None]:
# Re-declare the audio column so samples come out at the feature extractor's sampling rate.
gtzan = gtzan.cast_column(
    "audio",
    Audio(sampling_rate=feature_extractor.sampling_rate),
)


In [None]:
max_duration = 30.0


def preprocess_function(examples):
    """Convert a batch of decoded audio examples into model input features.

    Args:
        examples: a batch dict as passed by ``Dataset.map(batched=True)``;
            ``examples["audio"]`` is a list of dicts, each with an ``"array"`` of samples.

    Returns:
        The output of ``feature_extractor`` for the batch; ``map`` stores it as new columns.
    """
    sr = feature_extractor.sampling_rate
    # NOTE(review): for an AST checkpoint the extractor appears to accept
    # max_length / truncation / return_attention_mask only through **kwargs and to
    # pad/truncate to its own `feature_extractor.max_length` (spectrogram frames), so
    # `max_duration` may have no effect here — confirm against the installed transformers.
    return feature_extractor(
        [item["array"] for item in examples["audio"]],
        sampling_rate=sr,
        max_length=int(sr * max_duration),
        truncation=True,
        return_attention_mask=False,
    )


In [None]:
# Extract features in batches of 64 with 2 worker processes, drop the "audio" and
# "file" columns, then rename the target column from "genre" to "label".
gtzan_encoded = gtzan.map(
    preprocess_function,
    batched=True,
    batch_size=64,
    num_proc=2,
    remove_columns=["audio", "file"],
).rename_column("genre", "label")
gtzan_encoded


In [None]:
# Build the id <-> genre-name mappings passed to the model config in `trial`.
# Ids are stored as strings.
id2label_fn = gtzan["train"].features["genre"].int2str
num_genres = len(gtzan_encoded["train"].features["label"].names)
id2label = dict((str(idx), id2label_fn(idx)) for idx in range(num_genres))
label2id = dict((name, idx_str) for idx_str, name in id2label.items())

print("id2label:")
pprint(id2label)
print("========================")
print("label2id:")
pprint(label2id)


In [None]:
metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Return the accuracy of the predictions against the reference labels.

    ``eval_pred.predictions`` is reduced with an argmax over axis 1 (presumably
    per-class scores of shape ``(n_examples, num_labels)`` — confirm) and compared
    with ``eval_pred.label_ids`` using the module-level ``metric``.
    """
    scores = eval_pred.predictions
    predicted_labels = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_labels, references=eval_pred.label_ids)


In [None]:
def trial(
    lr,
    wd,
    *,
    seed=42,
    batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=7,
    push_to_hub=True,
):
    """Fine-tune ``model_id`` on GTZAN with one learning-rate / weight-decay setting.

    Args:
        lr: peak learning rate for ``TrainingArguments``.
        wd: weight decay for ``TrainingArguments``.
        seed: applied via ``set_seed`` *before* the model is built, so the newly
            initialised classification head (``ignore_mismatched_sizes=True``) is
            reproducible; also forwarded to ``TrainingArguments``.
        batch_size: per-device train and eval batch size.
        gradient_accumulation_steps: number of batches accumulated per optimizer step.
        num_train_epochs: number of training epochs.
        push_to_hub: if True (the original behaviour), push the final model and a
            model card to the Hub after training.

    Returns:
        None. Checkpoints are written to ``f"{model_name}-finetuned-gtzan"``; every
        call reuses that same directory.
    """
    import inspect

    from transformers import set_seed

    def _kwarg_name(fn, preferred, fallback):
        # Pick whichever keyword the installed transformers version accepts
        # (newer releases renamed `evaluation_strategy` and `tokenizer`).
        return preferred if preferred in inspect.signature(fn).parameters else fallback

    num_labels = len(id2label)
    model_name = model_id.split("/")[-1]
    output_dir = f"{model_name}-finetuned-gtzan"

    # Trainer applies `args.seed` in its constructor, which runs after the model
    # (and its randomly initialised head) is created below, so seed here first.
    set_seed(seed)

    print("Initializing model ...")
    model = AutoModelForAudioClassification.from_pretrained(
        model_id,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,  # needed for Audio Spectrogram Transformer (AST) model
    )

    eval_strategy_kw = _kwarg_name(TrainingArguments, "eval_strategy", "evaluation_strategy")
    training_args = TrainingArguments(
        output_dir,
        **{eval_strategy_kw: "epoch"},
        save_strategy="epoch",
        learning_rate=lr,
        weight_decay=wd,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_train_epochs,
        warmup_ratio=0.1,
        logging_steps=5,
        # Reload the best-accuracy checkpoint at the end so the pushed model is the
        # best one, not simply the last epoch. NOTE(review): the eval split is also
        # the test split, so the reported accuracy is optimistic.
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        save_total_limit=1,
        report_to="none",
        # fp16 mixed precision requires CUDA; fall back to fp32 on CPU-only machines.
        fp16=torch.cuda.is_available(),
        # Pushing is done explicitly after training (see the `push_to_hub` argument).
        push_to_hub=False,
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
        seed=seed,
    )

    processing_kw = _kwarg_name(Trainer, "processing_class", "tokenizer")
    trainer = Trainer(
        model,
        training_args,
        train_dataset=gtzan_encoded["train"],
        eval_dataset=gtzan_encoded["test"],
        compute_metrics=compute_metrics,
        **{processing_kw: feature_extractor},
    )

    print("Begin training ...")
    print("Dataloader_num_workers:", trainer.args.dataloader_num_workers)
    trainer.train()

    if not push_to_hub:
        return

    # Metadata for the auto-generated model card.
    model_card_kwargs = {
        "dataset_tags": "marsyas/gtzan",
        "dataset": "GTZAN",
        "model_name": f"{model_name}-finetuned-gtzan",
        "finetuned_from": model_id,
        "tasks": "audio-classification",
    }
    trainer.push_to_hub(**model_card_kwargs)


In [None]:
# Results recorded from earlier runs (presumably eval accuracy), per dataloader_num_workers:
#   workers=0: 0.57
#   workers=1: 0.8
#   workers=2: 0.89
#   workers=3: 0.88
#   workers=4: 0.9
# NOTE(review): the number of dataloader workers should not by itself change model
# quality; this spread looks like run-to-run variance (e.g. an unseeded classifier
# head) — re-check with a fixed seed before drawing conclusions.
trial(wd=0, lr=5e-5)
