# Training Music Classification Models

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from datetime import datetime
from pathlib import Path
import random

def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), then puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU rather than only the current
    # one; it is silently ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN's autotuner so the same kernels are picked on every run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before any of the cells below create data or models.
set_seed(42)

## Training Configuration

In [None]:
# Prefer the GPU whenever PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Optimisation settings shared by the training cells in this notebook
BATCH_SIZE, NUM_EPOCHS = 32, 20
LEARNING_RATE = 1e-4  # reduced from the earlier 1e-3
EARLY_STOPPING_PATIENCE = 10  # epochs without a val-loss improvement before stopping

Using device: cuda


In [None]:
# Every run gets its own timestamp-named output folder under ../runs
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = Path("..") / "runs" / run_id
run_dir.mkdir(parents=True, exist_ok=True)
print(f"Run directory created at: {run_dir}")

# Start changes.md with the configuration and a summary of what changed.
# The file ends on the "Results" heading so a later cell can append to it.
changes_file = run_dir / "changes.md"
with open(changes_file, "w") as f:
    f.writelines((
        f"# Run {run_id}\n\n",
        "## Configuration\n",
        f"- Batch Size: {BATCH_SIZE}\n",
        f"- Learning Rate: {LEARNING_RATE}\n",
        f"- Epochs: {NUM_EPOCHS}\n",
        f"- Device: {device}\n",
        "- Augmentation: Noise=0.005, Shift=0.2\n",
        "- Optimization: In-memory caching + Mixed Precision (AMP)\n",
        "- Stability: Seed=42, Weight Decay=1e-4, Gradient Clipping=1.0\n",
        "- Data Split: Stratified (Balanced Validation Set)\n\n",
        "## Changes\n",
        "- Implemented Stratified Split to ensure validation set has balanced genre distribution.\n",
        "- Lowered Learning Rate to 1e-4 to prevent oscillation.\n",
        "- Added Weight Decay (1e-4) and Gradient Clipping (1.0) for regularization and stability.\n",
        "- Set fixed seed for reproducibility.\n\n",
        "## Results\n",
    ))

Run directory created at: ..\runs\20251127_143203


