In [None]:
print("="*70)
print("EXTRACTING MAESTRO WITH VARIANCE DATASET")
print("="*70)

import os
import tarfile
from pathlib import Path

# Extract tar.gz to /tmp/ (local SSD for fast access)
tarball_path = Path("/content/maestro_with_variance.tar.gz")  # Adjust if uploaded elsewhere

if not tarball_path.exists():
    print(f"✗ ERROR: {tarball_path} not found!")
    print("\nPlease upload maestro_with_variance.tar.gz to /content/")
    print("You can create it by running:")
    print("  python scripts/prepare_maestro_for_upload.py --maestro_zip ~/Downloads/maestro-v3.0.0.zip")
    raise FileNotFoundError(f"{tarball_path} not found")

extract_dir = Path("/tmp/maestro_data")
extract_dir.mkdir(parents=True, exist_ok=True)

print(f"Extracting {tarball_path.name}...")
print(f"Size: {tarball_path.stat().st_size / (1024**3):.2f} GB")
print(f"Destination: {extract_dir}\n")

with tarfile.open(tarball_path, "r:gz") as tar:
    # A bare extractall() trusts member names, so an entry such as "../../x" or an
    # absolute path would be written outside extract_dir. The "data" filter
    # (Python 3.12+, also in security releases of 3.8-3.11) rejects such members;
    # interpreters without it fall back to the legacy behaviour.
    if hasattr(tarfile, "data_filter"):
        tar.extractall(extract_dir, filter="data")
    else:
        tar.extractall(extract_dir)

print("✓ Extraction complete!")
print(f"\nDataset structure:")
print(f"  Audio: {extract_dir}/audio/")
print(f"  MIDI: {extract_dir}/midi/")
print(f"  Annotations: {extract_dir}/annotations/")

# Verify files (sorted so the sample annotation shown below is deterministic)
audio_files = sorted((extract_dir / "audio").glob("*.wav"))
midi_files = sorted((extract_dir / "midi").glob("*.mid"))
annotation_files = sorted((extract_dir / "annotations").glob("*.jsonl"))

print(f"\nFiles extracted:")
print(f"  Audio files: {len(audio_files):,}")
print(f"  MIDI files: {len(midi_files):,}")
print(f"  Annotation files: {len(annotation_files)}")

# Read sample annotation to verify (first JSON line of the first annotation file)
if annotation_files:
    import json
    with open(annotation_files[0]) as f:
        first_line = f.readline().strip()
    if first_line:
        sample = json.loads(first_line)
        print(f"\nSample annotation:")
        print(f"  Dimensions: {list(sample.get('labels', {}).keys())}")
        print(f"  Quality tier: {sample.get('quality_tier', 'N/A')}")
        print(f"  Quality score: {sample.get('quality_score', 'N/A')}")
    else:
        # An empty file (or blank first line) would otherwise crash json.loads("").
        print(f"\n⚠️  {annotation_files[0].name} is empty or starts with a blank line")

print("\n" + "="*70)
print("✓ DATASET READY FOR TRAINING")
print("="*70)

# Piano Performance Evaluation - TRAINING_PLAN_v2.md Phase 2

Trains 3 models with controlled quality variance:
1. Audio-Only (MERT only)
2. MIDI-Only (MIDIBert only)
3. Fusion (MERT + MIDIBert)

**Updates from v1**:
- Dimensions: 8 (6→8: added musical_expression, overall_interpretation)
- Quality variance: 4 tiers (Pristine/Good/Moderate/Poor)
- Diagnostics: Attention entropy, cross-modal alignment, quality tier analysis
- Sample size: ~450K training samples (4x quality tiers)

- **Expected time**: 4-5 hours training + 15-20 min setup
- **Goal**: Validate that the fusion architecture learns quality (not complexity) before Phase 3

In [None]:
import torch

# Report whether PyTorch can see a CUDA device (prints True or False).
_cuda_ok = torch.cuda.is_available()
print(_cuda_ok)

In [None]:
# Install rclone (used by the Google Drive copy cell below).
# NOTE(review): pipes a remote install script straight into `sudo bash`; only run on a trusted runtime.
# grep keeps just installer lines mentioning "successfully" or "already"; if neither word
# appears, the `||` fallback prints that the installation status is unknown.
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installation status unknown"

In [None]:
%pip install -q huggingface_hub

import os
os.environ.pop("HF_TOKEN", None)
os.environ.pop("HUGGINGFACEHUB_API_TOKEN", None)

from huggingface_hub import login, HfApi

try:
    import getpass as gp
    raw = gp.getpass("Paste your Hugging Face token (input hidden): ")
    token = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
    if not isinstance(token, str):
        raise TypeError(f"Unexpected token type: {type(token).__name__}")
    token = token.strip()
    if not token:
        raise ValueError("Empty token provided")
    login(token=token, add_to_git_credential=False)
    who = HfApi().whoami(token=token)
    print(f"✓ Logged in as: {who.get('name') or who.get('email') or 'OK'}")
except Exception as e:
    print(f"[HF Login] getpass flow failed: {e}")
    print("Falling back to interactive login widget...")
    login()
    try:
        who = HfApi().whoami()
        print(f"✓ Logged in as: {who.get('name') or who.get('email') or 'OK'}")
    except Exception as e2:
        print(f"[HF Login] Verification skipped: {e2}")

In [None]:
import os
from pathlib import Path

print("\n" + "="*70)
print("COPYING MODELS AND CHECKPOINTS FROM GOOGLE DRIVE")
print("="*70)

# Check if rclone is configured
import subprocess
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)
if 'gdrive:' not in result.stdout:
    print("\n⚠️  rclone not configured!")
    print("Run 'rclone config' in terminal to set up 'gdrive' remote")
    print("Follow the OAuth flow for remote server authentication")
    raise RuntimeError("rclone gdrive remote not configured")

# 1. Copy MERT model from Google Drive
print("\n" + "="*70)
print("1. MERT-95M MODEL")
print("="*70)

HF_CACHE_ROOT = os.path.expanduser("~/.cache/huggingface/hub")
MERT_CACHE_DIR = os.path.join(HF_CACHE_ROOT, "models--m-a-p--MERT-v1-95M")
MERT_REFS_DIR = os.path.join(MERT_CACHE_DIR, "refs")
MERT_SNAPSHOTS_DIR = os.path.join(MERT_CACHE_DIR, "snapshots")
MERT_SNAPSHOT_MAIN = os.path.join(MERT_SNAPSHOTS_DIR, "main")

if os.path.exists(MERT_SNAPSHOT_MAIN) and os.listdir(MERT_SNAPSHOT_MAIN):
    print(f"✓ MERT-95M already cached at: {MERT_CACHE_DIR}")
    print(f"\nCached files:")
    !ls -lh {MERT_SNAPSHOT_MAIN}/
else:
    print("Copying MERT-95M from Google Drive (~380MB)...")
    
    # Create directory structure
    os.makedirs(MERT_SNAPSHOT_MAIN, exist_ok=True)
    os.makedirs(MERT_REFS_DIR, exist_ok=True)
    
    # Copy model files to snapshot directory
    print("\nCopying model files...")
    !rclone copy gdrive:MERT-v1-95M/ {MERT_SNAPSHOT_MAIN}/ -P --transfers 4
    
    # Create refs/main file pointing to the snapshot
    with open(os.path.join(MERT_REFS_DIR, "main"), 'w') as f:
        f.write("main")
    
    # Create a minimal .no_exist marker file
    Path(MERT_CACHE_DIR, ".no_exist").touch()
    
    print("\n✓ MERT-95M copied and cached")
    print(f"   Cache location: {MERT_CACHE_DIR}")
    print(f"   Snapshot: {MERT_SNAPSHOT_MAIN}")
    
    # List what was copied
    print("\nCopied files:")
    !ls -lh {MERT_SNAPSHOT_MAIN}/

# 2. Copy training checkpoints from Google Drive
print("\n" + "="*70)
print("2. TRAINING CHECKPOINTS")
print("="*70)

CHECKPOINT_ROOT = '/tmp/crescendai_checkpoints'
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)

print("\nCopying checkpoints from Google Drive...")
print("This may take a few minutes depending on checkpoint size...\n")

for mode in ['audio_full', 'midi_full', 'fusion_full']:
    print(f"Copying {mode}...")
    !rclone copy gdrive:crescendai_checkpoints/{mode} {CHECKPOINT_ROOT}/{mode} -P --transfers 4

print("\n" + "="*70)
print("✓ ALL FILES COPIED FROM GOOGLE DRIVE")
print("="*70)

# Verify what was copied
print("\nCheckpoint contents:")
!ls -lh {CHECKPOINT_ROOT}/*/*.ckpt 2>/dev/null || echo "No .ckpt files found"

print(f"\nMERT model cache: {MERT_CACHE_DIR}")
print(f"MERT snapshot: {MERT_SNAPSHOT_MAIN}")

In [None]:
# Leave the clone directory first so deleting it below never removes the kernel's
# current working directory on a re-run.
%cd ~

# Start from a fresh clone of the project repository.
!rm -rf /tmp/crescendai
!git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai
# Later cells use relative paths (train.py, scripts/..., src.*) from the model/ subdirectory.
%cd /tmp/crescendai/model
# Print the checked-out commit for provenance.
!git log -1 --oneline

In [None]:
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Add to PATH for this session
import os
os.environ['PATH'] = f"{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"

print("\n✓ uv installed")

In [None]:
!uv pip install --system -e .

# Install optional GPU dependencies
!uv pip install --system nnAudio torchcodec

import torch
import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")
print("✓ Dependencies installed")

!python scripts/setup_colab_environment.py

In [None]:
print("="*70)
print("VERIFYING MERT-95M MODEL")
print("="*70)

from transformers import AutoModel, Wav2Vec2FeatureExtractor
import torch
import os

# Use the snapshot directory directly instead of relying on HF cache resolution
MERT_SNAPSHOT_DIR = os.path.join(
    os.path.expanduser("~/.cache/huggingface/hub"),
    "models--m-a-p--MERT-v1-95M/snapshots/main"
)

print(f"\nChecking for model files in: {MERT_SNAPSHOT_DIR}")
if not os.path.exists(MERT_SNAPSHOT_DIR):
    raise RuntimeError(f"MERT snapshot directory not found: {MERT_SNAPSHOT_DIR}")

print("Model files:")
!ls -lh {MERT_SNAPSHOT_DIR}/

try:
    print("\nLoading model from local snapshot...")
    # Load directly from the snapshot directory (bypasses HF cache lookup)
    model = AutoModel.from_pretrained(
        MERT_SNAPSHOT_DIR,  # Use directory path directly
        trust_remote_code=True,
        local_files_only=True
    )
    
    print("Loading feature extractor from local snapshot...")
    processor = Wav2Vec2FeatureExtractor.from_pretrained(
        MERT_SNAPSHOT_DIR,  # Use directory path directly
        trust_remote_code=True,
        local_files_only=True
    )
    
    print(f"\n✓ Model type: {model.config.model_type}")
    print(f"✓ Model loaded from: {MERT_SNAPSHOT_DIR}")
    print("✓ Feature extractor loaded")
    
    # Clean up
    del model
    del processor
    torch.cuda.empty_cache()
    
    print("\n" + "="*70)
    print("✓ MERT-95M VERIFIED AND READY")
    print("="*70)
    
except Exception as e:
    print(f"\n✗ Verification failed: {e}")
    print("\nDebug info:")
    print(f"  Snapshot dir exists: {os.path.exists(MERT_SNAPSHOT_DIR)}")
    if os.path.exists(MERT_SNAPSHOT_DIR):
        print(f"  Files in snapshot: {os.listdir(MERT_SNAPSHOT_DIR)}")
    print("\nMake sure cell-3 completed successfully!")
    raise

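## STEP 2: EXTRACT UPLOADED DATASET

`maestro_with_variance.tar.gz` (uploaded to the runtime) is extracted to `/tmp/maestro_data` by the first cell of this notebook. The cell below verifies the extracted annotation, audio, and MIDI paths.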

In [None]:
print("Verifying data paths...")

import json
from pathlib import Path

# Updated paths for extracted dataset
train_path = '/tmp/maestro_data/annotations/train.jsonl'
val_path = '/tmp/maestro_data/annotations/val.jsonl'
test_path = '/tmp/maestro_data/annotations/test.jsonl'

# NOTE(review): this check runs before the "STEP 2.5: FIX ANNOTATION PATHS" cell
# further down; if the annotations still hold pre-fix paths, the existence checks
# below can report missing files even though extraction succeeded.

# Check a sample annotation (first line of the training split)
with open(train_path) as f:
    sample = json.loads(f.readline())

# Optional fields: the annotation may omit midi_path / quality metadata.
midi_path = sample.get('midi_path')
quality_score = sample.get('quality_score')
# Apply the numeric format only to numbers: formatting a string fallback such as
# 'N/A' with ':.1f' would raise ValueError.
quality_score_text = f"{quality_score:.1f}" if isinstance(quality_score, (int, float)) else 'N/A'

print(f"\nSample annotation:")
print(f"  Audio: {sample['audio_path']}")
print(f"  MIDI:  {midi_path}")
print(f"  Quality tier: {sample.get('quality_tier', 'N/A')}")
print(f"  Quality score: {quality_score_text}")

# Verify files exist
audio_exists = Path(sample['audio_path']).exists()
midi_exists = Path(midi_path).exists() if midi_path else False

print(f"\nFile existence check:")
print(f"  Audio exists: {'✓' if audio_exists else '✗'}")
print(f"  MIDI exists:  {'✓' if midi_exists else '✗ (may be OK if path is None)'}")

if not audio_exists:
    print(f"\n⚠️  WARNING: Audio file not found!")
    print(f"     Check that data extraction completed correctly")
    print(f"     Expected: {sample['audio_path']}")
elif not midi_exists and midi_path:
    print(f"\n⚠️  WARNING: MIDI file not found!")
    print(f"     Expected: {midi_path}")
else:
    print(f"\n✓ Sample files verified - data structure looks correct!")

# Preflight Check
print("="*70)
print("STEP 3: PREFLIGHT CHECK")
print("="*70)
print("\nVerifying training environment and data with 8 dimensions...\n")

# Updated dimension list (8 dimensions from TRAINING_PLAN_v2.md)
DIMENSIONS = [
    'note_accuracy',
    'rhythmic_stability', 
    'articulation_clarity',
    'pedal_technique',
    'tone_quality',
    'dynamic_range',
    'musical_expression',
    'overall_interpretation'
]

print(f"Dimensions ({len(DIMENSIONS)}): {DIMENSIONS}")

# Verify sample has all dimensions (a missing 'labels' key counts as all missing)
sample_labels = sample.get('labels', {})
missing_dims = [d for d in DIMENSIONS if d not in sample_labels]
if missing_dims:
    print(f"\n⚠️  WARNING: Sample missing dimensions: {missing_dims}")
    print("This may indicate the dataset was created with old labeling functions")
else:
    print(f"\n✓ All {len(DIMENSIONS)} dimensions present in annotations")

print(f"\nAnnotation paths:")
print(f"  Train: {train_path}")
print(f"  Val:   {val_path}")  
print(f"  Test:  {test_path}")

In [None]:
print("="*70)
print("STEP 2.5: FIX ANNOTATION PATHS")
print("="*70)
print("\nUpdating annotation files to use local SSD paths...\n")

!python scripts/fix_annotation_paths.py

print("\n✓ Annotation paths updated for local SSD access")

In [None]:
import warnings
warnings.filterwarnings('ignore', message='divide by zero')
warnings.filterwarnings('ignore', category=SyntaxWarning)  # pydub regex warnings
warnings.filterwarnings('ignore', category=UserWarning, module='torchaudio')
warnings.filterwarnings('ignore', category=UserWarning, module='torchmetrics')

%%time
# Updated to use 8 dimensions and new paths
!python train.py \
    --train-path /tmp/maestro_data/annotations/train.jsonl \
    --val-path /tmp/maestro_data/annotations/val.jsonl \
    --test-path /tmp/maestro_data/annotations/test.jsonl \
    --dimensions note_accuracy rhythmic_stability articulation_clarity pedal_technique tone_quality dynamic_range musical_expression overall_interpretation \
    --mode audio \
    --epochs 5 \
    --batch-size 16 \
    --learning-rate 3e-5 \
    --checkpoint-dir /content/drive/MyDrive/crescendai_checkpoints/audio_full

## Experiment 1: Audio-Only

The training cell directly above trains the audio-only model (`--mode audio`) using MERT-95M audio features.

In [None]:
import warnings
warnings.filterwarnings('ignore', message='divide by zero')
warnings.filterwarnings('ignore', category=SyntaxWarning)  # pydub regex warnings
warnings.filterwarnings('ignore', category=UserWarning, module='torchaudio')
warnings.filterwarnings('ignore', category=UserWarning, module='torchmetrics')

%%time
# Updated to use 8 dimensions and new paths
!python train.py \
    --train-path /tmp/maestro_data/annotations/train.jsonl \
    --val-path /tmp/maestro_data/annotations/val.jsonl \
    --test-path /tmp/maestro_data/annotations/test.jsonl \
    --dimensions note_accuracy rhythmic_stability articulation_clarity pedal_technique tone_quality dynamic_range musical_expression overall_interpretation \
    --mode midi \
    --epochs 5 \
    --batch-size 16 \
    --learning-rate 3e-5 \
    --checkpoint-dir /content/drive/MyDrive/crescendai_checkpoints/midi_full

## Experiment 2: MIDI-Only

The training cell directly above trains the MIDI-only model (`--mode midi`) using the MIDIBert encoder.

In [None]:
import warnings
warnings.filterwarnings('ignore', message='divide by zero')
warnings.filterwarnings('ignore', category=SyntaxWarning)  # pydub regex warnings
warnings.filterwarnings('ignore', category=UserWarning, module='torchaudio')
warnings.filterwarnings('ignore', category=UserWarning, module='torchmetrics')

%%time
# Updated to use 8 dimensions and new paths
!python train.py \
    --train-path /tmp/maestro_data/annotations/train.jsonl \
    --val-path /tmp/maestro_data/annotations/val.jsonl \
    --test-path /tmp/maestro_data/annotations/test.jsonl \
    --dimensions note_accuracy rhythmic_stability articulation_clarity pedal_technique tone_quality dynamic_range musical_expression overall_interpretation \
    --mode fusion \
    --epochs 5 \
    --batch-size 16 \
    --learning-rate 3e-5 \
    --checkpoint-dir /content/drive/MyDrive/crescendai_checkpoints/fusion_full

## Experiment 3: Fusion

The training cell directly above trains the fusion model (`--mode fusion`) on both audio and MIDI features.

In [None]:
import pytorch_lightning as pl
from src.models.lightning_module import PerformanceEvaluationModel
from src.data.dataset import create_dataloaders
from pathlib import Path
import numpy as np
import torch
from scipy import stats

print("="*80)
print("COMPREHENSIVE 3-WAY MODEL EVALUATION - TRAINING_PLAN_v2.md")
print("="*80)

# Updated dimensions (8 from v2)
dimensions = [
    'note_accuracy',
    'rhythmic_stability',
    'articulation_clarity', 
    'pedal_technique',
    'tone_quality',
    'dynamic_range',
    'musical_expression',
    'overall_interpretation'
]

# Updated paths
train_path = '/tmp/maestro_data/annotations/train.jsonl'
val_path = '/tmp/maestro_data/annotations/val.jsonl'
test_path = '/tmp/maestro_data/annotations/test.jsonl'

# Load all 3 models (same directory the training cells pass as --checkpoint-dir)
CHECKPOINT_ROOT = '/content/drive/MyDrive/crescendai_checkpoints'
models = {}
for mode in ['audio', 'midi', 'fusion']:
    ckpt_dir = Path(f'{CHECKPOINT_ROOT}/{mode}_full')
    ckpts = list(ckpt_dir.glob('*.ckpt'))
    if ckpts:
        # NOTE(review): lexicographically last filename; this is the newest checkpoint
        # only if the names sort chronologically.
        latest = sorted(ckpts)[-1]
        print(f"Loading {mode}: {latest.name}")
        models[mode] = PerformanceEvaluationModel.load_from_checkpoint(str(latest))
        models[mode].eval()
        models[mode] = models[mode].cuda()
    else:
        print(f"⚠️  No checkpoint found for {mode}")

print(f"\nLoaded {len(models)}/3 models")

# Create test dataloader
_, _, test_loader = create_dataloaders(
    train_annotation_path=train_path,
    val_annotation_path=val_path,
    test_annotation_path=test_path,
    dimension_names=dimensions,
    batch_size=8,
    num_workers=4,
    augmentation_config=None,
    audio_sample_rate=24000,
    max_audio_length=240000,
    max_midi_events=512,
)

print(f"Test set size: {len(test_loader.dataset)} samples")

# Evaluate each model
trainer = pl.Trainer(accelerator='auto', devices='auto', precision=16)
results = {}
predictions = {}

for mode, model in models.items():
    print(f"\nEvaluating {mode}...")
    test_results = trainer.test(model, dataloaders=test_loader, verbose=False)
    results[mode] = test_results[0]
    
    # Collect predictions for deeper analysis (re-assert GPU placement after trainer.test)
    model = model.cuda()
    model.eval()
    all_preds = []
    all_targets = []
    all_quality_tiers = []
    
    with torch.no_grad():
        for batch in test_loader:
            audio_waveform = batch['audio_waveform'].cuda()
            midi_tokens = batch.get('midi_tokens', None)
            if midi_tokens is not None:
                midi_tokens = midi_tokens.cuda()
            targets = batch['labels'].cuda()
            
            # Forward pass
            output = model(
                audio_waveform=audio_waveform,
                midi_tokens=midi_tokens,
            )
            
            # The model returned None for this batch (e.g. no usable MIDI); skip it.
            if output is None:
                continue
                
            preds = output['scores']
            all_preds.append(preds.cpu().numpy())
            all_targets.append(targets.cpu().numpy())
            
            # Track quality tiers if available
            if 'quality_tier' in batch:
                all_quality_tiers.extend(batch['quality_tier'])
    
    if not all_preds:
        # Every batch was skipped; np.concatenate([]) would raise here.
        print(f"  ⚠️  No predictions collected for {mode}")
        continue
    
    predictions[mode] = {
        'preds': np.concatenate(all_preds, axis=0),
        'targets': np.concatenate(all_targets, axis=0),
        'quality_tiers': all_quality_tiers if all_quality_tiers else None
    }
    print(f"  Collected {len(predictions[mode]['preds'])} predictions")

# Find common sample count across the models that produced predictions
if predictions:
    min_samples = min(len(p['preds']) for p in predictions.values())
    print(f"\nUsing {min_samples} samples for analysis")

print("\n" + "="*80)
print("1. PER-DIMENSION CORRELATION (Pearson r)")
print("="*80)
print(f"{'Dimension':<28} {'Audio':<12} {'MIDI':<12} {'Fusion':<12} {'Best':<12} {'Improvement'}")
print("-"*80)

for dim_idx, dim in enumerate(dimensions):
    audio_r = results.get('audio', {}).get(f'test_pearson_{dim}', 0)
    midi_r = results.get('midi', {}).get(f'test_pearson_{dim}', 0)
    fusion_r = results.get('fusion', {}).get(f'test_pearson_{dim}', 0)
    
    best_single = max(audio_r, midi_r)
    best_modality = 'Audio' if audio_r >= midi_r else 'MIDI'
    improvement = ((fusion_r - best_single) / best_single * 100) if best_single > 0 else 0
    
    print(f"{dim:<28} {audio_r:>11.3f} {midi_r:>11.3f} {fusion_r:>11.3f} {best_modality:<12} {improvement:>+6.1f}%")

print("-"*80)

print("\n" + "="*80)
print("2. OVERALL PERFORMANCE")
print("="*80)

# The overall comparison and the go/no-go decision need test metrics from all three
# models; with one missing, the lookups below would raise KeyError.
missing_modes = [m for m in ['audio', 'midi', 'fusion'] if m not in results]

if missing_modes:
    print(f"⚠️  Cannot compare models: no test results for {', '.join(missing_modes)}")
    print("→ NO-GO: train/load all three models before evaluating Phase 2 criteria")
else:
    audio_mean_r = np.mean([results['audio'][f'test_pearson_{d}'] for d in dimensions])
    midi_mean_r = np.mean([results['midi'][f'test_pearson_{d}'] for d in dimensions])
    fusion_mean_r = np.mean([results['fusion'][f'test_pearson_{d}'] for d in dimensions])

    print(f"{'Metric':<35} {'Audio':<12} {'MIDI':<12} {'Fusion':<12} {'Winner'}")
    print("-"*80)
    best_r = max(audio_mean_r, midi_mean_r, fusion_mean_r)
    winner_r = 'Audio' if audio_mean_r == best_r else ('MIDI' if midi_mean_r == best_r else 'Fusion')
    print(f"{'Mean Pearson Correlation':<35} {audio_mean_r:>11.3f} {midi_mean_r:>11.3f} {fusion_mean_r:>11.3f} {winner_r}")

    print("\n" + "="*80)
    print("3. PHASE 2 SUCCESS CRITERIA (TRAINING_PLAN_v2.md)")
    print("="*80)

    best_single_r = max(audio_mean_r, midi_mean_r)
    fusion_improvement = ((fusion_mean_r - best_single_r) / best_single_r * 100) if best_single_r > 0 else 0

    print(f"\nFusion improvement over best single-modal: {fusion_improvement:+.1f}%")
    print(f"Target: ≥10% improvement")

    if fusion_improvement >= 10:
        print("\n✓ PASS: Fusion beats single-modal by ≥10%")
        print("→ GO TO PHASE 3: Proceed with contrastive pre-training")
    else:
        print("\n✗ FAIL: Fusion improvement < 10% threshold")
        print("→ NO-GO: Debug fusion architecture before Phase 3")

# Display diagnostics if available.
# NOTE(review): `results` comes from trainer.test(), so these val_* keys appear only if
# the module logs them under that name during testing — confirm.
if 'val_attention_entropy' in results.get('fusion', {}):
    print("\n" + "="*80)
    print("4. FUSION DIAGNOSTICS")
    print("="*80)
    
    diag = results['fusion']
    print(f"Attention Entropy:       {diag.get('val_attention_entropy', 'N/A')}")
    print(f"Attention Sparsity:      {diag.get('val_attention_sparsity', 'N/A')}")
    print(f"Cross-Modal Alignment:   {diag.get('val_cross_modal_alignment', 'N/A')}")
    print(f"Audio Feature Diversity: {diag.get('val_audio_feature_diversity', 'N/A')}")
    print(f"MIDI Feature Diversity:  {diag.get('val_midi_feature_diversity', 'N/A')}")

print("\n" + "="*80)
print("EVALUATION COMPLETE")
print("="*80)

## Compare Results

Load all 3 trained models and compare performance on test set

In [None]:
import pytorch_lightning as pl
from src.models.lightning_module import PerformanceEvaluationModel
from src.data.dataset import create_dataloaders
from pathlib import Path
import numpy as np
import torch
from scipy import stats

print("="*80)
print("COMPREHENSIVE 3-WAY MODEL EVALUATION")
print("="*80)

# Configuration: must match what the models were trained on. The training cells use
# the /tmp/maestro_data annotations and these 8 dimensions, and write checkpoints under
# CHECKPOINT_ROOT. Defining everything here keeps the cell independent of variables
# left over from earlier cells.
dimensions = [
    'note_accuracy',
    'rhythmic_stability',
    'articulation_clarity',
    'pedal_technique',
    'tone_quality',
    'dynamic_range',
    'musical_expression',
    'overall_interpretation',
]
# Dimensions held to the interpretive MVP target in section 8; the rest are treated as
# technical. NOTE(review): split inferred from the dimension names — confirm.
interpretive_dims = ['musical_expression', 'overall_interpretation']
technical_dims = [d for d in dimensions if d not in interpretive_dims]

train_path = '/tmp/maestro_data/annotations/train.jsonl'
val_path = '/tmp/maestro_data/annotations/val.jsonl'
test_path = '/tmp/maestro_data/annotations/test.jsonl'
CHECKPOINT_ROOT = '/content/drive/MyDrive/crescendai_checkpoints'

MODES = ['audio', 'midi', 'fusion']
COST_PER_DIMENSION_USD = 3000  # per-dimension expert-annotation estimate used in section 9

# Load all 3 models
models = {}
for mode in MODES:
    ckpt_dir = Path(f'{CHECKPOINT_ROOT}/{mode}_full')
    ckpts = list(ckpt_dir.glob('*.ckpt'))
    if ckpts:
        # NOTE(review): lexicographically last filename; this is the newest checkpoint
        # only if the names sort chronologically.
        latest = sorted(ckpts)[-1]
        print(f"Loading {mode}: {latest.name}")
        models[mode] = PerformanceEvaluationModel.load_from_checkpoint(str(latest))
        models[mode].eval()
        models[mode] = models[mode].cuda()
    else:
        print(f"⚠️  No checkpoint found for {mode}")

print(f"\nLoaded {len(models)}/3 models")
missing_modes = [m for m in MODES if m not in models]
if missing_modes:
    # Every comparison below needs all three models; fail before the slow test passes.
    raise RuntimeError(f"3-way comparison needs all models; missing checkpoints for: {', '.join(missing_modes)}")

# Create test dataloader (using local SSD paths)
_, _, test_loader = create_dataloaders(
    train_annotation_path=train_path,
    val_annotation_path=val_path,
    test_annotation_path=test_path,
    dimension_names=dimensions,
    batch_size=8,
    num_workers=4,
    augmentation_config=None,
    audio_sample_rate=24000,
    max_audio_length=240000,
    max_midi_events=512,
)

print(f"Test set size: {len(test_loader.dataset)} samples")

# Evaluate each model
trainer = pl.Trainer(accelerator='auto', devices='auto', precision=16)
results = {}
predictions = {}

for mode, model in models.items():
    print(f"\nEvaluating {mode}...")
    test_results = trainer.test(model, dataloaders=test_loader, verbose=False)
    results[mode] = test_results[0]
    
    # Collect predictions for deeper analysis
    model = model.cuda()  # Ensure model is on GPU
    model.eval()
    # Keyed by batch index so all models can later be compared on the same samples.
    batch_preds = {}
    batch_targets = {}
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            audio_waveform = batch['audio_waveform'].cuda()
            midi_tokens = batch.get('midi_tokens', None)
            if midi_tokens is not None:
                midi_tokens = midi_tokens.cuda()
            targets = batch['labels'].cuda()
            
            output = model(
                audio_waveform=audio_waveform,
                midi_tokens=midi_tokens,
            )
            
            # Skip if batch was None (all MIDI failed in MIDI-only mode)
            if output is None:
                continue
            
            batch_preds[batch_idx] = output['scores'].cpu().numpy()
            batch_targets[batch_idx] = targets.cpu().numpy()
    
    predictions[mode] = {'batch_preds': batch_preds, 'batch_targets': batch_targets}
    print(f"  Collected {sum(len(p) for p in batch_preds.values())} predictions")

# Align the models on the batches every one of them produced. Truncating each model's
# predictions to the shortest length would pair *different* samples as soon as one model
# skipped a batch before the end, corrupting the paired comparisons below.
# NOTE(review): assumes test_loader yields batches in the same order on every pass
# (no shuffling) — confirm in create_dataloaders.
common_batches = sorted(set.intersection(*(set(predictions[m]['batch_preds']) for m in MODES)))
if not common_batches:
    raise RuntimeError("No test batch produced predictions for all three models")
for mode in MODES:
    predictions[mode]['preds'] = np.concatenate(
        [predictions[mode]['batch_preds'][b] for b in common_batches], axis=0)
    predictions[mode]['targets'] = np.concatenate(
        [predictions[mode]['batch_targets'][b] for b in common_batches], axis=0)

min_samples = len(predictions['audio']['preds'])
print(f"\nUsing {min_samples} samples for analysis (batches evaluated by all models)")


def _pearson(mode, dim):
    """Test-set Pearson r reported by trainer.test for one model/dimension (0 if absent)."""
    return results.get(mode, {}).get(f'test_pearson_{dim}', 0)


def _abs_errors(mode, dim_idx):
    """Absolute prediction errors of one model on one dimension over the aligned samples."""
    return np.abs(predictions[mode]['preds'][:, dim_idx] - predictions[mode]['targets'][:, dim_idx])


def _mae(mode, dim_idx):
    """Mean absolute error of one model on one dimension."""
    return np.mean(_abs_errors(mode, dim_idx))


def _rmse(mode, dim_idx):
    """Root mean squared error of one model on one dimension."""
    return np.sqrt(np.mean(_abs_errors(mode, dim_idx) ** 2))


print("\n" + "="*80)
print("1. PER-DIMENSION CORRELATION (Pearson r)")
print("="*80)
print(f"{'Dimension':<25} {'Audio':<12} {'MIDI':<12} {'Fusion':<12} {'Best':<12} {'Improvement'}")
print("-"*80)

for dim in dimensions:
    audio_r, midi_r, fusion_r = (_pearson(m, dim) for m in MODES)
    best_single = max(audio_r, midi_r)
    best_modality = 'Audio' if audio_r >= midi_r else 'MIDI'
    improvement = ((fusion_r - best_single) / best_single * 100) if best_single > 0 else 0
    print(f"{dim:<25} {audio_r:>11.3f} {midi_r:>11.3f} {fusion_r:>11.3f} {best_modality:<12} {improvement:>+6.1f}%")

print("-"*80)

print("\n" + "="*80)
print("2. PER-DIMENSION MAE (Mean Absolute Error, 0-100 scale)")
print("="*80)
print(f"{'Dimension':<25} {'Audio':<12} {'MIDI':<12} {'Fusion':<12} {'Best':<12} {'Reduction'}")
print("-"*80)

for dim_idx, dim in enumerate(dimensions):
    audio_mae, midi_mae, fusion_mae = (_mae(m, dim_idx) for m in MODES)
    best_single_mae = min(audio_mae, midi_mae)
    best_modality = 'Audio' if audio_mae <= midi_mae else 'MIDI'
    reduction = ((best_single_mae - fusion_mae) / best_single_mae * 100) if best_single_mae > 0 else 0
    print(f"{dim:<25} {audio_mae:>11.2f} {midi_mae:>11.2f} {fusion_mae:>11.2f} {best_modality:<12} {reduction:>+6.1f}%")

print("-"*80)

print("\n" + "="*80)
print("3. PER-DIMENSION RMSE (Root Mean Squared Error, 0-100 scale)")
print("="*80)
print(f"{'Dimension':<25} {'Audio':<12} {'MIDI':<12} {'Fusion':<12} {'Best':<12} {'Reduction'}")
print("-"*80)

for dim_idx, dim in enumerate(dimensions):
    audio_rmse, midi_rmse, fusion_rmse = (_rmse(m, dim_idx) for m in MODES)
    best_single_rmse = min(audio_rmse, midi_rmse)
    best_modality = 'Audio' if audio_rmse <= midi_rmse else 'MIDI'
    reduction = ((best_single_rmse - fusion_rmse) / best_single_rmse * 100) if best_single_rmse > 0 else 0
    print(f"{dim:<25} {audio_rmse:>11.2f} {midi_rmse:>11.2f} {fusion_rmse:>11.2f} {best_modality:<12} {reduction:>+6.1f}%")

print("-"*80)

print("\n" + "="*80)
print("4. OVERALL PERFORMANCE (Averaged Across All Dimensions)")
print("="*80)
print(f"{'Metric':<35} {'Audio':<12} {'MIDI':<12} {'Fusion':<12} {'Winner'}")
print("-"*80)

dim_indices = range(len(dimensions))

# Mean Pearson (strict lookup: a missing metric key means a logging mismatch)
audio_mean_r, midi_mean_r, fusion_mean_r = (
    np.mean([results[m][f'test_pearson_{d}'] for d in dimensions]) for m in MODES)
best_r = max(audio_mean_r, midi_mean_r, fusion_mean_r)
winner_r = 'Audio' if audio_mean_r == best_r else ('MIDI' if midi_mean_r == best_r else 'Fusion')
print(f"{'Mean Pearson Correlation':<35} {audio_mean_r:>11.3f} {midi_mean_r:>11.3f} {fusion_mean_r:>11.3f} {winner_r}")

# Mean MAE over the aligned samples
audio_mean_mae, midi_mean_mae, fusion_mean_mae = (
    np.mean([_mae(m, i) for i in dim_indices]) for m in MODES)
best_mae = min(audio_mean_mae, midi_mean_mae, fusion_mean_mae)
winner_mae = 'Audio' if audio_mean_mae == best_mae else ('MIDI' if midi_mean_mae == best_mae else 'Fusion')
print(f"{'Mean Absolute Error':<35} {audio_mean_mae:>11.2f} {midi_mean_mae:>11.2f} {fusion_mean_mae:>11.2f} {winner_mae}")

# Mean RMSE over the aligned samples
audio_mean_rmse, midi_mean_rmse, fusion_mean_rmse = (
    np.mean([_rmse(m, i) for i in dim_indices]) for m in MODES)
best_rmse = min(audio_mean_rmse, midi_mean_rmse, fusion_mean_rmse)
winner_rmse = 'Audio' if audio_mean_rmse == best_rmse else ('MIDI' if midi_mean_rmse == best_rmse else 'Fusion')
print(f"{'Root Mean Squared Error':<35} {audio_mean_rmse:>11.2f} {midi_mean_rmse:>11.2f} {fusion_mean_rmse:>11.2f} {winner_rmse}")

print("-"*80)

print("\n" + "="*80)
print("5. FUSION PERFORMANCE GAIN")
print("="*80)

best_single_r = max(audio_mean_r, midi_mean_r)
r_improvement = ((fusion_mean_r - best_single_r) / best_single_r * 100) if best_single_r > 0 else 0
print(f"Pearson r improvement:  {r_improvement:+.1f}% over best single-modal")

best_single_mae = min(audio_mean_mae, midi_mean_mae)
mae_reduction = ((best_single_mae - fusion_mean_mae) / best_single_mae * 100) if best_single_mae > 0 else 0
print(f"MAE reduction:          {mae_reduction:+.1f}% over best single-modal")

best_single_rmse = min(audio_mean_rmse, midi_mean_rmse)
rmse_reduction = ((best_single_rmse - fusion_mean_rmse) / best_single_rmse * 100) if best_single_rmse > 0 else 0
print(f"RMSE reduction:         {rmse_reduction:+.1f}% over best single-modal")

print("\n" + "="*80)
print("6. STATISTICAL SIGNIFICANCE (Fusion vs Best Single-Modal)")
print("="*80)
print(f"{'Dimension':<25} {'Best Single':<15} {'p-value':<12} {'Significant?'}")
print("-"*80)

for dim_idx, dim in enumerate(dimensions):
    # Paired t-test on absolute errors; valid because rows are aligned sample-for-sample.
    audio_errors, midi_errors, fusion_errors = (_abs_errors(m, dim_idx) for m in MODES)
    audio_is_best = np.mean(audio_errors) <= np.mean(midi_errors)
    best_single_errors = audio_errors if audio_is_best else midi_errors
    best_single_name = 'Audio' if audio_is_best else 'MIDI'
    
    t_stat, p_value = stats.ttest_rel(best_single_errors, fusion_errors)
    is_significant = p_value < 0.05 and np.mean(fusion_errors) < np.mean(best_single_errors)
    
    print(f"{dim:<25} {best_single_name:<15} {p_value:>11.4f} {'Yes' if is_significant else 'No':>12}")

print("-"*80)

print("\n" + "="*80)
print("7. DIMENSION LEARNABILITY CATEGORIZATION")
print("="*80)

# Learnability is judged on the best single-modal correlation per dimension.
best_single_r_by_dim = {d: max(_pearson('audio', d), _pearson('midi', d)) for d in dimensions}
strong = [d for d in dimensions if best_single_r_by_dim[d] > 0.4]
moderate = [d for d in dimensions if 0.25 <= best_single_r_by_dim[d] <= 0.4]
weak = [d for d in dimensions if best_single_r_by_dim[d] < 0.25]

print(f"Strong learners (r > 0.4):     {', '.join(strong) if strong else 'None'}")
print(f"Moderate learners (0.25-0.4):  {', '.join(moderate) if moderate else 'None'}")
print(f"Weak learners (r < 0.25):      {', '.join(weak) if weak else 'None'}")

print("\n" + "="*80)
print("8. MVP TARGET ASSESSMENT")
print("="*80)
print("Technical dimension target: r = 0.50-0.65 (Pearson with expert)")
print("Interpretive dimension target: r = 0.35-0.50")
print("MAE target: 10-15 points on 0-100 scale\n")

technical_mean_r = np.mean([_pearson('fusion', d) for d in technical_dims])
interpretive_mean_r = np.mean([_pearson('fusion', d) for d in interpretive_dims])

meets_r_target = technical_mean_r >= 0.50
meets_interpretive_target = interpretive_mean_r >= 0.35
meets_mae_target = fusion_mean_mae <= 15

print(f"Fusion technical r:     {technical_mean_r:.3f} {'(PASS)' if meets_r_target else '(FAIL - below 0.50 target)'}")
print(f"Fusion interpretive r:  {interpretive_mean_r:.3f} {'(PASS)' if meets_interpretive_target else '(FAIL - below 0.35 target)'}")
print(f"Fusion overall MAE:     {fusion_mean_mae:.2f} {'(PASS)' if meets_mae_target else '(FAIL - above 15 target)'}")

if meets_r_target and meets_interpretive_target and meets_mae_target:
    print("\nMVP TARGETS MET - Ready to proceed with expert annotation")
else:
    print("\nMVP TARGETS NOT MET - Consider architecture improvements or data augmentation")

print("\n" + "="*80)
print("9. EXPERT ANNOTATION RECOMMENDATION")
print("="*80)
recommended = strong + moderate
print(f"Include in expert labels:  {', '.join(recommended) if recommended else 'None'}")
print(f"Consider skipping:         {', '.join(weak) if weak else 'None'}")
print(f"\nEstimated cost savings: ${len(weak) * COST_PER_DIMENSION_USD:,} by excluding weak dimensions")
print(f"Recommended budget:     ${len(recommended) * COST_PER_DIMENSION_USD:,} for {len(recommended)} dimensions")
print("="*80)