# Score-Aligned Piano Performance Evaluation (v2 - Hierarchical)

**Goal**: Train a model with score-alignment features to improve R-squared from 0.10 to 0.30-0.40.

## Key Improvements in This Version

1. **Hierarchical Encoder (HAN)**: Note -> Beat -> Measure hierarchy matching PercePiano architecture
2. **Expanded Features**: 20 per-note features (was 6) including timing, articulation, dynamics, pitch
3. **Note Locations**: Beat/measure/voice indices for proper hierarchical aggregation
4. **Fixed Dimension Handling**: All 19 PercePiano dimensions are now handled correctly
5. **Pre-trained MIDI Encoder**: Loads encoder_pretrained.pt for better initialization
6. **Pre-flight Validation**: Validates data, scores, and encoder before training

## What You Need on Google Drive

- `gdrive:percepiano_data/` containing:
  - `percepiano_train.json`, `percepiano_val.json`, `percepiano_test.json`
  - `PercePiano/virtuoso/data/all_2rounds/` (performance MIDI files)
  - `PercePiano/virtuoso/data/score_xml/` (MusicXML score files)
- `gdrive:crescendai_checkpoints/midi_pretrain/encoder_pretrained.pt` (pre-trained encoder)

In [None]:
# Report whether a CUDA device is visible and, if so, its name and total memory
import torch

cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

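# Rclone Reminder

Before continuing, run `rclone config` in a terminal and create a Google Drive remote named `gdrive`; the cells below download data from that remote and sync checkpoints back to it.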

In [None]:
# Install uv
!curl -LsSf https://astral.sh/uv/install.sh | sh

import os
os.environ['PATH'] = f"{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"

# Clone repository
if not os.path.exists('/tmp/crescendai'):
    !git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai

%cd /tmp/crescendai/model
!git pull
!git log -1 --oneline

# Install dependencies
!uv pip install --system -e .
!pip install tensorboard rich

import torch
import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

In [None]:
import os
from pathlib import Path
import shutil
import subprocess

# Paths
CHECKPOINT_ROOT = '/tmp/checkpoints/score_aligned'
GDRIVE_CHECKPOINT_PATH = 'gdrive:crescendai_checkpoints/score_aligned'
GDRIVE_DATA_PATH = 'gdrive:percepiano_data'
DATA_ROOT = Path('/tmp/percepiano_data')

print("="*70)
print("SETUP: CHECKPOINTS AND DATA")
print("="*70)

# Create directories
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Check rclone
print("\nChecking rclone configuration...")

# Probe for the executable first: subprocess.run raises FileNotFoundError when
# rclone is not installed, which would abort the cell before RCLONE_AVAILABLE
# is ever defined.
rclone_installed = shutil.which('rclone') is not None
if rclone_installed:
    result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)
    # Compare whole lines so that a remote such as 'mygdrive:' is not mistaken
    # for 'gdrive:'.
    configured_remotes = {line.strip() for line in result.stdout.splitlines()}
else:
    configured_remotes = set()

RCLONE_AVAILABLE = 'gdrive:' in configured_remotes

if RCLONE_AVAILABLE:
    print("  rclone 'gdrive' remote: CONFIGURED")

    # Restore existing checkpoints
    print("\nRestoring checkpoints from Google Drive (if any)...")
    restore = subprocess.run(
        ['rclone', 'copy', GDRIVE_CHECKPOINT_PATH, CHECKPOINT_ROOT, '--progress'],
        capture_output=False
    )
    if restore.returncode != 0:
        # Best-effort restore: keep going, but do not hide the failure.
        print(f"  WARNING: rclone exited with code {restore.returncode}; checkpoints may not have been restored")
elif rclone_installed:
    print("  rclone 'gdrive' remote: NOT CONFIGURED")
    print("  Run 'rclone config' in terminal to set up Google Drive")
else:
    print("  rclone executable: NOT FOUND")
    print("  Run the rclone install cell above, then re-run this cell")

print(f"\nCheckpoint directory: {CHECKPOINT_ROOT}")
print(f"rclone available: {RCLONE_AVAILABLE}")

## Step 2: Download Data with Scores

In [None]:
from pathlib import Path
import subprocess
import json

DATA_ROOT = Path('/tmp/percepiano_data')
DATA_ROOT.mkdir(parents=True, exist_ok=True)

PRETRAIN_DIR = Path('/tmp/checkpoints/midi_pretrain')
PRETRAIN_DIR.mkdir(parents=True, exist_ok=True)

# Download PercePiano data (skipped when the train split is already on disk).
# GDRIVE_DATA_PATH comes from the setup cell.
train_file = DATA_ROOT / 'percepiano_train.json'
if train_file.exists():
    print(f"Data already exists at {DATA_ROOT}")
else:
    print("Downloading PercePiano data from Google Drive...")
    result = subprocess.run(
        ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
        capture_output=False
    )

# Download pre-trained encoder
pretrained_path = PRETRAIN_DIR / 'encoder_pretrained.pt'
if not pretrained_path.exists():
    print("\nDownloading pre-trained MIDI encoder from Google Drive...")
    subprocess.run(
        ['rclone', 'copy', 'gdrive:crescendai_checkpoints/midi_pretrain', str(PRETRAIN_DIR), '--progress'],
        capture_output=False
    )

# Verify data
print("\n" + "="*60)
print("DATA VERIFICATION")
print("="*60)

missing_splits = []
for split in ['train', 'val', 'test']:
    path = DATA_ROOT / f'percepiano_{split}.json'
    if path.exists():
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        # Count samples whose 'score_path' entry is present and non-empty
        has_scores = sum(1 for s in data if s.get('score_path'))
        print(f"{split}: {len(data)} samples ({has_scores} with score paths)")
    else:
        print(f"ERROR: {path} not found!")
        missing_splits.append(path)

# Fail fast: without this, a missing split only printed an error and the cell
# still finished with "[OK] All required files present".
if missing_splits:
    raise FileNotFoundError(
        "Missing split file(s): " + ", ".join(str(p) for p in missing_splits)
    )

# Check MIDI and score files
midi_dir = DATA_ROOT / 'PercePiano' / 'virtuoso' / 'data' / 'all_2rounds'
score_dir = DATA_ROOT / 'PercePiano' / 'virtuoso' / 'data' / 'score_xml'

if midi_dir.exists():
    midi_files = list(midi_dir.glob('*.mid'))
    print(f"\nMIDI files: {len(midi_files)}")
else:
    raise FileNotFoundError(f"MIDI directory not found at {midi_dir}")

if score_dir.exists():
    score_files = list(score_dir.glob('*.musicxml'))
    print(f"Score files: {len(score_files)}")
    if len(score_files) == 0:
        raise FileNotFoundError(f"No MusicXML files found in {score_dir}")
else:
    # Both lines need the f prefix; otherwise "{score_dir}" is printed literally.
    raise FileNotFoundError(
        f"Score directory not found at {score_dir}\n"
        f"Run: rclone copy gdrive:percepiano_data/PercePiano/virtuoso/data/score_xml/ {score_dir}/"
    )

# Verify pre-trained encoder
if pretrained_path.exists():
    size_mb = pretrained_path.stat().st_size / 1e6
    print(f"\nPre-trained encoder: {pretrained_path} ({size_mb:.1f} MB)")
else:
    raise FileNotFoundError(
        f"Pre-trained encoder not found at {pretrained_path}\n"
        f"Run: rclone copy gdrive:crescendai_checkpoints/midi_pretrain/ {PRETRAIN_DIR}/"
    )

print("\n[OK] All required files present")

## Step 3: Update JSON Files for Thunder Compute Paths

In [None]:
import json
from pathlib import Path

# Dataset locations used by this cell
DATA_ROOT = Path('/tmp/percepiano_data')
_VIRTUOSO_DATA = DATA_ROOT.joinpath('PercePiano', 'virtuoso', 'data')
MIDI_DIR = _VIRTUOSO_DATA / 'all_2rounds'
SCORE_DIR = _VIRTUOSO_DATA / 'score_xml'

# The 19 PercePiano rating dimensions, one per line
PERCEPIANO_DIMENSIONS = [
    "timing",
    "articulation_length",
    "articulation_touch",
    "pedal_amount",
    "pedal_clarity",
    "timbre_variety",
    "timbre_depth",
    "timbre_brightness",
    "timbre_loudness",
    "dynamic_range",
    "tempo",
    "space",
    "balance",
    "drama",
    "mood_valence",
    "mood_energy",
    "mood_imagination",
    "sophistication",
    "interpretation",
]

def update_paths_for_thunder(data_root: Path) -> None:
    """Rewrite the split JSON files in place for the Thunder Compute layout.

    For each of ``percepiano_{train,val,test}.json`` under ``data_root``:

    * ``midi_path`` is re-pointed at the file of the same name inside
      ``data_root/PercePiano/virtuoso/data/all_2rounds``;
    * ``score_path`` is left untouched;
    * when a sample has ``percepiano_scores``, a ``scores`` dict is written
      that maps each name in ``PERCEPIANO_DIMENSIONS`` to the value at the
      same index.

    Args:
        data_root: Directory that holds the three split files. The MIDI
            directory is derived from this argument, so the function also
            works for a root other than the notebook default.

    Raises:
        FileNotFoundError: If a split file does not exist.
        IndexError: If a sample's ``percepiano_scores`` has fewer entries
            than there are dimensions.
    """
    midi_dir = data_root / 'PercePiano' / 'virtuoso' / 'data' / 'all_2rounds'

    for split in ['train', 'val', 'test']:
        path = data_root / f'percepiano_{split}.json'

        with open(path, encoding='utf-8') as f:
            data = json.load(f)

        for sample in data:
            # Update MIDI path: keep only the file name, re-root it under midi_dir
            filename = Path(sample['midi_path']).name
            sample['midi_path'] = str(midi_dir / filename)

            # score_path is deliberately not modified here.
            # NOTE(review): presumably the dataset loader joins it with the
            # score directory - confirm against the loader.

            # Make sure scores dict uses all PercePiano dimensions
            if 'percepiano_scores' in sample:
                n_dims = len(PERCEPIANO_DIMENSIONS)
                pp_scores = sample['percepiano_scores'][:n_dims]
                if len(pp_scores) < n_dims:
                    # Same exception type as the bare index lookup would
                    # raise, but naming the offending file and sample.
                    raise IndexError(
                        f"{path}: sample '{filename}' has {len(pp_scores)} "
                        f"percepiano_scores, expected at least {n_dims}"
                    )
                sample['scores'] = dict(zip(PERCEPIANO_DIMENSIONS, pp_scores))

        # The file is only rewritten once every sample has been processed,
        # so a failure above leaves the original JSON untouched.
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

        print(f"Updated {split}: {len(data)} samples")

# Apply the rewrite to every split under DATA_ROOT
update_paths_for_thunder(DATA_ROOT)

# Verify: spot-check the first sample of the rewritten train split
with open(DATA_ROOT / 'percepiano_train.json') as f:
    train_samples = json.load(f)
sample = train_samples[0]

sample_midi = sample['midi_path']
print(f"\nSample MIDI path: {sample_midi}")
print(f"Sample score path: {sample.get('score_path', 'N/A')}")
print(f"Dimensions: {len(sample['scores'])}")
print(f"MIDI exists: {Path(sample_midi).exists()}")

In [None]:
# Pre-flight validation - FAIL FAST if requirements not met
from src.utils.preflight_validation import run_preflight_validation, PreflightValidationError

print("="*60)
print("PRE-FLIGHT VALIDATION")
print("="*60)

# DATA_ROOT, SCORE_DIR and PRETRAIN_DIR come from the earlier data cells.
try:
    run_preflight_validation(
        data_dir=DATA_ROOT,
        score_dir=SCORE_DIR,
        pretrained_checkpoint=PRETRAIN_DIR / 'encoder_pretrained.pt',
        require_pretrained=True,  # Require pre-trained encoder
        min_score_coverage=0.95,  # Require 95% score coverage
    )
    print("\n[OK] Pre-flight validation PASSED - ready to train")
except PreflightValidationError as e:
    print(f"\n[VALIDATION FAILED]\n{e}")
    # Chain explicitly so the traceback marks the validation error as the
    # direct cause rather than as an incidental exception.
    raise RuntimeError("Fix the issues above before training") from e

## Step 4: Training Configuration

In [None]:
# Path is imported here so the cell also runs on its own after a kernel
# restart, like the other cells that use it.
from pathlib import Path

import torch
torch.set_float32_matmul_precision('medium')

# Paths (same locations the download cell populated)
SCORE_DIR = Path('/tmp/percepiano_data/PercePiano/virtuoso/data/score_xml')
PRETRAINED_CHECKPOINT = Path('/tmp/checkpoints/midi_pretrain/encoder_pretrained.pt')

# Single source of truth for the data, model and trainer cells below
CONFIG = {
    # Data
    'data_dir': '/tmp/percepiano_data',
    'score_dir': str(SCORE_DIR),
    'pretrained_checkpoint': str(PRETRAINED_CHECKPOINT),
    
    # MIDI Encoder (768/12/12 to match MidiBERT)
    'midi_hidden_dim': 768,
    'midi_num_layers': 12,
    'midi_num_heads': 12,
    'max_seq_length': 512,
    
    # Score Encoder
    'score_hidden_dim': 256,
    'score_num_layers': 2,
    'score_note_features': 20,  # Expanded from 6 to 20 features per note
    'use_hierarchical_encoder': True,  # Use HAN-style note->beat->measure hierarchy
    
    # Fusion
    'fusion_type': 'gated',  # Options: 'concat', 'crossattn', 'gated'
    'fused_dim': 768,
    
    # Aggregation (PercePiano style)
    'attention_da': 128,
    'attention_r': 4,
    'head_hidden_dim': 256,
    'dropout': 0.1,
    
    # Training
    'batch_size': 4,
    'learning_rate': 1e-5,
    'weight_decay': 0.01,
    'max_epochs': 100,
    'early_stopping_patience': 20,
    'gradient_clip_val': 1.0,
    'accumulate_grad_batches': 4,  # Effective batch size = 16
    'precision': '16-mixed',
    
    # Sequence lengths
    'max_score_notes': 1024,
    'max_tempo_segments': 256,
    
    # Checkpoints
    'checkpoint_dir': '/tmp/checkpoints/score_aligned',
    'gdrive_checkpoint': 'gdrive:crescendai_checkpoints/score_aligned',
    
    # Options
    'freeze_midi_encoder': False,
}

print("Training Configuration (Score-Aligned Model with Hierarchical Encoder):")
print("="*70)
for k, v in CONFIG.items():
    print(f"  {k}: {v}")
print("="*70)
print("\nKey improvements in this version:")
print("  - Pre-trained MIDI encoder loaded at initialization")
print("  - Pre-flight validation ensures all data is present")
print("  - Hierarchical encoder (HAN): note -> beat -> measure hierarchy")
print("  - No fallback mode - fail fast if data missing")

## Step 5: Create DataLoaders with Score Features

In [None]:
from pathlib import Path
from src.data.percepiano_score_dataset import create_score_dataloaders

# Build the three loaders from CONFIG; an empty score_dir is forwarded as None
_score_dir_arg = Path(CONFIG['score_dir']) if CONFIG['score_dir'] else None
train_loader, val_loader, test_loader = create_score_dataloaders(
    data_dir=Path(CONFIG['data_dir']),
    score_dir=_score_dir_arg,
    batch_size=CONFIG['batch_size'],
    max_midi_seq_length=CONFIG['max_seq_length'],
    max_score_notes=CONFIG['max_score_notes'],
    num_workers=4,
)

# Number of batches per split
for split_label, split_loader in (
    ("Train", train_loader),
    ("Val", val_loader),
    ("Test", test_loader),
):
    print(f"{split_label} batches: {len(split_loader)}")

# Test a batch: pull one training batch and show the shape of each tensor
batch = next(iter(train_loader))
print("\nBatch shapes:")
for field in (
    'midi_tokens',
    'score_note_features',
    'score_global_features',
    'score_tempo_curve',
    'scores',
):
    print(f"  {field}: {batch[field].shape}")

# Check for note_locations (new hierarchical features)
if 'note_locations_beat' not in batch:
    print("\nWARNING: note_locations not found in batch - hierarchical encoder may not work optimally")
else:
    print("\nHierarchical features (note_locations):")
    for field in ('note_locations_beat', 'note_locations_measure', 'note_locations_voice'):
        print(f"  {field}: {batch[field].shape}")

## Step 6: Create Score-Aligned Model

In [None]:
import pytorch_lightning as pl
from src.models.score_aligned_module import ScoreAlignedModule

# Seed before the model is constructed: the parts of the model that are not
# loaded from the pre-trained checkpoint are randomly initialised, and seeding
# only right before trainer.fit would leave that initialisation unseeded.
pl.seed_everything(42, workers=True)

# Create model with pre-trained MIDI encoder
model = ScoreAlignedModule(
    # MIDI Encoder
    midi_hidden_dim=CONFIG['midi_hidden_dim'],
    midi_num_layers=CONFIG['midi_num_layers'],
    midi_num_heads=CONFIG['midi_num_heads'],
    max_seq_length=CONFIG['max_seq_length'],
    # Score Encoder
    score_hidden_dim=CONFIG['score_hidden_dim'],
    score_num_layers=CONFIG['score_num_layers'],
    score_note_features=CONFIG['score_note_features'],
    use_hierarchical_encoder=CONFIG['use_hierarchical_encoder'],
    # Fusion
    fusion_type=CONFIG['fusion_type'],
    fused_dim=CONFIG['fused_dim'],
    # Aggregation
    attention_da=CONFIG['attention_da'],
    attention_r=CONFIG['attention_r'],
    head_hidden_dim=CONFIG['head_hidden_dim'],
    # Training
    learning_rate=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
    dropout=CONFIG['dropout'],
    freeze_midi_encoder=CONFIG['freeze_midi_encoder'],
    # Pre-trained encoder - loaded at initialization!
    midi_pretrained_checkpoint=CONFIG['pretrained_checkpoint'],
)

# Count parameters (all, and only those that will receive gradients)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("="*70)
print("SCORE-ALIGNED MODEL WITH HIERARCHICAL ENCODER")
print("="*70)
print(f"Model class: ScoreAlignedModule (no fallback - fail fast)")
print(f"Pre-trained encoder: {CONFIG['pretrained_checkpoint']}")
print(f"Score encoder: {'Hierarchical (HAN)' if CONFIG['use_hierarchical_encoder'] else 'Flat (Transformer)'}")
print(f"Note features: {CONFIG['score_note_features']} per note")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Fusion type: {CONFIG['fusion_type']}")
print(f"Dimensions: {len(model.dimensions)}")
print("="*70)

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

# Checkpoint callback - monitor mean R-squared.
# The filename placeholder must use the metric's logged name ('val/mean_r2',
# the same key given to `monitor`); a placeholder for a name that is never
# logged does not pick up the monitored value. auto_insert_metric_name=False
# keeps the '/' of the metric name out of the file name, so the
# "epoch=" / "val_mean_r2=" labels are spelled out in the template instead.
checkpoint_callback = ModelCheckpoint(
    dirpath=CONFIG['checkpoint_dir'],
    filename='score_aligned-epoch={epoch:02d}-val_mean_r2={val/mean_r2:.4f}',
    auto_insert_metric_name=False,
    monitor='val/mean_r2',
    mode='max',
    save_top_k=3,
    save_last=True,
)

# Early stopping on the same metric.
# NOTE(review): patience is counted in validation checks; with
# val_check_interval=0.5 below there are two checks per epoch - confirm that
# this is the intended horizon.
early_stopping = EarlyStopping(
    monitor='val/mean_r2',
    patience=CONFIG['early_stopping_patience'],
    mode='max',
)

# LR monitor
lr_monitor = LearningRateMonitor(logging_interval='step')

# Logger
logger = TensorBoardLogger(
    save_dir='/tmp/logs',
    name='score_aligned',
)

# Trainer
# Note: deterministic=True removed because AdaptiveAvgPool1d in TempoCurveEncoder
# doesn't have a deterministic CUDA backward implementation.
# Seeding is still done via pl.seed_everything(42) in the training cell; without
# deterministic mode, bit-for-bit identical runs are not guaranteed.
trainer = pl.Trainer(
    max_epochs=CONFIG['max_epochs'],
    accelerator='gpu',
    devices=1,
    precision=CONFIG['precision'],
    gradient_clip_val=CONFIG['gradient_clip_val'],
    accumulate_grad_batches=CONFIG['accumulate_grad_batches'],
    callbacks=[checkpoint_callback, early_stopping, lr_monitor],
    logger=logger,
    log_every_n_steps=10,
    val_check_interval=0.5,  # Validate twice per epoch
)

print("Trainer configured!")
print(f"  Precision: {CONFIG['precision']}")
print(f"  Max epochs: {CONFIG['max_epochs']}")
print(f"  Effective batch size: {CONFIG['batch_size'] * CONFIG['accumulate_grad_batches']}")
print(f"  Early stopping patience: {CONFIG['early_stopping_patience']}")

## Step 8: Train!

In [None]:
# Seed the random number generators (including dataloader workers) before fitting
pl.seed_everything(42, workers=True)

# Announce which metrics to monitor, then start the fit loop
print(
    "Starting training...\n"
    "\nKey metrics to watch:\n"
    "  - val/mean_r2: Overall R-squared (target: 0.30-0.40)\n"
    "  - val/tempo_r2: Tempo dimension (currently -0.15, should improve!)\n"
)

trainer.fit(model, train_loader, val_loader)

In [None]:
# Sync checkpoints to Google Drive
import subprocess  # imported here so the cell does not depend on an earlier cell's import

if RCLONE_AVAILABLE:
    print("Syncing checkpoints to Google Drive...")
    sync = subprocess.run(
        ['rclone', 'copy', CONFIG['checkpoint_dir'], CONFIG['gdrive_checkpoint'], '--progress'],
        capture_output=False
    )
    # Only report success when rclone actually succeeded
    if sync.returncode == 0:
        print("Sync complete!")
    else:
        print(f"WARNING: rclone exited with code {sync.returncode}; checkpoints may not be synced")

## Step 9: Comprehensive Evaluation

This section provides:
- Per-dimension analysis with statistical tests
- Comparison to SOTA baselines (PercePiano paper)
- Category-level metrics (timing, dynamics, etc.)
- Visualization of results

In [None]:
# Test with best checkpoint
print("Running test with best checkpoint...")
best_path = checkpoint_callback.best_model_path
print(f"Best checkpoint: {best_path}")

if best_path:
    test_results = trainer.test(model, test_loader, ckpt_path=best_path)
    print("\nTest Results:")
    # Only the first entry of the returned list is reported here
    for k, v in test_results[0].items():
        print(f"  {k}: {v:.4f}")
else:
    # An empty path means no checkpoint was recorded; say so instead of
    # silently skipping, since the evaluation cells below need one.
    print("\nWARNING: no best checkpoint recorded - skipping test run")

In [None]:
import torch
import numpy as np

# Load best model (using ScoreAlignedModule, not fallback version).
# The cell previously also contained an older copy of this code that loaded
# the checkpoint through ScoreAlignedModuleWithFallback and ran the whole test
# set a second time; only the ScoreAlignedModule version is kept, matching the
# class the model was trained with.
from src.models.score_aligned_module import ScoreAlignedModule
best_model = ScoreAlignedModule.load_from_checkpoint(checkpoint_callback.best_model_path)
best_model.eval()
best_model.cuda()


# Helper to get note_locations from batch
def get_note_locations(batch):
    """Return the beat/measure/voice location tensors of ``batch`` as a dict.

    Returns ``None`` when the batch has no ``note_locations_beat`` entry.
    """
    if 'note_locations_beat' in batch:
        return {
            'beat': batch['note_locations_beat'],
            'measure': batch['note_locations_measure'],
            'voice': batch['note_locations_voice'],
        }
    return None


# Collect predictions on test set
all_preds = []
all_targets = []

print("Collecting predictions on test set...")
with torch.no_grad():
    for batch in test_loader:
        # Move tensors to the GPU; non-tensor entries are passed through as-is
        batch = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        note_locations = get_note_locations(batch)
        outputs = best_model(
            batch['midi_tokens'],
            batch['score_note_features'],
            batch['score_global_features'],
            batch['score_tempo_curve'],
            batch.get('midi_attention_mask'),
            batch.get('score_attention_mask'),
            note_locations=note_locations,
        )
        # Keep results on the CPU so GPU memory is not held across batches
        all_preds.append(outputs['predictions'].cpu())
        all_targets.append(batch['scores'].cpu())

# Concatenate the per-batch tensors and convert them to NumPy arrays
all_preds = torch.cat(all_preds).numpy()
all_targets = torch.cat(all_targets).numpy()
dimensions = best_model.dimensions

print(f"Collected {len(all_preds)} test samples across {len(dimensions)} dimensions")

In [None]:
from src.evaluation import (
    compute_all_metrics,
    PerDimensionAnalysis,
    compare_to_sota,
    format_comparison_table,
    create_results_table,
    PERCEPIANO_BASELINES,
    DIMENSION_BASELINES,
)

# Evaluate the collected test-set predictions for every dimension
_dimension_names = list(dimensions)
metrics = compute_all_metrics(
    predictions=all_preds,
    targets=all_targets,
    dimension_names=_dimension_names,
)

# Render the metrics as a text table and show it
_results_table = create_results_table(metrics)
print(_results_table)

### 9.2 Per-Dimension Analysis

Detailed analysis of which dimensions are working well and which need improvement.

In [None]:
# R^2 cut-offs: below `weak_threshold` a dimension needs improvement,
# above `strong_threshold` it is considered to be working well
_r2_thresholds = {'weak_threshold': 0.15, 'strong_threshold': 0.30}

# Build the per-dimension analysis from the collected predictions
analysis = PerDimensionAnalysis.from_predictions(
    predictions=all_preds,
    targets=all_targets,
    dimension_names=list(dimensions),
    **_r2_thresholds,
)

_analysis_report = analysis.format_report()
print(_analysis_report)

### 9.3 SOTA Comparison

Compare our model to published baselines from the PercePiano paper:
- **Bi-LSTM**: R^2 = 0.185 (baseline)
- **MidiBERT**: R^2 = 0.313 (pretrained MIDI encoder)
- **Bi-LSTM + SA + HAN**: R^2 = 0.397 (best published result with score alignment)

In [None]:
# Pull the overall and the per-dimension R^2 out of the computed metrics
_r2_metric = metrics['r2']
our_r2 = _r2_metric.value
per_dim_r2 = _r2_metric.per_dimension

# Rank this model against the baselines
comparison = compare_to_sota(
    model_r2=our_r2,
    model_name="CrescendAI (Score-Aligned)",
    split_type="piece",  # PercePiano uses piece-split by default
    per_dimension_r2=per_dim_r2,
)

_comparison_table = format_comparison_table(comparison)
print(_comparison_table)

### 9.4 Visualization

Visualize per-dimension results and error distributions.

In [None]:
import matplotlib.pyplot as plt
from src.evaluation import (
    plot_per_dimension_results,
    plot_prediction_scatter,
    plot_error_distribution,
)

# Bar chart of per-dimension R^2 with the baseline values passed alongside
_bar_chart_options = {
    'metric_name': "R^2",
    'title': "Per-Dimension R^2 (vs SOTA Baselines)",
    'figsize': (14, 6),
}
fig1 = plot_per_dimension_results(
    dimension_results=per_dim_r2,
    baseline_values=DIMENSION_BASELINES,
    **_bar_chart_options,
)
plt.show()

In [None]:
# Scatter plots of prediction vs target for a hand-picked set of dimensions
key_dims = [
    'tempo',
    'timing',
    'dynamic_range',
    'interpretation',
    'articulation_length',
    'mood_energy',
]

# Column index of each key dimension that the model actually predicts
_dimension_list = list(dimensions)
key_indices = [_dimension_list.index(d) for d in key_dims if d in dimensions]
_key_names = [dimensions[i] for i in key_indices]

fig2 = plot_prediction_scatter(
    predictions=all_preds[:, key_indices],
    targets=all_targets[:, key_indices],
    dimension_names=_key_names,
    n_cols=3,
    figsize=(12, 8),
)
plt.suptitle("Prediction vs Target (Key Dimensions)", y=1.02, fontsize=14)
plt.show()

In [None]:
# Plot how the prediction errors are distributed for each dimension
_error_plot_names = list(dimensions)
fig3 = plot_error_distribution(
    predictions=all_preds,
    targets=all_targets,
    dimension_names=_error_plot_names,
    figsize=(16, 6),
)
plt.show()

### 9.5 Key Findings Summary

In [None]:
# Print a six-part summary of the evaluation results
_rule = "=" * 70
print(_rule)
print("KEY FINDINGS SUMMARY")
print(_rule)

# 1. Overall performance against the 0.30 target and the baseline ranking
print("\n1. OVERALL PERFORMANCE")
print(f"   Mean R^2: {our_r2:.4f}")
if our_r2 >= 0.30:
    _target_status = 'ACHIEVED'
else:
    _target_status = 'NOT YET'
print(f"   Target (0.30-0.40): {_target_status}")
print(f"   Ranking vs SOTA: {comparison['rank']}/{comparison['total_baselines']}")

# 2. Tempo dimension (key target); falls back to 0 when it is not reported
tempo_r2 = per_dim_r2.get('tempo', 0)
print("\n2. TEMPO DIMENSION (Primary Target)")
print("   Previous (MIDI-only): R^2 ~ -0.15")
print(f"   Current (Score-aligned): R^2 = {tempo_r2:.4f}")
_tempo_verdict = 'YES' if tempo_r2 > 0 else 'NEEDS WORK'
print(f"   Improvement: {_tempo_verdict}")

# 3. Strong dimensions: the first five ranked entries that exceed 0.30
print("\n3. STRONG DIMENSIONS (R^2 > 0.30)")
strong = analysis.get_ranked_dimensions('r2')[:5]
for dim, r2 in [(name, score) for name, score in strong if score > 0.30]:
    print(f"   {dim}: {r2:.4f}")

# 4. Weak dimensions: the first five ascending entries that fall below 0.15
print("\n4. DIMENSIONS NEEDING IMPROVEMENT (R^2 < 0.15)")
weak = analysis.get_ranked_dimensions('r2', ascending=True)[:5]
for dim, r2 in [(name, score) for name, score in weak if score < 0.15]:
    print(f"   {dim}: {r2:.4f}")

# 5. Score alignment benefit (only when the comparison reports it)
if 'improvement_vs_midi_only' in comparison:
    print("\n5. SCORE ALIGNMENT BENEFIT")
    print(f"   vs Best MIDI-only baseline: {comparison['improvement_vs_midi_only']:+.4f}")

# 6. Category-level performance, highest R^2 first
print("\n6. CATEGORY PERFORMANCE")
_by_r2_desc = sorted(
    analysis.category_metrics.items(),
    key=lambda entry: entry[1]['r2'],
    reverse=True,
)
for cat, cat_metrics in _by_r2_desc:
    print(f"   {cat}: R^2={cat_metrics['r2']:.4f}")

print(_rule)

## Step 10: Save and Sync Final Model

In [None]:
import subprocess  # imported here so the cell does not depend on an earlier cell's import
import torch
from pathlib import Path

# Save final model with comprehensive evaluation results
final_path = Path(CONFIG['checkpoint_dir']) / 'score_aligned_final.pt'
torch.save({
    # Weights and hyper-parameters of the best checkpoint loaded for evaluation
    'state_dict': best_model.state_dict(),
    'hparams': dict(best_model.hparams),
    'dimensions': list(dimensions),
    # Test-set metrics computed in the evaluation cells
    'metrics': {
        'r2': our_r2,
        'per_dimension_r2': per_dim_r2,
        'mse': metrics['mse'].value,
        'mae': metrics['mae'].value,
        'pearson_r': metrics['pearson_r'].value,
        'spearman_rho': metrics['spearman_rho'].value,
    },
    'sota_comparison': {
        'rank': comparison['rank'],
        'total_baselines': comparison['total_baselines'],
        'vs_best_baseline': comparison['improvement_vs_best'],
    },
    'category_metrics': analysis.category_metrics,
    'strong_dimensions': analysis.strong_dimensions,
    'weak_dimensions': analysis.weak_dimensions,
}, final_path)
print(f"Saved final model to {final_path}")

# Final sync to Google Drive
if RCLONE_AVAILABLE:
    print("\nFinal sync to Google Drive...")
    final_sync = subprocess.run(
        ['rclone', 'copy', CONFIG['checkpoint_dir'], CONFIG['gdrive_checkpoint'], '--progress'],
        capture_output=False
    )
    # Only report success when rclone actually succeeded
    if final_sync.returncode == 0:
        print("Sync complete!")
        print(f"Checkpoints available at: {CONFIG['gdrive_checkpoint']}")
    else:
        print(f"WARNING: rclone exited with code {final_sync.returncode}; checkpoints may not be synced")