# Training Analysis

Visualize and interpret training runs for the StepMania difficulty classifier.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Global plot style for every figure in this notebook.
# NOTE(review): the 'seaborn-v0_8-' prefix is presumably the matplotlib >= 3.6
# name for the old 'seaborn-darkgrid' style — confirm against the installed version.
plt.style.use('seaborn-v0_8-darkgrid')
# IPython magic (not plain Python): render figures inline in the notebook output.
%matplotlib inline

## 1. Load Checkpoint

In [None]:
CHECKPOINT_DIR = Path('../checkpoints')

# List available checkpoints: the *.pt files directly inside CHECKPOINT_DIR
# (non-recursive), sorted once so the printout is stable.
checkpoints = sorted(CHECKPOINT_DIR.glob('*.pt'))
print("Available checkpoints:")
if not checkpoints:
    # glob() on a missing directory just yields nothing, so say where we looked.
    print(f"  (none found in {CHECKPOINT_DIR.resolve()})")
for cp in checkpoints:
    print(f"  {cp.name}")

In [None]:
# Load the best-validation-loss checkpoint, falling back to the most recent one.
checkpoint_path = CHECKPOINT_DIR / 'best_val_loss.pt'
if not checkpoint_path.exists():
    checkpoint_path = CHECKPOINT_DIR / 'last.pt'
if not checkpoint_path.exists():
    raise FileNotFoundError(
        f"Neither 'best_val_loss.pt' nor 'last.pt' found in {CHECKPOINT_DIR.resolve()}"
    )

# NOTE(review): torch.load unpickles arbitrary objects unless weights_only=True;
# only point this at checkpoints you produced yourself.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
print(f"Loaded: {checkpoint_path.name}")
print(f"Epoch: {checkpoint.get('epoch', 'N/A')}")

# 'best_val_loss' may be absent. Formatting a 'N/A' placeholder with ':.4f'
# raises ValueError, so only apply the numeric format to a real value.
# float() accepts Python floats, NumPy scalars and single-element tensors.
saved_best_val_loss = checkpoint.get('best_val_loss')
if saved_best_val_loss is None:
    print("Best val loss: N/A")
else:
    print(f"Best val loss: {float(saved_best_val_loss):.4f}")

# Per-epoch metric lists; an empty dict means no history was saved.
history = checkpoint.get('history', {})

## 2. Loss Curves

In [None]:
if history:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    epochs = range(1, len(history['train_loss']) + 1)

    # Left panel: loss, with the epoch of minimum validation loss marked.
    ax = axes[0]
    ax.plot(epochs, history['train_loss'], 'b-', label='Train', linewidth=2)
    ax.plot(epochs, history['val_loss'], 'r-', label='Validation', linewidth=2)

    best_epoch = int(np.argmin(history['val_loss'])) + 1  # 1-based epoch number
    best_val_loss = min(history['val_loss'])
    ax.axvline(best_epoch, color='green', linestyle='--', alpha=0.7, label=f'Best (epoch {best_epoch})')
    ax.scatter([best_epoch], [best_val_loss], color='green', s=100, zorder=5)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    # Build the legend only after every labeled artist exists; otherwise the
    # 'Best (epoch N)' marker line is missing from it.
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Right panel: accuracy, with the epoch of maximum validation accuracy marked.
    ax = axes[1]
    ax.plot(epochs, history['train_acc'], 'b-', label='Train', linewidth=2)
    ax.plot(epochs, history['val_acc'], 'r-', label='Validation', linewidth=2)

    best_val_acc = max(history['val_acc'])
    best_acc_epoch = int(np.argmax(history['val_acc'])) + 1  # 1-based epoch number
    ax.axvline(best_acc_epoch, color='green', linestyle='--', alpha=0.7, label=f'Best (epoch {best_acc_epoch})')
    ax.scatter([best_acc_epoch], [best_val_acc], color='green', s=100, zorder=5)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_title('Training and Validation Accuracy')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"Best validation loss: {best_val_loss:.4f} at epoch {best_epoch}")
    print(f"Best validation accuracy: {best_val_acc:.4f} at epoch {best_acc_epoch}")
else:
    print("No history found in checkpoint")

## 3. Overfitting Analysis

