# 🎹 Piano Performance Analysis - Colab Validation Test

## Objective
Test our JAX/Flax piano CNN architectures with synthetic data before collecting real Chopin/Liszt recordings.

## What We're Testing
1. **CNN Architectures**: Standard and Real-time (the Multi-Spectral Fusion variant is not defined or trained in this notebook)
2. **Synthetic Data Pipeline**: Realistic piano-like spectrograms 
3. **Training Convergence**: End-to-end training with proper metrics
4. **Performance Comparison**: Which architecture works best

## Success Criteria
- Models train without errors
- Loss decreases consistently 
- Correlations with synthetic labels > 0.7 (note: the automated check in the results cell only requires an average test correlation > 0.5)
- Training completes in reasonable time

# 📦 Environment Setup

In [None]:
# Install uv package manager first
!curl -LsSf https://astral.sh/uv/install.sh | sh
import os
os.environ['PATH'] = f"/root/.cargo/bin:{os.environ.get('PATH', '')}"

# Verify uv installation
!uv --version

# Install dependencies using uv (faster than pip)
!uv pip install --system jax[gpu] flax optax
!uv pip install --system librosa soundfile matplotlib seaborn
!uv pip install --system wandb tqdm pandas numpy scipy

# Import core libraries
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from typing import Dict, Tuple, List
from dataclasses import dataclass
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Report the runtime so a failed accelerator setup is visible immediately.
print(f"🚀 JAX version: {jax.__version__}")
print(f"🔧 Available devices: {jax.devices()}")
# jax.default_backend() is the public API for the active platform name;
# the private jax.lib.xla_bridge module is deprecated in recent JAX releases.
print(f"💻 Platform: {jax.default_backend()}")

# 🧠 Model Architectures

In [None]:
# Core model architectures adapted from your codebase
from flax.training import train_state, checkpoints
from flax.training.early_stopping import EarlyStopping
import functools

class SpectralConvBlock(nn.Module):
    """Conv -> BatchNorm -> ReLU -> Dropout unit for mel-spectrogram feature maps.

    Attributes:
        features: number of output channels of the convolution.
        kernel_size: spatial size of the convolution kernel.
        strides: convolution strides.
        dropout_rate: dropout probability applied after the activation.
    """
    features: int
    kernel_size: tuple = (3, 3)
    strides: tuple = (1, 1)
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, training: bool = True):
        """Apply the block; `training` toggles batch statistics and dropout."""
        inference = not training

        # Layers are created in the order they run so the automatically
        # assigned submodule names stay Conv_0 / BatchNorm_0 / Dropout_0.
        conv = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding='SAME',
        )
        norm = nn.BatchNorm(use_running_average=inference)
        drop = nn.Dropout(rate=self.dropout_rate, deterministic=inference)

        activated = nn.relu(norm(conv(x)))
        return drop(activated)

class PianoSpectroCNN(nn.Module):
    """Standard piano CNN - VGGish inspired.

    Four convolutional stages (each followed by 2x2 max pooling), a global
    average pool, and a three-layer dense head with a sigmoid output.

    Attributes:
        num_classes: number of output values per example.
        base_filters: channel count of the first stage; later stages use
            2x, 4x and 8x this value.
        dropout_rate: dropout probability used in the dense head.
    """
    num_classes: int = 19
    base_filters: int = 64
    dropout_rate: float = 0.2

    @nn.compact
    def __call__(self, x, training: bool = True):
        """Map spectrograms, documented by the original as (batch, time, freq, 1), to scores in (0, 1)."""
        # (channel multiplier, number of conv blocks) for each stage
        stage_plan = ((1, 1), (2, 1), (4, 2), (8, 2))
        for multiplier, depth in stage_plan:
            for _ in range(depth):
                block = SpectralConvBlock(self.base_filters * multiplier, kernel_size=(3, 3))
                x = block(x, training)
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Global average pooling: the window spans the whole remaining map.
        x = nn.avg_pool(x, window_shape=(x.shape[1], x.shape[2]))
        x = jnp.reshape(x, (x.shape[0], -1))

        # Shared dense trunk of the multi-task head
        for hidden_units in (512, 256):
            x = nn.Dense(hidden_units)(x)
            x = nn.relu(x)
            x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x)

        logits = nn.Dense(self.num_classes)(x)
        return nn.sigmoid(logits)

class RealTimePianoCNN(nn.Module):
    """Lightweight CNN for real-time inference.

    Three depthwise-separable convolution stages, a global average pool and a
    small dense classifier with a sigmoid output.

    Attributes:
        num_classes: number of output values per example.
        width_multiplier: scales the channel count of every stage
            (first stage uses ``int(32 * width_multiplier)`` channels).
    """
    num_classes: int = 19
    width_multiplier: float = 0.5

    @nn.compact
    def __call__(self, x, training: bool = True):
        """Run the network; `training` selects batch vs. running BatchNorm statistics."""
        stem_channels = int(32 * self.width_multiplier)

        # First two stages: separable conv followed by 2x2 down-sampling
        for scale in (1, 2):
            x = self._separable_conv_block(x, stem_channels * scale, training)
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Last stage is pooled globally instead of down-sampled
        x = self._separable_conv_block(x, stem_channels * 4, training)
        x = nn.avg_pool(x, window_shape=(x.shape[1], x.shape[2]))
        x = jnp.reshape(x, (x.shape[0], -1))

        # Compact classifier
        hidden = nn.relu(nn.Dense(128)(x))
        return nn.sigmoid(nn.Dense(self.num_classes)(hidden))

    def _separable_conv_block(self, x, filters, training):
        """Depth-wise 3x3 conv, point-wise 1x1 conv, BatchNorm, ReLU."""
        in_channels = x.shape[-1]

        # One 3x3 filter per input channel (groups == channels)
        depthwise = nn.Conv(
            features=in_channels,
            kernel_size=(3, 3),
            feature_group_count=in_channels,
            padding='SAME',
        )
        # 1x1 conv mixes channels and sets the output width
        pointwise = nn.Conv(features=filters, kernel_size=(1, 1))
        norm = nn.BatchNorm(use_running_average=not training)

        return nn.relu(norm(pointwise(depthwise(x))))

