# Music genre classification using Vision Transformers: ViT

In this notebook, I use the Vision Transformer (ViT) model to classify music genres. The model is trained on a modified GTZAN dataset available on Hugging Face Datasets ([egtzan_plus](https://huggingface.co/datasets/ghermoso/egtzan_plus)), which contains mel spectrograms generated with the `create_spectrogram` function from `create_melspectrogram.py`.

This notebook is a fine-tuning example of the Vision Transformer model on the music genre classification task. The model is pre-trained on the ImageNet-21k dataset and fine-tuned on the eGTZAN+ dataset.

The model is trained using the `Trainer` API from the 🤗 Transformers library.

# Importing necessary libraries


In [None]:
import torch
from transformers import (
    ViTImageProcessor,
    ViTForImageClassification,
    TrainingArguments,
    Trainer,
)
from datasets import load_dataset
from PIL import Image
import numpy as np
from datasets import load_metric
from transformers import pipeline
from huggingface_hub import notebook_login

# Model and data preparation


In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [None]:
# Download (or reuse the local cache of) the eGTZAN+ mel-spectrogram dataset.
ds = load_dataset("ghermoso/egtzan_plus")

# Genre names, read from the "label" feature of the train split.
label_feature = ds["train"].features["label"]
labels = label_feature.names

In [None]:
# Checkpoint shared by the image processor and the model so preprocessing
# matches what the backbone was pre-trained with.
model_path = "google/vit-base-patch16-224-in21k"

# The image processor is a plain preprocessing object, not an `nn.Module`:
# it has no `.to()` method, so calling `.to(device)` on it raises
# AttributeError. It stays on the CPU; only the model is moved to `device`.
processor = ViTImageProcessor.from_pretrained(model_path)

# Load the model with a classification head sized for our genres, and store
# the id <-> genre-name mappings in its config.
model = ViTForImageClassification.from_pretrained(
    model_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)},
).to(device)

In [None]:
def transform(example_batch):
    """Turn a batch of raw dataset examples into model-ready inputs.

    Args:
        example_batch: Batch dict with an "image" entry (the images to
            preprocess) and a "label" entry (their class ids).

    Returns:
        The processor output for the images (requested as PyTorch tensors),
        with the batch's labels attached under the "label" key.
    """
    images = list(example_batch["image"])
    inputs = processor(images, return_tensors="pt")
    inputs["label"] = example_batch["label"]
    return inputs


def collate_fn(batch):
    """Collate a list of per-example dicts into a single batch dict.

    Args:
        batch: List of examples, each a dict with a "pixel_values" tensor
            and a "label" value.

    Returns:
        A dict with "pixel_values" (the example tensors stacked along a new
        leading dimension) and "labels" (a tensor built from the labels).
    """
    pixel_values, targets = [], []
    for example in batch:
        pixel_values.append(example["pixel_values"])
        targets.append(example["label"])
    return {
        "pixel_values": torch.stack(pixel_values),
        "labels": torch.tensor(targets),
    }


# Accuracy metric object used by `compute_metrics` below.
# NOTE(review): `load_metric` appears to be deprecated in recent `datasets`
# releases in favor of the separate `evaluate` package — confirm against the
# installed version before upgrading.
metric = load_metric("accuracy")


def compute_metrics(p):
    """Compute accuracy for the Trainer from an evaluation prediction.

    Args:
        p: Object exposing `predictions` (per-class scores, one row per
            example) and `label_ids` (the reference class ids).

    Returns:
        The result of `metric.compute` for the arg-max predictions.
    """
    # The highest-scoring class along axis 1 is the predicted label.
    predicted_ids = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_ids, references=p.label_ids)

In [None]:
# Attach `transform` to every split of the dataset; `prepared_ds` is what the
# Trainer reads from.
prepared_ds = ds.with_transform(transform=transform)

# Training


In [None]:
training_args = TrainingArguments(
    # Where checkpoints and metrics are written: change it to your preferred directory.
    output_dir="./vit-eGTZANplus",

    # Evaluate, checkpoint and log on the same 10-step cadence.
    evaluation_strategy="steps",
    eval_steps=10,
    save_steps=10,
    logging_steps=10,

    # Optimisation settings.
    num_train_epochs=16,  # number of full passes over the training set
    learning_rate=2e-4,  # optimizer step size
    per_device_train_batch_size=16,  # batch size for training
    per_device_eval_batch_size=16,  # batch size for evaluation
    fp16=torch.cuda.is_available(),  # mixed precision only when a CUDA GPU is present

    # Checkpoint handling.
    save_total_limit=2,  # keep at most 2 checkpoints on disk
    load_best_model_at_end=True,  # reload the best checkpoint once training ends

    # Keep the raw "image" column: `transform` reads it when examples are fetched.
    remove_unused_columns=False,

    # Set to True to push the model to the Hugging Face Hub
    # (needs authentication with notebook_login() from huggingface_hub).
    push_to_hub=False,
)

# Build the Trainer: model + hyper-parameters, the two dataset splits, and the
# helpers that batch examples and score predictions.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["test"],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # Passing the image processor here lets the Trainer save it with the model.
    tokenizer=processor,
)

In [None]:
# Fine-tune the model, then persist the weights, the training metrics and the
# trainer state to `output_dir`.
train_results = trainer.train()
trainer.save_model()

train_metrics = train_results.metrics
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)

trainer.save_state()

# Evaluate the model


In [None]:
# Score the fine-tuned model on the test split, then log and save the results.
test_split = prepared_ds["test"]
metrics = trainer.evaluate(eval_dataset=test_split)
trainer.log_metrics(split="eval", metrics=metrics)
trainer.save_metrics(split="eval", metrics=metrics)