# PercePiano Replica Training (Google Colab)

Train the PercePiano replica model using preprocessed VirtuosoNet features.

## Attribution

> **PercePiano: A Benchmark for Perceptual Evaluation of Piano Performance**  
> Park, Jongho and Kim, Dasaem et al.  
> ISMIR 2024 / Nature Scientific Reports 2024  
> GitHub: https://github.com/JonghoKimSNU/PercePiano

## Data

Uses preprocessed VirtuosoNet features (79 normalized dimensions used as model input, plus 5 unnormalized dimensions used only for augmentation):
- Train: 945 samples
- Val: 34 samples  
- Test: 115 samples

## Expected Results

- Target R-squared: 0.35-0.40 (piece-split)
- Training time: ~1-2 hours on T4/A100

## Colab-Specific Features

- Native Google Drive mounting (no rclone config needed)
- Colab Secrets API for sensitive configuration
- Runtime detection and GPU info
- Session keep-alive utilities

## Step 1: Environment Detection and Setup

In [None]:
# Detect runtime environment
import sys
import os

def is_colab():
    """Report whether this notebook is executing inside Google Colab.

    The check is simply whether the ``google.colab`` package can be imported.

    Returns:
        True if the import succeeds, False if it raises ImportError.
    """
    try:
        import google.colab  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    else:
        return True

# Record the runtime once; later cells branch on RUNNING_IN_COLAB.
RUNNING_IN_COLAB = is_colab()
print(f"Running in Colab: {RUNNING_IN_COLAB}")
print(f"Python version: {sys.version}")