# Sanity marker: reaching this line means every class definition in this cell executed.
print("🧠 Model architectures loaded successfully!")

# 🎼 Synthetic Data Generation

In [None]:
class SyntheticPianoDataGenerator:
    """Generate synthetic piano-like spectrograms with correlated perceptual labels.

    Every sample is driven by a dict of style parameters. Only four of them
    ('brightness', 'richness', 'dynamics', 'timing_stability') shape the
    spectrogram; the remaining ones influence the labels only, so labels that
    depend solely on those cannot be recovered from the spectrogram.

    All randomness comes from one seeded ``np.random.RandomState``, so a given
    seed and call sequence reproduces the same data.
    """

    def __init__(self,
                 num_samples: int = 1000,
                 time_frames: int = 128,
                 freq_bins: int = 128,
                 seed: int = 42):
        """Store the dataset geometry and seed the private random generator.

        Args:
            num_samples: number of (spectrogram, labels) pairs `generate_dataset` builds.
            time_frames: spectrogram length along the time axis.
            freq_bins: spectrogram length along the frequency axis.
            seed: seed for the internal random generator.
        """
        self.num_samples = num_samples
        self.time_frames = time_frames
        self.freq_bins = freq_bins
        self.rng = np.random.RandomState(seed)

        # Define perceptual dimensions (one label per entry, in this order)
        self.dimensions = [
            "Timing_Stable_Unstable", "Articulation_Short_Long",
            "Articulation_Soft_Hard", "Pedal_Dry_Wet", "Pedal_Clean_Blurred",
            "Timbre_Even_Colorful", "Timbre_Shallow_Rich", "Timbre_Bright_Dark",
            "Timbre_Soft_Loud", "Dynamic_Mellow_Raw", "Dynamic_Small_Large_Range",
            "Music_Fast_Slow", "Music_Flat_Spacious", "Music_Unbalanced_Balanced",
            "Music_Pure_Expressive", "Mood_Pleasant_Dark", "Mood_Low_High_Energy",
            "Mood_Honest_Imaginative", "Interpretation_Poor_Convincing"
        ]

    def generate_piano_spectrogram(self, style_params: Dict) -> np.ndarray:
        """Build one normalized spectrogram from the style parameters.

        Args:
            style_params: must contain 'brightness', 'richness', 'dynamics'
                and 'timing_stability'; other keys are ignored here.

        Returns:
            NumPy array of shape (time_frames, freq_bins, 1) with values in [0, 1].
        """
        # Base piano harmonics (fundamental + overtones)
        base_freqs = np.array([88, 176, 264, 352, 440, 528, 660])  # Piano range

        spectrogram = np.zeros((self.time_frames, self.freq_bins))

        brightness = style_params['brightness']  # 0-1
        richness = style_params['richness']     # 0-1
        dynamics = style_params['dynamics']     # 0-1
        timing_stability = style_params['timing_stability']  # 0-1

        # Loop invariants: frequency-bin index of each fundamental and the
        # overtone gain shared by every harmonic.
        fundamental_bins = [int((freq / 11025) * self.freq_bins) for freq in base_freqs]
        overtone_gain = brightness * 0.3 + richness * 0.2

        attack_strength = 0.0  # overwritten at t == 0, the first note onset
        for t in range(self.time_frames):
            # NOTE(review): the original computed a timing-jitter value here
            # but never applied it. The draw is kept (and discarded) so the
            # random stream, and therefore the generated data, is unchanged
            # for a given seed.
            self.rng.normal(0, 0.1)

            # Simulate note attacks and decays
            if t % 16 == 0:  # Note onsets every 16 frames
                attack_strength = 0.8 + dynamics * 0.4
            else:
                attack_strength *= 0.92  # Exponential decay

            for freq_bin in fundamental_bins:
                if freq_bin >= self.freq_bins:
                    continue

                # Fundamental
                amplitude = attack_strength * (0.5 + 0.3 * self.rng.random())
                spectrogram[t, freq_bin] += amplitude

                # Overtones 2-5; bins past the top edge pile up in the last bin
                for h in range(2, 6):
                    h_bin = min(freq_bin + h * 8, self.freq_bins - 1)
                    spectrogram[t, h_bin] += amplitude * overtone_gain / h

        # Additive noise grows as timing stability drops
        noise_level = 0.05 + (1 - timing_stability) * 0.1
        spectrogram += self.rng.normal(0, noise_level, spectrogram.shape)

        # High-frequency roll-off envelope
        freq_envelope = np.exp(-np.linspace(0, 3, self.freq_bins))
        spectrogram *= freq_envelope[None, :]

        # Convert to dB; the floor also removes negative values left by the noise
        spectrogram = np.clip(spectrogram, 1e-8, None)
        spectrogram_db = 20 * np.log10(spectrogram)

        # Map [-80 dB, 0 dB] onto [0, 1]
        min_db, max_db = -80, 0
        spectrogram_norm = (spectrogram_db - min_db) / (max_db - min_db)
        spectrogram_norm = np.clip(spectrogram_norm, 0, 1)

        return spectrogram_norm[..., np.newaxis]  # Add channel dimension

    def generate_perceptual_labels(self, style_params: Dict) -> np.ndarray:
        """Derive one noisy label per perceptual dimension from the style parameters.

        Args:
            style_params: must contain 'timing_stability', 'articulation',
                'pedal_wetness', 'richness', 'brightness', 'dynamics',
                'expression', 'tempo_stability', 'mood_valence' and 'energy'.

        Returns:
            1-D NumPy array with ``len(self.dimensions)`` values clipped to [0, 1].
        """
        # Sized from the dimension list so the two cannot drift apart.
        num_dims = len(self.dimensions)
        labels = np.zeros(num_dims)

        # Timing, articulation and pedal
        labels[0] = style_params['timing_stability']  # Timing
        labels[1] = 0.3 + style_params['articulation'] * 0.7  # Articulation length
        labels[2] = style_params['articulation']  # Articulation softness
        labels[3] = style_params['pedal_wetness']  # Pedal wet/dry
        labels[4] = 1 - style_params['pedal_wetness'] * 0.6  # Pedal clarity

        # Timbre dimensions
        labels[5] = style_params['richness']  # Even/Colorful
        labels[6] = style_params['richness']  # Shallow/Rich
        labels[7] = style_params['brightness']  # Bright/Dark
        labels[8] = style_params['dynamics']  # Soft/Loud

        # Dynamic expression
        labels[9] = 0.3 + style_params['expression'] * 0.7  # Sophisticated/Raw
        labels[10] = style_params['dynamics']  # Dynamic range

        # Musical expression
        labels[11] = 1 - style_params['tempo_stability']  # Fast/Slow paced
        labels[12] = style_params['expression']  # Flat/Spacious
        labels[13] = style_params['timing_stability']  # Balanced
        labels[14] = style_params['expression']  # Pure/Dramatic

        # Emotion and mood
        labels[15] = 0.4 + style_params['mood_valence'] * 0.6  # Pleasant/Dark
        labels[16] = style_params['energy']  # Low/High energy
        labels[17] = style_params['expression']  # Honest/Imaginative

        # Overall interpretation
        labels[18] = np.mean([style_params['timing_stability'], style_params['expression'],
                             style_params['dynamics']])  # Overall convincing

        # Add label noise and keep everything inside [0, 1]
        labels += self.rng.normal(0, 0.05, num_dims)
        labels = np.clip(labels, 0, 1)

        return labels

    def generate_dataset(self) -> Tuple[np.ndarray, np.ndarray]:
        """Generate `num_samples` spectrogram/label pairs.

        Returns:
            (spectrograms, labels) with shapes
            (num_samples, time_frames, freq_bins, 1) and
            (num_samples, len(self.dimensions)).

        Raises:
            ValueError: if `num_samples` is smaller than 1.
        """
        if self.num_samples < 1:
            # np.stack on an empty list fails with an opaque message; fail early instead.
            raise ValueError(f"num_samples must be at least 1, got {self.num_samples}")

        print(f"🎼 Generating {self.num_samples} synthetic piano performances...")

        spectrograms = []
        all_labels = []

        for _ in tqdm(range(self.num_samples), desc="Generating data"):
            # Draw order below fixes the random stream; do not reorder.
            style_params = {
                'timing_stability': self.rng.beta(2, 2),
                'articulation': self.rng.uniform(0.2, 0.9),
                'pedal_wetness': self.rng.uniform(0.1, 0.8),
                'brightness': self.rng.uniform(0.3, 0.8),
                'richness': self.rng.uniform(0.2, 0.9),
                'dynamics': self.rng.uniform(0.3, 0.9),
                'expression': self.rng.uniform(0.2, 0.8),
                'tempo_stability': self.rng.beta(3, 2),
                'mood_valence': self.rng.uniform(0.3, 0.8),
                'energy': self.rng.uniform(0.2, 0.9)
            }

            # Spectrogram first, labels second (both consume the shared generator)
            spectrograms.append(self.generate_piano_spectrogram(style_params))
            all_labels.append(self.generate_perceptual_labels(style_params))

        spectrograms = np.stack(spectrograms)
        all_labels = np.stack(all_labels)

        print(f"✅ Generated dataset:")
        print(f"   Spectrograms: {spectrograms.shape}")
        print(f"   Labels: {all_labels.shape}")
        print(f"   Label ranges: [{all_labels.min():.3f}, {all_labels.max():.3f}]")

        return spectrograms, all_labels

