# **GTZAN Music Genre Classification using Audio Spectrogram Transformer (AST)**
Target: >87% accuracy on the GTZAN dataset

In [None]:
pip install torch==2.6.0+cu124 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

In [None]:
pip install datasets==2.16.0

In [None]:
pip install transformers scikit-learn accelerate

In [4]:
import torch
import torch.nn as nn
import numpy as np
from datasets import load_dataset, Audio
from transformers import (
    ASTFeatureExtractor,
    ASTForAudioClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import warnings
# Silence every Python warning for the rest of the session: the 'ignore' action
# is installed without any category or module restriction.
# NOTE(review): this also hides deprecation warnings from the libraries used below.
warnings.filterwarnings('ignore')

In [5]:
# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [6]:
# Fetch GTZAN from the Hugging Face Hub (repository "marsyas/gtzan", configuration "all")
print("Loading GTZAN dataset...")
dataset = load_dataset(path="marsyas/gtzan", name="all")

Loading GTZAN dataset...


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading readme: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0.00/1.23G [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [7]:
# Genre names come from the dataset's "genre" feature; a genre's position in
# that list is its integer class id.
labels = dataset["train"].features["genre"].names
num_labels = len(labels)

# Lookup tables in both directions (id -> name, name -> id) for the model config
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

print(f"Number of genres: {num_labels}")
print(f"Genres: {labels}")

Number of genres: 10
Genres: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']


In [8]:
# Load the pretrained AST checkpoint and its matching feature extractor
model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
print(f"\nLoading model: {model_name}")

feature_extractor = ASTFeatureExtractor.from_pretrained(model_name)

# Label configuration for the new classification head
classifier_head_config = dict(
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)
model = ASTForAudioClassification.from_pretrained(
    model_name,
    # The checkpoint's classifier has a different number of outputs than
    # `num_labels`; allow that layer to be re-initialised instead of failing.
    ignore_mismatched_sizes=True,
    **classifier_head_config,
)


Loading model: MIT/ast-finetuned-audioset-10-10-0.4593


preprocessor_config.json:   0%|          | 0.00/297 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([10, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# Move the model onto the device selected earlier (`device`) and keep the
# returned reference bound to `model`.
model = model.to(device)

In [10]:
# Preprocessing function (feature extraction only; no augmentation is applied)
def preprocess_function(examples):
    """
    Turn a batch of decoded audio clips into AST model inputs.

    Args:
        examples: batch dict as passed by `Dataset.map(batched=True)`. Reads
            `examples["audio"]` (each item providing "array" and
            "sampling_rate") and `examples["genre"]`.

    Returns:
        The feature extractor's output with an added "labels" entry holding
        the batch's genre values.

    Raises:
        ValueError: if a clip's sampling rate differs from the rate the
            feature extractor is configured for.
    """
    expected_rate = feature_extractor.sampling_rate

    # Check the real rate of every clip instead of assuming it. Otherwise this
    # function silently depends on the audio column having been cast to the
    # right rate before `map` is called.
    audio_arrays = []
    for clip in examples["audio"]:
        if clip["sampling_rate"] != expected_rate:
            raise ValueError(
                f"Expected audio sampled at {expected_rate} Hz but got "
                f"{clip['sampling_rate']} Hz; cast the audio column to the "
                "feature extractor's sampling rate before mapping."
            )
        audio_arrays.append(clip["array"])

    # ASTFeatureExtractor pads/truncates each clip to its own configured
    # `max_length` (a number of spectrogram frames). Per-call `padding`,
    # `max_length` and `truncation` arguments are not used by it, so none are
    # passed; longer clips are cut to the extractor's window.
    inputs = feature_extractor(
        audio_arrays,
        sampling_rate=expected_rate,
        return_tensors="pt"
    )

    # Add labels
    inputs["labels"] = examples["genre"]

    return inputs

In [11]:
# Build the model-ready dataset
print("\nPreparing datasets...")

# Re-declare the audio column with a 16 kHz sampling rate
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

# Run feature extraction in batches of 8 and drop all original columns, so only
# the columns returned by `preprocess_function` remain.
raw_columns = dataset["train"].column_names
encoded_dataset = dataset.map(
    preprocess_function,
    batched=True,
    batch_size=8,
    num_proc=1,
    remove_columns=raw_columns,
)


Preparing datasets...


Map:   0%|          | 0/999 [00:00<?, ? examples/s]

In [12]:
# 80/10/10 train/validation/test partition of the single "train" split.
# Step 1: hold out 20% of the examples.
train_validation_test_split = encoded_dataset["train"].train_test_split(test_size=0.2, seed=42)
# Step 2: cut the held-out 20% in half -> 10% validation, 10% test.
validation_test_split = train_validation_test_split["test"].train_test_split(test_size=0.5, seed=42)

train_dataset = train_validation_test_split["train"]
eval_dataset, test_dataset = validation_test_split["train"], validation_test_split["test"]

# NOTE(review): neither split is stratified by genre, so per-genre counts in the
# validation and test sets can be uneven.
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(eval_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Training samples: 799
Validation samples: 100
Test samples: 100


In [13]:
# Custom data collator for handling audio data
class AudioDataCollator:
    """Batch encoded examples into tensors using a feature extractor's `pad`."""

    def __init__(self, feature_extractor):
        # Object whose `pad` method assembles the batch of model inputs
        self.feature_extractor = feature_extractor

    def __call__(self, features):
        """
        Collate a list of examples into one batch.

        Args:
            features: sequence of mappings, each with "input_values" and "labels".

        Returns:
            Whatever `feature_extractor.pad` returns, with a "labels" entry
            added as a `torch.long` tensor (one label per example, in order).
        """
        model_inputs = []
        targets = []
        for example in features:
            # Forward only the model input to `pad`; labels are handled separately
            model_inputs.append({"input_values": example["input_values"]})
            targets.append(example["labels"])

        batch = self.feature_extractor.pad(
            model_inputs,
            padding=True,
            return_tensors="pt",
        )
        batch["labels"] = torch.tensor(targets, dtype=torch.long)
        return batch

# Collator instance built around the already-loaded feature extractor; it is
# handed to the Trainer as `data_collator`.
data_collator = AudioDataCollator(feature_extractor)

In [14]:
# Evaluation metrics
def compute_metrics(eval_pred):
    """
    Compute accuracy for one evaluation pass.

    Args:
        eval_pred: object exposing `predictions` (per-class scores, one row per
            example) and `label_ids` (the reference class ids).

    Returns:
        dict with a single key, "accuracy".
    """
    # The predicted class is the index of the highest score in each row
    predicted_ids = np.argmax(eval_pred.predictions, axis=1)
    return {"accuracy": accuracy_score(eval_pred.label_ids, predicted_ids)}

In [15]:
from huggingface_hub import notebook_login

# Show the Hugging Face login widget.
# NOTE(review): presumably needed because the training arguments set
# push_to_hub=True and the run later calls trainer.push_to_hub - confirm.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [16]:
# Keep only the text after the last "/" (the repository name without its owner);
# it is used to build the output directory / Hub model name.
# NOTE(review): this rebinds `model_name`, so re-running the model-loading cell
# afterwards would pass the shortened name to `from_pretrained`.
model_name = model_name.rsplit("/", 1)[-1]
model_name

'ast-finetuned-audioset-10-10-0.4593'

In [17]:
# Fine-tuning configuration, grouped by concern
training_args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-gtzan",
    seed=42,
    # --- optimisation ---
    learning_rate=5e-5,  # small step size for fine-tuning a pretrained model
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,
    num_train_epochs=30,  # upper bound; the early-stopping callback on the Trainer can end the run sooner
    # --- batching / memory ---
    per_device_train_batch_size=8,  # adjust to the available GPU memory
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,  # 8 x 2 = 16 samples per optimizer step
    gradient_checkpointing=True,  # memory optimization
    fp16=torch.cuda.is_available(),  # mixed precision only when CUDA is available
    # --- evaluation / checkpointing (both every 50 steps) ---
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    # --- logging / publishing ---
    logging_steps=10,
    report_to="none",
    push_to_hub=True,
)

In [18]:
# Initialize trainer with early stopping (stop after 5 evaluations without improvement)
trainer_kwargs = dict(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)

# Current transformers releases take the feature extractor as `processing_class`
# and have deprecated the older `tokenizer` argument. The installed version is
# not pinned, so try the current name first and fall back to `tokenizer` only
# when `processing_class` itself is rejected.
try:
    trainer = Trainer(processing_class=feature_extractor, **trainer_kwargs)
except TypeError as err:
    if "processing_class" not in str(err):
        raise
    trainer = Trainer(tokenizer=feature_extractor, **trainer_kwargs)

In [19]:
# Run fine-tuning; the returned object is kept for the summary printed later
separator = "=" * 50
print("\n" + separator)
print("Starting training...")
print(separator)

train_result = trainer.train()


Starting training...


Step,Training Loss,Validation Loss,Accuracy
50,0.8868,0.876064,0.72
100,0.4771,0.763162,0.76
150,0.3415,1.035609,0.72
200,0.2508,0.543248,0.82
250,0.1699,0.6632,0.81
300,0.024,0.874465,0.82
350,0.0353,0.864297,0.79
400,0.0341,0.561386,0.86
450,0.0411,0.623041,0.86
500,0.0345,0.93609,0.76


In [20]:
# Full Hub id of the base checkpoint (the `model_name` variable holds only the short name)
model_id = "MIT/ast-finetuned-audioset-10-10-0.4593"

# Metadata passed along with the upload
kwargs = dict(
    dataset_tags="marsyas/gtzan",
    dataset="GTZAN",
    model_name=f"{model_name}-finetuned-gtzan",
    finetuned_from=model_id,
    tasks="audio-classification",
)

trainer.push_to_hub(**kwargs)

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  ...3-finetuned-gtzan/training_args.bin: 100%|##########| 5.37kB / 5.37kB            

  ...3-finetuned-gtzan/model.safetensors:  10%|9         | 33.5MB /  345MB            

CommitInfo(commit_url='https://huggingface.co/arsonor/ast-finetuned-audioset-10-10-0.4593-finetuned-gtzan/commit/76f6ae620738f3fd41b7d25a4989f7a63d0ded69', commit_message='End of training', commit_description='', oid='76f6ae620738f3fd41b7d25a4989f7a63d0ded69', pr_url=None, repo_url=RepoUrl('https://huggingface.co/arsonor/ast-finetuned-audioset-10-10-0.4593-finetuned-gtzan', endpoint='https://huggingface.co', repo_type='model', repo_id='arsonor/ast-finetuned-audioset-10-10-0.4593-finetuned-gtzan'), pr_revision=None, pr_num=None)

In [21]:
# Score the trainer's current model on the held-out test split
print("\n" + "=" * 50)
print("Evaluating on test set...")
print("=" * 50)

test_results = trainer.evaluate(eval_dataset=test_dataset)
test_accuracy = test_results["eval_accuracy"]
print(f"\nTest Accuracy: {test_accuracy:.4f}")


Evaluating on test set...



Test Accuracy: 0.8300


In [22]:
# Raw test-set predictions for the per-class analysis below
predictions = trainer.predict(test_dataset)
y_true = predictions.label_ids
# Predicted class = index of the highest score in each row
y_pred = np.argmax(predictions.predictions, axis=1)

In [23]:
# Print classification report
print("\n" + "="*50)
print("Classification Report:")
print("="*50)
# Pass the full list of class ids explicitly. With `target_names` alone the
# report fails when a genre is missing from both y_true and y_pred, which can
# happen on a small, unstratified test split. With `labels=` every genre gets a
# row and the names stay aligned with the class ids.
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(len(labels))),
    target_names=labels,
))


Classification Report:
              precision    recall  f1-score   support

       blues       0.73      1.00      0.84         8
   classical       1.00      1.00      1.00        15
     country       0.86      0.55      0.67        11
       disco       0.88      0.70      0.78        10
      hiphop       0.73      1.00      0.84         8
        jazz       1.00      1.00      1.00         9
       metal       0.80      1.00      0.89         8
         pop       0.56      0.83      0.67         6
      reggae       1.00      0.75      0.86        12
        rock       0.73      0.62      0.67        13

    accuracy                           0.83       100
   macro avg       0.83      0.84      0.82       100
weighted avg       0.85      0.83      0.83       100



In [24]:
# Confusion matrix
print("\n" + "="*50)
print("Confusion Matrix:")
print("="*50)
# Fix the row/column order to the full list of class ids. Without `labels=` the
# matrix only covers classes that occur in y_true or y_pred, so a missing genre
# would shrink it and shift rows/columns away from the label list printed below.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))
print("\nRows: Actual labels, Columns: Predicted labels")
print(f"Labels: {labels}")
print(cm)


Confusion Matrix:

Rows: Actual labels, Columns: Predicted labels
Labels: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']
[[ 8  0  0  0  0  0  0  0  0  0]
 [ 0 15  0  0  0  0  0  0  0  0]
 [ 2  0  6  1  0  0  0  0  0  2]
 [ 0  0  0  7  0  0  0  3  0  0]
 [ 0  0  0  0  8  0  0  0  0  0]
 [ 0  0  0  0  0  9  0  0  0  0]
 [ 0  0  0  0  0  0  8  0  0  0]
 [ 0  0  0  0  1  0  0  5  0  0]
 [ 0  0  0  0  1  0  0  1  9  1]
 [ 1  0  1  0  1  0  2  0  0  8]]


In [25]:
# Per-genre accuracy: of the test clips whose true class is a given genre,
# the share predicted as that genre
print("\n" + "=" * 50)
print("Per-Genre Accuracy:")
print("=" * 50)
for genre_id, genre in enumerate(labels):
    # Correct/incorrect flags for the clips whose true class is this genre
    hits = y_pred[y_true == genre_id] == genre_id
    if hits.size == 0:
        # Genre not present in the test split: nothing to report
        continue
    genre_acc = hits.mean()
    print(f"{genre:15s}: {genre_acc:.4f}")


Per-Genre Accuracy:
blues          : 1.0000
classical      : 1.0000
country        : 0.5455
disco          : 0.7000
hiphop         : 1.0000
jazz           : 1.0000
metal          : 1.0000
pop            : 0.8333
reggae         : 0.7500
rock           : 0.6154


In [26]:
# Training summary
print("\n" + "="*50)
print("Training Summary:")
print("="*50)
print(f"Total training steps: {train_result.global_step}")
# `training_loss` on the Trainer's train output is the loss averaged over the
# whole run, not the loss of the last step, so label it accordingly.
print(f"Average training loss: {train_result.training_loss:.4f}")
print(f"Final test accuracy: {test_results['eval_accuracy']:.4f}")


Training Summary:
Total training steps: 1050
Final training loss: 0.1464
Final test accuracy: 0.8300


In [27]:
# Static checklist of ideas for closing the gap to the 87% accuracy target
improvement_tips = """
1. Data Augmentation: Consider using audiomentations library for:
   - Time stretching
   - Pitch shifting
   - Adding noise
   - Mix-up augmentation

2. Model Ensemble: Train multiple models with different seeds and average predictions

3. Longer Training: Increase num_train_epochs if model hasn't converged

4. Hyperparameter Tuning: Use Optuna or Ray Tune for systematic search

5. Alternative Models to Try:
   - Wav2Vec2 large model
   - HuBERT large model
   - Whisper for feature extraction

6. Post-processing: Use test-time augmentation (TTA) by averaging predictions
   from multiple augmented versions of test samples
"""

print("\n" + "=" * 50)
print("Tips for Further Improvement (if needed):")
print("=" * 50)
print(improvement_tips)


Tips for Further Improvement (if needed):

1. Data Augmentation: Consider using audiomentations library for:
   - Time stretching
   - Pitch shifting
   - Adding noise
   - Mix-up augmentation

2. Model Ensemble: Train multiple models with different seeds and average predictions

3. Longer Training: Increase num_train_epochs if model hasn't converged

4. Hyperparameter Tuning: Use Optuna or Ray Tune for systematic search

5. Alternative Models to Try:
   - Wav2Vec2 large model
   - HuBERT large model
   - Whisper for feature extraction

6. Post-processing: Use test-time augmentation (TTA) by averaging predictions
   from multiple augmented versions of test samples