## Training Function (Single-label Classification)

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    Mixed precision (float16 autocast plus loss scaling) is applied only when
    ``device`` is a CUDA device. On any other device the same loop runs in
    full precision, so the function no longer forces CUDA autocast/scaling
    onto CPU tensors.

    Args:
        model: Network to train; switched to training mode here.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable applied as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: ``torch.device`` (or device string) the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and
        the accuracy in percent over all batches seen.
    """
    from contextlib import nullcontext

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # Accept both torch.device objects and plain strings such as "cuda:0".
    use_amp = torch.device(device).type == 'cuda'

    # A disabled GradScaler turns scale()/unscale_()/update() into no-ops and
    # step() into a plain optimizer.step(), so one code path serves both cases.
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    pbar = tqdm(train_loader, desc='Training')
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass, in float16 autocast only when running on CUDA
        if use_amp:
            amp_ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16)
        else:
            amp_ctx = nullcontext()
        with amp_ctx:
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        # Backward pass on the (possibly scaled) loss
        scaler.scale(loss).backward()

        # Unscale first so max_norm is compared against the true gradient norm
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        # Statistics (read the loss once to avoid a second device sync)
        batch_loss = loss.item()
        running_loss += batch_loss * inputs.size(0)
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        pbar.set_postfix({'loss': batch_loss, 'acc': 100 * correct / total})

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc

In [5]:
def validate_epoch(model, val_loader, criterion, device):
    """Run one gradient-free evaluation pass over ``val_loader``.

    Args:
        model: Network to evaluate; switched to eval mode here.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable applied as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc, all_preds, all_labels)``: the
        sample-weighted mean loss, the accuracy in percent, and the
        per-sample predicted and true labels collected across batches.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        progress = tqdm(val_loader, desc='Validation')
        for inputs, labels in progress:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            batch_loss = criterion(outputs, labels)

            # Weight the batch loss by batch size so short batches count less
            loss_sum += batch_loss.item() * inputs.size(0)
            predicted = torch.max(outputs.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            progress.set_postfix({
                'loss': batch_loss.item(),
                'acc': 100 * n_correct / n_seen,
            })

    return loss_sum / n_seen, 100 * n_correct / n_seen, all_preds, all_labels

In [None]:
def train_model(model, train_loader, val_loader, num_epochs, learning_rate, device,
                save_path='../models/best_model.pth', patience=None):
    """Train a single-label classifier with LR scheduling and early stopping.

    Each epoch runs ``train_epoch`` then ``validate_epoch``. Whenever the
    validation loss reaches a new minimum the model's ``state_dict`` is
    written to ``save_path``. The model object itself is not rolled back, so
    on return it holds the weights of the last epoch that ran.

    Args:
        model: Network to train; moved to ``device``.
        train_loader: Training batches of ``(inputs, labels)``.
        val_loader: Validation batches of ``(inputs, labels)``.
        num_epochs: Maximum number of epochs.
        learning_rate: Initial Adam learning rate.
        device: Device to train on.
        save_path: Where the best checkpoint is written; missing parent
            directories are created.
        patience: Epochs without a validation-loss improvement before
            stopping. ``None`` uses the notebook-level
            ``EARLY_STOPPING_PATIENCE``.

    Returns:
        Dict with per-epoch lists under ``'train_loss'``, ``'train_acc'``,
        ``'val_loss'`` and ``'val_acc'``.
    """
    if patience is None:
        patience = EARLY_STOPPING_PATIENCE

    # torch.save does not create directories, so make sure the target exists
    # up front instead of failing after the first epoch has been trained.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    model = model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    # Weight decay for regularization
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    # Halve the learning rate after 5 epochs without a validation-loss drop
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                       patience=5)

    # Training history
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, _, _ = validate_epoch(model, val_loader, criterion, device)

        # Update scheduler
        scheduler.step(val_loss)

        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Early stopping: only a strictly lower validation loss counts
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Save best model
            torch.save(model.state_dict(), save_path)
            print(f"✓ Model saved to {save_path}")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break

    return history

## Training Function (Multi-label Classification)

In [7]:
def _run_multilabel_epoch(model, loader, criterion, device, optimizer=None,
                          desc='Training'):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    Trains when ``optimizer`` is given; otherwise evaluates without gradients.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is the same as eval()
    loss_sum = 0.0
    n_seen = 0

    with torch.set_grad_enabled(training):
        pbar = tqdm(loader, desc=desc)
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            # BCEWithLogitsLoss expects targets in the logits' floating dtype;
            # the cast is a no-op when the labels already match.
            loss = criterion(outputs, labels.to(outputs.dtype))
            if training:
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            # Weight by batch size so a short final batch is not over-counted
            loss_sum += batch_loss * inputs.size(0)
            n_seen += inputs.size(0)
            pbar.set_postfix({'loss': batch_loss})

    return loss_sum / n_seen