# Generate synthetic dataset
# The generator keeps its default 128x128 spectrogram size and seed; all
# samples are materialized up front and X_synthetic / y_synthetic are reused
# by the later cells.
data_generator = SyntheticPianoDataGenerator(num_samples=800)  # Reasonable size for Colab
X_synthetic, y_synthetic = data_generator.generate_dataset()

# 📊 Data Visualization

In [None]:
# Visualize synthetic data quality
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# Show sample spectrograms (transposed so frequency runs along the y axis)
for i in range(3):
    axes[0, i].imshow(X_synthetic[i, :, :, 0].T, aspect='auto', origin='lower', cmap='viridis')
    axes[0, i].set_title(f'Sample Spectrogram {i+1}')
    axes[0, i].set_xlabel('Time Frames')
    axes[0, i].set_ylabel('Frequency Bins')

# Show label distributions
dimension_names = data_generator.dimensions
for i in range(3):
    dim_idx = i * 6  # Show every 6th dimension
    if dim_idx < len(dimension_names):
        axes[1, i].hist(y_synthetic[:, dim_idx], bins=30, alpha=0.7, edgecolor='black')
        axes[1, i].set_title(f'{dimension_names[dim_idx][:20]}...')
        axes[1, i].set_xlabel('Rating [0-1]')
        axes[1, i].set_ylabel('Count')

