# PercePiano SOTA Replica Training - VirtuosoNet Features

**Goal**: Replicate PercePiano's SOTA results (R-squared = 0.35-0.40) using the exact VirtuosoNet features.

## Attribution

This notebook implements the architecture from:

> **PercePiano: A Benchmark for Perceptual Evaluation of Piano Performance**  
> Park, Jongho and Kim, Dasaem et al.  
> ISMIR 2024 / Nature Scientific Reports 2024  
> GitHub: https://github.com/JonghoKimSNU/PercePiano

## Critical Fixes Applied (Dec 2024)

| Parameter | Previous (Broken) | Fixed Value | Source |
|-----------|-------------------|-------------|--------|
| input_size | 79 | **84** | 79 base + 5 unnorm features |
| learning_rate | 2.5e-5 | **1e-4** | parser.py:119 |
| weight_decay | 0.01 | **1e-5** | parser.py:135 |
| batch_size | 8 | **32** | parser.py:107 |
| gradient_clip_val | 1.0 | **2.0** | parser.py:159 |
| Voice LSTM input | 512-dim (note output) | **256-dim (embeddings)** | encoder_score.py:496-516 |
| Key augmentation | estimated pitch range | **midi_pitch_unnorm** | data_for_training.py:412-419 |

**Feature Layout (84-dim)**:
- Indices 0-78: Base VirtuosoNet features (z-score normalized where applicable)
- Index 79: midi_pitch_unnorm (raw MIDI pitch 21-108, for key augmentation)
- Indices 80-83: duration_unnorm, beat_importance_unnorm, measure_length_unnorm, following_rest_unnorm
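
The layout above can be written down as index constants. This is a minimal sketch, assuming each note is a flat 84-element row and that the four names for indices 80-83 are listed in index order; the constant and function names are hypothetical and are not taken from the PercePiano or crescendai code.

```python
# Hypothetical names; only the index layout comes from the list above.
BASE_DIM = 79
UNNORM_NAMES = [
    "midi_pitch_unnorm",       # index 79
    "duration_unnorm",         # index 80 (order assumed from the list above)
    "beat_importance_unnorm",  # index 81
    "measure_length_unnorm",   # index 82
    "following_rest_unnorm",   # index 83
]
TOTAL_DIM = BASE_DIM + len(UNNORM_NAMES)  # 84


def split_note_features(row):
    """Split one 84-dim note row into (base features, dict of unnorm features)."""
    if len(row) != TOTAL_DIM:
        raise ValueError(f"expected {TOTAL_DIM} features, got {len(row)}")
    base = list(row[:BASE_DIM])
    unnorm = dict(zip(UNNORM_NAMES, row[BASE_DIM:]))
    return base, unnorm


# Usage: a dummy row whose raw pitch (index 79) is middle C (MIDI 60).
row = [0.0] * TOTAL_DIM
row[79] = 60.0
base, unnorm = split_note_features(row)
print(len(base), unnorm["midi_pitch_unnorm"])
```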

**Most Critical Fixes**:
1. The Voice LSTM now receives the 256-dim projected embeddings and runs in parallel with the Note LSTM; it no longer consumes the 512-dim Note LSTM output sequentially.
2. Key augmentation now uses midi_pitch_unnorm (raw MIDI pitch) to calculate the valid shift range, matching the original PercePiano.
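
To make the second fix concrete, the sketch below shows how a valid transposition range might be derived from raw MIDI pitches. It assumes the piano range 21-108 stated in the feature layout and a hypothetical cap of `max_shift` semitones in either direction; the function name and the cap are illustrative and are not the original PercePiano implementation.

```python
PIANO_LOW, PIANO_HIGH = 21, 108  # raw MIDI pitch range (A0 to C8)


def valid_shift_range(raw_pitches, max_shift=5):
    """Return (lowest, highest) semitone shift that keeps every note on the keyboard.

    max_shift is a hypothetical cap chosen for illustration.
    """
    if not raw_pitches:
        raise ValueError("need at least one pitch")
    lo, hi = min(raw_pitches), max(raw_pitches)
    lowest = max(-max_shift, PIANO_LOW - lo)    # cannot go below A0
    highest = min(max_shift, PIANO_HIGH - hi)   # cannot go above C8
    return lowest, highest


# A passage that already touches the top key can only be shifted down.
print(valid_shift_range([60, 72, 108]))  # (-5, 0)
# A passage close to the bottom key has little room to shift down.
print(valid_shift_range([23, 60, 64]))   # (-2, 5)
```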

## What This Notebook Does

1. Downloads PercePiano data and scores from Google Drive
2. Runs VirtuosoNet preprocessing to extract 84-dim features (79 base + 5 unnorm)
3. Trains the faithful PercePiano replica with correct features and hyperparameters
4. Evaluates against SOTA baselines
5. Saves the trained model as a **Teacher Model** for pseudo-labeling MAESTRO

## Expected Results

- **Target R-squared**: 0.35-0.40 (piece-split)
- **Training time**: ~1-2 hours on T4/A100
- **Model size**: ~8-10M parameters

## Step 1: Environment Setup

In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected. Training will be slow.")

In [None]:
# Install rclone for Google Drive sync
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "WARNING: could not confirm rclone install; run 'rclone version' to check"

In [None]:
# Install uv and clone repository
!curl -LsSf https://astral.sh/uv/install.sh | sh

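import os
# The uv installer has placed the binary in ~/.local/bin (newer releases) or ~/.cargo/bin (older releases); add both.
os.environ['PATH'] = f"{os.environ['HOME']}/.local/bin:{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"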

# Clone repository
if not os.path.exists('/tmp/crescendai'):
    !git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai

%cd /tmp/crescendai/model
!git pull
!git log -1 --oneline

# Install dependencies
!uv pip install --system -e .
!pip install tensorboard rich

import torch
import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

In [None]:
import os
from pathlib import Path
import subprocess

# Paths
CHECKPOINT_ROOT = '/tmp/checkpoints/percepiano_replica'
GDRIVE_CHECKPOINT_PATH = 'gdrive:crescendai_checkpoints/percepiano_replica'
GDRIVE_DATA_PATH = 'gdrive:percepiano_data'
DATA_ROOT = Path('/tmp/percepiano_data')

print("="*70)
print("PERCEPIANO REPLICA - SOTA REPRODUCTION")
print("="*70)
print("Reference: Park et al., 'PercePiano', ISMIR/Nature 2024")
print("GitHub: https://github.com/JonghoKimSNU/PercePiano")
print("="*70)

# Create directories
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Check rclone
print("\nChecking rclone configuration...")
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)

if 'gdrive:' in result.stdout:
    print("  rclone 'gdrive' remote: CONFIGURED")
    RCLONE_AVAILABLE = True
    
    # Restore existing checkpoints
    print("\nRestoring checkpoints from Google Drive (if any)...")
    subprocess.run(
        ['rclone', 'copy', GDRIVE_CHECKPOINT_PATH, CHECKPOINT_ROOT, '--progress'],
        capture_output=False
    )
else:
    print("  rclone 'gdrive' remote: NOT CONFIGURED")
    print("  Run 'rclone config' in terminal to set up Google Drive")
    RCLONE_AVAILABLE = False

print(f"\nCheckpoint directory: {CHECKPOINT_ROOT}")
print(f"rclone available: {RCLONE_AVAILABLE}")

## Step 2: Download PercePiano Data

In [None]:
from pathlib import Path
import subprocess
import json

DATA_ROOT = Path('/tmp/percepiano_data')
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Download PercePiano data
train_file = DATA_ROOT / 'percepiano_train.json'
if train_file.exists():
    print(f"Data already exists at {DATA_ROOT}")
else:
    print("Downloading PercePiano data from Google Drive...")
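    result = subprocess.run(
        ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
        capture_output=False
    )
    # Fail fast if the download itself failed
    if result.returncode != 0:
        raise RuntimeError(f"rclone copy failed with exit code {result.returncode}")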

# Verify data
print("\n" + "="*60)
print("DATA VERIFICATION")
print("="*60)

for split in ['train', 'val', 'test']:
    path = DATA_ROOT / f'percepiano_{split}.json'
    if path.exists():
        with open(path) as f:
            data = json.load(f)
        has_scores = sum(1 for s in data if s.get('score_path'))
        print(f"{split}: {len(data)} samples ({has_scores} with score paths)")
    else:
        print(f"ERROR: {path} not found!")

# Check MIDI and score files
midi_dir = DATA_ROOT / 'PercePiano' / 'virtuoso' / 'data' / 'all_2rounds'
score_dir = DATA_ROOT / 'PercePiano' / 'virtuoso' / 'data' / 'score_xml'

if midi_dir.exists():
    midi_files = list(midi_dir.glob('*.mid'))
    print(f"\nMIDI files: {len(midi_files)}")
else:
    raise FileNotFoundError(f"MIDI directory not found at {midi_dir}")

if score_dir.exists():
    score_files = list(score_dir.glob('*.musicxml'))
    print(f"Score files: {len(score_files)}")
    if len(score_files) == 0:
        raise FileNotFoundError(f"No MusicXML files found in {score_dir}")
else:
    raise FileNotFoundError(
        f"Score directory not found at {score_dir}\n"
        f"Run: rclone copy gdrive:percepiano_data/PercePiano/virtuoso/data/score_xml/ {score_dir}/"
    )

# Store paths for later
MIDI_DIR = midi_dir
SCORE_DIR = score_dir

print("\n[OK] All required files present")

## Step 3: Update Paths for Thunder Compute

In [None]:
import json
from pathlib import Path

# All 19 PercePiano dimensions
PERCEPIANO_DIMENSIONS = [
    "timing", "articulation_length", "articulation_touch",
    "pedal_amount", "pedal_clarity", "timbre_variety",
    "timbre_depth", "timbre_brightness", "timbre_loudness",
    "dynamic_range", "tempo", "space", "balance", "drama",
    "mood_valence", "mood_energy", "mood_imagination",
    "sophistication", "interpretation",
]

def update_paths_for_thunder(data_root: Path):
    """Update paths in JSON files for Thunder Compute environment."""
    
    for split in ['train', 'val', 'test']:
        path = data_root / f'percepiano_{split}.json'
        
        with open(path) as f:
            data = json.load(f)
        
        for sample in data:
            # Update MIDI path
            filename = Path(sample['midi_path']).name
            sample['midi_path'] = str(MIDI_DIR / filename)
            
            # Make sure scores dict uses all 19 dimensions
            if 'percepiano_scores' in sample:
                pp_scores = sample['percepiano_scores'][:19]
                sample['scores'] = {
                    dim: pp_scores[i]
                    for i, dim in enumerate(PERCEPIANO_DIMENSIONS)
                }
        
        with open(path, 'w') as f:
            json.dump(data, f, indent=2)
        
        print(f"Updated {split}: {len(data)} samples")

update_paths_for_thunder(DATA_ROOT)

# Verify
with open(DATA_ROOT / 'percepiano_train.json') as f:
    sample = json.load(f)[0]

print(f"\nSample MIDI path: {sample['midi_path']}")
print(f"Sample score path: {sample.get('score_path', 'N/A')}")
print(f"Dimensions: {len(sample['scores'])}")
print(f"MIDI exists: {Path(sample['midi_path']).exists()}")

## Step 4: Pre-Flight Validation

In [None]:
# Pre-flight validation - FAIL FAST if requirements not met
from src.utils.preflight_validation import run_preflight_validation, PreflightValidationError

print("="*60)
print("PRE-FLIGHT VALIDATION")
print("="*60)

try:
    run_preflight_validation(
        data_dir=DATA_ROOT,
        score_dir=SCORE_DIR,
        pretrained_checkpoint=None,  # PercePiano replica doesn't use pre-trained encoder
        require_pretrained=False,    # Training from scratch with Bi-LSTM
        min_score_coverage=0.95,
    )
    print("\n[OK] Pre-flight validation PASSED - ready to train")
except PreflightValidationError as e:
    print(f"\n[VALIDATION FAILED]\n{e}")
    raise RuntimeError("Fix the issues above before training") from e

## Step 5: Training Configuration

The configuration below matches `han_bigger256_concat.yml` from the PercePiano repository:

In [None]:
import torch
torch.set_float32_matmul_precision('medium')

# PercePiano SOTA Configuration
# From: https://github.com/JonghoKimSNU/PercePiano
# File: virtuoso/ymls/shared/label19/han_bigger256_concat.yml
# Training params: virtuoso/virtuoso/parser.py
CONFIG = {
    # Data
    'data_dir': str(DATA_ROOT),
    'score_dir': str(SCORE_DIR),
    'vnet_data_dir': str(DATA_ROOT / 'percepiano_vnet'),  # VirtuosoNet preprocessed features
    
    # VirtuosoNet Feature Dimension
    # 79 base features (14 scalar + 13 pitch + 5 tempo + 4 dynamic + 9 time_sig + 6 slur_beam + 17 composer + 9 notation + 2 tempo_primo)
    # + 5 unnorm features (midi_pitch_unnorm, duration_unnorm, beat_importance_unnorm, measure_length_unnorm, following_rest_unnorm)
    # = 84 total features
    # The unnorm features preserve raw values BEFORE normalization, critical for key augmentation
    'input_size': 84,  # VirtuosoNet features (79 base + 5 unnorm)
    
    # HAN Architecture (matched to PercePiano han_bigger256_concat.yml)
    'hidden_size': 256,        # PercePiano: 256 for all levels
    'note_layers': 2,          # PercePiano: 2
    'voice_layers': 2,         # PercePiano: 2
    'beat_layers': 2,          # PercePiano: 2
    'measure_layers': 1,       # PercePiano: 1
    'num_attention_heads': 8,  # PercePiano: 8
    'final_hidden': 128,       # PercePiano: 128
    
    # Training (matched to original PercePiano parser.py)
    'learning_rate': 1e-4,     # Original: parser.py:119 (was 2.5e-5 - WRONG!)
    'weight_decay': 1e-5,      # Original: parser.py:135 (was 0.01 - WRONG!)
    'dropout': 0.2,            # PercePiano: 0.2
    'batch_size': 32,          # Original: parser.py:107 (was 8 - WRONG!)
    'max_epochs': 100,
    'early_stopping_patience': 20,
    'gradient_clip_val': 2.0,  # Original: parser.py:159 (was 1.0 - WRONG!)
    'precision': '16-mixed',
    
    # Dataset
    'max_notes': 1024,
    
    # Checkpoints
    'checkpoint_dir': CHECKPOINT_ROOT,
    'gdrive_checkpoint': GDRIVE_CHECKPOINT_PATH,
}

print("="*70)
print("PERCEPIANO REPLICA CONFIGURATION (FIXED)")
print("="*70)
print("Reference: han_bigger256_concat.yml + parser.py from PercePiano repo")
print("="*70)
print("\nCRITICAL FIXES APPLIED:")
print("  - input_size: 84 (79 base + 5 unnorm, was 79)")
print("  - learning_rate: 1e-4 (was 2.5e-5, parser.py:119)")
print("  - weight_decay: 1e-5 (was 0.01, parser.py:135)")
print("  - batch_size: 32 (was 8, parser.py:107)")
print("  - gradient_clip_val: 2.0 (was 1.0, parser.py:159)")
print("  - Voice LSTM: Now receives 256-dim embeddings (was 512-dim note output)")
print("  - Key augmentation: Now uses midi_pitch_unnorm for pitch range")
print("="*70)
for k, v in CONFIG.items():
    print(f"  {k}: {v}")
print("="*70)

## Step 6: VirtuosoNet Preprocessing (CRITICAL)

This step extracts the 84-dim VirtuosoNet features used in the original PercePiano paper:
- **79 base features**: Normalized where applicable (z-score normalization)
- **5 unnorm features**: Raw values preserved BEFORE normalization (critical for key augmentation)

The unnorm features are:
- `midi_pitch_unnorm`: Raw MIDI pitch (21-108) for accurate key augmentation
- `duration_unnorm`, `beat_importance_unnorm`, `measure_length_unnorm`, `following_rest_unnorm`
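
To illustrate why the raw copies have to be taken before normalization, here is a minimal sketch that z-scores one feature column while keeping an untouched copy. It assumes a plain per-column z-score using the population standard deviation; the helper name is hypothetical, and the actual preprocessing script may compute its statistics differently.

```python
from statistics import mean, pstdev


def zscore_with_raw_copy(values):
    """Return (z-scored values, untouched raw copy) for one feature column."""
    raw = list(values)                 # copy taken BEFORE normalization
    mu, sigma = mean(raw), pstdev(raw)
    if sigma == 0:
        normed = [0.0 for _ in raw]    # constant column: avoid division by zero
    else:
        normed = [(v - mu) / sigma for v in raw]
    return normed, raw


pitches = [48, 60, 72]
normed, raw = zscore_with_raw_copy(pitches)
print([round(v, 3) for v in normed])  # [-1.225, 0.0, 1.225]
print(raw)                            # [48, 60, 72]
```

Once a column has been z-scored, the original MIDI pitch cannot be recovered without the normalization statistics, which is why the raw value travels alongside it.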

If features with the correct dimension already exist, this cell skips preprocessing.

In [None]:
from pathlib import Path
import subprocess
import shutil
import pickle
import os

# Expected feature dimension (79 base + 5 unnorm = 84)
EXPECTED_FEATURE_DIM = 84

def check_feature_dimension(vnet_dir: Path) -> int:
    """Check the feature dimension of existing preprocessed data."""
    train_dir = vnet_dir / 'train'
    if not train_dir.exists():
        return 0
    
    pkl_files = list(train_dir.glob('*.pkl'))
    if not pkl_files:
        return 0
    
    # Load first file and check dimension
    with open(pkl_files[0], 'rb') as f:
        data = pickle.load(f)
    
    if 'input' in data:
        return data['input'].shape[1]
    return 0

# First, ensure VirtuosoNet modules are available by cloning PercePiano repo
virtuoso_module_path = Path('/tmp/crescendai/model/data/raw/PercePiano/virtuoso/virtuoso/pyScoreParser')
percepiano_path = Path('/tmp/crescendai/model/data/raw/PercePiano')

if not (virtuoso_module_path / 'feature_extraction.py').exists():
    print("VirtuosoNet modules not found. Cloning PercePiano repository...")
    
    # Remove existing directory if it exists (might be empty or corrupted)
    if percepiano_path.exists():
        shutil.rmtree(percepiano_path)
    
    percepiano_path.parent.mkdir(parents=True, exist_ok=True)
    
    # Clone the repo
    result = subprocess.run(
        ['git', 'clone', '--depth', '1',
         'https://github.com/JonghoKimSNU/PercePiano.git',
         str(percepiano_path)],
        capture_output=True,
        text=True
    )
    
    if result.returncode != 0:
        print(f"Failed to clone PercePiano: {result.stderr}")
        raise RuntimeError("Could not clone PercePiano repository")
    
    print("PercePiano cloned successfully!")
    print(f"VirtuosoNet modules at: {virtuoso_module_path}")
else:
    print(f"VirtuosoNet modules already available at {virtuoso_module_path}")

# Check if VirtuosoNet features exist AND have correct dimension
vnet_dir = Path(CONFIG['vnet_data_dir'])
vnet_train_dir = vnet_dir / 'train'

needs_preprocessing = False
if vnet_train_dir.exists() and list(vnet_train_dir.glob('*.pkl')):
    # Check feature dimension
    current_dim = check_feature_dimension(vnet_dir)
    
    if current_dim == EXPECTED_FEATURE_DIM:
        print(f"\nVirtuosoNet features already exist with correct dimension ({current_dim}-dim)")
        print(f"  train: {len(list((vnet_dir / 'train').glob('*.pkl')))} samples")
        print(f"  val: {len(list((vnet_dir / 'val').glob('*.pkl')))} samples")
        print(f"  test: {len(list((vnet_dir / 'test').glob('*.pkl')))} samples")
    else:
        print(f"\n[WARNING] Existing features have wrong dimension: {current_dim} (expected {EXPECTED_FEATURE_DIM})")
        print(f"  This is likely old 79-dim data without unnorm features.")
        print(f"  Deleting old data and re-preprocessing...")
        
        # Delete old preprocessed data
        shutil.rmtree(vnet_dir)
        needs_preprocessing = True
else:
    needs_preprocessing = True

if needs_preprocessing:
    print("\nVirtuosoNet features not found or outdated. Running preprocessing...")
    print(f"This extracts {EXPECTED_FEATURE_DIM}-dim features (79 base + 5 unnorm) for key augmentation.")
    print("")
    
    # Run preprocessing script
    # Pass DATA_ROOT directly - the script will find JSON files and score_xml there
    result = subprocess.run(
        ['python', 'scripts/preprocess_percepiano_vnet.py',
         '--data_root', str(DATA_ROOT),
         '--output_dir', str(vnet_dir)],
        cwd='/tmp/crescendai/model',
        capture_output=True,
        text=True
    )
    
    print(result.stdout)
    if result.returncode != 0:
        print(f"Preprocessing failed: {result.stderr}")
        print("\nNOTE: VirtuosoNet preprocessing requires MusicXML parsing.")
        print("If this fails, you may need to manually align MIDI files to scores.")
        print("")
        print("Alternative: Download pre-processed features from Google Drive:")
        print(f"  rclone copy gdrive:percepiano_data/percepiano_vnet {vnet_dir}")
    else:
        print("Preprocessing complete!")
        
        # Verify new dimension
        new_dim = check_feature_dimension(vnet_dir)
        if new_dim == EXPECTED_FEATURE_DIM:
            print(f"[OK] Features have correct dimension: {new_dim}")
        else:
            print(f"[ERROR] Features still have wrong dimension: {new_dim} (expected {EXPECTED_FEATURE_DIM})")

# Verify features exist with correct dimension
if vnet_train_dir.exists():
    num_train = len(list((vnet_dir / 'train').glob('*.pkl')))
    final_dim = check_feature_dimension(vnet_dir)
    print(f"\nVirtuosoNet features ready: {num_train} training samples, {final_dim}-dim")
    
    if final_dim != EXPECTED_FEATURE_DIM:
        raise RuntimeError(
            f"Feature dimension mismatch: got {final_dim}, expected {EXPECTED_FEATURE_DIM}.\n"
            f"Delete {vnet_dir} and re-run this cell."
        )
else:
    raise RuntimeError(
        f"VirtuosoNet features not available at {vnet_dir}.\n"
        "Run preprocessing or download from Google Drive."
    )

## Step 6.1: Data Quality Diagnostics (CRITICAL)

This cell performs comprehensive data quality checks on the preprocessed VirtuosoNet features.
If you see RED warnings, they indicate potential causes for poor training performance.

In [None]:
"""
DATA QUALITY DIAGNOSTICS
========================
This cell checks for common issues that cause training to fail:
1. Insufficient preprocessed samples
2. NaN/Inf values in features
3. Label range issues (should be 0-1)
4. Note location format issues (critical for hierarchical attention)
"""

import pickle
import numpy as np
from pathlib import Path
from rich.console import Console
from rich.table import Table
import random

console = Console()
vnet_dir = Path(CONFIG['vnet_data_dir'])

def diagnose_sample(pkl_path: Path) -> dict:
    """Analyze a single preprocessed sample for issues."""
    issues = []
    
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)
    
    result = {
        'name': pkl_path.stem,
        'issues': issues,
    }
    
    # Check input features
    input_feat = data.get('input')
    if input_feat is None:
        issues.append("CRITICAL: Missing 'input' key")
        return result
    
    result['shape'] = input_feat.shape
    result['num_notes'] = input_feat.shape[0]
    result['feature_dim'] = input_feat.shape[1]
    
    # Check for NaN/Inf
    nan_count = np.isnan(input_feat).sum()
    inf_count = np.isinf(input_feat).sum()
    if nan_count > 0:
        issues.append(f"CRITICAL: {nan_count} NaN values in features")
    if inf_count > 0:
        issues.append(f"CRITICAL: {inf_count} Inf values in features")
    
    # Check feature dimension against the configured input size
    expected_dim = CONFIG['input_size']
    if input_feat.shape[1] != expected_dim:
        issues.append(f"WARNING: Feature dim is {input_feat.shape[1]}, expected {expected_dim}")
    
    # Check labels
    labels = data.get('labels')
    if labels is None:
        issues.append("CRITICAL: Missing 'labels' key")
    else:
        result['labels'] = labels
        result['label_min'] = float(labels.min())
        result['label_max'] = float(labels.max())
        
        if labels.min() < 0 or labels.max() > 1:
            issues.append(f"CRITICAL: Labels out of [0,1] range: [{labels.min():.3f}, {labels.max():.3f}]")
        
        # Constant labels (zero variance) can only be detected across samples,
        # so no per-sample variance check is performed here.
    
    # Check note locations (CRITICAL for hierarchical attention)
    note_loc = data.get('note_location')
    if note_loc is None:
        issues.append("CRITICAL: Missing 'note_location' key")
    else:
        for key in ['beat', 'measure', 'voice']:
            if key not in note_loc:
                issues.append(f"CRITICAL: Missing '{key}' in note_location")
            else:
                indices = np.array(note_loc[key])
                result[f'{key}_min'] = int(indices.min())
                result[f'{key}_max'] = int(indices.max())
                
                # Check for gaps in beat indices (critical for boundary detection)
                if key == 'beat':
                    unique = np.unique(indices)
                    expected = np.arange(unique.min(), unique.max() + 1)
                    missing = set(expected) - set(unique)
                    if missing and len(missing) <= 10:
                        issues.append(f"WARNING: Beat indices have gaps: missing {sorted(missing)[:5]}...")
                    elif missing:
                        issues.append(f"WARNING: Beat indices have {len(missing)} gaps")
                    
                    # Check if starts from 1 (expected) or 0
                    if indices.min() == 0:
                        issues.append("INFO: Beat indices start from 0 (hierarchy_utils expects 1)")
    
    return result

# Count samples per split
print("="*70)
print("DATA QUALITY DIAGNOSTICS")
print("="*70)

sample_counts = {}
for split in ['train', 'val', 'test']:
    split_dir = vnet_dir / split
    if split_dir.exists():
        pkl_files = list(split_dir.glob('*.pkl'))
        sample_counts[split] = len(pkl_files)
    else:
        sample_counts[split] = 0
        console.print(f"[red]CRITICAL: {split} directory not found![/red]")

print("\n1. SAMPLE COUNTS")
print("-" * 40)
for split, count in sample_counts.items():
    status = "[green]OK[/green]" if count > 0 else "[red]MISSING[/red]"
    console.print(f"  {split:5s}: {count:4d} samples  {status}")

total_samples = sum(sample_counts.values())
if total_samples < 100:
    console.print(f"\n[red]CRITICAL: Only {total_samples} total samples - likely preprocessing failures![/red]")
    console.print("[yellow]Check preprocessing logs for errors.[/yellow]")

# Sample 5 files from each split for detailed analysis
print("\n2. SAMPLE QUALITY CHECK")
print("-" * 40)

all_issues = []
all_results = []

for split in ['train', 'val', 'test']:
    split_dir = vnet_dir / split
    if not split_dir.exists():
        continue
    
    pkl_files = list(split_dir.glob('*.pkl'))
    if not pkl_files:
        continue
    
    # Sample up to 5 files
    sample_files = random.sample(pkl_files, min(5, len(pkl_files)))
    
    console.print(f"\n[cyan]{split.upper()} split ({len(sample_files)} samples checked):[/cyan]")
    
    for pkl_path in sample_files:
        result = diagnose_sample(pkl_path)
        all_results.append(result)
        
        if result['issues']:
            for issue in result['issues']:
                console.print(f"  {pkl_path.stem[:30]:30s} - [red]{issue}[/red]")
                all_issues.append((split, pkl_path.stem, issue))
        else:
            console.print(f"  {pkl_path.stem[:30]:30s} - [green]OK[/green] (shape: {result.get('shape')})")

# Aggregate statistics
print("\n3. AGGREGATE STATISTICS")
print("-" * 40)

if all_results:
    # Feature dimensions
    dims = [r['feature_dim'] for r in all_results if 'feature_dim' in r]
    if dims:
        console.print(f"  Feature dimensions: {set(dims)}")
        if len(set(dims)) > 1:
            console.print("[red]  CRITICAL: Inconsistent feature dimensions![/red]")
    
    # Note counts
    notes = [r['num_notes'] for r in all_results if 'num_notes' in r]
    if notes:
        console.print(f"  Notes per sample: min={min(notes)}, max={max(notes)}, mean={np.mean(notes):.0f}")
    
    # Label statistics
    label_mins = [r['label_min'] for r in all_results if 'label_min' in r]
    label_maxs = [r['label_max'] for r in all_results if 'label_max' in r]
    if label_mins:
        console.print(f"  Label range: [{min(label_mins):.3f}, {max(label_maxs):.3f}]")
        if min(label_mins) < 0 or max(label_maxs) > 1:
            console.print("[red]  CRITICAL: Labels outside [0,1] - sigmoid output mismatch![/red]")
    
    # Beat index statistics
    beat_mins = [r['beat_min'] for r in all_results if 'beat_min' in r]
    beat_maxs = [r['beat_max'] for r in all_results if 'beat_max' in r]
    if beat_mins:
        console.print(f"  Beat indices: min={min(beat_mins)}, max={max(beat_maxs)}")
        if min(beat_mins) == 0:
            console.print("[yellow]  INFO: Beat indices start at 0 - check hierarchy_utils compatibility[/yellow]")

# Summary
print("\n4. DIAGNOSTIC SUMMARY")
print("-" * 40)

critical_count = sum(1 for _, _, issue in all_issues if 'CRITICAL' in issue)
warning_count = sum(1 for _, _, issue in all_issues if 'WARNING' in issue)

if critical_count > 0:
    console.print(f"[red]CRITICAL ISSUES: {critical_count}[/red]")
    console.print("[red]Training will likely fail or produce garbage results.[/red]")
elif warning_count > 0:
    console.print(f"[yellow]WARNINGS: {warning_count}[/yellow]")
    console.print("[yellow]Training may work but with degraded performance.[/yellow]")
else:
    console.print("[green]No issues detected - data looks healthy![/green]")

print("="*70)

In [None]:
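# Patch PercePiano data_class.py to fix silent failures
# The original code has `except` clauses (some of them bare) that print a message and continue,
# so errors are swallowed silently. This patch replaces them with explicit exception handling.
# NOTE: the patch only affects preprocessing runs started after this cell. To apply it to the
# Step 6 features, delete them and re-run Step 6 (assumes the preprocessing script imports
# this data_class.py).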

import re
from pathlib import Path

data_class_path = Path('/tmp/crescendai/model/data/raw/PercePiano/virtuoso/virtuoso/pyScoreParser/data_class.py')

if data_class_path.exists():
    print("Patching PercePiano data_class.py to fix silent failures...")
    
    content = data_class_path.read_text()
    original_content = content
    
    # Patch 1: Fix the error-swallowing `except Exception` in load_all_piece (line ~95-98)
    # Change from printing the error to re-raising with context
    old_pattern1 = r"except Exception as ex:\s+# TODO: TGK: this is ambiguous.*?\s+print\(f'Error while processing \{scores\[n\]\}\. Error type :\{ex\}'\)"
    new_code1 = """except Exception as ex:
                # Re-raise with full context instead of silently continuing
                raise RuntimeError(f'Error loading piece {scores[n]}: {type(ex).__name__}: {ex}') from ex"""
    content = re.sub(old_pattern1, new_code1, content, flags=re.DOTALL)
    
    # Patch 2: Fix bare except in performance alignment (line ~332-335)
    old_pattern2 = r"except:\s+perform_data = None\s+print\(f'Cannot align \{perform\}'\)\s+self\.performances\.append\(None\)"
    new_code2 = """except (ValueError, IndexError, FileNotFoundError, OSError) as e:
                        # Explicit exception handling - re-raise with context
                        raise RuntimeError(f'Alignment failed for {perform}: {type(e).__name__}: {e}') from e"""
    content = re.sub(old_pattern2, new_code2, content)
    
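    # Patch 3: Fix bare except in Nakamura alignment (line ~425)
    # Patches 3 and 4 assume data_class.py imports `subprocess`; if it does not,
    # the patched `except` clauses raise NameError when an error occurs.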
    old_pattern3 = r"except:\s+print\('Error to process \{\}'\.format\(midi_file_path\)\)"
    new_code3 = """except subprocess.CalledProcessError as e:
            print(f'Alignment tool failed for {midi_file_path}: {e}')"""
    content = re.sub(old_pattern3, new_code3, content)
    
    # Patch 4: Fix second bare except in Nakamura alignment retry (line ~435)
    old_pattern4 = r"except:\s+align_success = False\s+print\('Fail to process \{\}'\.format\(midi_file_path\)\)\s+os\.chdir\(current_dir\)"
    new_code4 = """except subprocess.CalledProcessError as e2:
                align_success = False
                raise RuntimeError(f'Alignment tool failed after retry for {midi_file_path}: {e2}') from e2"""
    content = re.sub(old_pattern4, new_code4, content)
    
    if content != original_content:
        data_class_path.write_text(content)
        print("[OK] Patched data_class.py - silent failures now raise explicit exceptions")
    else:
        print("[INFO] data_class.py already patched or patterns not found")
else:
    print("[SKIP] data_class.py not found - run the Step 6 cell first to clone PercePiano")

## Step 7: Create DataLoaders and Model

Using the VirtuosoNet preprocessed features (84-dim: 79 base + 5 unnorm) with the faithful PercePiano replica architecture.

In [None]:
from src.data.percepiano_vnet_dataset import create_vnet_dataloaders
from src.models.percepiano_replica import PercePianoVNetModule

# Create DataLoaders with VirtuosoNet features (84-dim: 79 base + 5 unnorm)
# Note: batch_size defaults to 32 in create_vnet_dataloaders (matched to original PercePiano)
# Key augmentation uses midi_pitch_unnorm (index 79) to calculate valid pitch shift range
# NOTE: Using num_workers=0 to avoid shared memory issues on Thunder Compute
train_loader, val_loader, test_loader = create_vnet_dataloaders(
    data_dir=CONFIG['vnet_data_dir'],
    batch_size=CONFIG['batch_size'],  # 32 (original PercePiano parser.py:107)
    max_notes=CONFIG['max_notes'],
    num_workers=0,  # Avoid shared memory issues on Thunder Compute
)

print(f"Train: {len(train_loader.dataset)} samples")
print(f"Val: {len(val_loader.dataset)} samples")
print(f"Test: {len(test_loader.dataset)} samples")
print(f"Batch size: {CONFIG['batch_size']}")

# Create model with VirtuosoNet features (84-dim input: 79 base + 5 unnorm)
# Note: PercePianoVNetModule defaults are now matched to original PercePiano:
#   - input_size: 84 (79 base + 5 unnorm for key augmentation)
#   - learning_rate: 1e-4 (parser.py:119)
#   - weight_decay: 1e-5 (parser.py:135)
#   - Voice LSTM input: 256-dim embeddings (NOT 512-dim note output)
model = PercePianoVNetModule(
    # VirtuosoNet input dimension (79 base + 5 unnorm)
    input_size=CONFIG['input_size'],  # 84-dim features
    # HAN dimensions (matched to PercePiano han_bigger256_concat.yml)
    hidden_size=CONFIG['hidden_size'],
    note_layers=CONFIG['note_layers'],
    voice_layers=CONFIG['voice_layers'],
    beat_layers=CONFIG['beat_layers'],
    measure_layers=CONFIG['measure_layers'],
    num_attention_heads=CONFIG['num_attention_heads'],
    final_hidden=CONFIG['final_hidden'],
    # Training (matched to original PercePiano parser.py)
    learning_rate=CONFIG['learning_rate'],   # 1e-4 (parser.py:119)
    weight_decay=CONFIG['weight_decay'],     # 1e-5 (parser.py:135)
    dropout=CONFIG['dropout'],
)

# Count parameters
total_params = model.count_parameters()

print("\n" + "="*70)
print("PERCEPIANO REPLICA MODEL (VirtuosoNet Features)")
print("="*70)
print(f"Architecture: Bi-LSTM + HAN (Note -> Voice -> Beat -> Measure)")
print(f"Input features: {CONFIG['input_size']} (79 base + 5 unnorm)")
print(f"Hidden size: {CONFIG['hidden_size']}")
print(f"Total parameters: {total_params:,}")
print(f"")
print(f"CRITICAL FIXES APPLIED:")
print(f"  1. input_size: 84 (79 base + 5 unnorm, matches original _unnorm features)")
print(f"  2. Voice LSTM receives 256-dim embeddings (was 512-dim note output)")
print(f"  3. learning_rate: {CONFIG['learning_rate']} (was 2.5e-5)")
print(f"  4. weight_decay: {CONFIG['weight_decay']} (was 0.01)")
print(f"  5. batch_size: {CONFIG['batch_size']} (was 8)")
print(f"  6. gradient_clip_val: {CONFIG['gradient_clip_val']} (set in Trainer)")
print(f"  7. Key augmentation uses midi_pitch_unnorm for pitch range calculation")
print(f"  8. Beat/measure indices shifted to start from 1 (hierarchy_utils fix)")
print(f"")
print(f"Target R-squared: 0.35-0.40 (piece-split)")
print(f"Dimensions: {len(model.dimensions)}")
print("="*70)

## Step 7.1: Simple Baseline Test (Optional)

This cell tests whether a small baseline (mean-pooled features fed to an MLP) can learn anything from the data.
If even this baseline fails (R-squared < 0), the problem most likely lies in the data pipeline rather than the architecture.

Skip this cell for normal training - only run if diagnosing training failures.
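
To make the R^2 < 0 criterion concrete, the sketch below computes R^2 by hand for three toy predictors: one that always outputs the target mean, one that tracks the targets closely, and one that is anti-correlated. The helper `r2_score_1d` and the toy values are illustrative assumptions, not part of this notebook's code.

```python
def r2_score_1d(targets, preds):
    """Coefficient of determination for one dimension: 1 - SS_res / SS_tot."""
    mean_t = sum(targets) / len(targets)
    ss_res = sum((t - p) ** 2 for t, p in zip(targets, preds))
    ss_tot = sum((t - mean_t) ** 2 for t in targets)
    return 1.0 - ss_res / ss_tot

# Toy labels on the 0-1 scale (illustrative values only)
targets = [0.2, 0.4, 0.6, 0.8]

mean_pred = [0.5, 0.5, 0.5, 0.5]      # always predicts the target mean
good_pred = [0.25, 0.35, 0.65, 0.75]  # close to the targets
bad_pred = [0.8, 0.6, 0.4, 0.2]       # anti-correlated with the targets

r2_mean = r2_score_1d(targets, mean_pred)  # about 0.0
r2_good = r2_score_1d(targets, good_pred)  # about 0.95
r2_bad = r2_score_1d(targets, bad_pred)    # about -3.0

print(f"mean predictor: {r2_mean:+.3f}")
print(f"good predictor: {r2_good:+.3f}")
print(f"bad predictor:  {r2_bad:+.3f}")
```

A constant mean predictor scores roughly zero, so a trained model that stays below zero is doing worse than ignoring the inputs entirely.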

In [None]:
# OPTIONAL: Simple baseline test - skip this cell for normal training
# Only run if you need to verify the data pipeline works

RUN_BASELINE_TEST = False  # Set to True to run this test

if RUN_BASELINE_TEST:
    import torch
    import torch.nn as nn
    from sklearn.metrics import r2_score
    from tqdm import tqdm
    
    print("="*60)
    print("SIMPLE BASELINE TEST")
    print("="*60)
    print("Testing if a small mean-pooling MLP can learn from the data...")
    print("If this fails, the issue is most likely in the data pipeline, not the architecture.\n")
    
    # Simple baseline (small MLP): mean-pool the note features, then predict the scores
    class SimpleBaseline(nn.Module):
        def __init__(self, input_dim=84, hidden_dim=64, output_dim=19):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden_dim, output_dim),
                nn.Sigmoid()  # Match PercePiano replica
            )
        
        def forward(self, x, mask):
            # x: (batch, seq, 84), mask: (batch, seq)
            # Mean pool over valid notes
            mask_expanded = mask.unsqueeze(-1).float()  # (batch, seq, 1)
            masked_x = x * mask_expanded
            summed = masked_x.sum(dim=1)  # (batch, 84)
            counts = mask_expanded.sum(dim=1).clamp(min=1)  # (batch, 1)
            pooled = summed / counts  # (batch, 84)
            return self.net(pooled)
    
    # Create model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    simple_model = SimpleBaseline(input_dim=CONFIG['input_size']).to(device)
    optimizer = torch.optim.Adam(simple_model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    
    print(f"Simple model parameters: {sum(p.numel() for p in simple_model.parameters()):,}")
    print(f"Device: {device}\n")
    
    # Train for 10 epochs
    num_epochs = 10
    for epoch in range(num_epochs):
        simple_model.train()
        train_losses = []
        
        for batch in train_loader:
            x = batch['input_features'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['scores'].to(device)
            
            preds = simple_model(x, mask)
            loss = criterion(preds, targets)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_losses.append(loss.item())
        
        # Validate
        simple_model.eval()
        all_preds, all_targets = [], []
        
        with torch.no_grad():
            for batch in val_loader:
                x = batch['input_features'].to(device)
                mask = batch['attention_mask'].to(device)
                targets = batch['scores'].to(device)
                
                preds = simple_model(x, mask)
                all_preds.append(preds.cpu())
                all_targets.append(targets.cpu())
        
        all_preds = torch.cat(all_preds).numpy()
        all_targets = torch.cat(all_targets).numpy()
        
        r2 = r2_score(all_targets, all_preds)
        train_loss = sum(train_losses) / len(train_losses)
        
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val R2: {r2:.4f}")
    
    print("\n" + "="*60)
    print("BASELINE TEST RESULTS")
    print("="*60)
    
    if r2 > 0:
        print(f"[OK] Simple model achieved R2 = {r2:.4f}")
        print("The data pipeline appears to work. Issue may be in the HAN architecture.")
    else:
        print(f"[FAIL] Simple model has R2 = {r2:.4f} (negative!)")
        print("Even a simple model cannot learn from this data.")
        print("\nPossible causes:")
        print("  1. Labels are in wrong scale (should be 0-1)")
        print("  2. Features contain NaN/Inf values")
        print("  3. Features are all zeros or constant")
        print("  4. Label-feature alignment is wrong")
        print("\nRun the Data Quality Diagnostics cell for more details.")
    print("="*60)
    
    # Clean up
    del simple_model, optimizer
    torch.cuda.empty_cache()
else:
    print("[SKIP] Simple baseline test - set RUN_BASELINE_TEST = True to run")
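
The pooling step inside `SimpleBaseline.forward` can be checked in isolation. The following is a minimal pure-Python sketch of masked mean pooling for a single sequence. `masked_mean_pool` is a hypothetical helper written for illustration, not a function from this repository, and it uses plain lists instead of tensors.

```python
def masked_mean_pool(features, mask):
    """Average feature rows where mask is 1, ignoring padded rows.

    features: list of notes, each a list of floats of equal length
    mask:     list of 0/1 flags, 1 for a real note and 0 for padding
    """
    dim = len(features[0])
    summed = [0.0] * dim
    count = 0
    for row, keep in zip(features, mask):
        if keep:
            count += 1
            for j, value in enumerate(row):
                summed[j] += value
    # Mirrors clamp(min=1) in the tensor version: avoids dividing by zero
    # when a sequence consists only of padding.
    count = max(count, 1)
    return [s / count for s in summed]

# Two real notes followed by one padded row that must not affect the mean
pooled = masked_mean_pool(
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
    [1, 1, 0],
)
print(pooled)

# An all-padding sequence falls back to a zero vector instead of raising
empty = masked_mean_pool([[5.0, 5.0]], [0])
print(empty)
```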

## Step 8: Configure Trainer

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, Callback
from pytorch_lightning.loggers import TensorBoardLogger
from rich.console import Console
import time
import numpy as np

# Custom callback for detailed training progress with prediction diagnostics
class TrainingProgressCallback(Callback):
    """Custom callback to show detailed training progress and diagnose training issues."""
    
    def __init__(self, log_every_n_batches: int = 10):
        super().__init__()
        self.log_every_n_batches = log_every_n_batches
        self.epoch_start_time = None
        self.train_losses = []
        self.console = Console()
        self.first_batch_logged = False
        self.prediction_variances = []
        
    def on_train_epoch_start(self, trainer, pl_module):
        self.epoch_start_time = time.time()
        self.train_losses = []
        self.first_batch_logged = False
        self.console.print(f"\n[bold cyan]{'='*60}[/bold cyan]")
        self.console.print(f"[bold cyan]Epoch {trainer.current_epoch + 1}/{trainer.max_epochs}[/bold cyan]")
        self.console.print(f"[bold cyan]{'='*60}[/bold cyan]")
        
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Handle both tensor and dict outputs
        if outputs is not None:
            if isinstance(outputs, dict) and 'loss' in outputs:
                loss_val = outputs['loss'].item() if hasattr(outputs['loss'], 'item') else outputs['loss']
            elif hasattr(outputs, 'item'):
                loss_val = outputs.item()
            else:
                loss_val = float(outputs) if outputs is not None else 0.0
            self.train_losses.append(loss_val)
        
        # DIAGNOSTIC: Log first batch predictions on epoch 0
        if trainer.current_epoch == 0 and batch_idx == 0 and not self.first_batch_logged:
            self.first_batch_logged = True
            self._diagnose_first_batch(trainer, pl_module, batch)
        
        # Log progress
        total_batches = len(trainer.train_dataloader)
        if (batch_idx + 1) % self.log_every_n_batches == 0 or batch_idx == total_batches - 1:
            recent_losses = self.train_losses[-self.log_every_n_batches:]
            avg_loss = sum(recent_losses) / len(recent_losses) if recent_losses else 0
            progress_pct = (batch_idx + 1) / total_batches * 100
            elapsed = time.time() - self.epoch_start_time
            eta = elapsed / (batch_idx + 1) * (total_batches - batch_idx - 1)
            
            # Get current learning rate
            lr = trainer.optimizers[0].param_groups[0]['lr']
            
            self.console.print(
                f"  Step [{batch_idx + 1:4d}/{total_batches}] "
                f"({progress_pct:5.1f}%) | "
                f"Loss: {avg_loss:.4f} | "
                f"LR: {lr:.2e} | "
                f"Elapsed: {elapsed:5.0f}s | "
                f"ETA: {eta:5.0f}s"
            )
    
    def _diagnose_first_batch(self, trainer, pl_module, batch):
        """Diagnose model predictions on first batch to detect issues early."""
        import torch
        
        self.console.print(f"\n[yellow]PREDICTION DIAGNOSTICS (First Batch)[/yellow]")
        
        try:
            pl_module.eval()
            with torch.no_grad():
                # Move batch to device
                device = pl_module.device
                input_features = batch['input_features'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                scores = batch['scores'].to(device)
                
                note_locations = {
                    'beat': batch['note_locations_beat'].to(device),
                    'measure': batch['note_locations_measure'].to(device),
                    'voice': batch['note_locations_voice'].to(device),
                }
                
                # Forward pass
                outputs = pl_module(
                    input_features=input_features,
                    note_locations=note_locations,
                    attention_mask=attention_mask,
                )
                
                preds = outputs['predictions'].cpu().numpy()
                targets = scores.cpu().numpy()
                
                # Prediction statistics
                pred_mean = preds.mean()
                pred_std = preds.std()
                pred_min = preds.min()
                pred_max = preds.max()
                
                target_mean = targets.mean()
                target_std = targets.std()
                
                self.console.print(f"  Predictions: mean={pred_mean:.4f}, std={pred_std:.4f}, range=[{pred_min:.4f}, {pred_max:.4f}]")
                self.console.print(f"  Targets:     mean={target_mean:.4f}, std={target_std:.4f}")
                
                # Check for issues
                if pred_std < 0.01:
                    self.console.print(f"  [red]WARNING: Prediction std < 0.01 - model may be collapsing to constant![/red]")
                
                if abs(pred_mean - 0.5) > 0.3:
                    self.console.print(f"  [yellow]INFO: Predictions shifted from 0.5 - check initialization[/yellow]")
                
                # Per-dimension variance check
                per_dim_std = preds.std(axis=0)
                collapsed_dims = np.sum(per_dim_std < 0.01)
                if collapsed_dims > 0:
                    self.console.print(f"  [red]WARNING: {collapsed_dims}/{len(per_dim_std)} dimensions have std < 0.01 (collapsed)[/red]")
                
                self.prediction_variances.append(per_dim_std)
                
        except Exception as e:
            self.console.print(f"  [red]Error in diagnostics: {e}[/red]")
        finally:
            pl_module.train()
    
    def on_train_epoch_end(self, trainer, pl_module):
        epoch_time = time.time() - self.epoch_start_time
        avg_loss = sum(self.train_losses) / len(self.train_losses) if self.train_losses else 0
        
        self.console.print(f"\n[green]Train epoch complete[/green] | "
                          f"Avg Loss: {avg_loss:.4f} | "
                          f"Time: {epoch_time:.1f}s")
    
    def on_validation_epoch_start(self, trainer, pl_module):
        self.console.print(f"\n[yellow]Running validation...[/yellow]")
        self.val_start_time = time.time()
    
    def on_validation_epoch_end(self, trainer, pl_module):
        val_time = time.time() - self.val_start_time
        
        # Get validation metrics
        metrics = trainer.callback_metrics
        mean_r2 = metrics.get('val/mean_r2', None)
        val_loss = metrics.get('val/loss', None)
        
        # Print validation summary
        self.console.print(f"[green]Validation complete[/green] ({val_time:.1f}s)")
        if val_loss is not None:
            self.console.print(f"  Val Loss: {float(val_loss):.4f}")
        if mean_r2 is not None:
            r2_val = float(mean_r2)
            self.console.print(f"  [bold]Mean R2: {r2_val:.4f}[/bold]")
            
            # DIAGNOSTIC: Flag if R2 is negative
            if r2_val < 0:
                self.console.print(f"  [red]WARNING: R2 < 0 means model is worse than predicting mean![/red]")
                self.console.print(f"  [yellow]Check: 1) Data quality 2) Label scale 3) Note location format[/yellow]")
        
        # Collect per-dimension R2 values
        dim_r2s = {}
        for key, value in metrics.items():
            if key.startswith('val/') and key.endswith('_r2') and key != 'val/mean_r2':
                dim_name = key.replace('val/', '').replace('_r2', '')
                dim_r2s[dim_name] = float(value)
        
        if dim_r2s:
            sorted_dims = sorted(dim_r2s.items(), key=lambda x: x[1], reverse=True)
            
            # Count negative R2 dimensions
            negative_dims = sum(1 for _, r2 in sorted_dims if r2 < 0)
            if negative_dims > 10:
                self.console.print(f"\n  [red]WARNING: {negative_dims}/{len(sorted_dims)} dimensions have R2 < 0![/red]")
            
            # Show top 5
            self.console.print(f"\n  [cyan]Top 5 dimensions:[/cyan]")
            for dim, r2 in sorted_dims[:5]:
                bar = '#' * int(max(0, r2) * 20)  # Visual bar
                color = "green" if r2 > 0 else "red"
                self.console.print(f"    {dim:20s}: [{color}]{r2:+.4f}[/{color}] {bar}")
            
            # Show bottom 3 if we have enough dimensions
            if len(sorted_dims) > 8:
                self.console.print(f"\n  [cyan]Bottom 3 dimensions:[/cyan]")
                for dim, r2 in sorted_dims[-3:]:
                    bar = '#' * int(max(0, r2) * 20)
                    color = "green" if r2 > 0 else "red"
                    self.console.print(f"    {dim:20s}: [{color}]{r2:+.4f}[/{color}] {bar}")
        
        # Best model tracking
        if hasattr(trainer, 'checkpoint_callback') and trainer.checkpoint_callback is not None:
            best_r2 = trainer.checkpoint_callback.best_model_score
            if best_r2 is not None:
                self.console.print(f"\n  [bold magenta]Best R2 so far: {float(best_r2):.4f}[/bold magenta]")

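# Checkpoint callback - monitor mean R-squared
# The logged metric key is 'val/mean_r2', so the filename template must use that
# exact key. auto_insert_metric_name=False keeps the slash out of the file name.
checkpoint_callback = ModelCheckpoint(
    dirpath=CONFIG['checkpoint_dir'],
    filename='percepiano_replica-epoch={epoch:02d}-r2={val/mean_r2:.4f}',
    monitor='val/mean_r2',
    mode='max',
    save_top_k=3,
    save_last=True,
    auto_insert_metric_name=False,
)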

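# Early stopping (patience counts validation checks, not epochs;
# with val_check_interval=0.5 below there are two checks per epoch)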
early_stopping = EarlyStopping(
    monitor='val/mean_r2',
    patience=CONFIG['early_stopping_patience'],
    mode='max',
    verbose=True,
)

# LR monitor
lr_monitor = LearningRateMonitor(logging_interval='step')

# Custom progress callback
training_progress = TrainingProgressCallback(log_every_n_batches=10)

# Logger
logger = TensorBoardLogger(
    save_dir='/tmp/logs',
    name='percepiano_replica',
)

# Trainer with gradient clipping matched to original PercePiano (parser.py:159)
trainer = pl.Trainer(
    max_epochs=CONFIG['max_epochs'],
    accelerator='gpu',
    devices=1,
    precision=CONFIG['precision'],
    gradient_clip_val=CONFIG['gradient_clip_val'],  # 2.0 (original PercePiano parser.py:159)
    callbacks=[checkpoint_callback, early_stopping, lr_monitor, training_progress],
    logger=logger,
    log_every_n_steps=10,
    val_check_interval=0.5,  # Validate twice per epoch
    enable_progress_bar=True,
)

print("="*60)
print("TRAINER CONFIGURATION")
print("="*60)
print(f"  Precision: {CONFIG['precision']}")
print(f"  Max epochs: {CONFIG['max_epochs']}")
print(f"  Batch size: {CONFIG['batch_size']} (original: 32)")
print(f"  Learning rate: {CONFIG['learning_rate']} (original: 1e-4)")
print(f"  Weight decay: {CONFIG['weight_decay']} (original: 1e-5)")
print(f"  Gradient clip: {CONFIG['gradient_clip_val']} (original: 2.0)")
print(f"  Early stopping patience: {CONFIG['early_stopping_patience']}")
print("="*60)
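
The per-dimension collapse check in `_diagnose_first_batch` can be tried on toy data before a full run. The sketch below is a pure-Python stand-in for that check, assuming predictions arrive as rows of per-dimension values. `count_collapsed_dims` is a hypothetical helper, and the population standard deviation is used because that is what NumPy's `std` computes by default.

```python
from statistics import pstdev

def count_collapsed_dims(preds, threshold=0.01):
    """Return the indices of dimensions whose predictions barely vary.

    preds: list of samples, each a list with one prediction per dimension
    """
    columns = list(zip(*preds))  # transpose to one tuple per dimension
    per_dim_std = [pstdev(col) for col in columns]
    collapsed = [i for i, s in enumerate(per_dim_std) if s < threshold]
    return collapsed, per_dim_std

# Toy batch of 4 samples and 3 dimensions; dimension 1 is stuck at 0.5
preds = [
    [0.2, 0.5, 0.9],
    [0.4, 0.5, 0.1],
    [0.6, 0.5, 0.7],
    [0.8, 0.5, 0.3],
]

collapsed, per_dim_std = count_collapsed_dims(preds)
print(f"Collapsed dimensions: {collapsed}")
print("Per-dimension std:", [round(s, 3) for s in per_dim_std])
```

A dimension that stays flagged after the first few epochs is a hint that the sigmoid output for that dimension is saturating or that its labels carry little variance.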

## Step 9: Train!

In [None]:
# Set seed for reproducibility
pl.seed_everything(42, workers=True)

# Train
print("="*70)
print("STARTING TRAINING - PercePiano SOTA Replica")
print("="*70)
print("\nKey metrics to watch:")
print("  - val/mean_r2: Overall R-squared (target: 0.35-0.40)")
print("  - val/timing_r2: Timing dimension (should be highest)")
print("  - val/tempo_r2: Tempo dimension")
print("")
print("PercePiano SOTA baselines:")
print("  - Bi-LSTM: R^2 = 0.185")
print("  - MidiBERT: R^2 = 0.313")
print("  - Bi-LSTM + SA + HAN: R^2 = 0.397 (SOTA)")
print("="*70)

trainer.fit(model, train_loader, val_loader)

In [None]:
# Sync checkpoints to Google Drive
if RCLONE_AVAILABLE:
    print("Syncing checkpoints to Google Drive...")
    subprocess.run(
        ['rclone', 'copy', CONFIG['checkpoint_dir'], CONFIG['gdrive_checkpoint'], '--progress'],
        capture_output=False
    )
    print("Sync complete!")

## Step 10: Evaluation

In [None]:
# Test with best checkpoint
print("Running test with best checkpoint...")
best_path = checkpoint_callback.best_model_path
print(f"Best checkpoint: {best_path}")

if best_path:
    test_results = trainer.test(model, test_loader, ckpt_path=best_path)
    print("\nTest Results:")
    for k, v in test_results[0].items():
        print(f"  {k}: {v:.4f}")

In [None]:
import torch
import numpy as np

# Load best model
from src.models.percepiano_replica import PercePianoVNetModule
best_model = PercePianoVNetModule.load_from_checkpoint(checkpoint_callback.best_model_path)
best_model.eval()
best_model.cuda()

# Collect predictions on test set
all_preds = []
all_targets = []

print("Collecting predictions on test set...")
with torch.no_grad():
    for batch in test_loader:
        # Move batch to GPU
        input_features = batch['input_features'].cuda()
        attention_mask = batch['attention_mask'].cuda()
        scores = batch['scores'].cuda()
        
        note_locations = {
            'beat': batch['note_locations_beat'].cuda(),
            'measure': batch['note_locations_measure'].cuda(),
            'voice': batch['note_locations_voice'].cuda(),
        }
        
        # Forward pass with VirtuosoNet features
        outputs = best_model(
            input_features=input_features,
            note_locations=note_locations,
            attention_mask=attention_mask,
        )
        
        all_preds.append(outputs['predictions'].cpu())
        all_targets.append(scores.cpu())

all_preds = torch.cat(all_preds).numpy()
all_targets = torch.cat(all_targets).numpy()
dimensions = best_model.dimensions

print(f"Collected {len(all_preds)} test samples across {len(dimensions)} dimensions")

In [None]:
from src.evaluation import (
    compute_all_metrics,
    PerDimensionAnalysis,
    compare_to_sota,
    format_comparison_table,
    create_results_table,
    PERCEPIANO_BASELINES,
)

# Compute all metrics
metrics = compute_all_metrics(
    predictions=all_preds,
    targets=all_targets,
    dimension_names=list(dimensions),
)

# Print results table
print(create_results_table(metrics))

In [None]:
# Compare to SOTA baselines
our_r2 = metrics['r2'].value
per_dim_r2 = metrics['r2'].per_dimension

comparison = compare_to_sota(
    model_r2=our_r2,
    model_name="PercePiano Replica (CrescendAI)",
    split_type="piece",
    per_dimension_r2=per_dim_r2,
)

print(format_comparison_table(comparison))
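
The internals of `compare_to_sota` are not shown in this notebook. As a rough mental model only, ranking a result against the three baseline values printed in this notebook could look like the sketch below. `rank_against_baselines` and its return keys are hypothetical and may differ from what `src.evaluation` actually returns.

```python
# Baseline R^2 values as printed elsewhere in this notebook
BASELINES = {
    "Bi-LSTM": 0.185,
    "MidiBERT": 0.313,
    "Bi-LSTM + SA + HAN": 0.397,
}

def rank_against_baselines(model_r2, baselines=BASELINES):
    """Rank a model among the baselines (1 = best) and report the gap to the best one."""
    best_name, best_r2 = max(baselines.items(), key=lambda kv: kv[1])
    # Rank is one plus the number of baselines that beat the model
    rank = 1 + sum(1 for r2 in baselines.values() if r2 > model_r2)
    return {
        "rank": rank,
        "total": len(baselines) + 1,  # baselines plus the model itself
        "best_baseline": best_name,
        "delta_vs_best": model_r2 - best_r2,
    }

result = rank_against_baselines(0.35)
print(result)
```

With a hypothetical replica score of 0.35, the model would sit second of four, behind only the Bi-LSTM + SA + HAN baseline.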

In [None]:
# Summary
print("="*70)
print("PERCEPIANO REPLICA - RESULTS SUMMARY")
print("="*70)

print(f"\n1. OVERALL PERFORMANCE")
print(f"   Mean R^2: {our_r2:.4f}")
print(f"   Target (0.35-0.40): {'ACHIEVED' if our_r2 >= 0.35 else 'CLOSE' if our_r2 >= 0.30 else 'NOT YET'}")

print(f"\n2. COMPARISON TO PUBLISHED BASELINES")
print(f"   Bi-LSTM baseline: 0.185")
print(f"   MidiBERT: 0.313")
print(f"   Bi-LSTM + SA + HAN (SOTA): 0.397")
print(f"   Our replica: {our_r2:.4f}")

print(f"\n3. MODEL SIZE")
print(f"   Parameters: {best_model.count_parameters():,}")
print(f"   vs Previous (51.5M): {51_500_000 / best_model.count_parameters():.1f}x smaller")

print(f"\n4. TOP 5 DIMENSIONS")
sorted_dims = sorted(per_dim_r2.items(), key=lambda x: x[1], reverse=True)
for dim, r2 in sorted_dims[:5]:
    print(f"   {dim}: {r2:.4f}")

print(f"\n5. BOTTOM 5 DIMENSIONS (need improvement)")
for dim, r2 in sorted_dims[-5:]:
    print(f"   {dim}: {r2:.4f}")

print("="*70)

## Step 11: Save as Teacher Model

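If the model achieves R-squared >= 0.25 (the minimum threshold checked in the cell below), save it as a **Teacher Model** for pseudo-labeling MAESTRO. The follow-up plan in Next Steps assumes the stricter bar of R-squared >= 0.30.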

In [None]:
import torch
from pathlib import Path

# Save as teacher model
teacher_path = Path(CONFIG['checkpoint_dir']) / 'percepiano_teacher.pt'

if our_r2 >= 0.25:  # Minimum threshold for useful teacher
    torch.save({
        'state_dict': best_model.state_dict(),
        'config': {
            'input_size': CONFIG['input_size'],  # 84 (79 base + 5 unnorm)
            'hidden_size': CONFIG['hidden_size'],
            'note_layers': CONFIG['note_layers'],
            'voice_layers': CONFIG['voice_layers'],
            'beat_layers': CONFIG['beat_layers'],
            'measure_layers': CONFIG['measure_layers'],
            'num_attention_heads': CONFIG['num_attention_heads'],
            'final_hidden': CONFIG['final_hidden'],
            'dropout': CONFIG['dropout'],
        },
        'dimensions': list(dimensions),
        'metrics': {
            'r2': our_r2,
            'per_dimension_r2': per_dim_r2,
        },
        'sota_comparison': {
            'rank': comparison['rank'],
            'total_baselines': comparison['total_baselines'],
            'vs_best_baseline': comparison['improvement_vs_best'],
        },
        'architecture': 'PercePiano Replica (Bi-LSTM + HAN) with VirtuosoNet 84-dim features (79 base + 5 unnorm)',
        'reference': 'https://github.com/JonghoKimSNU/PercePiano',
    }, teacher_path)
    
    print(f"Saved teacher model to {teacher_path}")
    print(f"Teacher R^2: {our_r2:.4f}")
    print(f"\nThis model can be used for pseudo-labeling MAESTRO!")
    print(f"Run: python scripts/pseudo_label_maestro.py --teacher {teacher_path}")
else:
    print(f"R^2 = {our_r2:.4f} is below threshold (0.25) for teacher model.")
    print("Consider:")
    print("  1. Training for more epochs")
    print("  2. Adjusting hyperparameters")
    print("  3. Checking data quality")

# Final sync to Google Drive
if RCLONE_AVAILABLE:
    print("\nFinal sync to Google Drive...")
    subprocess.run(
        ['rclone', 'copy', CONFIG['checkpoint_dir'], CONFIG['gdrive_checkpoint'], '--progress'],
        capture_output=False
    )
    print("Sync complete!")
    print(f"Checkpoints available at: {CONFIG['gdrive_checkpoint']}")

## Next Steps

If R-squared >= 0.30:

1. **Pseudo-label MAESTRO**: Use this teacher model to generate labels for the MAESTRO dataset
2. **Train larger model**: With the expanded dataset (~6000 samples), train a larger model
3. **Noisy Student**: Apply noisy student training for a potential improvement over the teacher

If R-squared < 0.30:

1. Check whether the validation set is too small (only 27 samples)
2. Consider k-fold cross-validation for more robust estimates
3. Verify data preprocessing matches PercePiano exactly
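
For the k-fold suggestion above, a piece-level split keeps every performance of the same piece in one fold, which matches the piece-split evaluation targeted here. The following is a minimal sketch under the assumption that each sample carries a piece identifier. `piece_level_folds` and the toy `piece_ids` are hypothetical, and the greedy balancing is one simple choice among several.

```python
from collections import defaultdict

def piece_level_folds(piece_ids, k=4):
    """Assign sample indices to k folds so that no piece is split across folds."""
    groups = defaultdict(list)
    for idx, piece in enumerate(piece_ids):
        groups[piece].append(idx)

    folds = [[] for _ in range(k)]
    # Greedy balancing: place the largest pieces first, each into the
    # fold that currently holds the fewest samples.
    ordered = sorted(groups.items(), key=lambda kv: (-len(kv[1]), kv[0]))
    for _piece, idxs in ordered:
        smallest = min(range(k), key=lambda f: len(folds[f]))
        folds[smallest].extend(idxs)
    return folds

# Toy example: 8 performances of 4 pieces, split into 2 folds
piece_ids = ["a", "a", "b", "c", "c", "c", "d", "b"]
folds = piece_level_folds(piece_ids, k=2)

for i, val_idx in enumerate(folds):
    train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
    print(f"Fold {i}: val={sorted(val_idx)} train={sorted(train_idx)}")
```

Each fold serves once as the validation set while the remaining folds form the training set, so every sample contributes to exactly one validation estimate.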

---

**Attribution**: This model replicates the architecture from PercePiano (Park et al., ISMIR 2024 / Nature Scientific Reports 2024).  
GitHub: https://github.com/JonghoKimSNU/PercePiano