def train_multilabel(model, train_loader, val_loader, num_epochs, learning_rate, device,
                     save_path='../models/best_model_multilabel.pth', patience=None):
    """Train a multi-label classifier with BCE-with-logits loss.

    Each epoch trains on ``train_loader`` and evaluates on ``val_loader``.
    Whenever the validation loss reaches a new minimum the model's
    ``state_dict`` is written to ``save_path``. Losses are averaged per
    sample, matching ``train_epoch`` / ``validate_epoch``.

    Args:
        model: Network producing one logit per tag; moved to ``device``.
        train_loader: Training batches of ``(inputs, labels)``.
        val_loader: Validation batches of ``(inputs, labels)``.
        num_epochs: Maximum number of epochs.
        learning_rate: Initial Adam learning rate.
        device: Device to train on.
        save_path: Where the best checkpoint is written; missing parent
            directories are created.
        patience: Epochs without a validation-loss improvement before
            stopping. ``None`` uses the notebook-level
            ``EARLY_STOPPING_PATIENCE``.

    Returns:
        Dict with per-epoch lists under ``'train_loss'`` and ``'val_loss'``.
    """
    if patience is None:
        patience = EARLY_STOPPING_PATIENCE

    # torch.save does not create directories for us
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    model = model.to(device)

    # Loss and optimizer (BCE for multi-label)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                       patience=5)

    history = {
        'train_loss': [],
        'val_loss': []
    }

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)

        train_loss = _run_multilabel_epoch(model, train_loader, criterion, device,
                                           optimizer=optimizer, desc='Training')
        val_loss = _run_multilabel_epoch(model, val_loader, criterion, device,
                                         desc='Validation')

        scheduler.step(val_loss)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")

        # Early stopping: only a strictly lower validation loss counts
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
            print(f"✓ Model saved to {save_path}")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break

    return history

## Plot Training History

In [8]:
def plot_training_history(history, multi_label=False, save_path=None):
    """Draw the loss curves (and accuracy curves for single-label runs).

    Args:
        history: Dict of per-epoch lists. ``'train_loss'`` and ``'val_loss'``
            are always read; ``'train_acc'`` and ``'val_acc'`` are read only
            when ``multi_label`` is false.
        multi_label: When true, draw a single loss panel and no accuracy panel.
        save_path: If truthy, the figure is saved there and closed; otherwise
            it is shown.
    """
    def _draw_panel(ax, train_key, val_key, train_label, val_label, ylabel, title):
        # One panel: train and validation series of the same metric
        ax.plot(history[train_key], label=train_label)
        ax.plot(history[val_key], label=val_label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    n_panels = 1 if multi_label else 2
    fig, axes = plt.subplots(1, n_panels, figsize=(15, 5))

    # With a single panel plt.subplots returns the Axes itself, not an array
    loss_ax = axes if multi_label else axes[0]
    _draw_panel(loss_ax, 'train_loss', 'val_loss', 'Train Loss', 'Val Loss',
                'Loss', 'Training and Validation Loss')

    if not multi_label:
        _draw_panel(axes[1], 'train_acc', 'val_acc', 'Train Accuracy', 'Val Accuracy',
                    'Accuracy (%)', 'Training and Validation Accuracy')

    plt.tight_layout()

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path)
    plt.close()

### For Single-label Classification (GTZAN, FMA)

In [None]:
# Train SimpleCNN on GTZAN

# Ensure repository root is on sys.path
import os
import sys
from pathlib import Path
repo_root = Path.cwd().parent
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

# Import model (prefer module; fallback to notebook)
try:
    from model_cnn import SimpleCNN
except ModuleNotFoundError:
    print("Model module not found; loading from notebook via %run ...")
    %run "./04_model_cnn.ipynb"

# Import dataset from stable utils module (Windows-safe)
try:
    from utils.datasets_gtzan import GTZANDataset, create_dataloaders, GENRES, AudioAugmentation
except ModuleNotFoundError:
    print("Dataset module not found; loading from notebook via %run ...")
    %run "./01_data_loading_gtzan.ipynb"

# Create dataset with in-memory caching
gtzan_root = repo_root / "data" / "gtzan"
dataset = GTZANDataset(str(gtzan_root), cache_to_memory=True)
print(f"GTZAN files: {len(dataset)}")

# Define augmentation
train_transform = AudioAugmentation(noise_level=0.005, shift_max=0.2)

# Create loaders with Stratified Split
# NOTE: With cache_to_memory=True, we must use num_workers=0 on Windows to avoid 
# pickling the entire cached dataset to worker processes, which causes hangs/OOM.
train_loader, val_loader = create_dataloaders(
    dataset, 
    batch_size=BATCH_SIZE, 
    num_workers=0,
    train_transform=train_transform
)

