# PercePiano Replica Training (4-Fold Cross-Validation)

Train the PercePiano replica model using 4-fold piece-based cross-validation,
matching the methodology from the PercePiano paper.

## Attribution

> **PercePiano: A Benchmark for Perceptual Evaluation of Piano Performance**  
> Park, Jongho and Kim, Dasaem et al.  
> ISMIR 2024 / Nature Scientific Reports 2024  
> GitHub: https://github.com/JonghoKimSNU/PercePiano

## Methodology

- **Piece-based splits**: All performances of the same piece stay in the same fold
- **4-fold CV**: Each fold takes a turn as the validation set while the remaining folds are used for training
- **Test set**: 15% of pieces are held out for final evaluation
- **Per-fold normalization**: Normalization statistics are computed from the training folds only

## Expected Results

- Target R2: 0.35-0.40 (matching published SOTA)
- Training time: ~2-4 hours on T4/A100 (all 4 folds)

## Step 1: Environment Setup

In [None]:
# Report whether a CUDA device is visible and, if so, its name and total memory
import torch

cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    device_index = 0
    device_props = torch.cuda.get_device_properties(device_index)
    print(f"GPU: {torch.cuda.get_device_name(device_index)}")
    # total_memory is divided by 1e9 to print it in GB
    print(f"Memory: {device_props.total_memory / 1e9:.1f} GB")

In [None]:
# Install rclone
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
# Install uv and clone repository
!curl -LsSf https://astral.sh/uv/install.sh | sh

import os
os.environ['PATH'] = f"{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"

# Clone repository
if not os.path.exists('/tmp/crescendai'):
    !git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai

%cd /tmp/crescendai/model
!git pull
!git log -1 --oneline

# Install dependencies
!uv pip install --system -e .
!pip install tensorboard rich

import torch
import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

## Step 2: Configure Paths and Check rclone

In [None]:
import os
import subprocess
from pathlib import Path

# Local working directories (all under /tmp) and the rclone remote paths they sync with
DATA_ROOT = Path('/tmp/percepiano_vnet_split')
CHECKPOINT_ROOT = Path('/tmp/checkpoints/percepiano_kfold')
LOG_ROOT = Path('/tmp/logs/percepiano_kfold')
GDRIVE_DATA_PATH = 'gdrive:crescendai_data/percepiano_vnet_split'
GDRIVE_CHECKPOINT_PATH = 'gdrive:crescendai_checkpoints/percepiano_kfold'

print("="*60)
print("PERCEPIANO REPLICA TRAINING (4-FOLD CV)")
print("="*60)

# Create directories
CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
LOG_ROOT.mkdir(parents=True, exist_ok=True)
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Check rclone: the 'gdrive' remote must be listed by `rclone listremotes`.
# If the rclone binary itself is missing, subprocess.run raises FileNotFoundError;
# treat that like "not configured" so the guidance below is shown instead of a traceback.
try:
    result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)
    configured_remotes = result.stdout.split()
except FileNotFoundError:
    configured_remotes = []

# Compare whole remote names so that e.g. 'mygdrive:' is not mistaken for 'gdrive:'
if 'gdrive:' in configured_remotes:
    print("rclone 'gdrive' remote: CONFIGURED")
    RCLONE_AVAILABLE = True
else:
    print("rclone 'gdrive' remote: NOT CONFIGURED")
    print("Run 'rclone config' in terminal to set up Google Drive")
    RCLONE_AVAILABLE = False

print(f"\nData directory: {DATA_ROOT}")
print(f"Checkpoint directory: {CHECKPOINT_ROOT}")
print(f"Log directory: {LOG_ROOT}")

## Step 3: Download Data from Google Drive

In [None]:
import subprocess

# The download needs the 'gdrive' remote; stop early with a clear message without it
if not RCLONE_AVAILABLE:
    raise RuntimeError("rclone not configured. Run 'rclone config' first.")

# Download preprocessed data
print("Downloading preprocessed VirtuosoNet features from Google Drive...")
# check=True raises CalledProcessError when rclone exits non-zero, so a failed or
# interrupted download stops the notebook here instead of continuing with
# missing or partial data.
subprocess.run(
    ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
    capture_output=False,
    check=True,
)

# Verify data: count the .pkl files in each expected split directory
print("\n" + "="*60)
print("DATA VERIFICATION")
print("="*60)

total_samples = 0
for split in ['train', 'val', 'test']:
    split_dir = DATA_ROOT / split
    if split_dir.exists():
        count = len(list(split_dir.glob('*.pkl')))
        total_samples += count
        print(f"  {split}: {count} samples")
    else:
        print(f"  {split}: MISSING!")