In [None]:
if history and len(history['train_loss']) > 5:
    # Generalization gap per epoch: positive means validation loss is above
    # training loss. Assumes train/val losses are recorded every epoch (same length).
    train_loss = np.array(history['train_loss'])
    val_loss = np.array(history['val_loss'])
    gap = val_loss - train_loss

    fig, ax = plt.subplots(figsize=(10, 5))
    epochs = np.arange(1, len(gap) + 1)
    ax.plot(epochs, gap, 'purple', linewidth=2)
    ax.axhline(0, color='black', linestyle='--', alpha=0.5)
    # interpolate=True extends each shaded region to the exact zero crossing;
    # without it the fill stops at the last sample before the sign change.
    ax.fill_between(epochs, 0, gap, where=(gap > 0), interpolate=True,
                    color='red', alpha=0.3, label='Overfitting')
    # NOTE(review): val < train is not necessarily underfitting; it is often
    # caused by train-time-only regularization (dropout, augmentation).
    ax.fill_between(epochs, 0, gap, where=(gap <= 0), interpolate=True,
                    color='green', alpha=0.3, label='Underfitting')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Generalization Gap (Val - Train)')
    ax.set_title('Overfitting Analysis')
    ax.legend()
    plt.tight_layout()
    plt.show()

    # Summary based on the last epoch's gap. The thresholds are absolute
    # loss units, so their meaning depends on the scale of the loss.
    final_gap = gap[-1]
    if final_gap > 0.5:
        print(f"WARNING: Significant overfitting detected (gap: {final_gap:.3f})")
        print("Consider: more dropout, data augmentation, or early stopping")
    elif final_gap < -0.1:
        print(f"Model may be underfitting (gap: {final_gap:.3f})")
        print("Consider: larger model, more epochs, or lower regularization")
    else:
        print(f"Good generalization (gap: {final_gap:.3f})")
elif history:
    # Previously this case produced no output at all.
    print(f"Need more than 5 epochs for overfitting analysis (have {len(history['train_loss'])})")

## 4. Learning Rate Analysis (Training-Loss Smoothing and Plateau Detection)

In [None]:
if history:
    # Smooth the training loss with a trailing moving average and use the
    # smoothed curve to check whether training has stopped improving.
    train_loss = np.array(history['train_loss'])

    # Up to a 5-epoch window, shrunk for short runs; needs >= 6 epochs for window > 1.
    window = min(5, len(train_loss) // 3)
    if window > 1:
        # mode='valid' yields len(train_loss) - window + 1 points; point k averages
        # epochs k+1..k+window, so it is drawn at the window's last epoch.
        smoothed = np.convolve(train_loss, np.ones(window) / window, mode='valid')

        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(range(1, len(train_loss) + 1), train_loss, 'b-', alpha=0.3, label='Raw')
        ax.plot(range(window, len(train_loss) + 1), smoothed, 'b-', linewidth=2, label='Smoothed')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Training Loss')
        ax.set_title('Training Loss (Smoothed)')
        ax.legend()
        plt.tight_layout()
        plt.show()

        # Relative change between the latest smoothed value and the one (up to)
        # 4 smoothed points earlier.
        reference = smoothed[-min(5, len(smoothed))]
        if reference == 0:
            print("Smoothed training loss is zero - relative change is undefined")
        else:
            recent_change = (smoothed[-1] - reference) / reference
            if abs(recent_change) < 0.01:
                print("Training has plateaued - LR reduction may have kicked in")
            else:
                print(f"Smoothed training loss changed by {recent_change:+.1%} over the last few epochs")
    else:
        print("Not enough epochs to smooth the training loss (need at least 6)")

## 5. Confusion Matrix (if available)

In [None]:
confusion_matrix = checkpoint.get('confusion_matrix', None)

if confusion_matrix is not None:
    if isinstance(confusion_matrix, torch.Tensor):
        confusion_matrix = confusion_matrix.cpu().numpy()
    # Also accept a nested list or other array-like saved in the checkpoint.
    confusion_matrix = np.asarray(confusion_matrix)

    # Rows are treated as true labels, columns as predictions. Index i is shown
    # as difficulty i + 1 (same mapping as the per-class printout below).
    n_classes = confusion_matrix.shape[0]
    difficulty_labels = range(1, n_classes + 1)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Normalize by row (true labels). A class with no samples would divide by
    # zero and produce NaN; leave such rows at 0 instead.
    row_sums = confusion_matrix.sum(axis=1, keepdims=True).astype(float)
    cm_normalized = np.divide(
        confusion_matrix.astype(float),
        row_sums,
        out=np.zeros(confusion_matrix.shape, dtype=float),
        where=row_sums > 0,
    )

    sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=difficulty_labels, yticklabels=difficulty_labels, ax=ax)
    ax.set_xlabel('Predicted Difficulty')
    ax.set_ylabel('True Difficulty')
    ax.set_title('Normalized Confusion Matrix')
    plt.tight_layout()
    plt.show()

    # Per-class accuracy (recall) = diagonal of the row-normalized matrix.
    per_class_acc = np.diag(cm_normalized)
    has_samples = row_sums[:, 0] > 0
    print("\nPer-class accuracy:")
    for i, (acc, present) in enumerate(zip(per_class_acc, has_samples)):
        if not present:
            print(f"  Difficulty {i+1:2d}:  n/a (no samples)")
            continue
        bar = '#' * int(acc * 20)
        print(f"  Difficulty {i+1:2d}: {acc:.2f} {bar}")
