# Piano Performance Evaluation - 3-Way Model Comparison (Colab)

Trains 3 models to prove multi-modal fusion advantage:
1. Audio-Only (MERT only)
2. MIDI-Only (MIDIBert only)
3. Fusion (MERT + MIDIBert)

**Dimensions**: 3 core (note_accuracy, rhythmic_precision, tone_quality)
**Sample size**: 10,000 training samples
**Expected time**: 6-7 hours total (2h + 1.5h + 2.5h)
**Goal**: Prove fusion beats both baselines by 15-20%

## Google Drive Structure

```
MyDrive/
  crescendai_data/
    all_segments/              # Audio segments
      *.wav
      midi_segments/
        *.mid
    annotations/
      synthetic_train_filtered.jsonl    # 91,865 samples
      synthetic_val_filtered.jsonl
      synthetic_test_filtered.jsonl

  crescendai_checkpoints/
    audio_10k/                 # Audio-only checkpoints
    midi_10k/                  # MIDI-only checkpoints  
    fusion_10k/                # Fusion checkpoints
```

## Setup

In [1]:
# HuggingFace Login
import os

# Drop any token inherited from the environment so the one entered below is
# the one actually used. Done before importing huggingface_hub in case it
# reads these variables at import time.
os.environ.pop("HF_TOKEN", None)
os.environ.pop("HUGGINGFACEHUB_API_TOKEN", None)

from huggingface_hub import login, HfApi


def _hf_login():
    """Log in to the Hugging Face Hub and print the authenticated identity.

    Prompts for a token with ``getpass`` (input hidden). If that flow fails
    for any reason (e.g. an empty token), falls back to the interactive
    ``huggingface_hub.login()`` widget.

    Runs inside a function so the raw token never becomes a notebook global
    that later cells (or ``%whos``) could expose.
    """
    try:
        import getpass as gp
        raw = gp.getpass("Paste your Hugging Face token (input hidden): ")
        # getpass returns str; the bytes branch is purely defensive.
        token = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
        if not isinstance(token, str):
            raise TypeError(f"Unexpected token type: {type(token).__name__}")
        token = token.strip()
        if not token:
            raise ValueError("Empty token provided")
        login(token=token, add_to_git_credential=False)
        who = HfApi().whoami(token=token)
        print(f"✓ Logged in as: {who.get('name') or who.get('email') or 'OK'}")
    except Exception as e:
        print(f"[HF Login] getpass flow failed: {e}")
        print("Falling back to interactive login widget...")
        login()
        try:
            who = HfApi().whoami()
            print(f"✓ Logged in as: {who.get('name') or who.get('email') or 'OK'}")
        except Exception as e2:
            print(f"[HF Login] Verification skipped: {e2}")


_hf_login()

✓ Logged in as: Jai-D


In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Verify the annotation files exist before doing anything expensive
import os
ANNOTATIONS_ROOT = '/content/drive/MyDrive/crescendai_data/annotations'

required_files = [
    f'{ANNOTATIONS_ROOT}/synthetic_train_filtered.jsonl',
    f'{ANNOTATIONS_ROOT}/synthetic_val_filtered.jsonl',
    f'{ANNOTATIONS_ROOT}/synthetic_test_filtered.jsonl',
]

print("Checking for data files...")
missing_files = []
for path in required_files:
    if os.path.exists(path):
        print(f"✓ {os.path.basename(path)}")
    else:
        print(f"✗ MISSING: {path}")
        missing_files.append(path)

# Report every missing file at once (instead of stopping at the first one)
# so a single fix-up pass on Drive is enough.
if missing_files:
    raise FileNotFoundError(f"Required file(s) not found: {missing_files}")

print("\n✓ All data files present")

In [None]:
# Clone repo: remove any previous checkout and start from a clean clone
!rm -rf /content/crescendai
!git clone https://github.com/Jai-Dhiman/crescendai.git /content/crescendai
# Later cells (train.py, configs/, src/) use paths relative to model/
%cd /content/crescendai/model
# Print the checked-out commit so results can be tied to a code version
!git log -1 --oneline

In [None]:
# Install uv (fast Python package manager)
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Add to PATH for this session
import os
os.environ['PATH'] = f"{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"

print("\n✓ uv installed")

In [None]:
# Install dependencies
!uv pip install --system -e .