# Check GPU
import torch
print(f"\nCUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; convert to decimal gigabytes
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {gpu_mem:.1f} GB")

    # Memory-tier hints. The messages describe the memory tier instead of
    # guessing a GPU model: 16 GB cards such as T4, P100 and V100 all clear the
    # 14 GB threshold, so labelling that tier "A100/V100" and the next one
    # "T4/P100" was misleading.
    if gpu_mem >= 14:
        print("\n[INFO] GPU with >= 14 GB memory detected (e.g. T4/V100/A100) - can use larger batch sizes")
    elif gpu_mem >= 8:
        print("\n[INFO] GPU with 8-14 GB memory detected - standard batch size recommended")
    else:
        print("\n[WARN] Low-memory GPU - may need to reduce batch size")
else:
    print("\n[WARN] No GPU detected! Go to Runtime > Change runtime type > GPU")

In [None]:
# Install dependencies
# This cell installs `uv` (Colab only), clones or updates the crescendai
# repository, installs the package found in its `model/` directory in editable
# mode, and prints the PyTorch / Lightning versions.
if RUNNING_IN_COLAB:
    # Install uv for fast package management
    !curl -LsSf https://astral.sh/uv/install.sh | sh 2>/dev/null
    # Prepend ~/.local/bin and ~/.cargo/bin to PATH so the new binary can be
    # found (presumably where the installer places `uv` -- TODO confirm).
    os.environ['PATH'] = f"{os.environ['HOME']}/.local/bin:{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"

# Clone repository
REPO_PATH = '/content/crescendai' if RUNNING_IN_COLAB else '/tmp/crescendai'

# Clone only when the target directory does not exist yet; an existing
# checkout is refreshed by the `git pull` below.
if not os.path.exists(REPO_PATH):
    !git clone https://github.com/Jai-Dhiman/crescendai.git {REPO_PATH}

# Later cells import from `src.percepiano...`, so the working directory must
# be the repository's `model/` folder.
%cd {REPO_PATH}/model
!git pull
!git log -1 --oneline

# Install package and dependencies
# The `||` falls back to plain pip when the uv command fails or is missing.
!uv pip install --system -e . 2>/dev/null || pip install -e .
!pip install tensorboard rich -q

import torch
import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

## Step 2: Google Drive Integration

Two options for accessing data from Google Drive:
1. **Native Drive Mount** (recommended for Colab) - Simple, fast for small-medium datasets
2. **rclone** - Better for very large datasets or external cloud storage

The native mount is preferred since it requires no additional configuration.

In [None]:
from pathlib import Path
import subprocess

# Configuration
USE_NATIVE_DRIVE_MOUNT = True  # Set to False to use rclone instead

# Local working paths, selected by runtime:
# (Drive mount point, local data directory, local checkpoint directory)
_LOCAL_PATHS = {
    True: (
        '/content/drive',
        '/content/percepiano_vnet_split',
        '/content/checkpoints/percepiano_replica',
    ),
    False: (
        '/tmp/drive',
        '/tmp/percepiano_vnet_split',
        '/tmp/checkpoints/percepiano_replica',
    ),
}
DRIVE_MOUNT_POINT, _data_root, CHECKPOINT_ROOT = _LOCAL_PATHS[bool(RUNNING_IN_COLAB)]
DATA_ROOT = Path(_data_root)

# Google Drive paths (relative to mount point)
GDRIVE_DATA_REL_PATH = 'MyDrive/crescendai_data/percepiano_vnet_split'
GDRIVE_CHECKPOINT_REL_PATH = 'MyDrive/crescendai_checkpoints/percepiano_replica'

# For rclone (if used)
GDRIVE_DATA_PATH = 'gdrive:crescendai_data/percepiano_vnet_split'
GDRIVE_CHECKPOINT_PATH = 'gdrive:crescendai_checkpoints/percepiano_replica'

_RULE = "=" * 60
print(_RULE)
print("PERCEPIANO REPLICA TRAINING (COLAB)")
print(_RULE)
print(f"\nData directory: {DATA_ROOT}")
print(f"Checkpoint directory: {CHECKPOINT_ROOT}")

# Create the local checkpoint and data directories (including parents) up front
Path(CHECKPOINT_ROOT).mkdir(parents=True, exist_ok=True)
DATA_ROOT.mkdir(parents=True, exist_ok=True)

In [None]:
# Mount Google Drive
DRIVE_AVAILABLE = False
RCLONE_AVAILABLE = False

if USE_NATIVE_DRIVE_MOUNT and RUNNING_IN_COLAB:
    print("Using native Google Drive mount...")
    from google.colab import drive
    drive.mount(DRIVE_MOUNT_POINT)
    
    # Check if data exists
    drive_data_path = Path(DRIVE_MOUNT_POINT) / GDRIVE_DATA_REL_PATH
    if drive_data_path.exists():
        DRIVE_AVAILABLE = True
        print(f"\nData found at: {drive_data_path}")
    else:
        print(f"\n[WARN] Data not found at: {drive_data_path}")
        print("Please ensure your data is at: My Drive/crescendai_data/percepiano_vnet_split/")
else:
    # Try rclone
    print("Checking rclone configuration...")
    
    # Install rclone if needed
    result = subprocess.run(['which', 'rclone'], capture_output=True)
    if result.returncode != 0:
        print("Installing rclone...")
        !curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"
    
    result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)
    if 'gdrive:' in result.stdout:
        RCLONE_AVAILABLE = True
        print("rclone 'gdrive' remote: CONFIGURED")
    else:
        print("rclone 'gdrive' remote: NOT CONFIGURED")
        print("\nTo configure rclone with Colab Secrets:")
        print("  1. Run 'rclone config' locally to create gdrive remote")
        print("  2. Copy ~/.config/rclone/rclone.conf contents")
        print("  3. Add as Colab Secret named 'RCLONE_CONFIG'")
        print("  4. Re-run this cell")
        
        # Try to load from Colab Secrets
        if RUNNING_IN_COLAB:
            try:
                from google.colab import userdata
                rclone_config = userdata.get('RCLONE_CONFIG')
                if rclone_config:
                    os.makedirs(os.path.expanduser('~/.config/rclone'), exist_ok=True)
                    with open(os.path.expanduser('~/.config/rclone/rclone.conf'), 'w') as f:
                        f.write(rclone_config)
                    print("\nLoaded rclone config from Colab Secrets!")
                    RCLONE_AVAILABLE = True
            except Exception as e:
                print(f"\nCould not load RCLONE_CONFIG from secrets: {e}")

print(f"\nDrive mount available: {DRIVE_AVAILABLE}")
print(f"rclone available: {RCLONE_AVAILABLE}")

## Step 3: Download/Copy Data

In [None]:
import shutil