# Create model
model = SimpleCNN(n_classes=10)

# Train
history = train_model(
    model, train_loader, val_loader,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    device=device,
    save_path=str(run_dir / 'gtzan_cnn.pth')
)

# Plot results
plot_training_history(history, save_path=str(run_dir / 'training_history.png'))

# Append results to changes.md
with open(changes_file, "a") as f:
    f.write(f"- Final Train Loss: {history['train_loss'][-1]:.4f}\n")
    f.write(f"- Final Val Loss: {history['val_loss'][-1]:.4f}\n")
    f.write(f"- Final Train Acc: {history['train_acc'][-1]:.2f}%\n")
    f.write(f"- Final Val Acc: {history['val_acc'][-1]:.2f}%\n")

Model module not found; loading from notebook via %run ...
SimpleCNN:
SimpleCNN(
  (mel_spec): MelSpectrogram(
    (spectrogram): Spectrogram()
    (mel_scale): MelScale()
  )
  (amplitude_to_db): AmplitudeToDB()
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv2d(128, 256, ke

  model.load_state_dict(torch.load(path))


Caching complete.
GTZAN files: 999

Epoch 1/20
--------------------------------------------------

Epoch 1/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:14<00:00,  1.73it/s, loss=1.83, acc=26.4]
Training: 100%|██████████| 25/25 [00:14<00:00,  1.73it/s, loss=1.83, acc=26.4]
Validation: 100%|██████████| 7/7 [00:02<00:00,  2.98it/s, loss=2.56, acc=15]  



Train Loss: 2.0003, Train Acc: 26.41%
Val Loss: 2.3758, Val Acc: 15.00%
✓ Model saved to ..\runs\20251127_143203\gtzan_cnn.pth

Epoch 2/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.37it/s, loss=1.57, acc=34.2]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.37it/s, loss=1.57, acc=34.2]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.66it/s, loss=1.64, acc=44.5]



Train Loss: 1.7454, Train Acc: 34.17%
Val Loss: 1.5376, Val Acc: 44.50%
✓ Model saved to ..\runs\20251127_143203\gtzan_cnn.pth

Epoch 3/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.41it/s, loss=1.39, acc=40.9]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.41it/s, loss=1.39, acc=40.9]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.69it/s, loss=3.82, acc=26.5]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.69it/s, loss=3.82, acc=26.5]


Train Loss: 1.5838, Train Acc: 40.93%
Val Loss: 3.2618, Val Acc: 26.50%

Epoch 4/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.44it/s, loss=1.38, acc=45.4]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.44it/s, loss=1.38, acc=45.4]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.70it/s, loss=1.85, acc=38.5]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.70it/s, loss=1.85, acc=38.5]


Train Loss: 1.4835, Train Acc: 45.43%
Val Loss: 1.8503, Val Acc: 38.50%

Epoch 5/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.44it/s, loss=1.41, acc=50.2]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.44it/s, loss=1.41, acc=50.2]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.66it/s, loss=1.37, acc=55.5] 
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.66it/s, loss=1.37, acc=55.5]


Train Loss: 1.4023, Train Acc: 50.19%
Val Loss: 1.2775, Val Acc: 55.50%
✓ Model saved to ..\runs\20251127_143203\gtzan_cnn.pth

Epoch 6/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.37it/s, loss=1.44, acc=50.8]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.37it/s, loss=1.44, acc=50.8]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.64it/s, loss=2.13, acc=38]  
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.64it/s, loss=2.13, acc=38]


Train Loss: 1.3443, Train Acc: 50.81%
Val Loss: 1.7121, Val Acc: 38.00%

