# Piano Performance Analysis with CNN Models 🎹

**Complete training pipeline for predicting 19 perceptual dimensions using JAX/Flax CNN architectures**

This notebook trains deep learning models to analyze piano performances across dimensions like timing, articulation, dynamics, and musical expression using the PercePiano dataset.

## Architecture Options
- **Standard CNN**: Basic spectral analysis
- **Fusion CNN**: Multi-spectral feature fusion (mel + MFCC + chroma)
- **Realtime CNN**: Optimized for fast inference

## Expected Training Time
- GPU (T4): ~2-3 hours for complete training
- Dataset: 1,202 performances with 19 perceptual labels


## 1. Environment Setup & GPU Verification

In [None]:
# Install required packages
!pip install jax[cuda] flax optax librosa wandb soundfile matplotlib seaborn -q

# Verify GPU availability
import jax
print(f"JAX devices: {jax.devices()}")
print(f"JAX backend: {jax.lib.xla_bridge.get_backend().platform}")

gpu_available = len([d for d in jax.devices() if d.device_kind == 'gpu']) > 0
print(f"🚀 GPU Available: {gpu_available}")

if not gpu_available:
    print("⚠️  Enable GPU: Runtime → Change runtime type → Hardware accelerator → GPU")

## 2. Upload Project Files

Upload your two zip files:
1. `piano-analysis-colab.zip` (source code)
2. `percepiano-data.zip` (dataset)

In [None]:
from google.colab import files
import zipfile
import os

print("📁 Upload your project files...")
uploaded = files.upload()

# Extract uploaded files
for filename in uploaded.keys():
    if filename.endswith('.zip'):
        with zipfile.ZipFile(filename, 'r') as zip_ref:
            zip_ref.extractall('/content/')
        print(f"✅ Extracted {filename}")
        os.remove(filename)  # Clean up

# Verify structure
print("\n📂 Project structure:")
!ls -la /content/src/ | head -10
!ls -la /content/data/

## 3. Setup Project Environment

In [None]:
# Standard library
import json
import sys
from pathlib import Path

# Third-party libraries
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import matplotlib.pyplot as plt
import librosa
import wandb

# Make the uploaded project sources importable. The directory is appended, so
# entries already on sys.path take precedence on any module-name clash.
project_src = '/content/src'
sys.path.append(project_src)

# Project modules (expected to live under /content/src)
from piano_cnn_jax import get_piano_model
from training_pipeline_jax import PianoTrainer, TrainingConfig
from dataset_analysis import load_perceptual_labels, PERCEPTUAL_DIMENSIONS
from audio_preprocessing import PianoAudioPreprocessor

dimension_count = len(PERCEPTUAL_DIMENSIONS)
print("✅ All imports successful!")
print(f"📊 Available dimensions: {dimension_count}")
print(f"🎵 Sample dimensions: {PERCEPTUAL_DIMENSIONS[:5]}")

## 4. Data Verification & Quick Analysis

In [None]:
# Verify dataset
labels_path = '/content/data/label_2round_mean_reg_19_with0_rm_highstd0.json'
labels = load_perceptual_labels(labels_path)
print(f"✅ Loaded {len(labels)} performances")

# Quick audio verification
audio_file = Path('/content/data/Beethoven_WoO80_var27_8bars_3_15.wav')
if not audio_file.exists():
    print("❌ Audio file not found")
else:
    y, sr = librosa.load(audio_file)
    duration_s = len(y) / sr
    print(f"✅ Audio sample: {duration_s:.1f}s at {sr}Hz")

    # Two side-by-side panels: waveform on the left, mel spectrogram on the right
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(y[:2 * sr])  # sr samples per second -> first two seconds
    plt.title('Waveform (first 2s)')

    plt.subplot(1, 2, 2)
    mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=80)
    plt.imshow(librosa.power_to_db(mel), aspect='auto', origin='lower')
    plt.title('Mel Spectrogram')

    plt.tight_layout()
    plt.show()

