# Edge Case Analysis: Short Sequences & Difficult Genres

This notebook analyzes the performance of the trained music genre classification model on edge cases.
We specifically investigate:
1.  **Short Input Sequences**: How does the model perform when the audio clip is shorter than the training duration (3s)?
2.  **Difficult Genres**: Which genres are hardest to classify, and what are the characteristics of misclassified samples?
3.  **Confidence Analysis**: Can we trust the model's confidence scores?

In [None]:
import os
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torchaudio
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from tqdm import tqdm

# Add repo root to path so project modules (model_cnn, utils.*) are importable.
# Assumes the notebook is run from a direct subdirectory of the repo root.
repo_root = Path.cwd().parent
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Reproducibility: seed every RNG a random split or random crop could draw from,
# so Restart & Run All reproduces the same numbers.
# NOTE(review): this reproduces the *training* split only if training used the same
# seed and the same sequence of RNG calls — confirm before trusting that the test
# set below is truly unseen.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Constants
GENRE_NAMES = ['blues', 'classical', 'country', 'disco', 'hiphop',
               'jazz', 'metal', 'pop', 'reggae', 'rock']
# Samples per second used to convert durations to sample counts.
# NOTE(review): assumed to match the dataset's audio sample rate — confirm.
SAMPLE_RATE = 22050
TRAIN_CHUNK_DURATION = 3.0  # seconds; the model was trained on 3 s chunks

In [None]:
# Import Model and Dataset Utils
# We try to import from modules, or fallback to running notebooks if modules aren't set up as packages.
# `%run` is an IPython magic (not plain Python): it executes the target notebook and
# makes the names it defines available in this kernel's namespace. Every top-level
# cell of that notebook runs here too.
# NOTE(review): confirm those notebooks don't start training or other heavy work on %run.
# Caveat: ModuleNotFoundError is also raised when the module exists but one of *its*
# own imports is missing; in that case the fallback hides the real error.
try:
    from model_cnn import ComplexCNN
except ModuleNotFoundError:
    print("Model module not found; loading from notebook via %run ...")
    # Expected to define ComplexCNN in this namespace — TODO confirm.
    %run "./04_model_cnn.ipynb"

try:
    from utils.datasets_gtzan import GTZANDataset, create_dataloaders
except ModuleNotFoundError:
    print("Dataset module not found; loading from notebook via %run ...")
    # Expected to define GTZANDataset and create_dataloaders — TODO confirm.
    %run "./01_data_loading_gtzan.ipynb"

def load_trained_model(model_path, n_classes=10, device='cpu'):
    """Build a ComplexCNN, load saved weights into it and return it ready for inference.

    Args:
        model_path: Path (str or path-like) of a file holding the model's ``state_dict``.
        n_classes: Number of output classes passed to ``ComplexCNN``.
        device: Device the checkpoint tensors are mapped to and the model is moved to.

    Returns:
        The ``ComplexCNN`` instance on ``device``, switched to evaluation mode.
    """
    net = ComplexCNN(n_classes=n_classes)
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    net.to(device)
    net.eval()
    return net

In [None]:
# Load the Model
# Update this path to the run you want to analyze
CHECKPOINT_NAME = 'gtzan_cnn.pth'
runs_root = Path("../runs")
run_dir = Path("../runs/20251202_114502/")
model_path = run_dir / CHECKPOINT_NAME

if not model_path.exists():
    # Fall back to the most recent run that actually contains a checkpoint.
    # Run directories appear to be named YYYYMMDD_HHMMSS, so lexicographic order is
    # chronological (assumes every run dir follows that naming).
    candidate_runs = sorted(
        d for d in runs_root.iterdir()
        if d.is_dir() and (d / CHECKPOINT_NAME).exists()
    ) if runs_root.is_dir() else []
    if not candidate_runs:
        raise FileNotFoundError(
            f"No checkpoint at {model_path} and no run under {runs_root} "
            f"contains {CHECKPOINT_NAME}"
        )
    run_dir = candidate_runs[-1]
    model_path = run_dir / CHECKPOINT_NAME
    print(f"Specified model not found. Using latest run: {run_dir}")

print(f"Loading model from: {model_path}")
model = load_trained_model(str(model_path), device=device)

In [None]:
# Load Test Data
gtzan_root = repo_root / "data" / "gtzan"
if not gtzan_root.exists():
    # Fail loudly: every later cell needs `test_loader`, so carrying on would only
    # surface as a confusing NameError further down the notebook.
    raise FileNotFoundError(f"Dataset not found at {gtzan_root}")