def copy_data_from_drive(src_path: Path, dst_path: Path) -> None:
    """Copy data from mounted Drive to local storage for faster I/O.

    Copies every ``*.pkl`` file found in the ``train``, ``val`` and ``test``
    sub-directories of ``src_path`` into same-named sub-directories of
    ``dst_path``, then copies the top-level ``stat.pkl``. A missing split or
    a missing ``stat.pkl`` is reported on stdout rather than raised.

    Args:
        src_path: Directory holding the split folders and ``stat.pkl``.
        dst_path: Destination directory; created (with parents) if absent.
    """
    print(f"Copying data from {src_path} to {dst_path}...")

    # Create the destination root up front. Previously it only came into
    # existence as a side effect of creating a split folder, so copying
    # stat.pkl failed when no split directory was present.
    dst_path.mkdir(parents=True, exist_ok=True)

    for split in ('train', 'val', 'test'):
        src_split = src_path / split
        dst_split = dst_path / split

        if src_split.exists():
            dst_split.mkdir(parents=True, exist_ok=True)
            files = list(src_split.glob('*.pkl'))
            print(f"  {split}: copying {len(files)} files...")
            for f in files:
                shutil.copy2(f, dst_split / f.name)
        else:
            print(f"  {split}: NOT FOUND at {src_split}")

    # Copy stat.pkl
    stat_file = src_path / 'stat.pkl'
    if stat_file.exists():
        shutil.copy2(stat_file, dst_path / 'stat.pkl')
        print("  stat.pkl: copied")
    else:
        print("  stat.pkl: NOT FOUND")

# Copy data based on available method
if DRIVE_AVAILABLE:
    drive_data_path = Path(DRIVE_MOUNT_POINT) / GDRIVE_DATA_REL_PATH
    copy_data_from_drive(drive_data_path, DATA_ROOT)

    # Also restore checkpoints
    drive_ckpt_path = Path(DRIVE_MOUNT_POINT) / GDRIVE_CHECKPOINT_REL_PATH
    if drive_ckpt_path.exists():
        print(f"\nRestoring checkpoints from {drive_ckpt_path}...")
        for f in drive_ckpt_path.glob('*'):
            # shutil.copy2 handles regular files only; a sub-directory in the
            # Drive folder would otherwise abort the whole restore.
            if f.is_file():
                shutil.copy2(f, Path(CHECKPOINT_ROOT) / f.name)
        print("Checkpoints restored!")

elif RCLONE_AVAILABLE:
    print("Downloading data using rclone...")
    download = subprocess.run(
        ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
        capture_output=False
    )
    # Surface a failed download here instead of only as MISSING entries below
    if download.returncode != 0:
        print(f"[WARN] rclone data download exited with code {download.returncode}")

    # The exit status is deliberately not checked for checkpoints: there may
    # be nothing to restore yet.
    print("\nRestoring checkpoints...")
    subprocess.run(
        ['rclone', 'copy', GDRIVE_CHECKPOINT_PATH, CHECKPOINT_ROOT, '--progress'],
        capture_output=False
    )

else:
    raise RuntimeError(
        "No data source available!\n"
        "Either:\n"
        "  1. Mount Google Drive and ensure data exists at the correct path\n"
        "  2. Configure rclone with 'gdrive' remote\n"
        "  3. Upload data manually to Colab"
    )

# Verify data: count the pickles per split and check the stats file
print("\n" + "="*60)
print("DATA VERIFICATION")
print("="*60)

for split in ['train', 'val', 'test']:
    split_dir = DATA_ROOT / split
    if split_dir.exists():
        count = len(list(split_dir.glob('*.pkl')))
        print(f"  {split}: {count} samples")
    else:
        print(f"  {split}: MISSING!")

stat_file = DATA_ROOT / 'stat.pkl'
print(f"  stat.pkl: {'present' if stat_file.exists() else 'MISSING!'}")

## Step 4: Training Configuration

In [None]:
import torch
torch.set_float32_matmul_precision('medium')

# Pick the batch size from the GPU's total memory (decimal GB):
#   >= 14 GB -> 64, >= 8 GB -> 32, smaller GPUs -> 16, no GPU -> 8.
batch_size = 8  # CPU fallback
if torch.cuda.is_available():
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    batch_size = 64 if gpu_mem >= 14 else (32 if gpu_mem >= 8 else 16)

# PercePiano Configuration (matched to original paper)
CONFIG = dict(
    # Data
    data_dir=str(DATA_ROOT),

    # Model input (79 normalized features, unnorm used for augmentation only)
    input_size=79,

    # HAN Architecture (han_bigger256_concat.yml)
    hidden_size=256,
    note_layers=2,
    voice_layers=2,
    beat_layers=2,
    measure_layers=1,
    num_attention_heads=8,
    final_hidden=128,

    # Training (parser.py defaults)
    learning_rate=1e-4,
    weight_decay=1e-5,
    dropout=0.2,
    batch_size=batch_size,
    max_epochs=100,
    early_stopping_patience=20,
    gradient_clip_val=2.0,
    precision='16-mixed',

    # Dataset
    max_notes=1024,

    # Checkpoints
    checkpoint_dir=CHECKPOINT_ROOT,
)

# Echo the effective configuration, one "key: value" line per entry
_RULE = "=" * 60
print(_RULE)
print("TRAINING CONFIGURATION")
print(_RULE)
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## Step 5: Create DataLoaders and Model