# Dataset statistics
# NOTE(review): the last element of every label list is dropped; presumably it is not
# a perceptual rating (e.g. an id / metadata field) — confirm against the label file.
all_ratings = np.array([entry[:-1] for entry in labels.values()])
lowest, highest = all_ratings.min(), all_ratings.max()
print("\n📈 Dataset statistics:")
print(f"   Shape: {all_ratings.shape}")
print(f"   Rating range: [{lowest:.3f}, {highest:.3f}]")
print(f"   Mean rating: {all_ratings.mean():.3f}")

## 5. Configure Training

Choose your training configuration:

In [None]:
# Training configuration - adjust based on your needs
config = TrainingConfig(
    model_architecture="standard",  # Options: "standard", "fusion", "realtime"
    learning_rate=1e-3,
    batch_size=16,  # Lower this if the GPU runs out of memory
    epochs=50,      # Fewer for a quick smoke test, more for a full run
    early_stopping_patience=10,

    # Data paths (Colab)
    data_path="/content/data",
    checkpoint_path="/content/checkpoints",
    results_path="/content/results",

    # Audio processing
    sample_rate=22050,
    n_mels=128,
    n_fft=2048,
    hop_length=512,

    # Model architecture
    base_filters=64,
    dropout_rate=0.2
)

print("🔧 Training configuration:")
for setting, setting_value in config.__dict__.items():
    print(f"   {setting}: {setting_value}")

# Create output directories (no-op when they already exist)
for output_dir in (config.checkpoint_path, config.results_path):
    os.makedirs(output_dir, exist_ok=True)

## 6. Model Architecture Test

In [None]:
# Test model architecture before full training
print(f"🧪 Testing {config.model_architecture} architecture...")

model = get_piano_model(
    architecture=config.model_architecture,
    num_classes=len(PERCEPTUAL_DIMENSIONS),  # one regression output per perceptual dimension
    base_filters=config.base_filters,
    dropout_rate=config.dropout_rate
)

# Use separate PRNG keys for the dummy data and for parameter initialisation:
# JAX keys should never be reused for two different random draws.
rng = jax.random.PRNGKey(0)
data_rng, init_rng = jax.random.split(rng)
# NOTE(review): (batch=2, 128, 128, 1) is a placeholder input; confirm it matches the
# spectrogram shape the training pipeline actually feeds the model.
dummy_input = jax.random.normal(data_rng, (2, 128, 128, 1))

# `params` holds every variable collection returned by init (e.g. 'params', plus
# 'batch_stats' if the model uses BatchNorm); apply() needs all of them.
params = model.init(init_rng, dummy_input, training=False)
output = model.apply(params, dummy_input, training=False)

# Report only trainable weights: non-trainable collections (such as running
# BatchNorm statistics) would otherwise inflate the parameter count.
trainable_params = params.get('params', params)
param_count = sum(x.size for x in jax.tree_util.tree_leaves(trainable_params))

print(f"✅ Model test successful:")
print(f"   Architecture: {config.model_architecture}")
print(f"   Parameters: {param_count:,}")
print(f"   Input shape: {dummy_input.shape}")
print(f"   Output shape: {output.shape}")
print(f"   Output range: [{output.min():.3f}, {output.max():.3f}]")

## 7. Initialize Weights & Biases (Optional)

Log in to W&B for experiment tracking:

In [None]:
# Optional: Setup wandb for experiment tracking.
# If logging in or starting the run fails, training continues without logging.
# `except Exception` (not a bare `except:`) so KeyboardInterrupt / SystemExit raised
# while the login prompt is open still stop the cell instead of being swallowed.
use_wandb = False
try:
    wandb.login()
    use_wandb = True
    print("✅ Weights & Biases connected")
except Exception as err:
    print("⚠️  Skipping W&B - training will continue without logging")
    print(f"   Reason: {err}")

