# Training Music Classification Models

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from datetime import datetime
from pathlib import Path
import random

def set_seed(seed=42):
    """Set seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

## Training Configuration

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Training hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 0.001 
NUM_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 15 

Using device: cuda


In [3]:
# Setup run directory
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = Path(f"../runs/{run_id}")
run_dir.mkdir(parents=True, exist_ok=True)
print(f"Run directory created at: {run_dir}")

changes_file = run_dir / "changes.md"
with open(changes_file, "w") as f:
    f.write(f"# Run {run_id}\n\n")
    f.write("## Configuration\n")
    f.write(f"- Batch Size: {BATCH_SIZE}\n")
    f.write(f"- Learning Rate: {LEARNING_RATE}\n")
    f.write(f"- Epochs: {NUM_EPOCHS}\n")
    f.write(f"- Device: {device}\n")
    f.write(f"- Data Strategy: Chunking (3s chunks, 50% overlap)\n")
    f.write(f"- Augmentation: Noise=0.01, Shift=0.3\n")
    f.write(f"- Optimization: In-memory caching + Mixed Precision (AMP)\n")
    f.write(f"- Stability: Seed=42, Weight Decay=1e-4 (Standard), Gradient Clipping=1.0\n")
    f.write(f"- Data Split: Stratified (Balanced Validation Set)\n\n")
    f.write("## Changes\n")
    f.write("- Increased Dropout in Residual Blocks from 0.2 to 0.3 to combat slight overfitting.\n\n")
    f.write("## Results\n")

Run directory created at: ..\runs\20251202_142134


## Training Function (Single-label Classification)

In [4]:
def mixup_data(x, y, alpha=1.0, device='cuda'):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    # mixed precision training
    scaler = torch.amp.GradScaler()
    
    pbar = tqdm(train_loader, desc='Training')
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        
        # Apply Mixup
        inputs, targets_a, targets_b, lam = mixup_data(inputs, labels, alpha=0.4, device=device)
        
        optimizer.zero_grad()
        
        # pass with mixed precision
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            outputs = model(inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
        
        # Backwardprop and optimize
        scaler.scale(loss).backward()
        
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        
        # Weighted accuracy for mixup
        correct += (lam * (predicted == targets_a).float() + (1 - lam) * (predicted == targets_b).float()).sum().item()
        
        pbar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})
    
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    
    return epoch_loss, epoch_acc

In [5]:
def validate_epoch(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            pbar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})
    
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    
    return epoch_loss, epoch_acc, all_preds, all_labels

In [6]:
def train_model(model, train_loader, val_loader, num_epochs, learning_rate, device, 
                save_path='../models/best_model.pth', changes_file=None):
    
    model = model.to(device)
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    
    # OneCycleLR Scheduler
    # Steps per epoch is len(train_loader)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, 
        max_lr=learning_rate, 
        steps_per_epoch=len(train_loader), 
        epochs=num_epochs,
        pct_start=0.3, # Warmup for 30% of training
        div_factor=25.0,
        final_div_factor=1000.0
    )
    
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    
    best_val_loss = float('inf')
    patience_counter = 0
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)
        
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        scaler = torch.amp.GradScaler()
        
        pbar = tqdm(train_loader, desc='Training')
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Apply Mixup
            inputs, targets_a, targets_b, lam = mixup_data(inputs, labels, alpha=0.4, device=device)
            
            optimizer.zero_grad()
            
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(inputs)
                loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            
            # Step scheduler every batch for OneCycleLR
            scheduler.step()
            
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (lam * (predicted == targets_a).float() + (1 - lam) * (predicted == targets_b).float()).sum().item()
            
            pbar.set_postfix({'loss': loss.item(), 'lr': scheduler.get_last_lr()[0]})
        
        train_loss = running_loss / total
        train_acc = 100 * correct / total
        
        val_loss, val_acc, _, _ = validate_epoch(model, val_loader, criterion, device)
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        
        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # saving best model
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")
        else:
            patience_counter += 1
            if patience_counter >= EARLY_STOPPING_PATIENCE:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break
    
    if changes_file:
        with open(changes_file, "a") as f:
            f.write(f"- Final Train Loss: {history['train_loss'][-1]:.4f}\n")
            f.write(f"- Final Val Loss: {history['val_loss'][-1]:.4f}\n")
            f.write(f"- Final Train Acc: {history['train_acc'][-1]:.2f}%\n")
            f.write(f"- Final Val Acc: {history['val_acc'][-1]:.2f}%\n")

    return history

## Plot Training History

In [7]:
def plot_training_history(history, multi_label=False, save_path=None):
    fig, axes = plt.subplots(1, 2 if not multi_label else 1, figsize=(15, 5))
    
    # Loss
    axes[0].plot(history['train_loss'], label='Train Loss')
    axes[0].plot(history['val_loss'], label='Val Loss')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True)
    
    # Accuracy
    axes[1].plot(history['train_acc'], label='Train Accuracy')
    axes[1].plot(history['val_acc'], label='Val Accuracy')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('Training and Validation Accuracy')
    axes[1].legend()
    axes[1].grid(True)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

### For Single-label Classification (GTZAN, FMA)

In [9]:
import os
import sys
from pathlib import Path
repo_root = Path.cwd().parent
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))


%run "./model_cnn.ipynb"

from utils.datasets_gtzan import GTZANDataset, create_dataloaders, GENRES, AudioAugmentation

# Create dataset with in-memory caching
gtzan_root = repo_root / "data" / "gtzan"
dataset = GTZANDataset(str(gtzan_root), cache_to_memory=True)
print(f"GTZAN files: {len(dataset)}")

# Augment data
train_transform = AudioAugmentation(noise_level=0.01, shift_max=0.3)

# Create loaders with Stratified Split and Chunking
# NOTE: With cache_to_memory=True, use num_workers=0 on Windows to avoid 
# pickling the entire cached dataset to worker processes, which causes hangs/OOM.
train_loader, val_loader, test_loader = create_dataloaders(
    dataset, 
    batch_size=BATCH_SIZE, 
    num_workers=0,
    train_transform=train_transform,
    chunk_length_sec=3.0, # Enable chunking
    test_split=0.1 # Create test split
)

model = ComplexCNN(n_classes=10)

history = train_model(
    model, train_loader, val_loader,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    device=device,
    save_path=str(run_dir / 'gtzan_cnn.pth'),
    changes_file=changes_file
)

plot_training_history(history, save_path=str(run_dir / 'training_history.png'))

print(f"Loading best model from {run_dir / 'gtzan_cnn.pth'}...")
model.load_state_dict(torch.load(str(run_dir / 'gtzan_cnn.pth')))

Caching 999 audio files to memory...
Caching complete.
GTZAN files: 999
Created stratified split: 719 train, 180 val, 100 test songs
Applying chunking: 3.0s chunks with 50% overlap
Chunked dataset sizes: 13661 train, 3420 val, 1900 test chunks

Epoch 1/50
--------------------------------------------------


Training:  11%|█         | 47/427 [00:08<01:10,  5.40it/s, loss=2.15, lr=4.01e-5]


KeyboardInterrupt: 

## Evaluation Metrics

In [None]:
def evaluate_model(model, test_loader, device, genre_names=None, changes_file=None, split_name="Test"):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc=f'Evaluating {split_name}'):
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')
    
    print(f"\n{split_name} Metrics (Chunk-Level):")
    print(f"Accuracy: {accuracy*100:.2f}%")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")
    
    if changes_file:
        with open(changes_file, "a") as f:
            f.write(f"- {split_name} Accuracy (Chunk): {accuracy*100:.2f}%\n")
            f.write(f"- {split_name} Precision: {precision:.4f}\n")
            f.write(f"- {split_name} Recall: {recall:.4f}\n")
            f.write(f"- {split_name} F1-Score: {f1:.4f}\n")
    
    return all_preds, all_labels

def evaluate_by_song(model, val_dataset, device, changes_file=None, split_name="Test"):
    """
    Evaluate accuracy by aggregating chunk predictions for each song.
    Val_dataset is ordered by song.
    """
    model.eval()
    correct_songs = 0
    
    if not hasattr(val_dataset, 'num_chunks'):
        print("Dataset does not appear to be a ChunkedDataset. Skipping song-level evaluation.")
        return 0.0

    num_chunks = val_dataset.num_chunks
    total_songs = len(val_dataset) // num_chunks
    
    print(f"\nEvaluating on {total_songs} songs ({split_name}) (aggregating {num_chunks} chunks each)...")
    
    with torch.no_grad():
        for i in tqdm(range(total_songs), desc='Song Eval'):
            chunks = []
            label = None
            
            start_idx = i * num_chunks
            
            for j in range(num_chunks):
                c, l = val_dataset[start_idx + j]
                chunks.append(c)
                label = l 
            
            chunks_tensor = torch.stack(chunks).to(device)
            
            outputs = model(chunks_tensor) # (num_chunks, n_classes)
            
            avg_output = torch.mean(outputs, dim=0)
            pred_label = torch.argmax(avg_output).item()
            
            if pred_label == label:
                correct_songs += 1
                
    song_acc = 100 * correct_songs / total_songs
    print(f"{split_name} Song-Level Accuracy: {song_acc:.2f}%")
    
    if changes_file:
        with open(changes_file, "a") as f:
            f.write(f"- {split_name} Song-Level Accuracy: {song_acc:.2f}%\n")
            
    return song_acc

with open(changes_file, "a") as f:
    f.write("\n--- Validation Set ---\n")
print("\n--- Validation Set Evaluation ---")
evaluate_model(
    model, val_loader, device, genre_names=GENRES, changes_file=changes_file, split_name="Validation"
)
evaluate_by_song(model, val_loader.dataset, device, changes_file=changes_file, split_name="Validation")

with open(changes_file, "a") as f:
    f.write("\n--- Test Set ---\n")
print("\n--- Test Set Evaluation ---")
evaluate_model(
    model, test_loader, device, genre_names=GENRES, changes_file=changes_file, split_name="Test"
)

evaluate_by_song(model, test_loader.dataset, device, changes_file=changes_file, split_name="Test")



--- Validation Set Evaluation ---


Evaluating Validation: 100%|██████████| 107/107 [00:03<00:00, 33.23it/s]
Evaluating Validation: 100%|██████████| 107/107 [00:03<00:00, 33.23it/s]



Validation Metrics (Chunk-Level):
Accuracy: 80.23%
Precision: 0.8043
Recall: 0.8023
F1-Score: 0.7947

Evaluating on 180 songs (Validation) (aggregating 19 chunks each)...


Song Eval: 100%|██████████| 180/180 [00:03<00:00, 57.61it/s]
Song Eval: 100%|██████████| 180/180 [00:03<00:00, 57.61it/s]


Validation Song-Level Accuracy: 85.56%

--- Test Set Evaluation ---


Evaluating Test: 100%|██████████| 60/60 [00:02<00:00, 20.69it/s]
Evaluating Test: 100%|██████████| 60/60 [00:02<00:00, 20.69it/s]



Test Metrics (Chunk-Level):
Accuracy: 77.16%
Precision: 0.7823
Recall: 0.7716
F1-Score: 0.7657

Evaluating on 100 songs (Test) (aggregating 19 chunks each)...


Song Eval: 100%|██████████| 100/100 [00:01<00:00, 60.83it/s]

Test Song-Level Accuracy: 83.00%





83.0