In [None]:
from src.percepiano.data.percepiano_vnet_dataset import create_vnet_dataloaders
from src.percepiano.models.percepiano_replica import PercePianoVNetModule

# Seed before anything random is constructed. The training cell also seeds,
# but only after the model exists, which left weight initialisation
# unseeded and therefore not reproducible between runs.
pl.seed_everything(42, workers=True)

# Create DataLoaders
train_loader, val_loader, test_loader = create_vnet_dataloaders(
    data_dir=CONFIG['data_dir'],
    batch_size=CONFIG['batch_size'],
    max_notes=CONFIG['max_notes'],
    num_workers=2 if RUNNING_IN_COLAB else 0,  # Colab supports workers
)

print(f"Train: {len(train_loader.dataset)} samples")
print(f"Val: {len(val_loader.dataset)} samples")
print(f"Test: {len(test_loader.dataset)} samples")

# Create model: architecture and optimiser hyper-parameters come from CONFIG
model = PercePianoVNetModule(
    input_size=CONFIG['input_size'],
    hidden_size=CONFIG['hidden_size'],
    note_layers=CONFIG['note_layers'],
    voice_layers=CONFIG['voice_layers'],
    beat_layers=CONFIG['beat_layers'],
    measure_layers=CONFIG['measure_layers'],
    num_attention_heads=CONFIG['num_attention_heads'],
    final_hidden=CONFIG['final_hidden'],
    learning_rate=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
    dropout=CONFIG['dropout'],
)

print(f"\nModel parameters: {model.count_parameters():,}")

## Step 6: Configure Trainer

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

# Callbacks
# The monitored metric is logged as 'val/mean_r2'. The filename template must
# reference that exact key: a placeholder such as {val_mean_r2} names a metric
# that is never logged, so Lightning substitutes 0 and every checkpoint is
# named "...val_mean_r2=0.0000". auto_insert_metric_name=False stops Lightning
# from inserting the raw metric name (whose "/" would create a sub-directory);
# the "name=" prefixes are written out explicitly instead.
checkpoint_callback = ModelCheckpoint(
    dirpath=CONFIG['checkpoint_dir'],
    filename='percepiano-epoch={epoch:02d}-val_mean_r2={val/mean_r2:.4f}',
    auto_insert_metric_name=False,
    monitor='val/mean_r2',
    mode='max',
    save_top_k=3,
    save_last=True,
)

# NOTE: patience counts validation checks, not epochs. With
# val_check_interval=0.5 below there are two checks per epoch.
early_stopping = EarlyStopping(
    monitor='val/mean_r2',
    patience=CONFIG['early_stopping_patience'],
    mode='max',
    verbose=True,
)

lr_monitor = LearningRateMonitor(logging_interval='step')

# Logger - use Colab-friendly paths
log_dir = '/content/logs' if RUNNING_IN_COLAB else '/tmp/logs'
logger = TensorBoardLogger(save_dir=log_dir, name='percepiano_replica')

# Trainer: mixed precision is only requested when a GPU is present
trainer = pl.Trainer(
    max_epochs=CONFIG['max_epochs'],
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    precision=CONFIG['precision'] if torch.cuda.is_available() else 32,
    gradient_clip_val=CONFIG['gradient_clip_val'],
    callbacks=[checkpoint_callback, early_stopping, lr_monitor],
    logger=logger,
    log_every_n_steps=10,
    val_check_interval=0.5,
)

print("Trainer configured")

# TensorBoard in Colab
if RUNNING_IN_COLAB:
    print("\nTo view TensorBoard, run in a new cell:")
    print("  %load_ext tensorboard")
    print(f"  %tensorboard --logdir {log_dir}")

In [None]:
# Optional: Load TensorBoard (run this cell to enable inline TensorBoard)
# The log directory is hard-coded here; it matches the '/content/logs' value
# that the trainer cell assigns to `log_dir` when running in Colab.
if RUNNING_IN_COLAB:
    %load_ext tensorboard
    %tensorboard --logdir /content/logs

## Step 7: Session Keep-Alive (Optional)

Colab may disconnect after 90 minutes of inactivity. This cell sets up a simple, best-effort keep-alive mechanism; it does not guarantee that the session stays connected.