plt.tight_layout()
plt.show()

# Show correlation matrix of labels
plt.figure(figsize=(12, 10))
corr_matrix = np.corrcoef(y_synthetic.T)
sns.heatmap(corr_matrix, 
            xticklabels=[d[:15] + '...' for d in dimension_names],
            yticklabels=[d[:15] + '...' for d in dimension_names],
            cmap='coolwarm', center=0, annot=False)
plt.title('Synthetic Label Correlations')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

# Summary statistics use off-diagonal entries only: the diagonal is each
# dimension's correlation with itself and would otherwise dilute the mean.
off_diagonal = corr_matrix[~np.eye(corr_matrix.shape[0], dtype=bool)]

print(f"📈 Dataset Statistics:")
print(f"   Mean correlation between dimensions: {np.mean(np.abs(off_diagonal)): .3f}")
print(f"   Strongest correlation: {np.max(off_diagonal):.3f}")
print(f"   Label standard deviations: {np.std(y_synthetic, axis=0).mean():.3f}")

# 🏋️ Training Infrastructure

In [None]:
@dataclass
class TrainingConfig:
    """Training configuration for Colab validation.

    Groups the hyperparameters passed to the training functions in this
    notebook so every model is trained with the same settings.
    """
    learning_rate: float = 1e-3  # step size handed to the optimizer
    batch_size: int = 32  # samples per mini-batch
    epochs: int = 50  # Shorter for validation
    early_stopping_patience: int = 8  # epochs without validation-loss improvement before stopping
    val_split: float = 0.2  # intended fraction of samples held out for validation
    test_split: float = 0.1  # intended fraction of samples held out for testing

class TrainStateWithBatchStats(train_state.TrainState):
    """Training state that tracks batch normalization statistics.

    Adds one field to the Flax ``TrainState`` so the ``batch_stats`` variable
    collection is carried alongside the parameters and optimizer state.
    """
    # Contents of the model's 'batch_stats' variable collection.
    batch_stats: dict

def create_train_state_fixed(model, learning_rate: float, input_shape: tuple):
    """Initialize a model and wrap it in a training state that carries batch stats.

    Args:
        model: Flax module whose ``__call__`` accepts a ``training`` keyword.
        learning_rate: learning rate for the Adam optimizer.
        input_shape: shape of the all-ones dummy input used for initialization.

    Returns:
        A ``TrainStateWithBatchStats`` holding the initial parameters, the
        initial batch statistics (empty dict if the model has none) and the
        Adam optimizer state.
    """
    init_key = jax.random.PRNGKey(42)

    # training=False makes dropout deterministic, so no dropout RNG is needed here.
    variables = model.init(init_key, jnp.ones(input_shape), training=False)

    return TrainStateWithBatchStats.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.adam(learning_rate),
        batch_stats=variables.get('batch_stats', {}),
    )

@jax.jit
def train_step_fixed(state, batch_x, batch_y, dropout_rng):
    """Run one optimization step and refresh the batch statistics.

    Args:
        state: training state exposing ``apply_fn``, ``params``,
            ``batch_stats`` and ``apply_gradients``.
        batch_x: model inputs for this batch.
        batch_y: regression targets, same shape as the model output.
        dropout_rng: PRNG key passed to the model as the 'dropout' stream.

    Returns:
        (new_state, loss, predictions). Note that ``loss`` is the optimized
        objective, i.e. the MSE plus the L2 penalty, so it is not directly
        comparable to the pure-MSE loss reported by the evaluation step.
    """
    def objective(params):
        variables = {'params': params, 'batch_stats': state.batch_stats}
        outputs, updated_collections = state.apply_fn(
            variables,
            batch_x,
            training=True,
            rngs={'dropout': dropout_rng},
            mutable=['batch_stats'],
        )

        # Multi-task MSE over every output dimension
        data_term = jnp.mean((outputs - batch_y) ** 2)

        # L2 penalty over every parameter leaf
        weight_penalty = 0.001 * sum(
            jnp.sum(jnp.square(leaf)) for leaf in jax.tree_util.tree_leaves(params)
        )

        return data_term + weight_penalty, (outputs, updated_collections)

    grad_fn = jax.value_and_grad(objective, has_aux=True)
    (loss, (predictions, new_model_state)), grads = grad_fn(state.params)

    # Gradient update and batch-stat refresh happen in one state replacement.
    new_state = state.apply_gradients(
        grads=grads,
        batch_stats=new_model_state['batch_stats'],
    )

    return new_state, loss, predictions