# Suppress warnings
import warnings
warnings.filterwarnings('ignore', message='divide by zero')
warnings.filterwarnings('ignore', category=SyntaxWarning)  # pydub regex warnings

import torch
import pytorch_lightning as pl
print(f"PyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")
print("✓ Dependencies installed")

## GPU Check

In [None]:
!nvidia-smi

import torch
if not torch.cuda.is_available():
    print("\n⚠️  NO GPU! Enable GPU: Runtime → Change runtime type → T4 GPU")
    raise RuntimeError("GPU required")

print(f"\n✓ GPU: {torch.cuda.get_device_name(0)}")
print(f"✓ Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Download MERT model (cached after first download)
# torch is imported here (not only in earlier cells) so this cell also works
# when re-run on its own, e.g. after a kernel restart.
import torch
from transformers import AutoModel

print("Downloading MERT-95M (~380MB)...")
# Loading once populates the local Hugging Face cache so later training runs
# can reuse the downloaded files.
model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
print("✓ MERT-95M cached")

# Only the on-disk cache is needed here; drop the in-memory instance.
del model
torch.cuda.empty_cache()

In [None]:
import json
import shutil
import random
from pathlib import Path
from tqdm import tqdm

# Local (fast disk) destinations for annotations, audio and MIDI copies
LOCAL_DATA = Path('/tmp/training_data')
LOCAL_AUDIO = Path('/tmp/audio_segments')
LOCAL_MIDI = Path('/tmp/midi_segments')

# Create each directory; re-running the cell is harmless (exist_ok=True)
for _local_dir in (LOCAL_DATA, LOCAL_AUDIO, LOCAL_MIDI):
    _local_dir.mkdir(exist_ok=True)

def _copy_once(src, dst_dir, copied, missing, kind):
    """Copy ``src`` into ``dst_dir`` unless a file with the same basename was already copied.

    Args:
        src: Source file path.
        dst_dir: Destination directory.
        copied: Set of basenames already copied (updated in place).
        missing: Set of source paths already reported missing (updated in
            place) so each missing file is warned about only once.
        kind: Label used in the warning ("Audio" / "MIDI").

    Returns:
        The destination path, or ``None`` when ``src`` does not exist.
    """
    dst = dst_dir / src.name
    # NOTE: de-duplication is by basename, so this assumes basenames are
    # unique across source directories.
    if src.name in copied:
        return dst
    if src.exists():
        shutil.copy2(src, dst)
        copied.add(src.name)
        return dst
    if str(src) not in missing:
        missing.add(str(src))
        # tqdm.write keeps the progress bar intact while printing
        tqdm.write(f"  Warning: {kind} not found: {src}")
    return None


def copy_files_and_update_annotations(input_jsonl, output_jsonl, n_samples=10000, seed=42,
                                      *, audio_dir=None, midi_dir=None):
    """
    Copy only the files needed for a (sub)sampled annotation set to local
    storage and write a new JSONL whose paths point at the local copies.

    Args:
        input_jsonl: Source annotation file, one JSON object per non-blank
            line with at least 'audio_path' and 'midi_path' keys.
        output_jsonl: Where to write the rewritten annotations.
        n_samples: Keep at most this many samples. If the file has more, a
            random subset of this size is drawn; ``None`` keeps every sample.
        seed: Seed for the subsampling RNG. A private ``random.Random`` is
            used so the global ``random`` state is left untouched.
        audio_dir: Destination for audio files (default: ``LOCAL_AUDIO``).
        midi_dir: Destination for MIDI files (default: ``LOCAL_MIDI``).

    Returns:
        Number of samples written to ``output_jsonl``. Samples whose audio or
        MIDI source file is missing are skipped and reported, rather than
        written with a path to a file that was never copied.
    """
    audio_dir = Path(LOCAL_AUDIO if audio_dir is None else audio_dir)
    midi_dir = Path(LOCAL_MIDI if midi_dir is None else midi_dir)

    # Load annotations
    with open(input_jsonl, encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]

    original_count = len(data)
    print(f"Original: {original_count:,} samples")

    # Subsample if needed. random.Random(seed).sample draws the same subset as
    # random.seed(seed); random.sample(...) without touching the global RNG.
    if n_samples is not None and n_samples < original_count:
        data = random.Random(seed).sample(data, n_samples)
        print(f"Subsampled to {len(data):,} samples "
              f"({len(data) / original_count * 100:.1f}%)")

    copied_audio, copied_midi = set(), set()
    missing_audio, missing_midi = set(), set()
    updated_data = []

    print(f"Copying files for {len(data):,} samples to {audio_dir} and {midi_dir}...")
    for item in tqdm(data, desc="Copying files"):
        audio_dst = _copy_once(Path(item['audio_path']), audio_dir,
                               copied_audio, missing_audio, "Audio")
        midi_dst = _copy_once(Path(item['midi_path']), midi_dir,
                              copied_midi, missing_midi, "MIDI")
        if audio_dst is None or midi_dst is None:
            # Dropping the sample is safer than pointing it at a file that
            # does not exist, which would only fail later during training.
            continue

        # Update paths to point to the local copies
        item['audio_path'] = str(audio_dst)
        item['midi_path'] = str(midi_dst)
        updated_data.append(item)

    # Save updated annotations
    with open(output_jsonl, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item) + '\n' for item in updated_data)

    print(f"Copied {len(copied_audio):,} unique audio files")
    print(f"Copied {len(copied_midi):,} unique MIDI files")
    skipped = len(data) - len(updated_data)
    if skipped:
        print(f"Skipped {skipped:,} samples with missing files "
              f"({len(missing_audio):,} audio, {len(missing_midi):,} MIDI missing)")

    return len(updated_data)