else:
    print("No confusion matrix in checkpoint.")
    print("Run validation with confusion matrix tracking to generate one.")
    print("Run validation with confusion matrix tracking to generate one.")

## 6. Training Summary

In [None]:
# Text summary of the run plus coarse, threshold-based recommendations.
banner = "=" * 50
print(banner)
print("TRAINING SUMMARY")
print(banner)

if not history:
    print("No training history available.")
else:
    n_epochs = len(history['train_loss'])
    final_train_loss, final_val_loss = history['train_loss'][-1], history['val_loss'][-1]
    final_train_acc, final_val_acc = history['train_acc'][-1], history['val_acc'][-1]
    best_val_loss, best_val_acc = min(history['val_loss']), max(history['val_acc'])

    print(f"Epochs completed: {n_epochs}")
    print()
    print("Final metrics:")
    for caption, value in (
        ("Train loss: ", final_train_loss),
        ("Val loss:   ", final_val_loss),
        ("Train acc:  ", final_train_acc),
        ("Val acc:    ", final_val_acc),
    ):
        print(f"  {caption}{value:.4f}")
    print()
    print("Best metrics:")
    print(f"  Val loss:   {best_val_loss:.4f}")
    print(f"  Val acc:    {best_val_acc:.4f}")
    print()

    # Recommendations: the first accuracy ceiling the final val accuracy falls
    # under selects the advice; anything at or above 0.7 counts as strong.
    print("Recommendations:")
    accuracy_tiers = (
        (0.3, "  - Model struggling. Check data quality and class balance."),
        (0.5, "  - Moderate performance. Consider longer training or architecture changes."),
        (0.7, "  - Good progress. Fine-tune hyperparameters for better results."),
    )
    advice = next(
        (message for ceiling, message in accuracy_tiers if final_val_acc < ceiling),
        "  - Strong performance! Consider ensemble or test set evaluation.",
    )
    print(advice)

    # Flag a final validation loss more than 50% above the final training loss.
    if final_val_loss > final_train_loss * 1.5:
        print("  - Overfitting detected. Add regularization or data augmentation.")

## 7. Model Architecture Info

In [None]:
if 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']

    # A state_dict holds learnable parameters AND registered buffers. Exclude the
    # standard BatchNorm buffers so "Total parameters" is not inflated.
    # NOTE(review): any other custom buffers in this model are still counted — confirm.
    buffer_suffixes = ('running_mean', 'running_var', 'num_batches_tracked')
    total_params = sum(
        p.numel() for name, p in state_dict.items() if not name.endswith(buffer_suffixes)
    )
    buffer_elements = sum(
        p.numel() for name, p in state_dict.items() if name.endswith(buffer_suffixes)
    )

    print(f"Total parameters: {total_params:,}")
    if buffer_elements:
        print(f"BatchNorm buffer elements (excluded above): {buffer_elements:,}")
    print("")
    print("Layer sizes:")
    for name, param in state_dict.items():
        if 'weight' in name:
            print(f"  {name}: {list(param.shape)}")