@jax.jit
def eval_step_fixed(state, batch_x, batch_y):
    """Evaluate one batch: MSE loss plus a per-dimension Pearson correlation.

    Args:
        state: training state exposing ``apply_fn``, ``params`` and ``batch_stats``.
        batch_x: model inputs for this batch.
        batch_y: targets of shape (batch, num_dimensions).

    Returns:
        (loss, correlations, predictions) where ``correlations`` has one entry
        per output dimension, computed within this batch only. A dimension
        whose predictions or targets have (near-)zero variance reports 0.
    """
    predictions = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        batch_x,
        training=False
    )
    loss = jnp.mean((predictions - batch_y) ** 2)

    # All output dimensions are handled at once along the batch axis instead
    # of unrolling one Python iteration per dimension under jit.
    pred_centered = predictions - jnp.mean(predictions, axis=0, keepdims=True)
    true_centered = batch_y - jnp.mean(batch_y, axis=0, keepdims=True)

    covariance = jnp.mean(pred_centered * true_centered, axis=0)
    denominator = jnp.std(pred_centered, axis=0) * jnp.std(true_centered, axis=0)

    # Guard the division itself (not only the result) so degenerate
    # dimensions never produce inf/NaN intermediates.
    has_variance = denominator > 1e-8
    safe_denominator = jnp.where(has_variance, denominator, 1.0)
    correlations = jnp.where(has_variance, covariance / safe_denominator, 0.0)

    # Rounding can push values marginally outside the valid range
    correlations = jnp.clip(correlations, -1.0, 1.0)

    return loss, correlations, predictions