if use_wandb:
    try:
        wandb.init(
            project="piano-performance-analysis",
            name=f"piano-cnn-{config.model_architecture}",
            config=config.__dict__
        )
    except Exception as err:
        # Logged in, but the run could not be started (e.g. network error).
        use_wandb = False
        print(f"⚠️  wandb.init failed ({err}) - training will continue without logging")

# NOTE(review): `use_wandb` is not passed to PianoTrainer in the training cell;
# confirm how the trainer decides whether to log to W&B.

## 8. Start Training 🚀

This is the main training cell. Expected time: ~2-3 hours with GPU.

In [None]:
# Initialize trainer
print("🚀 Initializing trainer...")
trainer = PianoTrainer(config)

num_params = sum(x.size for x in jax.tree_util.tree_leaves(trainer.state.params))
print(f"📊 Training setup:")
print(f"   Model parameters: {num_params:,}")
print(f"   Training samples: {len(trainer.train_data['labels'])}")
print(f"   Validation samples: {len(trainer.val_data['labels'])}")
print(f"   Test samples: {len(trainer.test_data['labels'])}")

print("\n🎯 Starting training...")
print("This may take 2-3 hours. Monitor progress in W&B dashboard if enabled.")

# Start training
test_results = trainer.train()

print("\n🎉 Training completed!")
print(f"📈 Final test results:")
for key, value in test_results.items():
    # Skip entries that are clearly not metrics (NumPy would parse numeric strings
    # and turn None into NaN, which would be misleading here).
    if value is None or isinstance(value, str):
        continue
    # Route everything through NumPy so Python, NumPy (e.g. np.float32) and JAX
    # scalars/arrays are all handled; isinstance(value, float) misses most of them.
    try:
        values = np.asarray(value, dtype=float)
    except (TypeError, ValueError):
        continue  # non-numeric entry (e.g. a nested dict) — nothing to summarise
    if values.ndim == 0:
        print(f"   {key}: {float(values):.4f}")
    elif values.size > 0:  # an empty sequence has no meaningful mean/std
        print(f"   {key}: mean={values.mean():.4f}, std={values.std():.4f}")

## 9. Results Analysis & Visualization

In [None]:
# Load and visualize training results
results_file = Path(config.results_path) / "training_results.json"