In [None]:
# Keep-alive for long training sessions (optional)
if RUNNING_IN_COLAB:
    import IPython
    from google.colab import output

    # Script injected into the notebook page: it re-schedules itself every
    # 60 seconds and writes a line to the browser console each time.
    # Note: This doesn't guarantee session persistence, but helps
    _KEEP_ALIVE_JS = '''
        function KeepClicking() {
            console.log("Keeping session alive...");
            // Simulate activity every 60 seconds
            setTimeout(KeepClicking, 60000);
        }
        KeepClicking();
    '''
    display(IPython.display.Javascript(_KEEP_ALIVE_JS))

    print("Keep-alive enabled (pings every 60s)")
    print("Note: Colab Pro+ allows up to 24h sessions")

## Step 8: Train

In [None]:
# Fix the random seed (including dataloader workers) right before fitting
pl.seed_everything(42, workers=True)

# Banner with the published reference scores, then start training
_RULE = "=" * 60
for _line in (
    _RULE,
    "STARTING TRAINING",
    _RULE,
    "\nPercePiano SOTA baselines:",
    "  Bi-LSTM: R^2 = 0.185",
    "  MidiBERT: R^2 = 0.313",
    "  Bi-LSTM + SA + HAN: R^2 = 0.397 (SOTA)",
    _RULE,
):
    print(_line)

trainer.fit(model, train_loader, val_loader)

In [None]:
# Sync checkpoints back to Google Drive
# Uses whichever transport the mount cell found: native Drive, then rclone,
# otherwise only prints manual-download hints.
print("Syncing checkpoints to Google Drive...")

if DRIVE_AVAILABLE:
    # Use native drive mount
    drive_ckpt_path = Path(DRIVE_MOUNT_POINT) / GDRIVE_CHECKPOINT_REL_PATH
    drive_ckpt_path.mkdir(parents=True, exist_ok=True)

    for f in Path(CHECKPOINT_ROOT).glob('*'):
        # shutil.copy2 handles regular files only; skip any sub-directory so
        # it cannot abort the sync part-way through.
        if f.is_file():
            shutil.copy2(f, drive_ckpt_path / f.name)

    print(f"Checkpoints synced to: {drive_ckpt_path}")

elif RCLONE_AVAILABLE:
    subprocess.run(
        ['rclone', 'copy', CONFIG['checkpoint_dir'], GDRIVE_CHECKPOINT_PATH, '--progress'],
        capture_output=False
    )
    print("Checkpoints synced via rclone!")
else:
    print("[WARN] No sync method available - download checkpoints manually!")
    if RUNNING_IN_COLAB:
        from google.colab import files
        print("Checkpoints available at:", CHECKPOINT_ROOT)
        print("Use: files.download(path) to download specific files")

## Step 9: Evaluation

In [None]:
"""
COMPREHENSIVE EVALUATION
========================
This cell performs exhaustive analysis of model performance including:
1. Overall metrics (R2, R, MAE, RMSE)
2. Per-dimension breakdown with all metrics
3. Prediction distribution analysis
4. Residual analysis
5. Best/worst sample analysis
6. Comparison to baselines
7. Diagnostic checks
"""

import torch
import numpy as np
from scipy import stats
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from collections import defaultdict

print("="*80)
print("COMPREHENSIVE MODEL EVALUATION")
print("="*80)

# Load best model. best_model_path is empty when no checkpoint has been ranked
# yet; fall back to the last checkpoint and fail with a clear message if
# neither exists, rather than letting the loader choke on an empty path.
best_path = checkpoint_callback.best_model_path
if not best_path:
    best_path = checkpoint_callback.last_model_path
    if not best_path:
        raise RuntimeError("No checkpoint found - run the training cell before evaluating.")
    print("\n[WARN] No best checkpoint recorded; falling back to the last checkpoint")
print(f"\nBest checkpoint: {best_path}")

# Resolve the device once instead of once per batch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

best_model = PercePianoVNetModule.load_from_checkpoint(best_path)
best_model.eval()
if device == 'cuda':
    best_model.cuda()

dimensions = list(best_model.dimensions)
n_dims = len(dimensions)