def create_data_batches(X, y, batch_size, rng_key):
    """Shuffle (X, y) together and split them into mini-batches.

    Args:
        X: array of inputs, first axis indexes samples.
        y: array of targets aligned with `X` along the first axis.
        batch_size: number of samples per batch (must be >= 1).
        rng_key: JAX PRNG key controlling the shuffle.

    Returns:
        (X_batches, y_batches) with a leading batch-index axis. Samples that
        do not fill a complete batch are dropped. If there are fewer samples
        than `batch_size`, a single smaller batch holding all of them is
        returned instead of zero batches, so callers always have a batch to
        iterate over whenever any data exists.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    n_samples = X.shape[0]
    indices = jax.random.permutation(rng_key, n_samples)

    n_batches = n_samples // batch_size
    if n_batches == 0 and n_samples > 0:
        # Not enough data for one full batch: keep everything in one short batch.
        indices = indices.reshape(1, n_samples)
    else:
        # Trim to a whole number of batches
        indices = indices[:n_batches * batch_size].reshape(n_batches, batch_size)

    return X[indices], y[indices]

# Sanity marker: reaching this line means the training helpers in this cell were defined.
print("🏋️ Fixed training infrastructure ready (JAX-compatible correlations)!")

# 🚀 Model Training & Validation

In [None]:
def train_and_evaluate_model(model_name: str, model_class, config: TrainingConfig, 
                           X_train, y_train, X_val, y_val, X_test, y_test):
    """Train and evaluate a single model (legacy entry point).

    The original body called `create_train_state`, `train_step` and
    `eval_step`, none of which exist in this notebook, and otherwise
    duplicated `train_and_evaluate_model_fixed`. It now delegates to that
    function so there is a single training loop to maintain.

    Args:
        model_name: display name; the legacy name "RealTime" selects the
            lightweight model's constructor arguments.
        model_class: Flax module class to instantiate.
        config: training hyperparameters.
        X_train, y_train: training inputs and targets.
        X_val, y_val: validation inputs and targets.
        X_test, y_test: test inputs and targets.

    Returns:
        The results dict produced by `train_and_evaluate_model_fixed`, with
        'model_name' set to the name passed by the caller.
    """
    # The maintained implementation recognizes the lightweight model by the
    # name "RealTime CNN"; translate the legacy spelling before delegating.
    delegate_name = "RealTime CNN" if model_name == "RealTime" else model_name

    results = train_and_evaluate_model_fixed(
        delegate_name, model_class, config,
        X_train, y_train, X_val, y_val, X_test, y_test
    )
    results['model_name'] = model_name
    return results

# Training configuration is created first so the split sizes come from it
# instead of repeating the fractions as literals.
config = TrainingConfig()

# Data splitting
n_samples = X_synthetic.shape[0]
test_size = int(n_samples * config.test_split)
val_size = int(n_samples * config.val_split)

# Seeded permutation: keeps train/val/test membership identical across runs,
# in line with the fixed seeds used for data generation and model init.
split_rng = np.random.RandomState(42)
indices = split_rng.permutation(n_samples)
test_idx = indices[:test_size]
val_idx = indices[test_size:test_size + val_size]
train_idx = indices[test_size + val_size:]

# Convert to JAX arrays
X_train = jnp.array(X_synthetic[train_idx])
y_train = jnp.array(y_synthetic[train_idx])
X_val = jnp.array(X_synthetic[val_idx])
y_val = jnp.array(y_synthetic[val_idx])
X_test = jnp.array(X_synthetic[test_idx])
y_test = jnp.array(y_synthetic[test_idx])

print(f"📊 Data splits:")
print(f"   Train: {X_train.shape[0]} samples")
print(f"   Validation: {X_val.shape[0]} samples")
print(f"   Test: {X_test.shape[0]} samples")

print(f"\n⚙️ Training config: {config.epochs} epochs, lr={config.learning_rate}, batch_size={config.batch_size}")

# 🏆 Model Architecture Comparison

In [None]:
def _evaluate_in_batches(state, X, y, batch_size, rng_key, split_name):
    """Return (mean loss, mean per-dimension correlation) of `state` over mini-batches of (X, y)."""
    X_batches, y_batches = create_data_batches(X, y, batch_size, rng_key)
    n_batches = X_batches.shape[0]
    if n_batches == 0:
        raise ValueError(
            f"{split_name} split has {X.shape[0]} samples, "
            f"which yields no batch of size {batch_size}"
        )

    total_loss = 0.0
    batch_correlations = []
    for i in range(n_batches):
        batch_loss, correlations, _ = eval_step_fixed(state, X_batches[i], y_batches[i])
        total_loss += float(batch_loss)
        batch_correlations.append(correlations)

    return total_loss / n_batches, np.mean(np.stack(batch_correlations), axis=0)


def train_and_evaluate_model_fixed(model_name: str, model_class, config: TrainingConfig, 
                           X_train, y_train, X_val, y_val, X_test, y_test):
    """Train one model with early stopping and evaluate it on the test split.

    Args:
        model_name: display name used in log output and in the results.
        model_class: Flax module class. Classes declaring a
            `width_multiplier` field (or passed under the name
            "RealTime CNN") are built with ``width_multiplier=0.5``; all
            others with ``base_filters=32``.
        config: supplies learning rate, batch size, epoch budget and patience.
        X_train, y_train: training inputs and targets.
        X_val, y_val: validation inputs and targets (drive early stopping).
        X_test, y_test: test inputs and targets (final evaluation only).

    Returns:
        dict with per-epoch 'train_losses', 'val_losses' and
        'val_correlations'; final 'test_loss', 'test_correlations' and
        'avg_test_correlation'; 'final_epoch' (zero-based index of the last
        epoch that ran, -1 if none ran); 'best_val_loss'; and the 'params'
        and 'batch_stats' of the state with the lowest validation loss,
        which is also the state evaluated on the test split.

    Raises:
        ValueError: if any split is too small to form a batch.
    """
    print(f"\n🧠 Training {model_name}...")
    print("=" * 50)

    # The output width follows the label matrix, so model and labels cannot disagree.
    num_outputs = int(y_train.shape[-1])

    # Choose constructor arguments from the fields the class declares (Flax
    # modules are dataclasses); the exact-name check is kept for compatibility.
    declared_fields = getattr(model_class, '__dataclass_fields__', {})
    if model_name == "RealTime CNN" or 'width_multiplier' in declared_fields:
        model = model_class(num_classes=num_outputs, width_multiplier=0.5)
    else:
        model = model_class(num_classes=num_outputs, base_filters=32)  # Smaller for Colab

    input_shape = (config.batch_size, *X_train.shape[1:])
    state = create_train_state_fixed(model, config.learning_rate, input_shape)

    print(f"Model parameters: {sum(x.size for x in jax.tree_util.tree_leaves(state.params)):,}")

    train_losses = []
    val_losses = []
    val_correlations = []

    best_val_loss = float('inf')
    best_state = None  # state at the lowest validation loss seen so far
    patience_counter = 0
    epoch = -1  # stays -1 when config.epochs is 0

    rng = jax.random.PRNGKey(42)

    for epoch in range(config.epochs):
        rng, epoch_rng, dropout_rng = jax.random.split(rng, 3)

        # Training phase
        X_train_batches, y_train_batches = create_data_batches(
            X_train, y_train, config.batch_size, epoch_rng
        )
        n_train_batches = X_train_batches.shape[0]
        if n_train_batches == 0:
            raise ValueError(
                f"training split has {X_train.shape[0]} samples, "
                f"which yields no batch of size {config.batch_size}"
            )

        epoch_train_loss = 0.0
        for i in range(n_train_batches):
            # A distinct dropout key per batch, derived from the epoch key
            batch_rng = jax.random.fold_in(dropout_rng, i)
            state, batch_loss, _ = train_step_fixed(
                state, X_train_batches[i], y_train_batches[i], batch_rng
            )
            epoch_train_loss += float(batch_loss)

        avg_train_loss = epoch_train_loss / n_train_batches

        # Validation phase
        avg_val_loss, avg_correlations = _evaluate_in_batches(
            state, X_val, y_val, config.batch_size, epoch_rng, "validation"
        )

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        val_correlations.append(avg_correlations)

        # Early stopping bookkeeping; remember the best state so the epochs
        # spent waiting out the patience window do not degrade the result.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_state = state
            patience_counter = 0
        else:
            patience_counter += 1

        # Print progress
        if epoch % 10 == 0 or epoch < 5:
            avg_corr = np.mean(avg_correlations)
            print(f"Epoch {epoch:2d} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f} | Corr: {avg_corr:.3f}")

        if patience_counter >= config.early_stopping_patience:
            print(f"Early stopping at epoch {epoch}")
            break

    # Evaluate and export the best validation state (falls back to the last
    # state if validation never improved, e.g. when no epoch ran).
    if best_state is not None:
        state = best_state

    # Final test evaluation
    final_test_loss, final_test_corr = _evaluate_in_batches(
        state, X_test, y_test, config.batch_size, jax.random.PRNGKey(0), "test"
    )

    results = {
        'model_name': model_name,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_correlations': val_correlations,
        'test_loss': final_test_loss,
        'test_correlations': final_test_corr,
        'avg_test_correlation': float(np.mean(final_test_corr)),
        'final_epoch': epoch,
        'best_val_loss': best_val_loss,
        'params': state.params,
        # BatchNorm running statistics are required to run inference with `params`.
        'batch_stats': state.batch_stats
    }

    print(f"\n✅ {model_name} Final Results:")
    print(f"   Test Loss: {final_test_loss:.4f}")
    print(f"   Avg Correlation: {results['avg_test_correlation']:.3f}")
    print(f"   Best Dimensions: {np.argsort(final_test_corr)[-3:]}")

    return results

# Train every architecture listed below with the fixed training function.
# Each entry is (display name, model class).
model_configs = [
    ("Standard CNN", PianoSpectroCNN),
    ("RealTime CNN", RealTimePianoCNN)
]

# Maps display name -> results dict, or None when that model's run failed.
results = {}

for model_name, model_class in model_configs:
    try:
        result = train_and_evaluate_model_fixed(
            model_name, model_class, config,
            X_train, y_train, X_val, y_val, X_test, y_test
        )
        results[model_name] = result
        
    except Exception as e:
        # Deliberately broad: one failing architecture must not abort the
        # comparison. The traceback is printed and the failure is recorded.
        print(f"❌ {model_name} failed: {e}")
        import traceback
        traceback.print_exc()
        results[model_name] = None

print(f"\n🎯 ARCHITECTURE COMPARISON COMPLETE!")

# 📈 Results Analysis & Visualization

In [None]:
# Training curves comparison
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss curves (models whose training failed are stored as None and skipped)
for model_name, result in results.items():
    if result is not None:
        axes[0, 0].plot(result['train_losses'], label=f'{model_name} Train', alpha=0.8)
        axes[0, 1].plot(result['val_losses'], label=f'{model_name} Val', alpha=0.8)

axes[0, 0].set_title('Training Losses')
axes[0, 0].set_xlabel('Epoch')
# The training loss returned by the train step is MSE plus the L2 penalty,
# unlike the pure-MSE validation loss, so the axis is labelled accordingly.
axes[0, 0].set_ylabel('MSE + L2 Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

axes[0, 1].set_title('Validation Losses')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('MSE Loss')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Final performance comparison
model_names = [name for name, result in results.items() if result is not None]
test_losses = [result['test_loss'] for name, result in results.items() if result is not None]
avg_correlations = [result['avg_test_correlation'] for name, result in results.items() if result is not None]

x_pos = np.arange(len(model_names))

axes[1, 0].bar(x_pos, test_losses, alpha=0.7, color=['skyblue', 'orange'])
axes[1, 0].set_title('Final Test Loss (Lower = Better)')
axes[1, 0].set_xlabel('Model')
axes[1, 0].set_ylabel('MSE Loss')
axes[1, 0].set_xticks(x_pos)
axes[1, 0].set_xticklabels(model_names, rotation=45)

# Add value labels
for i, v in enumerate(test_losses):
    axes[1, 0].text(i, v + 0.001, f'{v:.4f}', ha='center', va='bottom')

axes[1, 1].bar(x_pos, avg_correlations, alpha=0.7, color=['lightgreen', 'gold'])
axes[1, 1].set_title('Average Test Correlation (Higher = Better)')
axes[1, 1].set_xlabel('Model')
axes[1, 1].set_ylabel('Correlation')
axes[1, 1].set_xticks(x_pos)
axes[1, 1].set_xticklabels(model_names, rotation=45)

# Add value labels
for i, v in enumerate(avg_correlations):
    axes[1, 1].text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Per-dimension correlation analysis
# Only models that trained successfully get bars; bar width and tick position
# are derived from their number, so the chart stays aligned for any number of
# models (two models reproduce the previous 0.35 width / 0.175 offset).
trained_results = {name: res for name, res in results.items() if res is not None}

if trained_results:
    plt.figure(figsize=(15, 8))

    dimension_names = data_generator.dimensions
    x_dims = np.arange(len(dimension_names))
    bar_width = 0.7 / len(trained_results)

    for i, (model_name, result) in enumerate(trained_results.items()):
        plt.bar(x_dims + i * bar_width, result['test_correlations'],
               width=bar_width, alpha=0.7, label=model_name)

    plt.title('Per-Dimension Test Correlations')
    plt.xlabel('Perceptual Dimension')
    plt.ylabel('Correlation')
    # Centre each tick under its group of bars
    tick_offset = bar_width * (len(trained_results) - 1) / 2
    plt.xticks(x_dims + tick_offset, [d[:15] + '...' for d in dimension_names], rotation=45, ha='right')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

print("\n📊 VALIDATION RESULTS SUMMARY:")
print("=" * 60)

for model_name, result in results.items():
    if result is not None:
        # 'final_epoch' is the zero-based index of the last epoch that ran,
        # so the number of epochs actually trained is one more.
        epochs_run = result['final_epoch'] + 1
        print(f"\n{model_name}:")
        print(f"  ✓ Training completed successfully")
        print(f"  ✓ Final test loss: {result['test_loss']:.4f}")
        print(f"  ✓ Average correlation: {result['avg_test_correlation']:.3f}")
        print(f"  ✓ Converged in {epochs_run} epochs")
    else:
        print(f"\n{model_name}: ❌ FAILED")

# Success criteria check. The two all(...) checks are vacuously true when no
# model trained; 'models_trained' is what fails the run in that case.
success_criteria = {
    'models_trained': len([r for r in results.values() if r is not None]) >= 1,
    'loss_decreased': all(r['val_losses'][-1] < r['val_losses'][0] for r in results.values() if r is not None),
    'correlations_reasonable': all(r['avg_test_correlation'] > 0.5 for r in results.values() if r is not None)
}

print(f"\n🎯 SUCCESS CRITERIA:")
for criterion, passed in success_criteria.items():
    status = "✅ PASS" if passed else "❌ FAIL"
    print(f"   {criterion}: {status}")

all_passed = all(success_criteria.values())
print(f"\n🚀 Overall Status: {'READY FOR REAL DATA!' if all_passed else 'NEEDS DEBUGGING'}")

# 🎵 Next Steps: Chopin/Liszt Data Collection

## If validation successful:

### 1. Audio Collection Strategy
```python
# Target repertoire for diversity
chopin_pieces = [
    "Etude Op.10 No.1-12",  # Technical variety
    "Nocturne Op.9 No.1-3", # Expressive range
    "Ballade No.1-4",       # Structural complexity
    "Prelude Op.28 (selection)" # Stylistic diversity
]

liszt_pieces = [
    "Hungarian Rhapsody No.2, 6", # Dramatic expression
    "Liebestraum No.3",          # Lyrical style
    "Transcendental Etude 4, 10" # Technical brilliance
]
```

### 2. Recording Sources
- **YouTube**: Concert recordings, competitions
- **IMSLP**: Public domain recordings
- **Personal**: Your own performances
- **Target**: 10-15 different interpreters per piece

### 3. Labeling Interface
```python
# Simple rating interface
import ipywidgets as widgets
from IPython.display import Audio, display

def create_rating_interface(audio_file, dimensions):
    # Audio playback + 19 sliders for ratings
    pass
```

### 4. Quality Control
- Rate same performance twice (consistency check)
- Use subset of PercePiano for calibration
- Have multiple raters label a shared subset of performances to check inter-rater agreement

## Expected Timeline
- **Week 1**: Audio collection (20-30 pieces)
- **Week 2**: Initial labeling + rating interface
- **Week 3**: Full dataset creation (100+ performances)
- **Week 4**: Model training on real Chopin/Liszt data

**Ready to move forward with real data collection! 🎹**

# 💾 Model Export for Real Data

In [None]:
# Save validation results and best model
# Builds a JSON-serializable summary of the synthetic dataset and each
# successful model's test metrics, records the best model by average test
# correlation, writes the summary to disk, and offers it as a browser
# download when running inside Colab.
results_path = 'piano_colab_validation_results.json'

validation_summary = {
    'synthetic_data_stats': {
        'num_samples': X_synthetic.shape[0],
        # Serialized as a JSON array.
        'spectrogram_shape': list(X_synthetic.shape[1:]),
        'num_dimensions': y_synthetic.shape[1]
    },
    'model_results': {}
}

# Find best model
# Start from -inf so a model is still selected when every correlation is
# zero or negative; a starting value of 0 would leave `best_model` as None.
best_model = None
best_correlation = float('-inf')

for model_name, result in results.items():
    if result is None:
        # Failed runs contribute nothing to the summary.
        continue

    # Cast to built-in types so json.dump can serialize the values.
    avg_correlation = float(result['avg_test_correlation'])
    validation_summary['model_results'][model_name] = {
        'test_loss': float(result['test_loss']),
        'avg_correlation': avg_correlation,
        'final_epoch': int(result['final_epoch'])
    }

    if avg_correlation > best_correlation:
        best_correlation = avg_correlation
        best_model = model_name

validation_summary['best_model'] = best_model
validation_summary['validation_passed'] = all(success_criteria.values())

# Export results
with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(validation_summary, f, indent=2)

print(f"\n💾 Results exported to: {results_path}")
if best_model is not None:
    print(f"\n🏆 Best performing model: {best_model} (correlation: {best_correlation:.3f})")
else:
    # No model produced a comparable correlation (all failed or all NaN).
    print("\n🏆 Best performing model: None")

if validation_summary['validation_passed']:
    print("\n🎉 VALIDATION SUCCESSFUL! Ready for Chopin/Liszt data collection!")
    print("\n🎯 Next steps:")
    print("   1. Collect 20-30 diverse Chopin/Liszt recordings")
    print("   2. Create rating interface for perceptual labeling")
    print(f"   3. Train {best_model} on real data")
    print("   4. Compare with PercePiano baseline")
else:
    print("\n🔧 Validation issues detected - debug before proceeding")

# Download results file
# `google.colab` is only importable inside a Colab runtime; elsewhere the
# results file has already been written, so skip the download rather than
# failing the cell.
try:
    from google.colab import files
except ImportError:
    print(f"\nℹ️ Not running in Colab - skipping download; results are in {results_path}")
else:
    files.download(results_path)