if results_file.exists():
    with open(results_file, encoding="utf-8") as f:
        training_results = json.load(f)

    # Normalise the test correlations once, so the bar chart, the summary panel and
    # the rankings all read the same numeric array (lists, NumPy or JAX arrays).
    # NOTE(review): assumes one correlation per entry of PERCEPTUAL_DIMENSIONS, in the
    # same order — confirm against the trainer's evaluation code.
    test_corrs = np.asarray(test_results.get('test_correlations', []), dtype=float)
    has_test_corrs = test_corrs.size > 0
    num_dims = test_corrs.size
    num_params = sum(x.size for x in jax.tree_util.tree_leaves(trainer.state.params))

    # Create comprehensive visualizations
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Training curves
    axes[0, 0].plot(training_results['train_loss'], label='Train', alpha=0.7)
    axes[0, 0].plot(training_results['val_loss'], label='Validation', alpha=0.7)
    axes[0, 0].set_title('Training Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('MSE Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Validation correlations over time: mean across dimensions, shaded ± one std
    if 'val_correlations' in training_results:
        # presumably (epochs, dimensions) — each row is reduced along axis=1
        val_corrs = np.asarray(training_results['val_correlations'], dtype=float)
        val_mean = val_corrs.mean(axis=1)
        val_std = val_corrs.std(axis=1)
        epochs = np.arange(len(val_corrs))
        axes[0, 1].plot(epochs, val_mean, label='Mean Correlation', color='green')
        axes[0, 1].fill_between(epochs, val_mean - val_std, val_mean + val_std,
                                alpha=0.2, color='green')
        axes[0, 1].set_title('Validation Correlations')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Correlation')
        axes[0, 1].grid(True, alpha=0.3)
    else:
        axes[0, 1].axis('off')  # nothing to plot; hide the empty panel

    # Final test correlations by dimension
    if has_test_corrs:
        x_pos = np.arange(num_dims)
        # green = good (>0.5), orange = fair (>0.3), red = weak
        colors = ['green' if c > 0.5 else 'orange' if c > 0.3 else 'red' for c in test_corrs]

        axes[1, 0].bar(x_pos, test_corrs, color=colors, alpha=0.7)
        axes[1, 0].set_title('Test Correlations by Dimension')
        axes[1, 0].set_xlabel('Perceptual Dimension')
        axes[1, 0].set_ylabel('Correlation')
        axes[1, 0].set_xticks(x_pos[::2])  # Show every other label to avoid overlap
        axes[1, 0].set_xticklabels([PERCEPTUAL_DIMENSIONS[i][:15] for i in x_pos[::2]],
                                   rotation=45, ha='right')
        axes[1, 0].grid(True, alpha=0.3)

        # Add correlation threshold lines
        axes[1, 0].axhline(y=0.5, color='green', linestyle='--', alpha=0.5, label='Good (>0.5)')
        axes[1, 0].axhline(y=0.3, color='orange', linestyle='--', alpha=0.5, label='Fair (>0.3)')
        axes[1, 0].legend()
    else:
        axes[1, 0].axis('off')  # nothing to plot; hide the empty panel

    # Pre-format summary values; missing metrics print as "N/A" instead of raising
    # (formatting the 'N/A' fallback with :.4f would raise ValueError) or showing a
    # fake 0.000 mean correlation.
    test_loss = test_results.get('test_loss')
    test_loss_str = "N/A" if test_loss is None else f"{float(test_loss):.4f}"
    if has_test_corrs:
        mean_corr_str = f"{test_corrs.mean():.3f}"
        strong_corr_str = f"{int((test_corrs > 0.5).sum())}/{num_dims}"
    else:
        mean_corr_str = strong_corr_str = "N/A"
    epochs_trained = len(training_results.get('train_loss', []))

    # Model performance summary
    axes[1, 1].axis('off')
    summary_text = f"""
Model: {config.model_architecture.upper()} CNN
Parameters: {num_params:,}

Dataset:
• Training: {len(trainer.train_data['labels'])} samples
• Validation: {len(trainer.val_data['labels'])} samples  
• Test: {len(trainer.test_data['labels'])} samples

Final Performance:
• Test Loss: {test_loss_str}
• Mean Correlation: {mean_corr_str}
• Strong Correlations (>0.5): {strong_corr_str}

Epochs trained: {epochs_trained}
    """
    axes[1, 1].text(0.1, 0.9, summary_text, transform=axes[1, 1].transAxes,
                    fontsize=10, verticalalignment='top', fontfamily='monospace',
                    bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.1))

    plt.tight_layout()
    plt.show()

    # Print top and bottom performing dimensions
    if has_test_corrs:
        sorted_dims = sorted(zip(PERCEPTUAL_DIMENSIONS, test_corrs),
                             key=lambda pair: pair[1], reverse=True)
        top_dims = sorted_dims[:5]
        bottom_dims = sorted_dims[-5:]

        print("\n🏆 Top 5 Best Predicted Dimensions:")
        for rank, (dim, corr) in enumerate(top_dims, start=1):
            print(f"   {rank}. {dim}: {corr:.3f}")

        print("\n🎯 Bottom 5 Dimensions (Need Improvement):")
        # Number by overall rank (e.g. 15..19) rather than restarting at 1
        first_bottom_rank = len(sorted_dims) - len(bottom_dims) + 1
        for rank, (dim, corr) in enumerate(bottom_dims, start=first_bottom_rank):
            print(f"   {rank}. {dim}: {corr:.3f}")

else:
    print("❌ Training results file not found")

## 10. Save & Download Results

Package your results for download:

In [None]:
# Create results package
import shutil
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M")
results_name = f"piano_cnn_results_{config.model_architecture}_{timestamp}"

# Create comprehensive results directory
results_dir = f"/content/{results_name}"
os.makedirs(results_dir, exist_ok=True)