print("="*70)
print("COPYING DATA TO LOCAL STORAGE")
print("="*70)
print("\nThis will copy ~10K audio + MIDI files (~20-30GB) to /tmp/")
print("Expected time: 5-10 minutes")
print()

# Copy training data (10K samples)
print("Training set:")
n_train = copy_files_and_update_annotations(
    f'{ANNOTATIONS_ROOT}/synthetic_train_filtered.jsonl',
    LOCAL_DATA / 'synthetic_train_filtered.jsonl',
    n_samples=10000
)

print("\nValidation set:")
n_val = copy_files_and_update_annotations(
    f'{ANNOTATIONS_ROOT}/synthetic_val_filtered.jsonl',
    LOCAL_DATA / 'synthetic_val_filtered.jsonl',
    n_samples=999999  # Copy all
)

print("\nTest set:")
n_test = copy_files_and_update_annotations(
    f'{ANNOTATIONS_ROOT}/synthetic_test_filtered.jsonl',
    LOCAL_DATA / 'synthetic_test_filtered.jsonl',
    n_samples=999999  # Copy all
)

# Check disk usage
import subprocess
result = subprocess.run(['df', '-h', '/tmp'], capture_output=True, text=True)
print("\n" + "="*70)
print("DISK USAGE")
print("="*70)
print(result.stdout)

print("="*70)
print("✓ DATA COPIED TO LOCAL STORAGE")
print("="*70)
print(f"\nAnnotation files (updated paths to /tmp/):")
print(f"  Train: {n_train:,} samples")
print(f"  Val:   {n_val:,} samples")
print(f"  Test:  {n_test:,} samples")
print(f"\nAll files now on fast local SSD")
print(f"Expected speedup: 10-20× faster than Google Drive")
!ls -lh /tmp/training_data/
!echo ""
!du -sh /tmp/audio_segments /tmp/midi_segments

## Preflight Check

Verify the local data copies, the config file, and model instantiation, and test audio/MIDI loading on one real sample before training.

In [None]:
import json
from pathlib import Path

# Imported here too (torch.cuda.empty_cache in step 4) so the cell runs standalone
import torch

print("="*70)
print("PREFLIGHT CHECK")
print("="*70)

# 1. Check the local annotation files exist, contain at least one record, and
#    that the first record has the fields the dataset needs.
print("\n1. Checking data files...")
required = ['audio_path', 'midi_path', 'labels']
first_samples = {}
for split in ['train', 'val', 'test']:
    path = f'/tmp/training_data/synthetic_{split}_filtered.jsonl'
    if not Path(path).exists():
        print(f"  ✗ {split}: FILE NOT FOUND")
        raise FileNotFoundError(f"Missing: {path}")

    # First non-blank line; a bare readline() on an empty file would pass ''
    # to json.loads and fail with an unhelpful JSONDecodeError.
    with open(path) as f:
        first_line = next((line for line in f if line.strip()), None)
    if first_line is None:
        print(f"  ✗ {split}: NO RECORDS")
        raise ValueError(f"Empty annotation file: {path}")
    sample = json.loads(first_line)

    # Check required fields
    missing = [field for field in required if field not in sample]
    if missing:
        print(f"  ✗ {split}: Missing fields {missing}")
        raise ValueError(f"Invalid annotation format in {path}")

    first_samples[split] = sample
    print(f"  ✓ {split}: {Path(path).stat().st_size / 1024 / 1024:.1f} MB")