print(f"  Total: {total_samples} samples")

# Presence checks for the two top-level files; the messages are informational only
stat_file = DATA_ROOT / 'stat.pkl'
print(f"  stat.pkl: {'present' if stat_file.exists() else 'MISSING!'}")

fold_file = DATA_ROOT / 'fold_assignments.json'
print(f"  fold_assignments.json: {'present' if fold_file.exists() else 'will be created'}")

## Step 4: Create Fold Assignments

In [None]:
from src.percepiano.data.kfold_split import (
    create_piece_based_folds,
    save_fold_assignments,
    load_fold_assignments,
    print_fold_statistics,
)

# Location of the persisted split and the parameters used when it has to be generated
FOLD_FILE = DATA_ROOT / 'fold_assignments.json'
N_FOLDS = 4
TEST_RATIO = 0.15
SEED = 42

section_rule = "=" * 60
print(section_rule)
print("FOLD ASSIGNMENT CREATION")
print(section_rule)

# Generate and save a split only when none is stored yet; otherwise reuse the stored one
if not FOLD_FILE.exists():
    print(f"\nCreating new {N_FOLDS}-fold piece-based splits...")
    fold_assignments = create_piece_based_folds(
        data_dir=DATA_ROOT,
        n_folds=N_FOLDS,
        test_ratio=TEST_RATIO,
        seed=SEED,
    )
    save_fold_assignments(fold_assignments, FOLD_FILE)
else:
    print(f"\nLoading existing fold assignments from {FOLD_FILE}")
    fold_assignments = load_fold_assignments(FOLD_FILE)

# Print statistics
print_fold_statistics(fold_assignments, n_folds=N_FOLDS)

## Step 5: Training Configuration

In [None]:
import torch

# Request 'medium' float32 matmul precision from PyTorch
torch.set_float32_matmul_precision('medium')

# PercePiano Configuration (matched to original paper).
# Built with keyword arguments; insertion order is kept, so the listing printed
# below appears in the order written here.
CONFIG = dict(
    # K-Fold settings
    n_folds=N_FOLDS,
    test_ratio=TEST_RATIO,

    # Data
    data_dir=str(DATA_ROOT),
    checkpoint_dir=str(CHECKPOINT_ROOT),
    log_dir=str(LOG_ROOT),

    # Model input (79 normalized features, unnorm used for augmentation only)
    input_size=79,

    # HAN Architecture (han_bigger256_concat.yml)
    hidden_size=256,
    note_layers=2,
    voice_layers=2,
    beat_layers=2,
    measure_layers=1,
    num_attention_heads=8,
    final_hidden=128,

    # Training (parser.py defaults)
    learning_rate=1e-4,
    weight_decay=1e-5,
    dropout=0.2,
    batch_size=32,
    max_epochs=100,
    early_stopping_patience=20,
    gradient_clip_val=2.0,
    precision='16-mixed',

    # Dataset
    max_notes=1024,
    num_workers=0,  # Avoid shared memory issues on Thunder Compute
    augment_train=True,
)

# Echo every setting so the run's configuration is recorded in the cell output
config_rule = "=" * 60
print(config_rule)
print("TRAINING CONFIGURATION")
print(config_rule)
for setting, value in CONFIG.items():
    print(f"  {setting}: {value}")

## Step 6: Initialize K-Fold Trainer

In [None]:
from src.percepiano.training.kfold_trainer import KFoldTrainer
import pytorch_lightning as pl

# Set seed for reproducibility. Reuse SEED (the value handed to create_piece_based_folds)
# instead of repeating the literal, so changing the seed in one place cannot leave the
# split and the training run on different seeds.
pl.seed_everything(SEED, workers=True)

# Create K-Fold trainer from the shared configuration, fold assignments and directories
kfold_trainer = KFoldTrainer(
    config=CONFIG,
    fold_assignments=fold_assignments,
    data_dir=DATA_ROOT,
    checkpoint_dir=CHECKPOINT_ROOT,
    log_dir=LOG_ROOT,
    n_folds=N_FOLDS,
)

print("K-Fold Trainer initialized")
print(f"  Folds: {N_FOLDS}")
print(f"  Checkpoints: {CHECKPOINT_ROOT}")
print(f"  Logs: {LOG_ROOT}")

## Step 7: Train All Folds

In [None]:
# Banner plus the published reference scores, printed for comparison before training starts
train_rule = "=" * 60
print(train_rule)
print("STARTING 4-FOLD CROSS-VALIDATION TRAINING")
print(train_rule)
print("\nPercePiano SOTA baselines:")
for baseline_line in (
    "  Bi-LSTM: R2 = 0.185",
    "  MidiBERT: R2 = 0.313",
    "  Bi-LSTM + SA + HAN: R2 = 0.397 (SOTA)",
):
    print(baseline_line)