Epoch 7/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.42it/s, loss=1.1, acc=54.7]  
Training: 100%|██████████| 25/25 [00:10<00:00,  2.42it/s, loss=1.1, acc=54.7]  
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.71it/s, loss=1.34, acc=49]   
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.71it/s, loss=1.34, acc=49]  


Train Loss: 1.2730, Train Acc: 54.69%
Val Loss: 1.4645, Val Acc: 49.00%

Epoch 8/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.41it/s, loss=0.729, acc=61.1]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.41it/s, loss=0.729, acc=61.1]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.67it/s, loss=1.71, acc=50]  
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.67it/s, loss=1.71, acc=50]


Train Loss: 1.1742, Train Acc: 61.08%
Val Loss: 1.5564, Val Acc: 50.00%

Epoch 9/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.44it/s, loss=0.972, acc=60.6]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.44it/s, loss=0.972, acc=60.6]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.68it/s, loss=1.23, acc=57.5] 
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.68it/s, loss=1.23, acc=57.5]


Train Loss: 1.1670, Train Acc: 60.58%
Val Loss: 1.2712, Val Acc: 57.50%
✓ Model saved to ..\runs\20251127_143203\gtzan_cnn.pth

Epoch 10/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.43it/s, loss=1.59, acc=63]   
Training: 100%|██████████| 25/25 [00:10<00:00,  2.43it/s, loss=1.59, acc=63]   
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.72it/s, loss=3.35, acc=20]  
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.72it/s, loss=3.35, acc=20]  


Train Loss: 1.1346, Train Acc: 62.95%
Val Loss: 2.8610, Val Acc: 20.00%

Epoch 11/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.42it/s, loss=0.869, acc=65.3]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.42it/s, loss=0.869, acc=65.3]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.70it/s, loss=2.2, acc=26.5] 
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.70it/s, loss=2.2, acc=26.5]


Train Loss: 1.0635, Train Acc: 65.33%
Val Loss: 2.4168, Val Acc: 26.50%

Epoch 12/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.44it/s, loss=0.844, acc=66.3]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.44it/s, loss=0.844, acc=66.3]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.69it/s, loss=1.75, acc=45]  
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.69it/s, loss=1.75, acc=45]  


Train Loss: 1.0260, Train Acc: 66.33%
Val Loss: 1.6349, Val Acc: 45.00%

Epoch 13/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.43it/s, loss=0.848, acc=64.8]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.43it/s, loss=0.848, acc=64.8]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.66it/s, loss=2.6, acc=34]   
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.66it/s, loss=2.6, acc=34]   


Train Loss: 1.0109, Train Acc: 64.83%
Val Loss: 2.3223, Val Acc: 34.00%

Epoch 14/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.38it/s, loss=0.62, acc=68.5] 
Training: 100%|██████████| 25/25 [00:10<00:00,  2.38it/s, loss=0.62, acc=68.5]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.68it/s, loss=2.05, acc=36]  
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.68it/s, loss=2.05, acc=36]  


Train Loss: 0.9402, Train Acc: 68.46%
Val Loss: 2.1658, Val Acc: 36.00%

Epoch 15/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.42it/s, loss=1.01, acc=68.5] 
Training: 100%|██████████| 25/25 [00:10<00:00,  2.42it/s, loss=1.01, acc=68.5] 
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.70it/s, loss=1.35, acc=56]   
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.70it/s, loss=1.35, acc=56]  


Train Loss: 0.9207, Train Acc: 68.46%
Val Loss: 1.2195, Val Acc: 56.00%
✓ Model saved to ..\runs\20251127_143203\gtzan_cnn.pth

Epoch 16/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.42it/s, loss=0.895, acc=69.7]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.42it/s, loss=0.895, acc=69.7]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.69it/s, loss=1.53, acc=35.5]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.69it/s, loss=1.53, acc=35.5]


Train Loss: 0.8945, Train Acc: 69.71%
Val Loss: 1.9044, Val Acc: 35.50%