# 2. Load one real training sample end to end. The annotations were rewritten
#    by the data-copy step to point at the local /tmp copies.
print("\n2. Testing actual data loading (from local /tmp copies)...")

from src.data.audio_processing import load_audio, normalize_audio
from src.data.midi_processing import load_midi, align_midi_to_audio, encode_octuple_midi

sample = first_samples['train']
audio_path = sample['audio_path']
midi_path = sample['midi_path']

try:
    # Clear error if an annotation points at a file that was never copied
    for label, p in (("Audio", audio_path), ("MIDI", midi_path)):
        if not Path(p).exists():
            raise FileNotFoundError(f"{label} file not found: {p}")

    # Load audio
    audio, sr = load_audio(audio_path, sr=24000)
    audio = normalize_audio(audio)
    print(f"  ✓ Audio loaded: {len(audio)} samples @ {sr} Hz")

    # Load MIDI and align it to the audio duration
    midi = load_midi(midi_path)
    audio_duration = len(audio) / sr
    midi = align_midi_to_audio(midi, audio_duration)
    tokens = encode_octuple_midi(midi)
    print(f"  ✓ MIDI loaded: {len(tokens)} tokens")

except Exception as e:
    print(f"  ✗ Data loading failed: {e}")
    print(f"     Audio path: {audio_path}")
    print(f"     MIDI path: {midi_path}")
    raise

# 3. Test config file
print("\n3. Testing config file...")
try:
    import yaml
    with open('configs/experiment_10k.yaml') as f:
        config = yaml.safe_load(f)
    print("  ✓ Config loaded")
    print(f"    Dimensions: {config['data']['dimensions']}")
    print(f"    Batch size: {config['data']['batch_size']}")
    print(f"    Num workers: {config['data']['num_workers']}")
except Exception as e:
    print(f"  ✗ Config loading failed: {e}")
    raise

# 4. Instantiate the model once per mode: shared model config + mode overrides
print("\n4. Testing model instantiation...")
try:
    from src.models.lightning_module import PerformanceEvaluationModel

    for mode in ['audio', 'midi', 'fusion']:
        model_config = config['model'].copy()
        model_config.update(config['modes'][mode])

        model = PerformanceEvaluationModel(
            dimension_names=config['data']['dimensions'],
            **model_config
        )
        params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f"  ✓ {mode}: {params:.1f}M params")
        del model

    torch.cuda.empty_cache()
except Exception as e:
    print(f"  ✗ Model instantiation failed: {e}")
    raise

print("\n" + "="*70)
print("✓ ALL PREFLIGHT CHECKS PASSED")
print("="*70)
print("\nLocal data, config and models verified - ready to train")

## Experiment 1: Audio-Only (~2 hours)

In [None]:
%%time
# Experiment 1: audio-only baseline (MERT only) with the shared 10K config.
# %%time reports this cell's wall-clock time.
!python train.py --config configs/experiment_10k.yaml --mode audio

## Experiment 2: MIDI-Only (~1.5 hours)

In [None]:
%%time
# Experiment 2: MIDI-only baseline (MIDIBert only) with the shared 10K config.
# %%time reports this cell's wall-clock time.
!python train.py --config configs/experiment_10k.yaml --mode midi

## Experiment 3: Fusion (~2.5 hours)

In [None]:
%%time
# Experiment 3: fusion model (MERT + MIDIBert) with the shared 10K config.
# %%time reports this cell's wall-clock time.
!python train.py --config configs/experiment_10k.yaml --mode fusion

## Compare Results

In [None]:
import math
from pathlib import Path

import pytorch_lightning as pl
from src.models.lightning_module import PerformanceEvaluationModel
from src.data.dataset import create_dataloaders

DIMENSIONS = ['note_accuracy', 'rhythmic_precision', 'tone_quality']
MODES = ['audio', 'midi', 'fusion']