print(train_rule)

# Train all folds; the returned aggregate metrics are read by the summary cells
aggregate_metrics = kfold_trainer.train_all_folds(verbose=True)

# Save results
kfold_trainer.save_results()

In [None]:
# Sync checkpoints to Google Drive after training
if RCLONE_AVAILABLE:
    print("Syncing checkpoints to Google Drive...")
    sync_result = subprocess.run(
        ['rclone', 'copy', str(CHECKPOINT_ROOT), GDRIVE_CHECKPOINT_PATH, '--progress'],
        capture_output=False
    )
    # Report success only when rclone exited with status 0. A failure is printed rather
    # than raised so the evaluation cells can still run on the local checkpoints.
    if sync_result.returncode == 0:
        print("Sync complete!")
    else:
        print(
            f"Sync FAILED (rclone exit code {sync_result.returncode}); "
            f"checkpoints remain in {CHECKPOINT_ROOT}"
        )

## Step 8: Test Set Evaluation

In [None]:
# Evaluate all fold models on held-out test set.
# The result is kept in `test_results`; the final summary reads its 'ensemble' entry
# (keys 'r2', 'pearson', 'spearman', 'mae', 'rmse').
test_results = kfold_trainer.evaluate_on_test(verbose=True)

## Step 9: Per-Fold Results Summary

In [None]:
import numpy as np
from src.percepiano.models.percepiano_replica import PERCEPIANO_DIMENSIONS

# Table of validation metrics: one row per fold, then the cross-fold mean and std
table_rule = "=" * 80
print(table_rule)
print("PER-FOLD VALIDATION RESULTS")
print(table_rule)

# Dashed separator whose segment widths match the column widths of the rows below
column_rule = f"{'-'*6} {'-'*10} {'-'*12} {'-'*10} {'-'*10} {'-'*8} {'-'*10}"

print(f"\n{'Fold':<6} {'Val R2':>10} {'Val Pearson':>12} {'Val MAE':>10} {'Val RMSE':>10} {'Epochs':>8} {'Time (s)':>10}")
print(column_rule)

for fold in kfold_trainer.fold_metrics:
    print(
        f"{fold.fold_id:<6} {fold.val_r2:>+10.4f} {fold.val_pearson:>+12.4f} "
        f"{fold.val_mae:>10.4f} {fold.val_rmse:>10.4f} "
        f"{fold.epochs_trained:>8} {fold.training_time_seconds:>10.1f}"
    )

print(column_rule)

# Summary rows carry only the four metric columns (no epochs or time)
agg = aggregate_metrics
print(f"{'Mean':<6} {agg.mean_r2:>+10.4f} {agg.mean_pearson:>+12.4f} {agg.mean_mae:>10.4f} {agg.mean_rmse:>10.4f}")
print(f"{'Std':<6} {agg.std_r2:>+10.4f} {agg.std_pearson:>+12.4f} {agg.std_mae:>10.4f} {agg.std_rmse:>10.4f}")

## Step 10: Per-Dimension Analysis

In [None]:
# Per-dimension R2 table, best dimension first, with a coarse quality label per row
dim_rule = "=" * 80
print(dim_rule)
print("PER-DIMENSION R2 (Mean +/- Std across folds)")
print(dim_rule)

# Sort dimensions by mean R2 (descending)
sorted_dims = sorted(
    aggregate_metrics.per_dim_mean_r2.items(),
    key=lambda name_and_r2: name_and_r2[1],
    reverse=True,
)

print(f"\n{'Dimension':<25} {'Mean R2':>10} {'Std R2':>10} {'Status':<12}")
print(f"{'-'*25} {'-'*10} {'-'*10} {'-'*12}")

# (minimum mean R2, label) pairs, checked from highest to lowest;
# a value below every floor is labelled "[FAILED]"
status_floors = ((0.3, "[GOOD]"), (0.1, "[OK]"), (0, "[WEAK]"))

for dim, mean_r2 in sorted_dims:
    std_r2 = aggregate_metrics.per_dim_std_r2[dim]
    status = next(
        (label for floor, label in status_floors if mean_r2 >= floor),
        "[FAILED]",
    )
    print(f"{dim:<25} {mean_r2:>+10.4f} {std_r2:>10.4f} {status:<12}")

# Summary: how many dimensions are strictly positive, and how many reach 0.2
positive = len([r2 for _, r2 in sorted_dims if r2 > 0])
strong = len([r2 for _, r2 in sorted_dims if r2 >= 0.2])
n_dims = len(sorted_dims)

