# CNNv2: Advanced Multi-Instrument Note Analysis Pipeline

This notebook implements the advanced architecture from `architecture.md` with:
- Instrument embeddings for handling multiple instruments
- Temporal processing with duration-aware pooling (short notes are pooled over their attack frames)
- Variable-length spectrogram handling
- Class balancing for imbalanced labels
- GPU-optimized training with mixed precision

## 🚨 IMPORTANT: GPU Training Setup

**For GPU Training (recommended):**
1. Set `TEST_MODE = False` in the config cell
2. Ensure you have a CUDA-compatible GPU with 8GB+ memory
3. Expect ~30-60 minutes training time on modern GPU

**For Testing/Development:**
1. Set `TEST_MODE = True` (default)
2. Safe to run on CPU/Mac for validation
3. Only 2 epochs with small batches for quick testing

## 📋 Before You Start:
- Check that the database path is correct for your system
- Verify GPU is detected in the config cell output
- Run the data loading test cell first to catch any issues
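- The advanced model has ~3.9M parameters with 12 instruments and 100 classes (about 3.1M of them in the temporal Conv1d; the exact count is printed at model initialization) and uses temporal processing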

## 🎯 Expected Results:
- Target accuracy: 75-85% (from architecture.md)
- The model handles all 12 instruments and ~100 quality classes
- Duration-aware processing for better note analysis

In [None]:
import sys
from pathlib import Path
sys.path.append(str(Path().resolve().parent))

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
from tqdm import tqdm

from utils.notes_processing import generate_spectrogram
from models.dataset import GoodSoundsDatabase, GoodSoundsDataset

In [None]:
import logging
import warnings


def _silence_audio_io_noise():
    """Hide known-noisy warnings and log output from the audio-loading libraries."""
    # scipy's WAV reader: generic UserWarnings, plus the specific message it
    # emits when it skips a chunk type it does not understand.
    warnings.filterwarnings(action="ignore", category=UserWarning, module="scipy.io.wavfile")
    warnings.filterwarnings(action="ignore", message="Chunk \\(non-data\\) not understood, skipping it.")
    # madmom: let only ERROR-level (and more severe) records through.
    logging.getLogger("madmom").setLevel(logging.ERROR)


_silence_audio_io_noise()

In [None]:
# Configuration with test mode for development
import os

TEST_MODE = True  # Set to False for full GPU training

# Settings shared by both modes. Only the mode-specific overrides below differ,
# so the two configurations cannot drift apart.
_SHARED_CONFIG = {
    # Set the GOOD_SOUNDS_PATH environment variable to point at another checkout.
    'database_path': os.environ.get(
        'GOOD_SOUNDS_PATH', '/Users/dhanush/documents/musaic/good-sounds'
    ),
    'learning_rate': 0.001,
    'max_lr': 0.01,
    'test_size': 0.2,
    'val_size': 0.2,
    'random_state': 42,
    # STFT settings. MultiInstrumentDataset's frame -> seconds conversion also
    # uses hop_length=128 and sr=22050 by default; keep them in sync.
    'n_fft': 512,
    'hop_length': 128,
    'sr': 22050,
}

if TEST_MODE:
    print("🧪 TEST MODE ENABLED - Safe for CPU/Mac testing")
    CONFIG = {
        **_SHARED_CONFIG,
        'batch_size': 4,  # Very small for testing
        'num_workers': 0,  # IMPORTANT: Set to 0 for Mac/notebook compatibility
        'epochs': 2,  # Just 2 epochs for testing
    }
    # Derived from CONFIG so the summary always matches the real settings.
    print(f"  - Small batch size ({CONFIG['batch_size']})")
    print(f"  - Only {CONFIG['epochs']} epochs")
    print(f"  - num_workers={CONFIG['num_workers']} (Mac compatibility)")
    print("  - Will work on CPU")
else:
    print("🚀 FULL TRAINING MODE - GPU optimized")
    CONFIG = {
        **_SHARED_CONFIG,
        'batch_size': 64,  # Large batch for GPU
        'num_workers': 8,  # More workers for GPU
        'epochs': 50,  # Full training
    }

# Device setup with fallbacks
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    # Enable GPU optimizations only if CUDA available
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    print("GPU optimizations enabled")
else:
    print("No GPU detected - using CPU mode")
    if not TEST_MODE:
        print("⚠️  WARNING: Full training on CPU will be very slow!")
        print("   Consider enabling TEST_MODE for development")

## Data Loading and Preparation

In [None]:
# Open the GoodSounds database and print a quick overview of what it contains.
db = GoodSoundsDatabase(CONFIG['database_path'])
all_df = db.get_all_sounds()


def _report_counts(header, counts):
    """Print *header*, then one '  <value>: <n> samples' line per entry of *counts*."""
    print(header)
    for value, n in counts.items():
        print(f"  {value}: {n} samples")


print(f"Total samples: {len(all_df)}")

instrument_counts = all_df['instrument'].value_counts()
_report_counts("\nInstruments:", instrument_counts)

label_counts = all_df['klass'].value_counts().head(20)
_report_counts("\nTop 20 labels:", label_counts)

In [None]:
# Fit integer encoders: instrument name -> id, and quality class ('klass') -> id.
# LabelEncoder.fit returns the encoder itself, so fitting can be chained.
instrument_encoder = LabelEncoder().fit(all_df['instrument'])
label_encoder = LabelEncoder().fit(all_df['klass'])

print(f"Number of instruments: {len(instrument_encoder.classes_)}")
print(f"Number of classes: {len(label_encoder.classes_)}")

print("\nInstrument mapping:")
for code, name in enumerate(instrument_encoder.classes_):
    print(f"  {code}: {name}")

print("\nSample label mapping (first 20):")
for code, name in enumerate(label_encoder.classes_[:20]):
    print(f"  {code}: {name}")

## Multi-Instrument Dataset and Batch Collation

In [None]:
class MultiInstrumentDataset(Dataset):
    """Wrap ``GoodSoundsDataset`` and add an instrument id and a note duration.

    Each item is a dict containing the base item's ``spectrogram``, ``label``,
    ``file_path`` and ``original_label`` entries, plus:

    * ``instrument_id`` -- long tensor, ``instrument_encoder`` code of the row's
      ``instrument`` column.
    * ``duration`` -- float tensor, approximate length in seconds computed as
      ``time_frames * hop_length / sr``.

    Args:
        dataframe: rows to serve; must contain an ``instrument`` column.
        instrument_encoder: encoder already fitted on the instrument names.
        label_encoder: forwarded to ``GoodSoundsDataset``.
        spectrogram_function: forwarded to ``GoodSoundsDataset``.
        cache_spectrograms: forwarded to ``GoodSoundsDataset``.
        cache_dir: forwarded to ``GoodSoundsDataset``.
        hop_length: STFT hop (samples per frame) used for the duration estimate.
        sr: sample rate used for the duration estimate. The defaults 128 / 22050
            match the values previously hard-coded here; pass
            ``CONFIG['hop_length']`` / ``CONFIG['sr']`` to keep them in sync.
    """

    def __init__(self, dataframe, instrument_encoder, label_encoder,
                 spectrogram_function, cache_spectrograms=False, cache_dir=None,
                 hop_length=128, sr=22050):
        self.df = dataframe.reset_index(drop=True)
        self.instrument_encoder = instrument_encoder
        self.hop_length = hop_length
        self.sr = sr

        # Encode every row's instrument once, instead of calling transform() on
        # a one-element list for every item fetched.
        self.instrument_ids = np.asarray(
            instrument_encoder.transform(self.df['instrument'])
        )

        # NOTE(review): items are paired with rows of self.df by position, which
        # assumes GoodSoundsDataset indexes `dataframe` positionally -- confirm.
        self.base_dataset = GoodSoundsDataset(
            dataframe, spectrogram_function, label_encoder,
            cache_spectrograms=cache_spectrograms, cache_dir=cache_dir
        )

    def __len__(self):
        return len(self.base_dataset)

    def _duration_seconds(self, spectrogram):
        """Return the approximate duration in seconds of *spectrogram*."""
        shape = tuple(spectrogram.shape)
        # collate_fn also accepts a (freq, time, 1) layout. There the last axis
        # is a singleton channel, so the time axis is the one before it.
        if len(shape) == 3 and shape[0] != 1 and shape[-1] == 1:
            time_frames = shape[-2]
        else:
            time_frames = shape[-1]
        return time_frames * self.hop_length / self.sr

    def __getitem__(self, idx):
        item = self.base_dataset[idx]
        spectrogram = item['spectrogram']

        return {
            'spectrogram': spectrogram,
            'instrument_id': torch.tensor(int(self.instrument_ids[idx]), dtype=torch.long),
            'duration': torch.tensor(self._duration_seconds(spectrogram), dtype=torch.float),
            'label': item['label'],
            'file_path': item['file_path'],
            'original_label': item['original_label']
        }

In [None]:
def _as_single_channel(spec):
    """Return *spec* as a ``(1, freq, time)`` tensor (numpy input -> float32).

    Accepted layouts: ``(freq, time)``, ``(1, freq, time)`` and ``(freq, time, 1)``.

    Raises:
        ValueError: for any other layout (e.g. multi-channel ``(C, freq, time)``),
            which the single-input-channel models cannot consume.
    """
    if not isinstance(spec, torch.Tensor):
        spec = torch.from_numpy(spec).float()

    if spec.dim() == 2:  # (freq, time)
        return spec.unsqueeze(0)
    if spec.dim() == 3:
        if spec.shape[0] == 1:  # already (1, freq, time)
            return spec
        if spec.shape[2] == 1:  # (freq, time, 1) -> (1, freq, time)
            return spec.squeeze(2).unsqueeze(0)
    raise ValueError(
        f"Unsupported spectrogram shape {tuple(spec.shape)}; expected "
        "(freq, time), (1, freq, time) or (freq, time, 1)"
    )


def collate_fn(batch):
    """Collate variable-size spectrogram items into one zero-padded batch.

    Each spectrogram is first normalised to ``(1, freq, time)``. Only after
    that are the batch-wide maximum frequency and time sizes measured, so the
    padding target is computed on the real axes for every accepted layout.
    Smaller spectrograms are zero-padded at the bottom (frequency) and on the
    right (time).

    Args:
        batch: list of dataset items, each a dict with ``spectrogram`` (tensor
            or numpy array), ``instrument_id``, ``duration`` and ``label`` tensors.

    Returns:
        dict with ``spectrogram`` of shape ``(batch, 1, max_freq, max_time)``
        and the stacked ``instrument_id``, ``duration`` and ``label`` tensors.

    Raises:
        ValueError: if a spectrogram has an unsupported layout.
    """
    instrument_ids = torch.stack([item['instrument_id'] for item in batch])
    durations = torch.stack([item['duration'] for item in batch])
    labels = torch.stack([item['label'] for item in batch])

    specs = [_as_single_channel(item['spectrogram']) for item in batch]
    max_freq = max(spec.shape[-2] for spec in specs)
    max_time = max(spec.shape[-1] for spec in specs)

    # Start from zeros and copy each spectrogram into its top-left corner; the
    # untouched region is exactly the bottom/right zero padding.
    stacked_specs = specs[0].new_zeros((len(specs), 1, max_freq, max_time))
    for i, spec in enumerate(specs):
        stacked_specs[i, :, :spec.shape[-2], :spec.shape[-1]] = spec

    return {
        'spectrogram': stacked_specs,
        'instrument_id': instrument_ids,
        'duration': durations,
        'label': labels
    }

## Model Architectures: AdvancedNoteAnalyzer and SimpleStarterModel

In [None]:
class AdvancedNoteAnalyzer(nn.Module):
    """Complete architecture from architecture.md with temporal processing.

    Pipeline: 2-D CNN over the spectrogram -> frequency pooled to 8 bins ->
    1-D conv along time -> duration-aware temporal pooling -> concatenation
    with a 64-dim instrument embedding -> fusion MLP -> class logits.
    """

    # Duration-aware pooling settings (class attributes so they can be tuned
    # without touching forward()).
    short_note_seconds = 0.5  # notes shorter than this are pooled over their attack only
    attack_fraction = 0.2     # attack window = this fraction of the time frames ...
    min_attack_frames = 5     # ... but never fewer than this many frames

    def __init__(self, num_instruments=12, num_classes=100):
        super().__init__()

        # 1. Instrument Embedding Layer
        # Maps instrument ID to learned 64-dim representation
        self.instrument_embedding = nn.Embedding(num_instruments, 64)

        # 2. CNN Feature Extractor
        # Processes variable-length spectrograms
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, None))  # Pool frequency only, preserve time
        )
        # 3. Temporal Processing Layer
        # Aggregates features across time (Conv1d faster than GRU on GPU)
        self.temporal = nn.Conv1d(256*8, 512, kernel_size=3, padding=1)
        self.temporal_pool = nn.AdaptiveAvgPool1d(1)

        # 4. Fusion Layer
        # Combines audio features with instrument embedding
        self.fusion = nn.Sequential(
            nn.Linear(512 + 64, 512),  # features + instrument
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # 5. Classification Head
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, spectrogram, instrument_id, duration=None):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            spectrogram: ``(batch, 1, freq, time)`` tensor.
            instrument_id: ``(batch,)`` long tensor of instrument codes.
            duration: optional per-sample durations in seconds (anything
                ``torch.as_tensor`` accepts that broadcasts to ``(batch,)``).
                Samples with ``duration < short_note_seconds`` are pooled over
                the first ``max(min_attack_frames, int(T * attack_fraction))``
                temporal frames (T = frames after the CNN); all other samples
                are pooled over every frame. ``None`` pools every sample over
                all frames.
        """
        # 1. Extract CNN features
        cnn_features = self.cnn(spectrogram)  # (batch, 256, 8, T)

        # 2. Merge channel and pooled-frequency axes so the Conv1d runs along time only
        temporal_features = self.temporal(cnn_features.flatten(1, 2))  # (batch, 512, T)

        # 3. Temporal pooling
        full_context = self.temporal_pool(temporal_features).squeeze(-1)  # (batch, 512)
        if duration is None:
            pooled_temporal = full_context
        else:
            # NOTE(review): T includes frames that come from collate_fn's zero
            # padding, and the attack window is measured on the padded batch
            # length rather than each sample's own length. The model receives
            # no per-sample frame counts that could be used to mask this.
            num_frames = temporal_features.shape[-1]
            attack_frames = max(self.min_attack_frames, int(num_frames * self.attack_fraction))
            attack = self.temporal_pool(temporal_features[..., :attack_frames]).squeeze(-1)

            # Pick attack vs. full-context pooling per sample in one vectorised
            # op. This replaces a Python loop whose `if dur < ...` forced a
            # device->host sync for every sample.
            duration = torch.as_tensor(duration, device=temporal_features.device)
            is_short = (duration < self.short_note_seconds).reshape(-1, 1)
            pooled_temporal = torch.where(is_short, attack, full_context)  # (batch, 512)

        # 4. Get instrument embedding
        inst_emb = self.instrument_embedding(instrument_id)  # (batch, 64)

        # 5. Fusion and classification
        combined = torch.cat([pooled_temporal, inst_emb], dim=1)  # (batch, 576)
        fused = self.fusion(combined)  # (batch, 512)
        return self.classifier(fused)  # (batch, num_classes)


class SimpleStarterModel(nn.Module):
    """Lightweight baseline that trains faster than AdvancedNoteAnalyzer.

    A two-block CNN with global average pooling is concatenated with a 32-dim
    instrument embedding and fed to a small MLP classifier. ``duration`` is
    accepted so both models share a call signature, but it is not used.
    """

    def __init__(self, num_instruments=12, num_classes=100):
        super().__init__()

        embed_dim = 32
        # One learned vector per instrument id.
        self.instrument_embedding = nn.Embedding(num_instruments, embed_dim)

        # Two stride-2 conv blocks, then global average pooling to (batch, 128, 1, 1).
        # Keep the submodule order: it determines the state_dict keys (cnn.0 ... cnn.6).
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=1),
        )

        hidden = 256
        # MLP over [audio features (128) | instrument embedding (32)].
        self.classifier = nn.Sequential(
            nn.Linear(128 + embed_dim, hidden),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, spectrogram, instrument_id, duration=None):
        audio = torch.flatten(self.cnn(spectrogram), start_dim=1)  # (batch, 128)
        instrument = self.instrument_embedding(instrument_id)     # (batch, 32)
        return self.classifier(torch.cat((audio, instrument), dim=1))

## Data Splitting and Preparation

In [None]:
# Stratified split on the quality class ('klass').
# Rare classes are dropped first: every remaining class must have enough rows
# to be stratified twice (train+val vs. test, then train vs. val).
min_samples = 3  # Need at least 3 samples per class for train/val/test split
class_counts = all_df['klass'].value_counts()
valid_classes = class_counts[class_counts >= min_samples].index
filtered_df = all_df.loc[all_df['klass'].isin(valid_classes)].copy()

print(f"Filtered from {len(all_df)} to {len(filtered_df)} samples")
print(f"Removed {len(all_df) - len(filtered_df)} samples with insufficient class representation")


def _stratified_split(frame, held_out_fraction):
    """Split *frame* into (kept, held_out) parts, stratified on 'klass'."""
    return train_test_split(
        frame,
        test_size=held_out_fraction,
        random_state=CONFIG['random_state'],
        stratify=frame['klass'],
    )


train_val_df, test_df = _stratified_split(filtered_df, CONFIG['test_size'])
train_df, val_df = _stratified_split(train_val_df, CONFIG['val_size'])

print("\nData split:")
for split_name, split_df in (("Training", train_df), ("Validation", val_df), ("Testing", test_df)):
    print(f"{split_name}: {len(split_df)} ({len(split_df)/len(filtered_df)*100:.1f}%)")

In [None]:
# Build the train/val/test datasets (instrument id + duration per item) and their loaders.
spectrogram_function = generate_spectrogram
cache_spectrograms = False  # Disable caching for now (can enable if needed)
cache_dir = None


def _make_dataset(split_df):
    """Wrap *split_df* in a MultiInstrumentDataset using the shared settings above."""
    return MultiInstrumentDataset(
        split_df, instrument_encoder, label_encoder, spectrogram_function,
        cache_spectrograms=cache_spectrograms, cache_dir=cache_dir
    )


train_dataset = _make_dataset(train_df)
val_dataset = _make_dataset(val_df)
test_dataset = _make_dataset(test_df)

# The training/evaluation loops move batches with .to(device, non_blocking=True).
# That copy only overlaps with compute when the batch sits in pinned host
# memory, so pin when training on CUDA.
pin_memory = device.type == 'cuda'
# The train and validation loaders are iterated every epoch. Keep their worker
# processes alive instead of re-spawning them (only valid with num_workers > 0).
persistent_workers = CONFIG['num_workers'] > 0

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    collate_fn=collate_fn,
    pin_memory=pin_memory,
    persistent_workers=persistent_workers,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'] * 2,  # Larger batch for validation
    num_workers=CONFIG['num_workers'],
    collate_fn=collate_fn,
    pin_memory=pin_memory,
    persistent_workers=persistent_workers,
)

# The test loader is iterated once, so persistent workers would only hold memory.
test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'] * 2,
    num_workers=CONFIG['num_workers'],
    collate_fn=collate_fn,
    pin_memory=pin_memory,
)

print(f"Created data loaders:")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

## Class Weight Calculation for Imbalanced Data

In [None]:
# Calculate class weights for imbalanced data.
# Encode all training labels in one vectorised call instead of calling
# transform() on a one-element list per row.
y_train = label_encoder.transform(train_df['klass'])
unique_classes_in_train = np.unique(y_train)
all_classes = np.arange(len(label_encoder.classes_))

print(f"Total classes in label encoder: {len(label_encoder.classes_)}")
print(f"Classes present in training data: {len(unique_classes_in_train)}")

# 'balanced' weights (n_samples / (n_present_classes * class_count)) can only be
# computed for classes that occur in y_train.
class_weights_partial = compute_class_weight(
    'balanced',
    classes=unique_classes_in_train,
    y=y_train
)

# Full-length weight vector indexed by encoded label. Classes the encoder knows
# but that are absent from training (e.g. dropped by the min_samples filter)
# keep weight 1.0.
class_weights_full = np.ones(len(label_encoder.classes_))
class_weights_full[unique_classes_in_train] = class_weights_partial

class_weights = torch.FloatTensor(class_weights_full).to(device)

print(f"Class weights shape: {class_weights.shape}")
print(f"Weight range: {class_weights.min():.3f} to {class_weights.max():.3f}")
print(f"Classes with custom weights: {len(unique_classes_in_train)}")
print(f"Classes with default weight (1.0): {len(all_classes) - len(unique_classes_in_train)}")

## Test Data Loading

In [None]:
# Smoke-test data loading: pull one training batch and print its shapes and a few values.
print("Testing data loading...")

try:
    # Test a single batch
    batch = next(iter(train_loader))

    print(f"✅ Data loading successful!")
    print(f"Batch shapes:")
    print(f"  Spectrogram: {batch['spectrogram'].shape}")
    print(f"  Instrument ID: {batch['instrument_id'].shape}")
    print(f"  Duration: {batch['duration'].shape}")
    print(f"  Label: {batch['label'].shape}")

    # Slicing already clamps to the batch length, so [:5] is safe for tiny batches.
    print(f"\nSample values:")
    print(f"  Instruments: {batch['instrument_id'][:5]}")
    print(f"  Labels: {batch['label'][:5]}")
    print(f"  Durations: {batch['duration'][:5].numpy()}")
    print(f"  Spectrogram range: [{batch['spectrogram'].min():.3f}, {batch['spectrogram'].max():.3f}]")

    print(f"\n🎉 Data loading test passed! Model test will happen after model initialization.")

except Exception as e:
    print(f"❌ Error in data loading test: {e}")
    print(f"\nDebugging info:")
    print(f"  Train dataset length: {len(train_dataset)}")
    print(f"  Train loader length: {len(train_loader)}")
    print(f"  Device: {device}")
    print(f"  Test mode: {TEST_MODE}")

    # In TEST_MODE the error is reported but not re-raised.
    if TEST_MODE:
        print(f"\n💡 This is expected in test mode - continue anyway")
    else:
        print(f"\n🚨 Fix these errors before full training!")
        raise

## Model Initialization

In [None]:
# Initialize the advanced model
num_instruments = len(instrument_encoder.classes_)
num_classes = len(label_encoder.classes_)

# Use the full AdvancedNoteAnalyzer model
model = AdvancedNoteAnalyzer(num_instruments=num_instruments, num_classes=num_classes)

# Alternatively, use SimpleStarterModel for faster training:
# model = SimpleStarterModel(num_instruments=num_instruments, num_classes=num_classes)

# Report the class actually instantiated, so swapping models keeps the log accurate.
print(f"Model initialized: {type(model).__name__}")
print(f"  Number of instruments: {num_instruments}")
print(f"  Number of classes: {num_classes}")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Move model to GPU if available
model = model.to(device)
print(f"  Model moved to: {device}")

# Test model forward pass with a sample batch
print(f"\nTesting model forward pass...")
try:
    model.eval()
    batch = next(iter(train_loader))

    with torch.no_grad():
        test_spectrograms = batch['spectrogram'][:2].to(device)  # Take only 2 samples
        test_instrument_ids = batch['instrument_id'][:2].to(device)
        test_durations = batch['duration'][:2].to(device)

        outputs = model(test_spectrograms, test_instrument_ids, test_durations)
        print(f"✅ Model forward pass successful!")
        print(f"  Input shape: {test_spectrograms.shape}")
        print(f"  Output shape: {outputs.shape}")
        print(f"  Output range: [{outputs.min():.3f}, {outputs.max():.3f}]")

    # Memory cleanup
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    print(f"\n🎉 Model test passed! Ready for training.")

except Exception as e:
    print(f"❌ Error in model forward pass: {e}")
    if TEST_MODE:
        print(f"💡 Continuing in test mode...")
    else:
        raise

## Training Function

In [None]:
def _to_device(batch):
    """Move a collated batch's inputs and labels to ``device``.

    Returns ``(spectrograms, instrument_ids, durations, labels)``.
    """
    return tuple(
        batch[key].to(device, non_blocking=True)
        for key in ('spectrogram', 'instrument_id', 'duration', 'label')
    )


def _val_loss_plateaued(val_losses, patience, min_delta=0.001):
    """Return True when the last *patience* validation losses brought no improvement.

    "No improvement" means none of the last *patience* losses is lower than the
    best loss recorded before that window by more than *min_delta*. Returns
    False until at least one epoch precedes the window.
    """
    if len(val_losses) <= patience:
        return False
    best_before_window = min(val_losses[:-patience])
    return min(val_losses[-patience:]) > best_before_window - min_delta


def train_model_gpu(model, train_loader, val_loader, config, class_weights):
    """Train *model* with AdamW, OneCycleLR and class-weighted cross-entropy.

    Mixed precision (autocast + GradScaler) is used only when ``device`` is CUDA.
    After every epoch the model is evaluated on *val_loader*. Whenever
    validation accuracy improves, a checkpoint is written to
    ``best_test_model.pth`` (TEST_MODE) or ``best_advanced_model.pth``.
    Training stops early once validation loss has not improved for
    ``early_stop_patience`` epochs.

    Args:
        model: module called as ``model(spectrogram, instrument_id, duration)``.
        train_loader: loader yielding ``collate_fn``-style dicts.
        val_loader: loader yielding ``collate_fn``-style dicts.
        config: must provide ``learning_rate``, ``max_lr``, ``epochs`` and ``batch_size``.
        class_weights: per-class weight tensor for ``nn.CrossEntropyLoss``.

    Returns:
        ``(model, history)``, where history holds per-epoch ``train_loss``,
        ``val_loss``, ``val_acc`` and ``lr`` lists.

    Raises:
        Any exception raised by a training/validation batch when not in
        TEST_MODE. In TEST_MODE the error is printed and the batch skipped.
    """
    model = model.to(device)

    # Mixed precision only if CUDA is available. A disabled GradScaler passes
    # scale()/step()/update() straight through, so one code path serves both.
    use_amp = device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Safety check for batch size vs dataset size
    if len(train_loader) < 2:
        print("⚠️  WARNING: Very few training batches - consider smaller batch size")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=1e-4 if not TEST_MODE else 1e-5,  # Lighter regularization in test mode
        eps=1e-8
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config['max_lr'],
        epochs=config['epochs'],
        steps_per_epoch=len(train_loader),
        pct_start=0.1,
        div_factor=10,
        final_div_factor=100
    )

    # Loss function with class weights
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Start below any achievable accuracy so the first epoch always writes a
    # checkpoint; otherwise a run stuck at 0% would leave no file to load.
    best_val_acc = float('-inf')
    early_stop_patience = 3 if TEST_MODE else 5  # More aggressive in test mode
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'lr': []}

    print(f"Starting training:")
    print(f"  Mixed precision: {use_amp}")
    print(f"  Device: {device}")
    print(f"  Epochs: {config['epochs']}")
    print(f"  Batch size: {config['batch_size']}")
    print(f"  Steps per epoch: {len(train_loader)}")
    print(f"  Test mode: {TEST_MODE}")

    for epoch in range(config['epochs']):
        epoch_label = f'Epoch {epoch+1}/{config["epochs"]}'

        # ---- Training phase ----
        model.train()
        train_loss = 0.0
        train_batches = 0  # batches that completed (TEST_MODE may skip failures)
        train_correct = 0
        train_total = 0

        # Use simpler progress output in test mode
        if TEST_MODE:
            train_iter = enumerate(train_loader)
            print(f"{epoch_label} - Training...")
        else:
            train_iter = enumerate(tqdm(train_loader, desc=f'{epoch_label} - Training'))

        for batch_idx, batch in train_iter:
            try:
                spectrograms, instrument_ids, durations, labels = _to_device(batch)

                optimizer.zero_grad(set_to_none=True)
                with torch.autocast(device_type=device.type, enabled=use_amp):
                    outputs = model(spectrograms, instrument_ids, durations)
                    loss = criterion(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()

                # Statistics
                train_loss += loss.item()
                train_batches += 1
                _, predicted = outputs.max(1)
                train_total += labels.size(0)
                train_correct += predicted.eq(labels).sum().item()

                # Progress updates (about four per epoch in test mode)
                if TEST_MODE and batch_idx % max(1, len(train_loader)//4) == 0:
                    train_acc = 100. * train_correct / train_total if train_total > 0 else 0
                    print(f"  Batch {batch_idx+1}/{len(train_loader)}: Loss={loss.item():.4f}, Acc={train_acc:.1f}%")

                # Memory management (only for GPU)
                if batch_idx % 50 == 0 and device.type == 'cuda':
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error in training batch {batch_idx}: {e}")
                if not TEST_MODE:
                    raise
                print("Continuing in test mode...")

        # ---- Validation phase ----
        model.eval()
        val_loss = 0.0
        val_batches = 0
        val_correct = 0
        val_total = 0

        if TEST_MODE:
            print(f"{epoch_label} - Validation...")
            val_iter = enumerate(val_loader)
        else:
            val_iter = enumerate(tqdm(val_loader, desc=f'{epoch_label} - Validation'))

        with torch.no_grad():
            for batch_idx, batch in val_iter:
                try:
                    spectrograms, instrument_ids, durations, labels = _to_device(batch)

                    with torch.autocast(device_type=device.type, enabled=use_amp):
                        outputs = model(spectrograms, instrument_ids, durations)
                        loss = criterion(outputs, labels)

                    val_loss += loss.item()
                    val_batches += 1
                    _, predicted = outputs.max(1)
                    val_total += labels.size(0)
                    val_correct += predicted.eq(labels).sum().item()

                except Exception as e:
                    print(f"Error in validation batch {batch_idx}: {e}")
                    if not TEST_MODE:
                        raise

        # Epoch metrics, averaged over the batches that actually ran
        epoch_train_loss = train_loss / train_batches if train_batches > 0 else 0
        epoch_val_loss = val_loss / val_batches if val_batches > 0 else 0
        epoch_val_acc = 100. * val_correct / val_total if val_total > 0 else 0
        epoch_train_acc = 100. * train_correct / train_total if train_total > 0 else 0
        current_lr = scheduler.get_last_lr()[0]

        history['train_loss'].append(epoch_train_loss)
        history['val_loss'].append(epoch_val_loss)
        history['val_acc'].append(epoch_val_acc)
        history['lr'].append(current_lr)

        print(f'\nEpoch {epoch+1}/{config["epochs"]} Summary:')
        print(f'  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}%')
        print(f'  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.2f}%')
        print(f'  Learning Rate: {current_lr:.6f}')

        # GPU memory info (only if GPU available)
        if device.type == 'cuda':
            memory_allocated = torch.cuda.memory_allocated(device) / 1e9
            memory_reserved = torch.cuda.memory_reserved(device) / 1e9
            print(f'  GPU Memory: {memory_allocated:.2f}GB allocated, {memory_reserved:.2f}GB reserved')

        # Save best model
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_acc': best_val_acc,
                'config': config,
                'scaler_state_dict': scaler.state_dict() if use_amp else None,
                'test_mode': TEST_MODE
            }
            save_name = 'best_test_model.pth' if TEST_MODE else 'best_advanced_model.pth'
            torch.save(checkpoint, save_name)
            print(f'  ✓ New best model saved as {save_name}! Val accuracy: {best_val_acc:.2f}%')

        # Early stopping: stop once validation loss stops improving (the previous
        # check fired while the loss was still decreasing).
        if _val_loss_plateaued(history['val_loss'], early_stop_patience):
            print(f'  Early stopping triggered - validation loss plateaued')
            break

        print('-' * 60)

        # Cleanup GPU memory
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    return model, history

## Start Training

In [None]:
# Start training with mode-appropriate settings
if TEST_MODE:
    print("🧪 Starting TEST training (safe for CPU/Mac)...")
    print(f"  - This will run quickly on your Mac")
    print(f"  - Only {CONFIG['epochs']} epochs with small batches")
    print(f"  - Perfect for testing the pipeline")
    print(f"  - Change TEST_MODE=False for full GPU training")
else:
    print("🚀 Starting FULL GPU training...")
    print(f"  - Make sure you have a powerful GPU!")
    print(f"  - {CONFIG['epochs']} epochs with large batches")
    # Rough heuristic only: assumes ~100 training batches per minute.
    print(f"  - Expected time: ~{CONFIG['epochs'] * len(train_loader) / 100:.1f} minutes on modern GPU")
    print(f"  - Will save model as 'best_advanced_model.pth'")

print(f"\nTraining configuration:")
print(f"  Model: {type(model).__name__}")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Epochs: {CONFIG['epochs']}")
print(f"  Device: {device}")
print(f"  Mixed precision: {device.type == 'cuda'}")

# Confirm before starting full training on CPU
if not TEST_MODE and device.type != 'cuda':
    response = input("\n⚠️  WARNING: Full training on CPU will be VERY slow. Continue? (y/N): ")
    if response.lower() != 'y':
        print("Training cancelled. Consider setting TEST_MODE=True for development.")
        raise KeyboardInterrupt("Training cancelled by user")

print(f"\n{'='*60}")
print(f"STARTING TRAINING")
print(f"{'='*60}")

trained_model, history = train_model_gpu(
    model, train_loader, val_loader, CONFIG, class_weights
)

## Training Results Analysis

In [None]:
# Visualise the per-epoch training history on a 2x2 grid of panels.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

ax1.plot(history['train_loss'], label='Train Loss', alpha=0.8)
ax1.plot(history['val_loss'], label='Validation Loss', alpha=0.8)
ax2.plot(history['val_acc'], label='Validation Accuracy', color='green', linewidth=2)
ax3.plot(history['lr'], label='Learning Rate', color='orange')
ax3.set_yscale('log')
# Colour each point by its epoch index.
ax4.scatter(history['val_loss'], history['val_acc'], alpha=0.6,
            c=range(len(history['val_acc'])), cmap='viridis')

panel_text = (
    (ax1, 'Epoch', 'Loss', 'Training and Validation Loss', True),
    (ax2, 'Epoch', 'Accuracy (%)', 'Validation Accuracy', True),
    (ax3, 'Epoch', 'Learning Rate', 'Learning Rate Schedule (OneCycleLR)', True),
    (ax4, 'Validation Loss', 'Validation Accuracy (%)', 'Loss vs Accuracy (colored by epoch)', False),
)
for ax, x_label, y_label, title, show_legend in panel_text:
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    if show_legend:
        ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Training Summary:")
for summary_line in (
    f"  Best validation accuracy: {max(history['val_acc']):.2f}%",
    f"  Final validation accuracy: {history['val_acc'][-1]:.2f}%",
    f"  Final train loss: {history['train_loss'][-1]:.4f}",
    f"  Final validation loss: {history['val_loss'][-1]:.4f}",
    f"  Total epochs completed: {len(history['val_acc'])}",
):
    print(summary_line)

# Memory cleanup
if device.type == 'cuda':
    torch.cuda.empty_cache()
    print("  GPU memory cleared")

In [None]:
# Load the best checkpoint and evaluate it on the test set.
# train_model_gpu writes its best checkpoint to 'best_test_model.pth' in
# TEST_MODE and to 'best_advanced_model.pth' otherwise; load the matching file.
checkpoint_path = 'best_test_model.pth' if TEST_MODE else 'best_advanced_model.pth'
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

test_correct = 0
test_total = 0
all_predictions = []
all_labels = []
all_confidences = []

print("Evaluating on test set...")

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Testing'):
        spectrograms = batch['spectrogram'].to(device, non_blocking=True)
        instrument_ids = batch['instrument_id'].to(device, non_blocking=True)
        durations = batch['duration'].to(device, non_blocking=True)
        labels = batch['label'].to(device, non_blocking=True)

        # Mixed precision for inference on GPU; a no-op on CPU.
        with torch.autocast(device_type=device.type, enabled=device.type == 'cuda'):
            outputs = model(spectrograms, instrument_ids, durations)

        # Softmax in float32 so confidences are not limited by fp16 precision.
        probabilities = torch.softmax(outputs.float(), dim=1)
        confidences, predicted = probabilities.max(1)

        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_confidences.extend(confidences.cpu().numpy())

test_accuracy = 100. * test_correct / test_total if test_total > 0 else 0.0
mean_confidence = np.mean(all_confidences) if all_confidences else 0.0

print(f"\nFinal Test Results:")
print(f"  Test Accuracy: {test_accuracy:.2f}%")
print(f"  Test samples: {test_total}")
print(f"  Correct predictions: {test_correct}")
print(f"  Mean confidence: {mean_confidence:.3f}")

# Per-class analysis
from collections import Counter
from sklearn.metrics import classification_report, confusion_matrix

# Get class names
class_names = label_encoder.classes_

# label_encoder was fit on all_df, so it also knows classes removed by the
# min_samples filter and classes absent from this split. classification_report
# raises ValueError when target_names does not match the labels it infers, so
# report exactly the labels that occur in the true or predicted values.
report_labels = np.union1d(all_labels, all_predictions).astype(int)

print(f"\nDetailed Classification Report:")
print(classification_report(all_labels, all_predictions,
                            labels=report_labels,
                            target_names=class_names[report_labels],
                            zero_division=0))

# Per-class accuracy, indexed by encoded label
labels_arr = np.asarray(all_labels, dtype=int)
preds_arr = np.asarray(all_predictions, dtype=int)
total_per_class = np.bincount(labels_arr, minlength=len(class_names)).astype(float)
correct_per_class = np.bincount(labels_arr[labels_arr == preds_arr],
                                minlength=len(class_names)).astype(float)

# Avoid division by zero
class_accuracies = np.divide(correct_per_class, total_per_class,
                             out=np.zeros_like(correct_per_class),
                             where=total_per_class != 0) * 100

# Rank only classes that have test samples. Classes without any are 0% by
# construction and would otherwise fill the bottom-5 slots with nothing to print.
evaluated_classes = np.flatnonzero(total_per_class > 0)
sorted_indices = evaluated_classes[np.argsort(class_accuracies[evaluated_classes], kind='stable')]

print(f"\nTop 5 performing classes:")
for i in sorted_indices[-5:]:
    print(f"  {class_names[i]}: {class_accuracies[i]:.1f}% ({int(correct_per_class[i])}/{int(total_per_class[i])})")

print(f"\nBottom 5 performing classes:")
for i in sorted_indices[:5]:
    print(f"  {class_names[i]}: {class_accuracies[i]:.1f}% ({int(correct_per_class[i])}/{int(total_per_class[i])})")

# Memory cleanup
if device.type == 'cuda':
    torch.cuda.empty_cache()

In [None]:
# Save the final model together with the metadata needed to rebuild and
# interpret it (config, class counts, fitted encoders, test accuracy, history).
# NOTE: the sklearn encoders are pickled into this file, so reloading it with
# torch>=2.6 requires torch.load(..., weights_only=False). Only load such
# checkpoints from trusted sources.
model_class_name = type(model).__name__  # report the class actually trained
torch.save({
    'model_state_dict': model.state_dict(),
    'model_class': model_class_name,
    'config': CONFIG,
    'num_instruments': num_instruments,
    'num_classes': num_classes,
    'instrument_encoder': instrument_encoder,
    'label_encoder': label_encoder,
    'test_accuracy': test_accuracy,
    'history': history
}, 'cnnv2_final_model.pth')

print("Final model saved as 'cnnv2_final_model.pth'")
print("\nModel Summary:")
print(f"  Architecture: {model_class_name} with instrument embeddings")
print(f"  Instruments: {num_instruments}")
print(f"  Classes: {num_classes}")
print(f"  Test accuracy: {test_accuracy:.2f}%")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")