# Load the most recent checkpoint for each mode.
# NOTE(review): "most recent" may not be the best-validation checkpoint;
# confirm which checkpoint train.py intends for evaluation.
models = {}
for mode in MODES:
    ckpt_dir = Path(f'/content/drive/MyDrive/crescendai_checkpoints/{mode}_10k')
    ckpts = list(ckpt_dir.glob('*.ckpt'))
    if ckpts:
        # Newest by modification time: sorting names lexicographically would
        # rank e.g. "epoch=9-..." after "epoch=10-...".
        latest = max(ckpts, key=lambda p: p.stat().st_mtime)
        print(f"Loading {mode}: {latest.name}")
        models[mode] = PerformanceEvaluationModel.load_from_checkpoint(str(latest))
        models[mode].eval()
        models[mode] = models[mode].cuda()
    else:
        print(f"⚠️  No checkpoint found for {mode}")

# Create test dataloader (train/val loaders are also built but unused here)
_, _, test_loader = create_dataloaders(
    train_annotation_path='/tmp/training_data/synthetic_train_filtered.jsonl',
    val_annotation_path='/tmp/training_data/synthetic_val_filtered.jsonl',
    test_annotation_path='/tmp/training_data/synthetic_test_filtered.jsonl',
    dimension_names=DIMENSIONS,
    batch_size=8,
    num_workers=0,
    augmentation_config=None,
    audio_sample_rate=24000,
    max_audio_length=240000,
    max_midi_events=512,
)

# Evaluate each loaded model on the test split
trainer = pl.Trainer(accelerator='auto', devices='auto', precision=16)
results = {}

for mode, model in models.items():
    print(f"\nEvaluating {mode}...")
    test_results = trainer.test(model, dataloaders=test_loader, verbose=False)
    results[mode] = test_results[0]


def _pearson(mode, dim):
    """Test-set Pearson r for ``mode``/``dim``, or NaN if the model or metric is missing."""
    return float(results.get(mode, {}).get(f'test_pearson_{dim}', math.nan))


print("\n" + "="*70)
print("COMPARISON")
print("="*70)
print(f"{'Dimension':<25} {'Audio r':<12} {'MIDI r':<12} {'Fusion r':<12} {'Gain'}")
print("-"*70)

# Gain = fusion r minus the stronger single-modality baseline. Missing values
# are NaN (not 0), so an absent model cannot masquerade as a real result.
gains, best_baselines = [], []
for dim in DIMENSIONS:
    audio_r, midi_r, fusion_r = (_pearson(m, dim) for m in MODES)
    if any(math.isnan(r) for r in (audio_r, midi_r, fusion_r)):
        # max() with NaN is order-dependent, so don't compute a gain at all
        best_r, gain = math.nan, math.nan
    else:
        best_r = max(audio_r, midi_r)
        gain = fusion_r - best_r
    gains.append(gain)
    best_baselines.append(best_r)

    print(f"{dim:<25} {audio_r:>11.3f} {midi_r:>11.3f} {fusion_r:>11.3f} {gain:>+11.3f}")

# NaN propagates, so either average is NaN if any dimension is incomplete
avg_gain = sum(gains) / len(gains)
avg_best = sum(best_baselines) / len(best_baselines)

print("-"*70)
if math.isnan(avg_gain):
    print("Average fusion gain: n/a (missing model or metric)")
else:
    # avg_gain is an absolute difference in Pearson r; the relative figure is
    # that difference as a share of the mean best-baseline r.
    rel = (f"{avg_gain / avg_best * 100:+.1f}% relative to best baseline"
           if avg_best > 0 else "relative gain n/a")
    print(f"Average fusion gain: {avg_gain:+.3f} Pearson r ({rel})")
print("="*70)

if math.isnan(avg_gain):
    print("\n⚠️  INCOMPLETE: evaluate all three models before judging fusion.")
elif avg_gain > 0.05:
    print("\n✓ SUCCESS: Fusion shows clear multi-modal advantage!")
else:
    print("\n⚠️  WARNING: Fusion gain is marginal. Check fusion implementation.")

In [None]:
# NOTE: this cell repeats the previous comparison cell verbatim; one copy can
# be deleted.
import math
from pathlib import Path