print(f"\nSummary: {positive}/{n_dims} positive R2, {strong}/{n_dims} strong (>= 0.2)")

## Step 11: Final Summary and Save

In [None]:
import json
import torch
from pathlib import Path

# Final report: cross-validation aggregates, test-set ensemble scores, the published
# baselines, and a coarse verdict derived from the mean cross-validation R2
summary_rule = "=" * 80
print(summary_rule)
print("FINAL SUMMARY")
print(summary_rule)

# Cross-validation results
cv = aggregate_metrics
print("\n4-Fold Cross-Validation Results:")
print(f"  Mean R2:       {cv.mean_r2:.4f} +/- {cv.std_r2:.4f}")
print(f"  Mean Pearson:  {cv.mean_pearson:.4f} +/- {cv.std_pearson:.4f}")
print(f"  Mean Spearman: {cv.mean_spearman:.4f} +/- {cv.std_spearman:.4f}")
print(f"  Mean MAE:      {cv.mean_mae:.4f} +/- {cv.std_mae:.4f}")
print(f"  Mean RMSE:     {cv.mean_rmse:.4f} +/- {cv.std_rmse:.4f}")
# total_training_time is divided by 60 to print minutes
print(f"  Training time: {cv.total_training_time/60:.1f} minutes")

# Test set results
print("\nTest Set (Ensemble of 4 models):")
ensemble = test_results['ensemble']
print(f"  R2:       {ensemble['r2']:.4f}")
print(f"  Pearson:  {ensemble['pearson']:.4f}")
print(f"  Spearman: {ensemble['spearman']:.4f}")
print(f"  MAE:      {ensemble['mae']:.4f}")
print(f"  RMSE:     {ensemble['rmse']:.4f}")

# Comparison to baselines
print("\nComparison to PercePiano baselines:")
print("  Bi-LSTM:      R2 = 0.185")
print("  MidiBERT:     R2 = 0.313")
print("  HAN SOTA:     R2 = 0.397")
print(f"  Ours (CV):    R2 = {cv.mean_r2:.3f} +/- {cv.std_r2:.3f}")
print(f"  Ours (Test):  R2 = {ensemble['r2']:.3f}")

# Interpretation
cv_r2 = aggregate_metrics.mean_r2
test_r2 = test_results['ensemble']['r2']

# (minimum CV R2, message) pairs, checked from highest to lowest
verdicts = (
    (0.35, "  [EXCELLENT] CV R2 >= 0.35 matches published SOTA!"),
    (0.25, "  [GOOD] CV R2 >= 0.25 is usable for pseudo-labeling"),
    (0.10, "  [FAIR] CV R2 >= 0.10 shows learning, needs improvement"),
)

print("\nInterpretation:")
print(
    next(
        (message for floor, message in verdicts if cv_r2 >= floor),
        "  [NEEDS WORK] CV R2 < 0.10, significant improvement needed",
    )
)

# Only announces eligibility; nothing is written to disk in this cell
if cv_r2 >= 0.25:
    print("\nModel qualifies for pseudo-labeling MAESTRO!")

In [None]:
# Final sync to Google Drive
print("="*60)
print("SYNC TO GOOGLE DRIVE")
print("="*60)

if RCLONE_AVAILABLE:
    print(f"\nSyncing all checkpoints and results...")

    # (label, local source, remote destination) for each upload. The fold assignments
    # file is copied back into the remote data directory alongside the dataset.
    uploads = (
        ("Checkpoints", str(CHECKPOINT_ROOT), GDRIVE_CHECKPOINT_PATH),
        ("Fold assignments", str(FOLD_FILE), GDRIVE_DATA_PATH),
    )

    # Run every upload and remember those whose rclone exit status was non-zero, so
    # success is never reported for a transfer that failed.
    failed_uploads = []
    for label, source, destination in uploads:
        completed = subprocess.run(
            ['rclone', 'copy', source, destination, '--progress'],
            capture_output=False
        )
        if completed.returncode != 0:
            failed_uploads.append(f"{label} (rclone exit code {completed.returncode})")

    if failed_uploads:
        print("\nSync FAILED for: " + ", ".join(failed_uploads))
        print(f"  Local checkpoints remain in {CHECKPOINT_ROOT}")
    else:
        print(f"\nSync complete!")
        print(f"  Checkpoints: {GDRIVE_CHECKPOINT_PATH}")
        print(f"  Fold assignments: {GDRIVE_DATA_PATH}")
else:
    print(f"\nrclone not available - skipping sync")

print(f"\n{'='*60}")
print("TRAINING COMPLETE")
print(f"{'='*60}")