# We rely on create_dataloaders reproducing the training split so we test on unseen data.
# NOTE(review): that only holds if the split is deterministic (seeded) and these
# arguments match the training run — confirm in create_dataloaders / the training config.
full_dataset = GTZANDataset(str(gtzan_root), cache_to_memory=False)

# Create test loader with standard 3s chunks first to establish baseline
_, _, test_loader = create_dataloaders(
    full_dataset,
    batch_size=32,
    chunk_length_sec=TRAIN_CHUNK_DURATION,
    test_split=0.1,
)
print(f"Test set size: {len(test_loader.dataset)} chunks")

## 1. Analysis of Short Input Sequences

The model was trained on 3-second chunks. In real-world scenarios, we might have shorter clips.
Here we simulate shorter inputs by cutting, from each test song, a window of the given duration (from 0.1 s up to the full 3 s) centred on the song's midpoint.
Since the model architecture (CNN) likely expects a fixed input size (corresponding to 3s), we will **pad** these shorter clips with silence to reach the 3s length.
We hypothesize that performance will drop as the signal becomes shorter.

In [None]:
def evaluate_on_short_sequences(model, dataset, durations, device,
                                sample_rate=None, model_duration=None):
    """Accuracy of ``model`` on one centred crop per test song, for each crop duration.

    For every song, a window of ``duration`` seconds centred on the song's midpoint is
    cut out. The window is zero-padded at the end (or truncated) to ``model_duration``
    seconds and classified on its own, to simulate short real-world clips.

    Args:
        model: Classifier called as ``model(batch)``; the argmax over dim 1 of its
            output is taken as the predicted class.
        dataset: Either the chunked test dataset (its ``.dataset`` attribute is used)
            or a sequence of ``(waveform, label)`` pairs. Waveforms are indexed as
            ``[:, start:end]``, i.e. (channels, samples).
        durations: Iterable of crop lengths in seconds.
        device: Device inputs are moved to before the forward pass.
        sample_rate: Samples per second of the waveforms. Defaults to ``SAMPLE_RATE``.
        model_duration: Input length in seconds the model expects. Defaults to
            ``TRAIN_CHUNK_DURATION``.

    Returns:
        dict mapping each duration to accuracy in [0, 1]. The value is ``nan`` when
        there are no songs to evaluate.
    """
    if sample_rate is None:
        sample_rate = SAMPLE_RATE
    if model_duration is None:
        model_duration = TRAIN_CHUNK_DURATION

    # test_loader.dataset is a chunked wrapper; unwrap it to reach whole songs.
    # NOTE(review): assumes the wrapped object yields full-length (waveform, label)
    # pairs — confirm against ChunkedDataset.
    subset = dataset.dataset if hasattr(dataset, 'dataset') else dataset
    n_songs = len(subset)
    # Input length the model expects; loop-invariant, so computed once.
    model_input_samples = int(model_duration * sample_rate)

    if hasattr(model, 'eval'):
        model.eval()

    print(f"Evaluating on {n_songs} test songs with varying durations...")

    results = {}
    for duration in durations:
        print(f"Testing duration: {duration}s")
        if n_songs == 0:
            # Avoid ZeroDivisionError; a NaN point is simply skipped when plotted.
            results[duration] = float('nan')
            print("  No songs to evaluate; accuracy undefined.")
            continue

        target_samples = int(duration * sample_rate)
        correct = 0
        for i in range(n_songs):
            waveform, label = subset[i]
            n_samples = waveform.shape[1]

            # Centre the crop on the midpoint to avoid silence at the start.
            # Clamp the start so a crop longer than the song begins at sample 0.
            # An unclamped negative start would make Python slicing wrap around
            # and keep only the song's tail.
            start = n_samples // 2 - target_samples // 2
            start = max(0, min(start, n_samples - target_samples))
            chunk = waveform[:, start:start + target_samples]

            # Zero-pad at the end (or truncate) to the model's expected input length.
            if chunk.shape[1] < model_input_samples:
                chunk = torch.nn.functional.pad(chunk, (0, model_input_samples - chunk.shape[1]))
            else:
                chunk = chunk[:, :model_input_samples]

            with torch.no_grad():
                output = model(chunk.unsqueeze(0).to(device))
            pred = torch.argmax(output, dim=1).item()
            if pred == int(label):
                correct += 1

        acc = correct / n_songs
        results[duration] = acc
        print(f"  Accuracy: {acc*100:.2f}%")

    return results

durations_to_test = [0.1, 0.5, 1.0, 1.5, 2.0, 2.5, TRAIN_CHUNK_DURATION]
short_seq_results = evaluate_on_short_sequences(model, test_loader.dataset, durations_to_test, device)