Epoch 17/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.42it/s, loss=0.618, acc=71.2]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.42it/s, loss=0.618, acc=71.2]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.72it/s, loss=1.94, acc=54.5] 
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.72it/s, loss=1.94, acc=54.5]


Train Loss: 0.8925, Train Acc: 71.21%
Val Loss: 1.4159, Val Acc: 54.50%

Epoch 18/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.43it/s, loss=0.572, acc=75.3]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.43it/s, loss=0.572, acc=75.3]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.71it/s, loss=2.42, acc=42]  
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.71it/s, loss=2.42, acc=42] 


Train Loss: 0.7945, Train Acc: 75.34%
Val Loss: 2.3863, Val Acc: 42.00%

Epoch 19/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.43it/s, loss=0.928, acc=73.7]
Training: 100%|██████████| 25/25 [00:10<00:00,  2.43it/s, loss=0.928, acc=73.7]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.71it/s, loss=1.43, acc=61.5] 
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.71it/s, loss=1.43, acc=61.5]


Train Loss: 0.7921, Train Acc: 73.72%
Val Loss: 1.2310, Val Acc: 61.50%

Epoch 20/20
--------------------------------------------------


Training: 100%|██████████| 25/25 [00:10<00:00,  2.43it/s, loss=0.76, acc=75.1] 
Training: 100%|██████████| 25/25 [00:10<00:00,  2.43it/s, loss=0.76, acc=75.1]
Validation: 100%|██████████| 7/7 [00:01<00:00,  3.72it/s, loss=1.82, acc=50]   



Train Loss: 0.7694, Train Acc: 75.09%
Val Loss: 1.6303, Val Acc: 50.00%


### For Multi-label Classification (MTAT)

In [None]:
# Example: Train DeepCNN on MTAT
# Uncomment and adapt to your dataset

# from notebooks.model_cnn import DeepCNN
# from notebooks.data_loading_mtat import MTATDataset, create_dataloaders

# # Create dataset
# dataset = MTATDataset(MTAT_AUDIO_PATH, MTAT_ANNOTATIONS_PATH, top_tags=50)
# train_loader, val_loader = create_dataloaders(dataset, batch_size=BATCH_SIZE)

# # Create model
# model = DeepCNN(n_classes=50)

# # Train
# history = train_multilabel(
#     model, train_loader, val_loader,
#     num_epochs=NUM_EPOCHS,
#     learning_rate=LEARNING_RATE,
#     device=device,
#     save_path='../models/mtat_cnn.pth'
# )

# # Plot results
# plot_training_history(history, multi_label=True)

## Evaluation Metrics

In [11]:
def evaluate_model(model, test_loader, device, genre_names=None):
    """Evaluate ``model`` on ``test_loader`` and print summary metrics.

    Prints accuracy plus support-weighted precision, recall and F1. When
    ``genre_names`` is given, a per-class report is printed as well.

    Args:
        model: Network to evaluate; switched to eval mode here.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the inputs are moved to.
        genre_names: Optional sequence of class names for the per-class
            report. Assumes labels are integer class indices
            ``0 .. len(genre_names) - 1`` in the same order — confirm
            against the dataset's label encoding.

    Returns:
        Tuple ``(all_preds, all_labels)`` of per-sample predicted and true
        labels.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Evaluating'):
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate metrics. zero_division=0 keeps the same score as the default
    # for classes that are never predicted, without emitting a warning.
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    print(f"\nTest Metrics:")
    print(f"Accuracy: {accuracy*100:.2f}%")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")

    if genre_names is not None:
        from sklearn.metrics import classification_report

        names = list(genre_names)
        print("\nPer-class report:")
        # Passing labels explicitly keeps the rows aligned with the names even
        # when some class is absent from this split.
        print(classification_report(
            all_labels, all_preds,
            labels=list(range(len(names))),
            target_names=names,
            zero_division=0,
        ))

    return all_preds, all_labels