import pytorch_lightning as pl
from src.models.lightning_module import PerformanceEvaluationModel
from src.data.dataset import create_dataloaders

DIMENSIONS = ['note_accuracy', 'rhythmic_precision', 'tone_quality']
MODES = ['audio', 'midi', 'fusion']

# Load the most recent checkpoint for each mode.
# NOTE(review): "most recent" may not be the best-validation checkpoint;
# confirm which checkpoint train.py intends for evaluation.
models = {}
for mode in MODES:
    ckpt_dir = Path(f'/content/drive/MyDrive/crescendai_checkpoints/{mode}_10k')
    ckpts = list(ckpt_dir.glob('*.ckpt'))
    if ckpts:
        # Newest by modification time: sorting names lexicographically would
        # rank e.g. "epoch=9-..." after "epoch=10-...".
        latest = max(ckpts, key=lambda p: p.stat().st_mtime)
        print(f"Loading {mode}: {latest.name}")
        models[mode] = PerformanceEvaluationModel.load_from_checkpoint(str(latest))
        models[mode].eval()
        models[mode] = models[mode].cuda()
    else:
        print(f"⚠️  No checkpoint found for {mode}")

# Create test dataloader (train/val loaders are also built but unused here)
_, _, test_loader = create_dataloaders(
    train_annotation_path='/tmp/training_data/synthetic_train_filtered.jsonl',
    val_annotation_path='/tmp/training_data/synthetic_val_filtered.jsonl',
    test_annotation_path='/tmp/training_data/synthetic_test_filtered.jsonl',
    dimension_names=DIMENSIONS,
    batch_size=8,
    num_workers=0,
    augmentation_config=None,
    audio_sample_rate=24000,
    max_audio_length=240000,
    max_midi_events=512,
)

# Evaluate each loaded model on the test split
trainer = pl.Trainer(accelerator='auto', devices='auto', precision=16)
results = {}

for mode, model in models.items():
    print(f"\nEvaluating {mode}...")
    test_results = trainer.test(model, dataloaders=test_loader, verbose=False)
    results[mode] = test_results[0]


def _pearson(mode, dim):
    """Test-set Pearson r for ``mode``/``dim``, or NaN if the model or metric is missing."""
    return float(results.get(mode, {}).get(f'test_pearson_{dim}', math.nan))


print("\n" + "="*70)
print("COMPARISON")
print("="*70)
print(f"{'Dimension':<25} {'Audio r':<12} {'MIDI r':<12} {'Fusion r':<12} {'Gain'}")
print("-"*70)

# Gain = fusion r minus the stronger single-modality baseline. Missing values
# are NaN (not 0), so an absent model cannot masquerade as a real result.
gains, best_baselines = [], []
for dim in DIMENSIONS:
    audio_r, midi_r, fusion_r = (_pearson(m, dim) for m in MODES)
    if any(math.isnan(r) for r in (audio_r, midi_r, fusion_r)):
        # max() with NaN is order-dependent, so don't compute a gain at all
        best_r, gain = math.nan, math.nan
    else:
        best_r = max(audio_r, midi_r)
        gain = fusion_r - best_r
    gains.append(gain)
    best_baselines.append(best_r)

    print(f"{dim:<25} {audio_r:>11.3f} {midi_r:>11.3f} {fusion_r:>11.3f} {gain:>+11.3f}")

# NaN propagates, so either average is NaN if any dimension is incomplete
avg_gain = sum(gains) / len(gains)
avg_best = sum(best_baselines) / len(best_baselines)

print("-"*70)
if math.isnan(avg_gain):
    print("Average fusion gain: n/a (missing model or metric)")
else:
    # avg_gain is an absolute difference in Pearson r; the relative figure is
    # that difference as a share of the mean best-baseline r.
    rel = (f"{avg_gain / avg_best * 100:+.1f}% relative to best baseline"
           if avg_best > 0 else "relative gain n/a")
    print(f"Average fusion gain: {avg_gain:+.3f} Pearson r ({rel})")
print("="*70)

if math.isnan(avg_gain):
    print("\n⚠️  INCOMPLETE: evaluate all three models before judging fusion.")
elif avg_gain > 0.05:
    print("\n✓ SUCCESS: Fusion shows clear multi-modal advantage!")
else:
    print("\n⚠️  WARNING: Fusion gain is marginal. Check fusion implementation.")