# Plot accuracy against crop length; the full training-length crop is the baseline.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(list(short_seq_results.keys()), [v * 100 for v in short_seq_results.values()], marker='o')
ax.set_title("Model Accuracy vs Input Duration")
ax.set_xlabel("Input Duration (seconds)")
ax.set_ylabel("Accuracy (%)")
ax.grid(True)
# Look the baseline up via the constant rather than a hard-coded 3.0, so this keeps
# working if the training chunk length or the tested durations change.
baseline_acc = short_seq_results.get(TRAIN_CHUNK_DURATION)
if baseline_acc is not None:
    ax.axhline(y=baseline_acc * 100, color='r', linestyle='--',
               label=f'Baseline ({TRAIN_CHUNK_DURATION:g}s)')
    ax.legend()
plt.show()

## 2. Analysis of Difficult Genres

We analyze the confusion matrix to identify which genres are most frequently misclassified.

In [None]:
def get_predictions(model, loader, device):
    """Run ``model`` over every batch in ``loader`` and collect per-sample results.

    Args:
        model: Classifier whose output has classes along dim 1 (treated as logits).
        loader: Iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Tuple ``(preds, labels, probs)`` of numpy arrays:
        - ``preds``: the predicted class index for each sample;
        - ``labels``: the true label for each sample;
        - ``probs``: one row of softmax probabilities per sample.
    """
    all_preds, all_labels, all_probs = [], [], []

    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            probs = torch.nn.functional.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            # Move labels to the CPU like the other two arrays: .numpy() raises on CUDA
            # tensors. as_tensor also accepts labels collated as plain Python lists.
            all_labels.extend(torch.as_tensor(labels).cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    return np.array(all_preds), np.array(all_labels), np.array(all_probs)

preds, labels, probs = get_predictions(model, test_loader, device)

# Pin the label set to all genres. If a genre is absent from both labels and preds,
# sklearn would otherwise shrink the report/matrix and misalign it with GENRE_NAMES
# (classification_report raises on the length mismatch).
class_ids = np.arange(len(GENRE_NAMES))

# Classification Report
print(classification_report(labels, preds, labels=class_ids, target_names=GENRE_NAMES))

# Confusion Matrix (rows = true genre, columns = predicted genre)
cm = confusion_matrix(labels, preds, labels=class_ids)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=GENRE_NAMES, yticklabels=GENRE_NAMES, ax=ax)
ax.set_title("Confusion Matrix")
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
plt.show()

# Identify lowest performing genres (per-class recall: correct / true samples).
# A genre with no test samples gets NaN instead of a divide-by-zero warning, and
# is excluded from the ranking.
support = cm.sum(axis=1)
class_acc = np.divide(cm.diagonal(), support,
                      out=np.full(len(support), np.nan), where=support > 0)
worst_genres_idx = [idx for idx in np.argsort(class_acc) if not np.isnan(class_acc[idx])][:3]
print("\nTop 3 Most Difficult Genres:")
for idx in worst_genres_idx:
    print(f"{GENRE_NAMES[idx]}: {class_acc[idx]*100:.2f}% accuracy")

## 3. Confidence Analysis

Do misclassified examples have lower confidence scores? Here, the confidence of a prediction is its maximum softmax probability.

In [None]:
# Extract confidence (max softmax probability) for correct vs incorrect predictions
confidences = np.max(probs, axis=1)
correct_mask = preds == labels
incorrect_mask = ~correct_mask

fig, ax = plt.subplots(figsize=(10, 6))
# Draw only the groups that have samples: histplot on an empty group fails, and a
# KDE needs at least two points.
for mask, color, name in ((correct_mask, 'green', 'Correct'),
                          (incorrect_mask, 'red', 'Incorrect')):
    n_in_group = int(mask.sum())
    if n_in_group:
        sns.histplot(confidences[mask], color=color, label=name, kde=n_in_group > 1,
                     stat="density", alpha=0.5, ax=ax)
ax.set_title("Confidence Distribution: Correct vs Incorrect Predictions")
ax.set_xlabel("Confidence Score")
if ax.get_legend_handles_labels()[0]:
    ax.legend()
plt.show()

# The mean of an empty group is undefined (np.mean([]) warns and returns NaN).
for name, mask in (("Correct", correct_mask), ("Incorrect", incorrect_mask)):
    if mask.any():
        print(f"Average Confidence ({name}): {np.mean(confidences[mask]):.4f}")
    else:
        print(f"Average Confidence ({name}): n/a (no {name.lower()} predictions)")