# Copy all important files
files_to_save = [
    (f"{config.results_path}/training_results.json", "training_results.json"),
    (f"{config.checkpoint_path}/best_model", "model_checkpoint/"),
    ("/content/Piano_CNN_Training_Complete.ipynb", "training_notebook.ipynb")
]

for src, dst in files_to_save:
    src_path = Path(src)
    dst_path = Path(results_dir) / dst

    if src_path.exists():
        if src_path.is_dir():
            shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
        else:
            shutil.copy2(src_path, dst_path)
        print(f"✅ Saved: {dst}")
    else:
        print(f"⚠️  Not found: {src}")


def _to_jsonable(obj):
    """Convert a value that json.dump cannot serialise natively.

    NumPy arrays/scalars and JAX arrays expose ``tolist()``, which yields plain
    Python lists/scalars. Anything else (e.g. a ``pathlib.Path`` in the config)
    is recorded deliberately as its ``str()`` form.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    return str(obj)


# Create summary report
summary_report = {
    "model_architecture": config.model_architecture,
    "training_config": config.__dict__,
    "final_results": test_results,
    "training_timestamp": timestamp,
    "dataset_info": {
        "total_performances": len(labels),
        "perceptual_dimensions": len(PERCEPTUAL_DIMENSIONS),
        "train_samples": len(trainer.train_data['labels']),
        "val_samples": len(trainer.val_data['labels']),
        "test_samples": len(trainer.test_data['labels'])
    }
}

with open(f"{results_dir}/experiment_summary.json", 'w', encoding="utf-8") as f:
    # test_results may hold NumPy/JAX arrays or scalars, which plain json.dump rejects
    json.dump(summary_report, f, indent=2, default=_to_jsonable)

# Create zip file for download. make_archive returns the path of the archive it
# wrote, so use that instead of assuming it landed in the current directory.
archive_path = shutil.make_archive(results_name, 'zip', '/content', results_name)

print(f"\n📦 Results package created: {results_name}.zip")
print(f"📊 Package size: {os.path.getsize(archive_path) / 1024 / 1024:.1f} MB")

# Download the results
print("\n⬇️  Downloading results...")
files.download(archive_path)

print("\n🎉 Training and analysis complete!")
print("\n📋 Next steps for your portfolio:")
print("   1. Extract and analyze the downloaded results")
print("   2. Document model performance and insights")
print("   3. Try different architectures (fusion, realtime)")
print("   4. Experiment with hyperparameter tuning")
print("   5. Extend to new piano repertoire (Chopin, Liszt)")

## 11. Quick Feature-Extraction Test

Run the audio feature-extraction pipeline on a sample recording (this cell does not run the trained CNN):

In [None]:
# Run the audio feature-extraction pipeline on a sample recording.
# NOTE: this cell does not call the trained CNN; it only shows the scalar audio
# features that PianoAudioPreprocessor extracts.
sample_audio = Path("/content/data/Beethoven_WoO80_var27_8bars_3_15.wav")
if sample_audio.exists():
    print("🎹 Extracting features from sample audio...")

    # Load and process audio
    preprocessor = PianoAudioPreprocessor()
    audio_result = preprocessor.process_audio_file(sample_audio, "beethoven_sample")
    scalar_features = audio_result['scalar_features']

    print(f"\n🎯 Sample audio features:")
    print(f"   Audio duration: {audio_result['duration']:.1f}s")
    print(f"   Features extracted: {len(scalar_features)}")

    # Show the first few features under their own names. These are raw audio
    # descriptors, not predictions for any perceptual dimension, so they are not
    # paired with PERCEPTUAL_DIMENSIONS.
    for feature_name, feature_value in list(scalar_features.items())[:5]:
        print(f"   {feature_name}: {float(feature_value):.3f}")

    print("\n📝 Note: Full CNN prediction would use spectrograms as input")
    print("    This demonstrates the feature extraction pipeline")
else:
    print("⚠️  Sample audio not found for inference test")