# Collect predictions on ALL splits
# NOTE(review): the train split is read through train_loader as configured for
# training; if that loader augments or shuffles, train metrics reflect that.
results = {}
for split_name, loader in [('train', train_loader), ('val', val_loader), ('test', test_loader)]:
    all_preds = []
    all_targets = []
    sample_info = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            input_features = batch['input_features'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            scores = batch['scores'].to(device)
            num_notes = batch['num_notes']

            note_locations = {
                'beat': batch['note_locations_beat'].to(device),
                'measure': batch['note_locations_measure'].to(device),
                'voice': batch['note_locations_voice'].to(device),
            }

            outputs = best_model(
                input_features=input_features,
                note_locations=note_locations,
                attention_mask=attention_mask,
            )

            all_preds.append(outputs['predictions'].cpu().numpy())
            all_targets.append(scores.cpu().numpy())

            # Store sample info for analysis (position within its batch + note count)
            for sample_idx, note_count in enumerate(num_notes):
                sample_info.append({
                    'batch_idx': batch_idx,
                    'sample_idx': sample_idx,
                    'num_notes': note_count.item(),
                })

    results[split_name] = {
        'preds': np.concatenate(all_preds),
        'targets': np.concatenate(all_targets),
        'sample_info': sample_info,
    }
    print(f"  {split_name}: {len(results[split_name]['preds'])} samples")

# Use test set for detailed analysis
preds = results['test']['preds']
targets = results['test']['targets']
sample_info = results['test']['sample_info']
n_samples = len(preds)

print(f"\n{'='*80}")
print("1. OVERALL METRICS")
print("="*80)

# Calculate overall metrics (r2_score averages uniformly over output columns)
overall_r2 = r2_score(targets, preds)
# Despite the name, multioutput='raw_values' yields one score per output
# column (i.e. per dimension), not per sample. Name kept for compatibility.
overall_r2_per_sample = r2_score(targets, preds, multioutput='raw_values')
overall_mae = mean_absolute_error(targets, preds)
overall_rmse = np.sqrt(mean_squared_error(targets, preds))

# Flatten for correlation: all (sample, dimension) pairs pooled together
flat_preds = preds.flatten()
flat_targets = targets.flatten()
overall_pearson_r, overall_pearson_p = stats.pearsonr(flat_targets, flat_preds)
overall_spearman_r, overall_spearman_p = stats.spearmanr(flat_targets, flat_preds)

print(f"\n  Test Set ({n_samples} samples, {n_dims} dimensions)")
print(f"  {'-'*50}")
print(f"  R-squared (R2):        {overall_r2:+.4f}")
print(f"  Pearson R:             {overall_pearson_r:+.4f}  (p={overall_pearson_p:.2e})")
print(f"  Spearman R:            {overall_spearman_r:+.4f}  (p={overall_spearman_p:.2e})")
print(f"  MAE:                   {overall_mae:.4f}")
print(f"  RMSE:                  {overall_rmse:.4f}")

# Interpretation
print("\n  Interpretation:")
if overall_r2 >= 0.35:
    print("    [EXCELLENT] R2 >= 0.35 matches published SOTA")
elif overall_r2 >= 0.25:
    print("    [GOOD] R2 >= 0.25 is usable for pseudo-labeling")
elif overall_r2 >= 0.10:
    print("    [FAIR] R2 >= 0.10 shows some learning, needs improvement")
elif overall_r2 >= 0:
    print("    [POOR] R2 > 0 but barely better than mean prediction")
else:
    print("    [FAILED] R2 < 0 means model is WORSE than predicting the mean!")
    print("    This indicates a fundamental problem with data or architecture.")

In [None]:
print(f"\n{'='*80}")
print("2. PER-DIMENSION DETAILED METRICS")
print("="*80)


def _summarise_dimension(name, y_true, y_pred):
    """Build the metric record for one output dimension of the test set."""
    pearson_r, pearson_p = stats.pearsonr(y_true, y_pred)
    spearman_r, _ = stats.spearmanr(y_true, y_pred)
    # Residuals are target minus prediction
    residuals = y_true - y_pred
    return {
        'dim': name,
        'r2': r2_score(y_true, y_pred),
        'r': pearson_r,
        'r_p': pearson_p,
        'spearman': spearman_r,
        'mae': mean_absolute_error(y_true, y_pred),
        'rmse': np.sqrt(mean_squared_error(y_true, y_pred)),
        'pred_mean': y_pred.mean(),
        'pred_std': y_pred.std(),
        'target_mean': y_true.mean(),
        'target_std': y_true.std(),
        'residual_mean': residuals.mean(),
        'residual_std': residuals.std(),
    }


def _status_for(r2_value):
    """Map an R2 value to its status tag; anything below 0 is [FAILED]."""
    tiers = ((0.3, "[GOOD]"), (0.1, "[OK]"), (0, "[WEAK]"))
    return next((tag for floor, tag in tiers if r2_value >= floor), "[FAILED]")


# One record per dimension, best R2 first
dim_metrics = sorted(
    (_summarise_dimension(dim, targets[:, col], preds[:, col])
     for col, dim in enumerate(dimensions)),
    key=lambda record: record['r2'],
    reverse=True,
)

# Print detailed table
print(f"\n  {'Dimension':<22} {'R2':>8} {'R':>8} {'MAE':>8} {'RMSE':>8} {'Status':<12}")
print(f"  {'-'*22} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*12}")

for m in dim_metrics:
    status = _status_for(m['r2'])
    print(f"  {m['dim']:<22} {m['r2']:>+8.4f} {m['r']:>+8.4f} {m['mae']:>8.4f} {m['rmse']:>8.4f} {status:<12}")

# Summary statistics: how many dimensions fall in each R2 band
positive_r2 = len([m for m in dim_metrics if m['r2'] > 0])
strong_r2 = len([m for m in dim_metrics if m['r2'] >= 0.2])
negative_r2 = len([m for m in dim_metrics if m['r2'] < 0])

print(f"\n  Summary:")
print(f"    Positive R2 (> 0):    {positive_r2}/{n_dims}")
print(f"    Strong R2 (>= 0.2):   {strong_r2}/{n_dims}")
print(f"    Negative R2 (< 0):    {negative_r2}/{n_dims}")

In [None]:
print(f"\n{'='*80}")
print("3. CROSS-SPLIT COMPARISON")
print("="*80)

print(f"\n  Checking for overfitting (train >> test) or data issues...")
print(f"\n  {'Split':<10} {'R2':>10} {'MAE':>10} {'RMSE':>10} {'Samples':>10}")
print(f"  {'-'*10} {'-'*10} {'-'*10} {'-'*10} {'-'*10}")

# Aggregate metrics for each split, printed as one table row per split
split_metrics = {}
for split_name in ('train', 'val', 'test'):
    split_preds = results[split_name]['preds']
    split_targets = results[split_name]['targets']

    row = {
        'r2': r2_score(split_targets, split_preds),
        'mae': mean_absolute_error(split_targets, split_preds),
        'rmse': np.sqrt(mean_squared_error(split_targets, split_preds)),
        'n': len(split_preds),
    }
    split_metrics[split_name] = row

    print(f"  {split_name:<10} {row['r2']:>+10.4f} {row['mae']:>10.4f} {row['rmse']:>10.4f} {row['n']:>10}")

# Check for overfitting: how far train R2 sits above test R2
train_r2 = split_metrics['train']['r2']
test_r2 = split_metrics['test']['r2']
overfit_gap = train_r2 - test_r2

if overfit_gap > 0.2:
    gap_verdict = f"    [WARNING] Large gap suggests overfitting!"
elif overfit_gap > 0.1:
    gap_verdict = f"    [CAUTION] Moderate gap, some overfitting present"
else:
    gap_verdict = f"    [OK] Gap is acceptable"

print(f"\n  Overfitting analysis:")
print(f"    Train-Test R2 gap: {overfit_gap:+.4f}")
print(gap_verdict)

print(f"\n{'='*80}")
print("4. COMPARISON TO BASELINES")
print("="*80)

# Reference R2 values the test-set score is compared against
baselines = {
    'Mean Prediction': 0.0,
    'Random': -0.5,
    'Bi-LSTM (published)': 0.185,
    'MidiBERT (published)': 0.313,
    'HAN SOTA (published)': 0.397,
}

print(f"\n  {'Model':<30} {'R2':>10} {'vs Ours':>12}")
print(f"  {'-'*30} {'-'*10} {'-'*12}")

# Rows in ascending order of baseline R2
for name, r2 in sorted(baselines.items(), key=lambda entry: entry[1]):
    diff = overall_r2 - r2
    if diff != 0:
        diff_str = f"{diff:+.4f}"
    else:
        diff_str = "---"
    # The mean-prediction row is tagged when our own score is below 0.05
    if name == 'Mean Prediction' and overall_r2 < 0.05:
        marker = " <-- US"
    else:
        marker = ""
    print(f"  {name:<30} {r2:>+10.4f} {diff_str:>12}{marker}")

print(f"  {'-'*30} {'-'*10} {'-'*12}")
print(f"  {'Our Model':<30} {overall_r2:>+10.4f} {'---':>12}")

# Determine rank: one plus the number of baselines that score strictly higher
our_rank = 1 + len([score for score in baselines.values() if score > overall_r2])
print(f"\n  Our model ranks #{our_rank} out of {len(baselines) + 1} models")

## Step 10: Save Teacher Model and Final Sync

In [None]:
# Save teacher model and final sync
import numpy as np
import torch
from pathlib import Path


def _as_builtin(value):
    """Return NumPy scalars as built-in Python numbers; pass anything else through."""
    return value.item() if isinstance(value, np.generic) else value


teacher_path = Path(CONFIG['checkpoint_dir']) / 'percepiano_teacher.pt'

print("="*60)
print("SAVE TEACHER MODEL")
print("="*60)

# The teacher is only exported when the test-set R2 reaches 0.25
if overall_r2 >= 0.25:
    # Metrics are converted to built-in numbers before saving. NumPy scalars
    # inside the pickle make the file unloadable under torch.load's
    # weights_only=True mode unless the reader allow-lists NumPy types.
    torch.save({
        'state_dict': best_model.state_dict(),
        'config': {
            'input_size': CONFIG['input_size'],
            'hidden_size': CONFIG['hidden_size'],
            'note_layers': CONFIG['note_layers'],
            'voice_layers': CONFIG['voice_layers'],
            'beat_layers': CONFIG['beat_layers'],
            'measure_layers': CONFIG['measure_layers'],
            'num_attention_heads': CONFIG['num_attention_heads'],
            'final_hidden': CONFIG['final_hidden'],
            'dropout': CONFIG['dropout'],
        },
        'dimensions': dimensions,
        'metrics': {
            'overall_r2': _as_builtin(overall_r2),
            'overall_mae': _as_builtin(overall_mae),
            'overall_rmse': _as_builtin(overall_rmse),
            'pearson_r': _as_builtin(overall_pearson_r),
            'per_dimension': {
                m['dim']: {key: _as_builtin(m[key]) for key in ('r2', 'r', 'mae')}
                for m in dim_metrics
            },
            'split_metrics': {
                split: {key: _as_builtin(val) for key, val in vals.items()}
                for split, vals in split_metrics.items()
            },
        },
    }, teacher_path)

    print(f"\n  Saved teacher model to: {teacher_path}")
    print(f"  Teacher R2: {overall_r2:.4f}")
    print(f"  Teacher MAE: {overall_mae:.4f}")
    print(f"\n  This model can be used for pseudo-labeling MAESTRO!")
else:
    print(f"\n  R2 = {overall_r2:.4f} is below threshold (0.25) for teacher model.")
    print(f"  Not saving teacher model.")
    print(f"\n  Consider:")
    print(f"    1. Training for more epochs")
    print(f"    2. Checking data quality")
    print(f"    3. Reviewing the diagnostic summary above")

# Final sync to Google Drive
print(f"\n{'='*60}")
print("FINAL SYNC TO GOOGLE DRIVE")
print("="*60)

if DRIVE_AVAILABLE:
    drive_ckpt_path = Path(DRIVE_MOUNT_POINT) / GDRIVE_CHECKPOINT_REL_PATH
    drive_ckpt_path.mkdir(parents=True, exist_ok=True)

    # shutil.copy2 handles regular files only, so sub-directories are skipped
    # and the reported count covers the files actually copied.
    synced_files = [f for f in Path(CHECKPOINT_ROOT).glob('*') if f.is_file()]
    for f in synced_files:
        shutil.copy2(f, drive_ckpt_path / f.name)

    print(f"\n  Synced to: {drive_ckpt_path}")
    print(f"  Files synced: {len(synced_files)}")

elif RCLONE_AVAILABLE:
    print(f"\n  Syncing checkpoints to: {GDRIVE_CHECKPOINT_PATH}")
    subprocess.run(
        ['rclone', 'copy', CONFIG['checkpoint_dir'], GDRIVE_CHECKPOINT_PATH, '--progress'],
        capture_output=False
    )
    print(f"\n  Sync complete!")

else:
    print(f"\n  No sync method available.")
    print(f"  Checkpoints saved locally at: {CHECKPOINT_ROOT}")
    if RUNNING_IN_COLAB:
        print(f"\n  To download checkpoints:")
        print(f"    from google.colab import files")
        print(f"    files.download('{teacher_path}')")

print(f"\n{'='*60}")
print("TRAINING COMPLETE")
print("="*60)

## Utilities: Download Checkpoints Manually

Run this cell if you need to download checkpoints directly to your local machine.

In [None]:
# Download checkpoints (run only if needed)
# Lists every entry in the local checkpoint directory with its size in
# decimal megabytes; the download calls themselves stay opt-in below.
if RUNNING_IN_COLAB:
    from google.colab import files

    print("Available checkpoints:")
    for entry in Path(CHECKPOINT_ROOT).glob('*'):
        print(f"  {entry.name} ({entry.stat().st_size / 1e6:.1f} MB)")

    # Uncomment to download specific file:
    # files.download(str(teacher_path))

    # Or download best checkpoint:
    # files.download(checkpoint_callback.best_model_path)