# PercePiano SOTA Replica Training - VirtuosoNet Features

**Goal**: Replicate PercePiano's SOTA results (R-squared = 0.35-0.40) using the exact VirtuosoNet 78-dim features.

## Attribution

This notebook implements the architecture from:

> **PercePiano: A Benchmark for Perceptual Evaluation of Piano Performance**  
> Park, Jongho and Kim, Dasaem et al.  
> ISMIR 2024 / Nature Scientific Reports 2024  
> GitHub: https://github.com/JonghoKimSNU/PercePiano

## Key Changes from Previous Version

| Aspect | Previous (Broken) | This Version (Fixed) |
|--------|-------------------|---------------------|
| Input Features | 20 custom + 256 global concat | **78-dim VirtuosoNet features** |
| HAN Input | 276-dim (wrong) | **78-dim (correct)** |
| Global Context | Concatenated to every note | **Not used (matches PercePiano)** |
| Feature Source | Custom score alignment | **VirtuosoNet preprocessing** |

## What This Notebook Does

1. Downloads PercePiano data and scores from Google Drive
2. Runs VirtuosoNet preprocessing to extract 78-dim features
3. Trains the faithful PercePiano replica with correct features
4. Evaluates against SOTA baselines
5. Saves the trained model as a **Teacher Model** for pseudo-labeling MAESTRO

## Expected Results

- **Target R-squared**: 0.35-0.40 (piece-split)
- **Training time**: ~1-2 hours on T4/A100
- **Model size**: ~8-10M parameters

## Step 1: Environment Setup

In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected. Training will be slow.")

In [None]:
# Install rclone for Google Drive sync
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "No success message matched; run 'rclone version' to confirm the install"

In [None]:
# Install uv and clone repository
!curl -LsSf https://astral.sh/uv/install.sh | sh

import os
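# Recent uv installers place the binary in ~/.local/bin, older ones in ~/.cargo/bin; add both
home = os.environ['HOME']
os.environ['PATH'] = f"{home}/.local/bin:{home}/.cargo/bin:{os.environ['PATH']}"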

# Clone repository
if not os.path.exists('/tmp/crescendai'):
    !git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai

%cd /tmp/crescendai/model
!git pull
!git log -1 --oneline

# Install dependencies
!uv pip install --system -e .
!pip install tensorboard rich

import torch
import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

In [None]:
import os
from pathlib import Path
import subprocess

# Paths
CHECKPOINT_ROOT = '/tmp/checkpoints/percepiano_replica'
GDRIVE_CHECKPOINT_PATH = 'gdrive:crescendai_checkpoints/percepiano_replica'
GDRIVE_DATA_PATH = 'gdrive:percepiano_data'
DATA_ROOT = Path('/tmp/percepiano_data')

print("="*70)
print("PERCEPIANO REPLICA - SOTA REPRODUCTION")
print("="*70)
print("Reference: Park et al., 'PercePiano', ISMIR/Nature 2024")
print("GitHub: https://github.com/JonghoKimSNU/PercePiano")
print("="*70)

# Create directories
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Check rclone
print("\nChecking rclone configuration...")
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)

if 'gdrive:' in result.stdout:
    print("  rclone 'gdrive' remote: CONFIGURED")
    RCLONE_AVAILABLE = True
    
    # Restore existing checkpoints
    print("\nRestoring checkpoints from Google Drive (if any)...")
    subprocess.run(
        ['rclone', 'copy', GDRIVE_CHECKPOINT_PATH, CHECKPOINT_ROOT, '--progress'],
        capture_output=False
    )
else:
    print("  rclone 'gdrive' remote: NOT CONFIGURED")
    print("  Run 'rclone config' in terminal to set up Google Drive")
    RCLONE_AVAILABLE = False

print(f"\nCheckpoint directory: {CHECKPOINT_ROOT}")
print(f"rclone available: {RCLONE_AVAILABLE}")

## Step 2: Download PercePiano Data

In [None]:
from pathlib import Path
import subprocess
import json

DATA_ROOT = Path('/tmp/percepiano_data')
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Download PercePiano data
train_file = DATA_ROOT / 'percepiano_train.json'
if train_file.exists():
    print(f"Data already exists at {DATA_ROOT}")
else:
    print("Downloading PercePiano data from Google Drive...")
    result = subprocess.run(
        ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
        capture_output=False
    )

# Verify data
print("\n" + "="*60)
print("DATA VERIFICATION")
print("="*60)

for split in ['train', 'val', 'test']:
    path = DATA_ROOT / f'percepiano_{split}.json'
    if path.exists():
        with open(path) as f:
            data = json.load(f)
        has_scores = sum(1 for s in data if s.get('score_path'))
        print(f"{split}: {len(data)} samples ({has_scores} with score paths)")
    else:
        print(f"ERROR: {path} not found!")

# Check MIDI and score files
midi_dir = DATA_ROOT / 'PercePiano' / 'virtuoso' / 'data' / 'all_2rounds'
score_dir = DATA_ROOT / 'PercePiano' / 'virtuoso' / 'data' / 'score_xml'

if midi_dir.exists():
    midi_files = list(midi_dir.glob('*.mid'))
    print(f"\nMIDI files: {len(midi_files)}")
else:
    raise FileNotFoundError(f"MIDI directory not found at {midi_dir}")

if score_dir.exists():
    score_files = list(score_dir.glob('*.musicxml'))
    print(f"Score files: {len(score_files)}")
    if len(score_files) == 0:
        raise FileNotFoundError(f"No MusicXML files found in {score_dir}")
else:
    raise FileNotFoundError(
        f"Score directory not found at {score_dir}\n"
        f"Run: rclone copy gdrive:percepiano_data/PercePiano/virtuoso/data/score_xml/ {score_dir}/"
    )

# Store paths for later
MIDI_DIR = midi_dir
SCORE_DIR = score_dir

print("\n[OK] All required files present")

## Step 3: Update Paths for Thunder Compute

In [None]:
import json
from pathlib import Path

# All 19 PercePiano dimensions
PERCEPIANO_DIMENSIONS = [
    "timing", "articulation_length", "articulation_touch",
    "pedal_amount", "pedal_clarity", "timbre_variety",
    "timbre_depth", "timbre_brightness", "timbre_loudness",
    "dynamic_range", "tempo", "space", "balance", "drama",
    "mood_valence", "mood_energy", "mood_imagination",
    "sophistication", "interpretation",
]

def update_paths_for_thunder(data_root: Path):
    """Update paths in JSON files for Thunder Compute environment."""
    
    for split in ['train', 'val', 'test']:
        path = data_root / f'percepiano_{split}.json'
        
        with open(path) as f:
            data = json.load(f)
        
        for sample in data:
            # Update MIDI path
            filename = Path(sample['midi_path']).name
            sample['midi_path'] = str(MIDI_DIR / filename)
            
            # Make sure scores dict uses all 19 dimensions
            if 'percepiano_scores' in sample:
                pp_scores = sample['percepiano_scores'][:19]
                sample['scores'] = {
                    dim: pp_scores[i]
                    for i, dim in enumerate(PERCEPIANO_DIMENSIONS)
                }
        
        with open(path, 'w') as f:
            json.dump(data, f, indent=2)
        
        print(f"Updated {split}: {len(data)} samples")

update_paths_for_thunder(DATA_ROOT)

# Verify
with open(DATA_ROOT / 'percepiano_train.json') as f:
    sample = json.load(f)[0]

print(f"\nSample MIDI path: {sample['midi_path']}")
print(f"Sample score path: {sample.get('score_path', 'N/A')}")
print(f"Dimensions: {len(sample['scores'])}")
print(f"MIDI exists: {Path(sample['midi_path']).exists()}")
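
The dict comprehension in `update_paths_for_thunder` indexes into the score list, so a sample with fewer than 19 values fails with a bare `IndexError`. The sketch below is one way to make that failure explicit. The helper name `scores_to_dict` is hypothetical and not part of the repository.

```python
from typing import Dict, List, Sequence

# Same 19 dimension names as PERCEPIANO_DIMENSIONS in the cell above
DIMENSIONS: List[str] = [
    "timing", "articulation_length", "articulation_touch",
    "pedal_amount", "pedal_clarity", "timbre_variety",
    "timbre_depth", "timbre_brightness", "timbre_loudness",
    "dynamic_range", "tempo", "space", "balance", "drama",
    "mood_valence", "mood_energy", "mood_imagination",
    "sophistication", "interpretation",
]


def scores_to_dict(
    raw_scores: Sequence[float],
    dimensions: Sequence[str] = DIMENSIONS,
) -> Dict[str, float]:
    """Map a raw score list onto named dimensions (hypothetical helper)."""
    if len(raw_scores) < len(dimensions):
        raise ValueError(
            f"Expected at least {len(dimensions)} scores, got {len(raw_scores)}"
        )
    # zip stops after len(dimensions) items, so extra trailing values are ignored
    return dict(zip(dimensions, raw_scores))
```

Calling `scores_to_dict(sample['percepiano_scores'])` would then replace the slice-and-index step while keeping the same output shape.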

## Step 4: Pre-Flight Validation

In [None]:
# Pre-flight validation - FAIL FAST if requirements not met
from src.utils.preflight_validation import run_preflight_validation, PreflightValidationError

print("="*60)
print("PRE-FLIGHT VALIDATION")
print("="*60)

try:
    run_preflight_validation(
        data_dir=DATA_ROOT,
        score_dir=SCORE_DIR,
        pretrained_checkpoint=None,  # PercePiano replica doesn't use pre-trained encoder
        require_pretrained=False,    # Training from scratch with Bi-LSTM
        min_score_coverage=0.95,
    )
    print("\n[OK] Pre-flight validation PASSED - ready to train")
except PreflightValidationError as e:
    print(f"\n[VALIDATION FAILED]\n{e}")
    raise RuntimeError("Fix the issues above before training") from e
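
The internals of `run_preflight_validation` are not shown in this notebook. As a rough illustration of what a `min_score_coverage=0.95` check could measure, the sketch below computes the fraction of samples whose `score_path` is set and exists on disk. The function name `score_coverage` and this exact definition of coverage are assumptions.

```python
from pathlib import Path
from typing import Iterable, Mapping


def score_coverage(samples: Iterable[Mapping], score_key: str = "score_path") -> float:
    """Fraction of samples with a score path that exists on disk (hypothetical helper)."""
    samples = list(samples)
    if not samples:
        return 0.0
    covered = sum(
        1 for s in samples
        if s.get(score_key) and Path(s[score_key]).exists()
    )
    return covered / len(samples)
```

Comparing the result against 0.95 for each split gives a quick manual check when the validation step reports low coverage.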

## Step 5: Training Configuration

Values annotated `PercePiano:` below are taken from `han_bigger256_concat.yml` in the PercePiano repository:

In [None]:
import torch
torch.set_float32_matmul_precision('medium')

# PercePiano SOTA Configuration
# From: https://github.com/JonghoKimSNU/PercePiano
# File: virtuoso/ymls/shared/label19/han_bigger256_concat.yml
CONFIG = {
    # Data
    'data_dir': str(DATA_ROOT),
    'score_dir': str(SCORE_DIR),
    'vnet_data_dir': str(DATA_ROOT / 'percepiano_vnet'),  # VirtuosoNet preprocessed features
    
    # VirtuosoNet Feature Dimension (CRITICAL FIX)
    'input_size': 78,  # VirtuosoNet 78-dim features (NOT 20 + 256 = 276!)
    
    # HAN Architecture (matched to PercePiano)
    'hidden_size': 256,        # PercePiano: 256 for all levels
    'note_layers': 2,          # PercePiano: 2
    'voice_layers': 2,         # PercePiano: 2
    'beat_layers': 2,          # PercePiano: 2
    'measure_layers': 1,       # PercePiano: 1
    'num_attention_heads': 8,  # PercePiano: 8
    'final_hidden': 128,       # PercePiano: 128
    
    # Training (matched to PercePiano)
    'learning_rate': 2.5e-5,   # PercePiano: 2.5e-5
    'weight_decay': 0.01,
    'dropout': 0.2,            # PercePiano: 0.2
    'batch_size': 8,           # PercePiano: 8
    'max_epochs': 100,
    'early_stopping_patience': 20,
    'gradient_clip_val': 1.0,
    'precision': '16-mixed',
    
    # Dataset
    'max_notes': 1024,
    
    # Checkpoints
    'checkpoint_dir': CHECKPOINT_ROOT,
    'gdrive_checkpoint': GDRIVE_CHECKPOINT_PATH,
}

print("="*70)
print("PERCEPIANO REPLICA CONFIGURATION (FIXED)")
print("="*70)
print("Reference: han_bigger256_concat.yml from PercePiano repo")
print("="*70)
print("\nCRITICAL FIX: Using VirtuosoNet 78-dim features instead of 20+256 concat")
print("="*70)
for k, v in CONFIG.items():
    print(f"  {k}: {v}")
print("="*70)
print("\nKey fixes from broken version:")
print("  - Input features: 78-dim VirtuosoNet (was 20 + 256 = 276)")
print("  - No global context concatenation (matches PercePiano exactly)")
print("  - Using VirtuosoNet preprocessing for features")
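
A few cheap consistency checks on `CONFIG` can catch a regression to the old 276-dim input before any training starts. The sketch below is a hypothetical helper, not repository code. The divisibility check assumes the attention layers split the hidden size evenly across heads, as is usual for multi-head attention.

```python
from typing import List, Mapping


def check_config(config: Mapping) -> List[str]:
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    if config["input_size"] != 78:
        problems.append(
            f"input_size is {config['input_size']}, expected 78 (VirtuosoNet features)"
        )
    # Assumption: hidden_size is divided evenly among the attention heads
    if config["hidden_size"] % config["num_attention_heads"] != 0:
        problems.append("hidden_size is not divisible by num_attention_heads")
    if not 0.0 <= config["dropout"] < 1.0:
        problems.append("dropout must be in [0, 1)")
    if config["learning_rate"] <= 0:
        problems.append("learning_rate must be positive")
    return problems
```

With the values above (`hidden_size=256`, `num_attention_heads=8`), each head would get 32 dimensions and the function returns an empty list.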

## Step 6: VirtuosoNet Preprocessing (CRITICAL)

This step extracts the exact 78-dim features used in the original PercePiano paper. If the features already exist, the cell skips preprocessing.

In [None]:
from pathlib import Path
import subprocess
import shutil
import os

# First, ensure VirtuosoNet modules are available by cloning PercePiano repo
virtuoso_module_path = Path('/tmp/crescendai/model/data/raw/PercePiano/virtuoso/virtuoso/pyScoreParser')
percepiano_path = Path('/tmp/crescendai/model/data/raw/PercePiano')

if not (virtuoso_module_path / 'feature_extraction.py').exists():
    print("VirtuosoNet modules not found. Cloning PercePiano repository...")
    
    # Remove existing directory if it exists (might be empty or corrupted)
    if percepiano_path.exists():
        shutil.rmtree(percepiano_path)
    
    percepiano_path.parent.mkdir(parents=True, exist_ok=True)
    
    # Clone the repo
    result = subprocess.run(
        ['git', 'clone', '--depth', '1',
         'https://github.com/JonghoKimSNU/PercePiano.git',
         str(percepiano_path)],
        capture_output=True,
        text=True
    )
    
    if result.returncode != 0:
        print(f"Failed to clone PercePiano: {result.stderr}")
        raise RuntimeError("Could not clone PercePiano repository")
    
    print("PercePiano cloned successfully!")
    print(f"VirtuosoNet modules at: {virtuoso_module_path}")
else:
    print(f"VirtuosoNet modules already available at {virtuoso_module_path}")

# Check if VirtuosoNet features exist
vnet_dir = Path(CONFIG['vnet_data_dir'])
vnet_train_dir = vnet_dir / 'train'

if vnet_train_dir.exists() and list(vnet_train_dir.glob('*.pkl')):
    print(f"\nVirtuosoNet features already exist at {vnet_dir}")
    print(f"  train: {len(list((vnet_dir / 'train').glob('*.pkl')))} samples")
    print(f"  val: {len(list((vnet_dir / 'val').glob('*.pkl')))} samples")
    print(f"  test: {len(list((vnet_dir / 'test').glob('*.pkl')))} samples")
else:
    print("\nVirtuosoNet features not found. Running preprocessing...")
    print("This extracts the exact 78-dim features used in the original PercePiano paper.")
    print("")
    
    # Run preprocessing script
    # Pass DATA_ROOT directly - the script will find JSON files and score_xml there
    result = subprocess.run(
        ['python', 'scripts/preprocess_percepiano_vnet.py',
         '--data_root', str(DATA_ROOT),
         '--output_dir', str(vnet_dir)],
        cwd='/tmp/crescendai/model',
        capture_output=True,
        text=True
    )
    
    print(result.stdout)
    if result.returncode != 0:
        print(f"Preprocessing failed: {result.stderr}")
        print("\nNOTE: VirtuosoNet preprocessing requires MusicXML parsing.")
        print("If this fails, you may need to manually align MIDI files to scores.")
        print("")
        print("Alternative: Download pre-processed features from Google Drive:")
        print(f"  rclone copy gdrive:percepiano_data/percepiano_vnet {vnet_dir}")
    else:
        print("Preprocessing complete!")

# Verify features exist (an empty train directory does not count)
num_train = len(list(vnet_train_dir.glob('*.pkl'))) if vnet_train_dir.exists() else 0
if num_train > 0:
    print(f"\nVirtuosoNet features ready: {num_train} training samples")
else:
    raise RuntimeError(
        f"VirtuosoNet features not available at {vnet_dir}.\n"
        "Run preprocessing or download from Google Drive."
    )
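
Counting `.pkl` files confirms that preprocessing produced output, but not that the features are 78-dimensional. The sketch below is one way to spot-check a file. The pickle layout is an assumption: it supposes each file holds a dict with a per-note feature matrix under some key, and that key name must be supplied by the caller because it is not documented here.

```python
import pickle
from pathlib import Path
from typing import Set, Union


def feature_dims(pkl_path: Union[str, Path], feature_key: str) -> Set[int]:
    """Return the set of per-note feature lengths found in one pickle.

    Assumes the pickle is a dict and sample[feature_key] is a sequence of
    per-note feature vectors (list of lists or a 2-D array).
    """
    with open(pkl_path, "rb") as f:
        sample = pickle.load(f)
    return {len(row) for row in sample[feature_key]}
```

A result of `{78}` for a handful of training files would be consistent with `CONFIG['input_size']`; any other value points to a preprocessing mismatch. Only load pickles that were generated locally or come from a trusted source.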

In [None]:
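# Patch PercePiano data_class.py to fix silent failures
# The original code has `except` clauses that swallow errors silently
# This patch replaces them with explicit exception handling
# Note: the preprocessing cell above has already run by this point, so the
# patch only affects preprocessing runs started after this cell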

import re
from pathlib import Path

data_class_path = Path('/tmp/crescendai/model/data/raw/PercePiano/virtuoso/virtuoso/pyScoreParser/data_class.py')

if data_class_path.exists():
    print("Patching PercePiano data_class.py to fix silent failures...")
    
    content = data_class_path.read_text()
    original_content = content
    
    # Patch 1: Fix swallowed exception in load_all_piece (line ~95-98)
    # Change from printing error to re-raising with context
    old_pattern1 = r"except Exception as ex:\s+# TODO: TGK: this is ambiguous.*?\s+print\(f'Error while processing \{scores\[n\]\}\. Error type :\{ex\}'\)"
    new_code1 = """except Exception as ex:
                # Re-raise with full context instead of silently continuing
                raise RuntimeError(f'Error loading piece {scores[n]}: {type(ex).__name__}: {ex}') from ex"""
    content = re.sub(old_pattern1, new_code1, content, flags=re.DOTALL)
    
    # Patch 2: Fix bare except in performance alignment (line ~332-335)
    old_pattern2 = r"except:\s+perform_data = None\s+print\(f'Cannot align \{perform\}'\)\s+self\.performances\.append\(None\)"
    new_code2 = """except (ValueError, IndexError, FileNotFoundError, OSError) as e:
                        # Explicit exception handling - re-raise with context
                        raise RuntimeError(f'Alignment failed for {perform}: {type(e).__name__}: {e}') from e"""
    content = re.sub(old_pattern2, new_code2, content)
    
    # Patch 3: Fix bare except in Nakamura alignment (line ~425)
    old_pattern3 = r"except:\s+print\('Error to process \{\}'\.format\(midi_file_path\)\)"
    new_code3 = """except subprocess.CalledProcessError as e:
            print(f'Alignment tool failed for {midi_file_path}: {e}')"""
    content = re.sub(old_pattern3, new_code3, content)
    
    # Patch 4: Fix second bare except in Nakamura alignment retry (line ~435)
    old_pattern4 = r"except:\s+align_success = False\s+print\('Fail to process \{\}'\.format\(midi_file_path\)\)\s+os\.chdir\(current_dir\)"
    new_code4 = """except subprocess.CalledProcessError as e2:
                align_success = False
                raise RuntimeError(f'Alignment tool failed after retry for {midi_file_path}: {e2}') from e2"""
    content = re.sub(old_pattern4, new_code4, content)
    
    if content != original_content:
        data_class_path.write_text(content)
        print("[OK] Patched data_class.py - silent failures now raise explicit exceptions")
    else:
        print("[INFO] data_class.py already patched or patterns not found")
else:
    print("[SKIP] data_class.py not found - run the cloning cell above first")
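
The patch cell cannot tell "already patched" apart from "pattern not found", because `re.sub` does not report how many substitutions it made. A minimal sketch of the same technique using `re.subn`, which returns a count per pattern, is shown below on a toy string. The helper name `apply_patches` and the toy source are illustrative only.

```python
import re
from typing import Dict, List, Tuple


def apply_patches(
    source: str,
    patches: List[Tuple[str, str, str]],
) -> Tuple[str, Dict[str, int]]:
    """Apply (name, pattern, replacement) patches and count substitutions per patch."""
    counts: Dict[str, int] = {}
    for name, pattern, replacement in patches:
        # A function replacement inserts the text literally, so backslashes
        # in the replacement are not interpreted as group references
        source, n = re.subn(pattern, lambda _m, r=replacement: r, source)
        counts[name] = n
    return source, counts


# Toy example: one pattern matches, one does not
toy = "try:\n    run()\nexcept:\n    pass\n"
patched, counts = apply_patches(toy, [
    ("bare_except", r"except:\s+pass",
     "except Exception as e:\n    raise RuntimeError('run failed') from e"),
    ("missing", r"except KeyError:", "except LookupError:"),
])
print(counts)
```

A count of zero for a named patch makes it obvious which upstream code has changed or was patched earlier.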

## Step 7: Create DataLoaders and Model

Using the VirtuosoNet preprocessed features (78-dim) with the faithful PercePiano replica architecture.

In [None]:
from src.data.percepiano_vnet_dataset import create_vnet_dataloaders
from src.models.percepiano_replica import PercePianoVNetModule

# Create DataLoaders with VirtuosoNet features
train_loader, val_loader, test_loader = create_vnet_dataloaders(
    data_dir=CONFIG['vnet_data_dir'],
    batch_size=CONFIG['batch_size'],
    max_notes=CONFIG['max_notes'],
    num_workers=4,
)

print(f"Train: {len(train_loader.dataset)} samples")
print(f"Val: {len(val_loader.dataset)} samples")
print(f"Test: {len(test_loader.dataset)} samples")

# Create model with VirtuosoNet features (78-dim input)
model = PercePianoVNetModule(
    # VirtuosoNet input dimension
    input_size=CONFIG['input_size'],  # 78-dim features
    # HAN dimensions (matched to PercePiano)
    hidden_size=CONFIG['hidden_size'],
    note_layers=CONFIG['note_layers'],
    voice_layers=CONFIG['voice_layers'],
    beat_layers=CONFIG['beat_layers'],
    measure_layers=CONFIG['measure_layers'],
    num_attention_heads=CONFIG['num_attention_heads'],
    final_hidden=CONFIG['final_hidden'],
    # Training
    learning_rate=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
    dropout=CONFIG['dropout'],
)

# Count parameters
total_params = model.count_parameters()

print("\n" + "="*70)
print("PERCEPIANO REPLICA MODEL (VirtuosoNet Features)")
print("="*70)
print(f"Architecture: Bi-LSTM + HAN (Note -> Voice -> Beat -> Measure)")
print(f"Input features: {CONFIG['input_size']} (VirtuosoNet)")
print(f"Hidden size: {CONFIG['hidden_size']}")
print(f"Total parameters: {total_params:,}")
print(f"")
print(f"Key fix: Input is 78-dim VirtuosoNet features directly")
print(f"         (NOT 20 + 256 global context = 276 like before)")
print(f"")
print(f"Target R-squared: 0.35-0.40 (piece-split)")
print(f"Dimensions: {len(model.dimensions)}")
print("="*70)

## Step 8: Configure Trainer

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

# Checkpoint callback - monitor mean R-squared
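checkpoint_callback = ModelCheckpoint(
    dirpath=CONFIG['checkpoint_dir'],
    # The logged metric key is 'val/mean_r2'; reference it by that name and
    # disable auto-inserted metric names so the '/' stays out of the filename
    filename='percepiano_replica-epoch{epoch:02d}-val_mean_r2{val/mean_r2:.4f}',
    auto_insert_metric_name=False,
    monitor='val/mean_r2',
    mode='max',
    save_top_k=3,
    save_last=True,
)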

# Early stopping (patience counts validation checks, not epochs)
early_stopping = EarlyStopping(
    monitor='val/mean_r2',
    patience=CONFIG['early_stopping_patience'],
    mode='max',
)

# LR monitor
lr_monitor = LearningRateMonitor(logging_interval='step')

# Logger
logger = TensorBoardLogger(
    save_dir='/tmp/logs',
    name='percepiano_replica',
)

# Trainer
trainer = pl.Trainer(
    max_epochs=CONFIG['max_epochs'],
    accelerator='gpu',
    devices=1,
    precision=CONFIG['precision'],
    gradient_clip_val=CONFIG['gradient_clip_val'],
    callbacks=[checkpoint_callback, early_stopping, lr_monitor],
    logger=logger,
    log_every_n_steps=10,
    val_check_interval=0.5,  # Validate twice per epoch
)

print("Trainer configured!")
print(f"  Precision: {CONFIG['precision']}")
print(f"  Max epochs: {CONFIG['max_epochs']}")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Learning rate: {CONFIG['learning_rate']}")
print(f"  Early stopping patience: {CONFIG['early_stopping_patience']}")
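
Lightning's `EarlyStopping` counts validation checks rather than epochs, so `val_check_interval` changes how long training continues without improvement. The arithmetic is sketched below, assuming a fractional `val_check_interval` in (0, 1] and the default of validating every epoch. The helper name is hypothetical.

```python
def patience_in_epochs(patience: int, val_check_interval: float) -> float:
    """Epochs without improvement before early stopping triggers."""
    if not 0.0 < val_check_interval <= 1.0:
        raise ValueError("val_check_interval must be a fraction in (0, 1]")
    # Each validation check covers val_check_interval of an epoch
    return patience * val_check_interval


print(patience_in_epochs(20, 0.5))  # 10.0
```

With the settings above, a patience of 20 corresponds to about 10 epochs without improvement in `val/mean_r2`.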

## Step 9: Train!

In [None]:
# Set seed for reproducibility
pl.seed_everything(42, workers=True)

# Train
print("="*70)
print("STARTING TRAINING - PercePiano SOTA Replica")
print("="*70)
print("\nKey metrics to watch:")
print("  - val/mean_r2: Overall R-squared (target: 0.35-0.40)")
print("  - val/timing_r2: Timing dimension (should be highest)")
print("  - val/tempo_r2: Tempo dimension")
print("")
print("PercePiano SOTA baselines:")
print("  - Bi-LSTM: R^2 = 0.185")
print("  - MidiBERT: R^2 = 0.313")
print("  - Bi-LSTM + SA + HAN: R^2 = 0.397 (SOTA)")
print("="*70)

trainer.fit(model, train_loader, val_loader)

In [None]:
# Sync checkpoints to Google Drive
if RCLONE_AVAILABLE:
    print("Syncing checkpoints to Google Drive...")
    subprocess.run(
        ['rclone', 'copy', CONFIG['checkpoint_dir'], CONFIG['gdrive_checkpoint'], '--progress'],
        capture_output=False
    )
    print("Sync complete!")

## Step 10: Evaluation

In [None]:
# Test with best checkpoint
print("Running test with best checkpoint...")
best_path = checkpoint_callback.best_model_path
print(f"Best checkpoint: {best_path}")

if best_path:
    test_results = trainer.test(model, test_loader, ckpt_path=best_path)
    print("\nTest Results:")
    for k, v in test_results[0].items():
        print(f"  {k}: {v:.4f}")

In [None]:
import torch
import numpy as np

# Load best model
from src.models.percepiano_replica import PercePianoVNetModule
best_model = PercePianoVNetModule.load_from_checkpoint(checkpoint_callback.best_model_path)
best_model.eval()
best_model.cuda()

# Collect predictions on test set
all_preds = []
all_targets = []

print("Collecting predictions on test set...")
with torch.no_grad():
    for batch in test_loader:
        # Move batch to GPU
        input_features = batch['input_features'].cuda()
        attention_mask = batch['attention_mask'].cuda()
        scores = batch['scores'].cuda()
        
        note_locations = {
            'beat': batch['note_locations_beat'].cuda(),
            'measure': batch['note_locations_measure'].cuda(),
            'voice': batch['note_locations_voice'].cuda(),
        }
        
        # Forward pass with VirtuosoNet features
        outputs = best_model(
            input_features=input_features,
            note_locations=note_locations,
            attention_mask=attention_mask,
        )
        
        all_preds.append(outputs['predictions'].cpu())
        all_targets.append(scores.cpu())

all_preds = torch.cat(all_preds).numpy()
all_targets = torch.cat(all_targets).numpy()
dimensions = best_model.dimensions

print(f"Collected {len(all_preds)} test samples across {len(dimensions)} dimensions")

In [None]:
from src.evaluation import (
    compute_all_metrics,
    PerDimensionAnalysis,
    compare_to_sota,
    format_comparison_table,
    create_results_table,
    PERCEPIANO_BASELINES,
)

# Compute all metrics
metrics = compute_all_metrics(
    predictions=all_preds,
    targets=all_targets,
    dimension_names=list(dimensions),
)

# Print results table
print(create_results_table(metrics))
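
For a sanity check that does not depend on `src.evaluation`, the coefficient of determination can be computed by hand. The sketch below assumes the reported mean R-squared is the standard per-dimension R-squared averaged uniformly over dimensions; the project's `compute_all_metrics` may aggregate differently.

```python
from typing import Dict, Sequence


def r2_score(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    if ss_tot == 0:
        return float("nan")  # undefined when the targets are constant
    return 1.0 - ss_res / ss_tot


def per_dimension_r2(
    targets: Sequence[Sequence[float]],
    preds: Sequence[Sequence[float]],
    dimensions: Sequence[str],
) -> Dict[str, float]:
    """R-squared for each column of (samples x dimensions) data."""
    return {
        dim: r2_score([row[j] for row in targets], [row[j] for row in preds])
        for j, dim in enumerate(dimensions)
    }


def mean_r2(per_dim: Dict[str, float]) -> float:
    """Uniform average over dimensions."""
    return sum(per_dim.values()) / len(per_dim)
```

Passing `all_targets.tolist()` and `all_preds.tolist()` with `list(dimensions)` should give values close to the table above if the assumption about averaging holds. A model that always predicts the mean target scores 0, and negative values mean it does worse than that.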

In [None]:
# Compare to SOTA baselines
our_r2 = metrics['r2'].value
per_dim_r2 = metrics['r2'].per_dimension

comparison = compare_to_sota(
    model_r2=our_r2,
    model_name="PercePiano Replica (CrescendAI)",
    split_type="piece",
    per_dimension_r2=per_dim_r2,
)

print(format_comparison_table(comparison))
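
The implementation of `compare_to_sota` is not shown in this notebook. A minimal sketch of the kind of ranking it presumably performs, using only the three baseline values quoted earlier, might look like this. The function name and the returned keys are assumptions.

```python
from typing import Dict, Mapping

# Piece-split baselines as quoted in the training cell above
BASELINES: Mapping[str, float] = {
    "Bi-LSTM": 0.185,
    "MidiBERT": 0.313,
    "Bi-LSTM + SA + HAN": 0.397,
}


def rank_against_baselines(
    model_r2: float,
    baselines: Mapping[str, float] = BASELINES,
) -> Dict[str, object]:
    """Rank a model's R-squared among the baselines (1 = best)."""
    best_name, best_r2 = max(baselines.items(), key=lambda kv: kv[1])
    rank = 1 + sum(1 for r2 in baselines.values() if r2 > model_r2)
    return {
        "rank": rank,
        "total": len(baselines) + 1,  # baselines plus this model
        "best_baseline": best_name,
        "improvement_vs_best": model_r2 - best_r2,
    }
```

For example, a replica at 0.36 would rank second of four, behind only the Bi-LSTM + SA + HAN baseline.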

In [None]:
# Summary
print("="*70)
print("PERCEPIANO REPLICA - RESULTS SUMMARY")
print("="*70)

print(f"\n1. OVERALL PERFORMANCE")
print(f"   Mean R^2: {our_r2:.4f}")
print(f"   Target (0.35-0.40): {'ACHIEVED' if our_r2 >= 0.35 else 'CLOSE' if our_r2 >= 0.30 else 'NOT YET'}")

print(f"\n2. COMPARISON TO PUBLISHED BASELINES")
print(f"   Bi-LSTM baseline: 0.185")
print(f"   MidiBERT: 0.313")
print(f"   Bi-LSTM + SA + HAN (SOTA): 0.397")
print(f"   Our replica: {our_r2:.4f}")

print(f"\n3. MODEL SIZE")
print(f"   Parameters: {best_model.count_parameters():,}")
print(f"   vs Previous (51.5M): {51_500_000 / best_model.count_parameters():.1f}x smaller")

print(f"\n4. TOP 5 DIMENSIONS")
sorted_dims = sorted(per_dim_r2.items(), key=lambda x: x[1], reverse=True)
for dim, r2 in sorted_dims[:5]:
    print(f"   {dim}: {r2:.4f}")

print(f"\n5. BOTTOM 5 DIMENSIONS (need improvement)")
for dim, r2 in sorted_dims[-5:]:
    print(f"   {dim}: {r2:.4f}")

print("="*70)

## Step 11: Save as Teacher Model

If the model achieves R-squared >= 0.25, save it as a **Teacher Model** for pseudo-labeling MAESTRO.

In [None]:
import torch
from pathlib import Path

# Save as teacher model
teacher_path = Path(CONFIG['checkpoint_dir']) / 'percepiano_teacher.pt'

if our_r2 >= 0.25:  # Minimum threshold for useful teacher
    torch.save({
        'state_dict': best_model.state_dict(),
        'config': {
            'input_size': CONFIG['input_size'],
            'hidden_size': CONFIG['hidden_size'],
            'note_layers': CONFIG['note_layers'],
            'voice_layers': CONFIG['voice_layers'],
            'beat_layers': CONFIG['beat_layers'],
            'measure_layers': CONFIG['measure_layers'],
            'num_attention_heads': CONFIG['num_attention_heads'],
            'final_hidden': CONFIG['final_hidden'],
            'dropout': CONFIG['dropout'],
        },
        'dimensions': list(dimensions),
        'metrics': {
            'r2': our_r2,
            'per_dimension_r2': per_dim_r2,
        },
        'sota_comparison': {
            'rank': comparison['rank'],
            'total_baselines': comparison['total_baselines'],
            'vs_best_baseline': comparison['improvement_vs_best'],
        },
        'architecture': 'PercePiano Replica (Bi-LSTM + HAN) with VirtuosoNet 78-dim features',
        'reference': 'https://github.com/JonghoKimSNU/PercePiano',
    }, teacher_path)
    
    print(f"Saved teacher model to {teacher_path}")
    print(f"Teacher R^2: {our_r2:.4f}")
    print(f"\nThis model can be used for pseudo-labeling MAESTRO!")
    print(f"Run: python scripts/pseudo_label_maestro.py --teacher {teacher_path}")
else:
    print(f"R^2 = {our_r2:.4f} is below threshold (0.25) for teacher model.")
    print("Consider:")
    print("  1. Training for more epochs")
    print("  2. Adjusting hyperparameters")
    print("  3. Checking data quality")

# Final sync to Google Drive
if RCLONE_AVAILABLE:
    print("\nFinal sync to Google Drive...")
    subprocess.run(
        ['rclone', 'copy', CONFIG['checkpoint_dir'], CONFIG['gdrive_checkpoint'], '--progress'],
        capture_output=False
    )
    print("Sync complete!")
    print(f"Checkpoints available at: {CONFIG['gdrive_checkpoint']}")

## Next Steps

If R-squared >= 0.30:

1. **Pseudo-label MAESTRO**: Use this teacher model to generate labels for MAESTRO dataset
2. **Train larger model**: With expanded dataset (~6000 samples), train a larger model
3. **Noisy Student**: Apply noisy student training for potential improvement over teacher

If R-squared < 0.30:

1. Check if validation set is too small (only 27 samples)
2. Consider k-fold cross-validation for more robust estimates
3. Verify data preprocessing matches PercePiano exactly
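
For the k-fold suggestion, the folds should keep all performances of a piece together so the evaluation stays comparable to the piece-split setting. The sketch below is one standard-library way to build such folds. It assumes a piece identifier is available for every sample; how to derive one from the PercePiano filenames is not covered here.

```python
import random
from collections import defaultdict
from typing import Hashable, List, Sequence, Tuple


def grouped_kfold(
    groups: Sequence[Hashable],
    k: int,
    seed: int = 42,
) -> List[Tuple[List[int], List[int]]]:
    """Return k (train_indices, val_indices) pairs with no group in both."""
    by_group = defaultdict(list)
    for idx, g in enumerate(groups):
        by_group[g].append(idx)
    if k < 2 or k > len(by_group):
        raise ValueError("k must be between 2 and the number of groups")

    # Deterministic order, then a seeded shuffle to break size ties randomly
    group_ids = sorted(by_group, key=str)
    random.Random(seed).shuffle(group_ids)

    # Greedy balancing: largest groups first, each into the smallest fold
    folds: List[List[int]] = [[] for _ in range(k)]
    for g in sorted(group_ids, key=lambda g: -len(by_group[g])):
        min(folds, key=len).extend(by_group[g])

    splits = []
    for i in range(k):
        val = sorted(folds[i])
        train = sorted(idx for j, f in enumerate(folds) if j != i for idx in f)
        splits.append((train, val))
    return splits
```

Training one model per fold and averaging the validation R-squared gives a less noisy estimate than a single small validation set, at the cost of k training runs.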

---

**Attribution**: This model replicates the architecture from PercePiano (Park et al., ISMIR/Nature 2024).  
GitHub: https://github.com/JonghoKimSNU/PercePiano