# PercePiano Replica Training

Train the PercePiano replica model using preprocessed VirtuosoNet features.

## Attribution

> **PercePiano: A Benchmark for Perceptual Evaluation of Piano Performance**  
> Park, Jongho and Kim, Dasaem et al.  
> ISMIR 2024 / Nature Scientific Reports 2024  
> GitHub: https://github.com/JonghoKimSNU/PercePiano

## Data

Uses preprocessed VirtuosoNet features (79 normalized dimensions used as model input, plus 5 unnormalized dimensions used only for augmentation):
- Train: 945 samples
- Val: 34 samples  
- Test: 115 samples

## Expected Results

- Target R-squared: 0.35-0.40 (piece-split)
- Training time: ~1-2 hours on T4/A100

## Step 1: Environment Setup

In [None]:
# Check GPU: report whether CUDA is visible and, if so, device 0's name and memory.
import torch

has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    gpu_name = torch.cuda.get_device_name(0)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"Memory: {total_gb:.1f} GB")

In [None]:
# Install rclone
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
# Install uv and clone repository
!curl -LsSf https://astral.sh/uv/install.sh | sh

import os
os.environ['PATH'] = f"{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"

# Clone repository
if not os.path.exists('/tmp/crescendai'):
    !git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai

%cd /tmp/crescendai/model
!git pull
!git log -1 --oneline

# Install dependencies
!uv pip install --system -e .
!pip install tensorboard rich

import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

## Step 2: Configure Paths and Check rclone

In [None]:
import os
import subprocess
from pathlib import Path

# Paths: local scratch locations plus their Google Drive (rclone) counterparts
DATA_ROOT = Path('/tmp/percepiano_vnet_split')
CHECKPOINT_ROOT = '/tmp/checkpoints/percepiano_replica'
GDRIVE_DATA_PATH = 'gdrive:crescendai_data/percepiano_vnet_split'
GDRIVE_CHECKPOINT_PATH = 'gdrive:crescendai_checkpoints/percepiano_replica'

print("="*60)
print("PERCEPIANO REPLICA TRAINING")
print("="*60)

# Create directories
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Check rclone. `rclone listremotes` prints one remote per line ("name:"), so
# compare whole lines; a substring test would also accept e.g. "mygdrive:".
# A missing rclone binary means "not available", not a crash of this cell.
try:
    result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)
    remotes = [line.strip() for line in result.stdout.splitlines()]
except FileNotFoundError:
    print("rclone binary not found - re-run the rclone install cell")
    remotes = []

RCLONE_AVAILABLE = 'gdrive:' in remotes
if RCLONE_AVAILABLE:
    print("rclone 'gdrive' remote: CONFIGURED")
else:
    print("rclone 'gdrive' remote: NOT CONFIGURED")
    print("Run 'rclone config' in terminal to set up Google Drive")

print(f"\nData directory: {DATA_ROOT}")
print(f"Checkpoint directory: {CHECKPOINT_ROOT}")

## Step 3: Download Data from Google Drive

In [None]:
import subprocess

if not RCLONE_AVAILABLE:
    raise RuntimeError("rclone not configured. Run 'rclone config' first.")

# Download preprocessed data
print("Downloading preprocessed VirtuosoNet features from Google Drive...")
download = subprocess.run(
    ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
    capture_output=False
)
# Stop here rather than failing obscurely in the DataLoader cell.
if download.returncode != 0:
    raise RuntimeError(
        f"rclone copy from {GDRIVE_DATA_PATH} failed (exit code {download.returncode})"
    )

# Verify data: every split must contain .pkl samples and stat.pkl must exist.
print("\n" + "="*60)
print("DATA VERIFICATION")
print("="*60)

missing = []
for split in ['train', 'val', 'test']:
    split_dir = DATA_ROOT / split
    if split_dir.exists():
        count = len(list(split_dir.glob('*.pkl')))
        print(f"  {split}: {count} samples")
        if count == 0:
            missing.append(f"{split} (no .pkl files)")
    else:
        print(f"  {split}: MISSING!")
        missing.append(split)

stat_file = DATA_ROOT / 'stat.pkl'
print(f"  stat.pkl: {'present' if stat_file.exists() else 'MISSING!'}")
if not stat_file.exists():
    missing.append('stat.pkl')

if missing:
    raise RuntimeError(f"Preprocessed data incomplete in {DATA_ROOT}: {missing}")

# Restore existing checkpoints. This is best-effort: the remote folder may not
# exist on a first run, so a non-zero exit is reported but not fatal.
# NOTE(review): restored files land in the directory ModelCheckpoint writes to,
# and the training cell calls trainer.fit without ckpt_path (no resume).
print("\nRestoring checkpoints from Google Drive (if any)...")
restore = subprocess.run(
    ['rclone', 'copy', GDRIVE_CHECKPOINT_PATH, CHECKPOINT_ROOT, '--progress'],
    capture_output=False
)
if restore.returncode != 0:
    print(f"  No checkpoints restored (rclone exit code {restore.returncode})")

## Step 4: Training Configuration

In [None]:
import torch

# Let float32 matmuls use the faster 'medium' internal precision mode.
torch.set_float32_matmul_precision('medium')

# PercePiano Configuration (matched to original paper).
# Key order only affects the printout below.
CONFIG = dict(
    # Data
    data_dir=str(DATA_ROOT),

    # Model input (79 normalized features, unnorm used for augmentation only)
    input_size=79,

    # HAN Architecture (han_bigger256_concat.yml)
    hidden_size=256,
    note_layers=2,
    voice_layers=2,
    beat_layers=2,
    measure_layers=1,
    num_attention_heads=8,
    final_hidden=128,

    # Training (parser.py defaults)
    learning_rate=1e-4,
    weight_decay=1e-5,
    dropout=0.2,
    batch_size=32,
    max_epochs=100,
    early_stopping_patience=20,
    gradient_clip_val=2.0,
    precision='16-mixed',

    # Dataset
    max_notes=1024,

    # Checkpoints
    checkpoint_dir=CHECKPOINT_ROOT,
)

rule = "=" * 60
print(rule)
print("TRAINING CONFIGURATION")
print(rule)
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## Step 5: Create DataLoaders and Model

In [None]:
from src.percepiano.data.percepiano_vnet_dataset import create_vnet_dataloaders
from src.percepiano.models.percepiano_replica import PercePianoVNetModule

# Build train/val/test loaders from the preprocessed split directories.
# num_workers=0 loads in the main process (avoids shared memory issues).
train_loader, val_loader, test_loader = create_vnet_dataloaders(
    data_dir=CONFIG['data_dir'],
    batch_size=CONFIG['batch_size'],
    max_notes=CONFIG['max_notes'],
    num_workers=0,
)

for split_label, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{split_label}: {len(split_loader.dataset)} samples")

# Model hyperparameters are forwarded from CONFIG under identical keyword names.
_MODEL_KEYS = (
    'input_size', 'hidden_size',
    'note_layers', 'voice_layers', 'beat_layers', 'measure_layers',
    'num_attention_heads', 'final_hidden',
    'learning_rate', 'weight_decay', 'dropout',
)
model = PercePianoVNetModule(**{key: CONFIG[key] for key in _MODEL_KEYS})

print(f"\nModel parameters: {model.count_parameters():,}")

## Step 6: Configure Trainer

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

# Callbacks
# Keep the 3 best checkpoints by validation mean R2, plus last.ckpt.
# The filename template must name the metric by its logged key ('val/mean_r2').
# Lightning substitutes 0 for keys it does not find, so a never-logged key such
# as 'val_mean_r2' would always render as 0.0000. auto_insert_metric_name=False
# stops Lightning from prefixing 'val/mean_r2=' itself; that '/' would create a
# sub-directory. The readable labels are written into the template instead.
checkpoint_callback = ModelCheckpoint(
    dirpath=CONFIG['checkpoint_dir'],
    filename='percepiano-epoch={epoch:02d}-val_mean_r2={val/mean_r2:.4f}',
    auto_insert_metric_name=False,
    monitor='val/mean_r2',
    mode='max',
    save_top_k=3,
    save_last=True,
)

# NOTE: EarlyStopping's patience counts validation checks, not epochs. With
# val_check_interval=0.5 below there are two checks per epoch, so patience=20
# stops after roughly 10 epochs without improvement.
early_stopping = EarlyStopping(
    monitor='val/mean_r2',
    patience=CONFIG['early_stopping_patience'],
    mode='max',
    verbose=True,
)

lr_monitor = LearningRateMonitor(logging_interval='step')

# Logger
logger = TensorBoardLogger(save_dir='/tmp/logs', name='percepiano_replica')

# Trainer (single GPU, mixed precision, validation twice per epoch)
trainer = pl.Trainer(
    max_epochs=CONFIG['max_epochs'],
    accelerator='gpu',
    devices=1,
    precision=CONFIG['precision'],
    gradient_clip_val=CONFIG['gradient_clip_val'],
    callbacks=[checkpoint_callback, early_stopping, lr_monitor],
    logger=logger,
    log_every_n_steps=10,
    val_check_interval=0.5,
)

print("Trainer configured")

## Step 7: Train

In [None]:
# Seed all RNGs (workers=True also derives per-DataLoader-worker seeds).
pl.seed_everything(42, workers=True)

divider = "=" * 60
published = (
    "  Bi-LSTM: R^2 = 0.185",
    "  MidiBERT: R^2 = 0.313",
    "  Bi-LSTM + SA + HAN: R^2 = 0.397 (SOTA)",
)
print(divider)
print("STARTING TRAINING")
print(divider)
print("\nPercePiano SOTA baselines:")
print(*published, sep="\n")
print(divider)

# No ckpt_path is passed: this always starts a fresh run.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Sync checkpoints to Google Drive. This is best-effort: a failure is reported
# (not raised) so the evaluation cells can still run on the local checkpoints.
import subprocess

if RCLONE_AVAILABLE:
    print("Syncing checkpoints to Google Drive...")
    sync = subprocess.run(
        ['rclone', 'copy', CONFIG['checkpoint_dir'], GDRIVE_CHECKPOINT_PATH, '--progress'],
        capture_output=False
    )
    if sync.returncode == 0:
        print("Sync complete!")
    else:
        print(f"Sync FAILED (rclone exit code {sync.returncode}); "
              f"checkpoints remain only in {CONFIG['checkpoint_dir']}")

## Step 8: Evaluation

In [None]:
"""
COMPREHENSIVE EVALUATION
========================
This cell performs exhaustive analysis of model performance including:
1. Overall metrics (R2, R, MAE, RMSE)
2. Per-dimension breakdown with all metrics
3. Prediction distribution analysis
4. Residual analysis
5. Best/worst sample analysis
6. Comparison to baselines
7. Diagnostic checks
"""

import torch
import numpy as np
from scipy import stats
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from collections import defaultdict

print("="*80)
print("COMPREHENSIVE MODEL EVALUATION")
print("="*80)

# Load best model (the checkpoint ModelCheckpoint ranked highest on val/mean_r2)
best_path = checkpoint_callback.best_model_path
if not best_path:
    # best_model_path stays empty until ModelCheckpoint has saved a checkpoint,
    # e.g. when training stopped before the first validation check.
    raise RuntimeError(
        "ModelCheckpoint has no best checkpoint - run the training cell first."
    )
print(f"\nBest checkpoint: {best_path}")

best_model = PercePianoVNetModule.load_from_checkpoint(best_path)
best_model.eval()
best_model.cuda()

dimensions = list(best_model.dimensions)
n_dims = len(dimensions)

# Collect predictions on ALL splits.
# NOTE(review): train_loader is the loader used for training; if it shuffles or
# augments, the train-split metrics reflect that. Confirm in create_vnet_dataloaders.
results = {}
for split_name, loader in [('train', train_loader), ('val', val_loader), ('test', test_loader)]:
    all_preds = []
    all_targets = []
    sample_info = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            num_notes = batch['num_notes']
            note_locations = {
                'beat': batch['note_locations_beat'].cuda(),
                'measure': batch['note_locations_measure'].cuda(),
                'voice': batch['note_locations_voice'].cuda(),
            }

            outputs = best_model(
                input_features=batch['input_features'].cuda(),
                note_locations=note_locations,
                attention_mask=batch['attention_mask'].cuda(),
            )

            all_preds.append(outputs['predictions'].cpu().numpy())
            # Targets never enter the model, so they stay on the host.
            all_targets.append(batch['scores'].cpu().numpy())

            # Store sample info for analysis (row order matches all_preds)
            for i in range(len(num_notes)):
                sample_info.append({
                    'batch_idx': batch_idx,
                    'sample_idx': i,
                    'num_notes': num_notes[i].item(),
                })

    results[split_name] = {
        'preds': np.concatenate(all_preds),
        'targets': np.concatenate(all_targets),
        'sample_info': sample_info,
    }
    print(f"  {split_name}: {len(results[split_name]['preds'])} samples")

# Use test set for detailed analysis
preds = results['test']['preds']
targets = results['test']['targets']
sample_info = results['test']['sample_info']
n_samples = len(preds)

print(f"\n{'='*80}")
print("1. OVERALL METRICS")
print("="*80)

# r2_score's default multioutput='uniform_average' computes R2 per dimension,
# then averages the dimensions with equal weight.
overall_r2 = r2_score(targets, preds)
overall_mae = mean_absolute_error(targets, preds)
overall_rmse = np.sqrt(mean_squared_error(targets, preds))

# Correlations pool every (sample, dimension) pair into one flat vector, so
# differences in mean level between dimensions also contribute to them.
flat_preds = preds.flatten()
flat_targets = targets.flatten()
overall_pearson_r, overall_pearson_p = stats.pearsonr(flat_targets, flat_preds)
overall_spearman_r, overall_spearman_p = stats.spearmanr(flat_targets, flat_preds)

print(f"\n  Test Set ({n_samples} samples, {n_dims} dimensions)")
print(f"  {'-'*50}")
print(f"  R-squared (R2):        {overall_r2:+.4f}")
print(f"  Pearson R:             {overall_pearson_r:+.4f}  (p={overall_pearson_p:.2e})")
print(f"  Spearman R:            {overall_spearman_r:+.4f}  (p={overall_spearman_p:.2e})")
print(f"  MAE:                   {overall_mae:.4f}")
print(f"  RMSE:                  {overall_rmse:.4f}")

# Interpretation
print(f"\n  Interpretation:")
if overall_r2 >= 0.35:
    print(f"    [EXCELLENT] R2 >= 0.35 matches published SOTA")
elif overall_r2 >= 0.25:
    print(f"    [GOOD] R2 >= 0.25 is usable for pseudo-labeling")
elif overall_r2 >= 0.10:
    print(f"    [FAIR] R2 >= 0.10 shows some learning, needs improvement")
elif overall_r2 >= 0:
    print(f"    [POOR] R2 > 0 but barely better than mean prediction")
else:
    print(f"    [FAILED] R2 < 0 means model is WORSE than predicting the mean!")
    print(f"    This indicates a fundamental problem with data or architecture.")

print(f"\n{'='*80}")
print("2. PER-DIMENSION DETAILED METRICS")
print("="*80)

# One record per dimension on the test split. Later sections and the teacher
# export read these keys, so they must not be renamed.
dim_metrics = []
for col, dim_name in enumerate(dimensions):
    truth = targets[:, col]
    guess = preds[:, col]
    resid = truth - guess
    r_val, r_pval = stats.pearsonr(truth, guess)
    rho, _ = stats.spearmanr(truth, guess)
    dim_metrics.append({
        'dim': dim_name,
        'r2': r2_score(truth, guess),
        'r': r_val,
        'r_p': r_pval,
        'spearman': rho,
        'mae': mean_absolute_error(truth, guess),
        'rmse': np.sqrt(mean_squared_error(truth, guess)),
        'pred_mean': guess.mean(),
        'pred_std': guess.std(),
        'target_mean': truth.mean(),
        'target_std': truth.std(),
        'residual_mean': resid.mean(),
        'residual_std': resid.std(),
    })

# Best-scoring dimension first
dim_metrics.sort(key=lambda rec: rec['r2'], reverse=True)

# (lower bound, label) bands checked top-down; below 0 (or NaN) is a failure.
_R2_BANDS = ((0.3, "[GOOD]"), (0.1, "[OK]"), (0, "[WEAK]"))

print(f"\n  {'Dimension':<22} {'R2':>8} {'R':>8} {'MAE':>8} {'RMSE':>8} {'Status':<12}")
print(f"  {'-'*22} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*12}")
for m in dim_metrics:
    status = next((label for floor, label in _R2_BANDS if m['r2'] >= floor), "[FAILED]")
    print(f"  {m['dim']:<22} {m['r2']:>+8.4f} {m['r']:>+8.4f} {m['mae']:>8.4f} {m['rmse']:>8.4f} {status:<12}")

positive_r2 = len([m for m in dim_metrics if m['r2'] > 0])
strong_r2 = len([m for m in dim_metrics if m['r2'] >= 0.2])
negative_r2 = len([m for m in dim_metrics if m['r2'] < 0])

print(f"\n  Summary:")
print(f"    Positive R2 (> 0):    {positive_r2}/{n_dims}")
print(f"    Strong R2 (>= 0.2):   {strong_r2}/{n_dims}")
print(f"    Negative R2 (< 0):    {negative_r2}/{n_dims}")

print(f"\n{'='*80}")
print("3. PREDICTION DISTRIBUTION ANALYSIS")
print("="*80)

print(f"\n  Checking for prediction collapse or bias...")
print(f"\n  {'Dimension':<22} {'Pred Mean':>10} {'Pred Std':>10} {'Tgt Mean':>10} {'Tgt Std':>10} {'Issue':<15}")
print(f"  {'-'*22} {'-'*10} {'-'*10} {'-'*10} {'-'*10} {'-'*15}")

collapsed_dims = []
biased_dims = []

for m in dim_metrics:
    flags = []
    mean_shift = abs(m['pred_mean'] - m['target_mean'])

    # Near-constant output for this dimension (std below 0.05)
    if m['pred_std'] < 0.05:
        flags.append("COLLAPSED")
        collapsed_dims.append(m['dim'])
    # Systematic offset between predicted and true means (> 0.1)
    if mean_shift > 0.1:
        flags.append(f"BIAS={mean_shift:.2f}")
        biased_dims.append(m['dim'])
    # Predictions spread less than half as much as the targets
    if m['pred_std'] < 0.5 * m['target_std']:
        flags.append("LOW_VAR")

    issue_str = ", ".join(flags) or "OK"
    print(f"  {m['dim']:<22} {m['pred_mean']:>10.4f} {m['pred_std']:>10.4f} {m['target_mean']:>10.4f} {m['target_std']:>10.4f} {issue_str:<15}")

if collapsed_dims:
    print(f"\n  WARNING: {len(collapsed_dims)} dimensions have collapsed predictions!")
    print(f"    Collapsed: {collapsed_dims}")
if biased_dims:
    print(f"\n  WARNING: {len(biased_dims)} dimensions show systematic bias!")
    print(f"    Biased: {biased_dims}")

print(f"\n{'='*80}")
print("4. RESIDUAL ANALYSIS")
print("="*80)

print(f"\n  {'Dimension':<22} {'Residual Mean':>14} {'Residual Std':>14} {'Skewness':>10} {'Kurtosis':>10}")
print(f"  {'-'*22} {'-'*14} {'-'*14} {'-'*10} {'-'*10}")

# Column of each dimension name (first occurrence, matching list.index).
dim_column = {}
for col, name in enumerate(dimensions):
    dim_column.setdefault(name, col)

for m in dim_metrics:
    col = dim_column[m['dim']]
    resid = targets[:, col] - preds[:, col]
    # scipy's kurtosis defaults to the Fisher (excess) definition.
    print(f"  {m['dim']:<22} {m['residual_mean']:>+14.4f} {m['residual_std']:>14.4f} {stats.skew(resid):>+10.2f} {stats.kurtosis(resid):>+10.2f}")

print(f"\n{'='*80}")
print("5. CROSS-SPLIT COMPARISON")
print("="*80)

print(f"\n  Checking for overfitting (train >> test) or data issues...")
print(f"\n  {'Split':<10} {'R2':>10} {'MAE':>10} {'RMSE':>10} {'Samples':>10}")
print(f"  {'-'*10} {'-'*10} {'-'*10} {'-'*10} {'-'*10}")

split_metrics = {}
for split_name in ('train', 'val', 'test'):
    split_pred = results[split_name]['preds']
    split_true = results[split_name]['targets']
    row = {
        'r2': r2_score(split_true, split_pred),
        'mae': mean_absolute_error(split_true, split_pred),
        'rmse': np.sqrt(mean_squared_error(split_true, split_pred)),
        'n': len(split_pred),
    }
    split_metrics[split_name] = row
    print(f"  {split_name:<10} {row['r2']:>+10.4f} {row['mae']:>10.4f} {row['rmse']:>10.4f} {row['n']:>10}")

# A large positive gap means the model fits training data much better than test data.
train_r2 = split_metrics['train']['r2']
test_r2 = split_metrics['test']['r2']
overfit_gap = train_r2 - test_r2

print(f"\n  Overfitting analysis:")
print(f"    Train-Test R2 gap: {overfit_gap:+.4f}")
if overfit_gap > 0.2:
    verdict = "[WARNING] Large gap suggests overfitting!"
elif overfit_gap > 0.1:
    verdict = "[CAUTION] Moderate gap, some overfitting present"
else:
    verdict = "[OK] Gap is acceptable"
print(f"    {verdict}")

print(f"\n{'='*80}")
print("6. COMPARISON TO BASELINES")
print("="*80)

# Reference R2 values: trivial predictors plus published PercePiano results.
baselines = {
    'Mean Prediction': 0.0,
    'Random': -0.5,
    'Bi-LSTM (published)': 0.185,
    'MidiBERT (published)': 0.313,
    'HAN SOTA (published)': 0.397,
}

print(f"\n  {'Model':<30} {'R2':>10} {'vs Ours':>12}")
print(f"  {'-'*30} {'-'*10} {'-'*12}")

near_mean = overall_r2 < 0.05
for name in sorted(baselines, key=baselines.get):
    ref_r2 = baselines[name]
    delta = overall_r2 - ref_r2
    delta_str = "---" if delta == 0 else f"{delta:+.4f}"
    # Point at the mean-prediction row when our model is barely above it.
    marker = " <-- US" if (near_mean and name == 'Mean Prediction') else ""
    print(f"  {name:<30} {ref_r2:>+10.4f} {delta_str:>12}{marker}")

print(f"  {'-'*30} {'-'*10} {'-'*12}")
print(f"  {'Our Model':<30} {overall_r2:>+10.4f} {'---':>12}")

# Rank = 1 + number of baselines scoring strictly higher than us.
our_rank = 1 + sum(1 for ref in baselines.values() if ref > overall_r2)
print(f"\n  Our model ranks #{our_rank} out of {len(baselines) + 1} models")

print(f"\n{'='*80}")
print("7. WORST AND BEST SAMPLES")
print("="*80)

# Per-sample error: mean absolute error across all dimensions of one test sample.
sample_errors = np.mean(np.abs(targets - preds), axis=1)
error_order = np.argsort(sample_errors)  # ascending: lowest error first
worst_indices = error_order[-5:][::-1]
best_indices = error_order[:5]


def _print_sample_table(title, indices, errors, info):
    """Print index / MAE / note count rows for the given test-sample indices."""
    print(f"\n  {title}:")
    print(f"  {'Index':<8} {'MAE':>10} {'Num Notes':>12}")
    print(f"  {'-'*8} {'-'*10} {'-'*12}")
    for idx in indices:
        print(f"  {idx:<8} {errors[idx]:>10.4f} {info[idx]['num_notes']:>12}")


_print_sample_table("WORST 5 SAMPLES (highest MAE)", worst_indices, sample_errors, sample_info)
_print_sample_table("BEST 5 SAMPLES (lowest MAE)", best_indices, sample_errors, sample_info)

# Correlation between num_notes and error.
# NOTE(review): num_notes is whatever the loader reports, possibly after
# truncation to max_notes. Confirm before reading too much into this.
num_notes_arr = np.array([info['num_notes'] for info in sample_info])
corr, p_val = stats.pearsonr(num_notes_arr, sample_errors)
print(f"\n  Correlation (num_notes vs error): r={corr:+.3f} (p={p_val:.3f})")
if abs(corr) > 0.3:
    if corr > 0:
        print(f"    [WARNING] Longer pieces have higher errors - model struggles with length")
    else:
        print(f"    [INFO] Shorter pieces have higher errors - may need more context")

print(f"\n{'='*80}")
print("8. DIAGNOSTIC SUMMARY")
print("="*80)

problems = []   # hard failures
cautions = []   # worth investigating
positives = []  # healthy signals

# Overall performance
if overall_r2 < 0:
    problems.append(f"R2 is negative ({overall_r2:.4f}) - model worse than mean prediction")
elif overall_r2 < 0.1:
    cautions.append(f"R2 is low ({overall_r2:.4f}) - barely learning")
else:
    positives.append(f"R2 is acceptable ({overall_r2:.4f})")

# Collapsed dimensions (computed in section 3)
if collapsed_dims:
    problems.append(f"{len(collapsed_dims)} dimensions have collapsed predictions")

# Negative-R2 dimensions (computed in section 2); "majority" means more than half
if negative_r2 > n_dims // 2:
    problems.append(f"Majority of dimensions ({negative_r2}/{n_dims}) have negative R2")
elif negative_r2 > 0:
    cautions.append(f"{negative_r2} dimensions have negative R2")

# Overfitting (computed in section 5)
if overfit_gap > 0.2:
    cautions.append(f"Significant overfitting (train-test gap: {overfit_gap:.3f})")

# Validation set size
val_size = len(results['val']['preds'])
if val_size < 50:
    cautions.append(f"Validation set very small ({val_size} samples) - metrics may be unstable")

# (heading, bullet, entries, print "None!" when empty)
for heading, bullet, entries, say_none in (
    ("ISSUES", "[X]", problems, True),
    ("WARNINGS", "[!]", cautions, True),
    ("GOOD", "[+]", positives, False),
):
    print(f"\n  {heading} ({len(entries)}):")
    for entry in entries:
        print(f"    {bullet} {entry}")
    if say_none and not entries:
        print(f"    None!")

print(f"\n{'='*80}")
print("EVALUATION COMPLETE")
print("="*80)

In [None]:
# Save teacher model and final sync
import subprocess
import torch
from pathlib import Path

# Minimum test R2 for the best checkpoint to be exported as a teacher model.
TEACHER_R2_THRESHOLD = 0.25

teacher_path = Path(CONFIG['checkpoint_dir']) / 'percepiano_teacher.pt'

print("="*60)
print("SAVE TEACHER MODEL")
print("="*60)

if overall_r2 >= TEACHER_R2_THRESHOLD:
    # Move weights to CPU in place (keeping the state_dict's OrderedDict type and
    # metadata) so the file loads on machines without a GPU.
    cpu_state_dict = best_model.state_dict()
    for key in list(cpu_state_dict):
        cpu_state_dict[key] = cpu_state_dict[key].cpu()

    # Metrics are stored as plain Python numbers. PyTorch >= 2.6 defaults
    # torch.load to weights_only=True, which rejects numpy scalar types.
    torch.save({
        'state_dict': cpu_state_dict,
        'config': {
            'input_size': CONFIG['input_size'],
            'hidden_size': CONFIG['hidden_size'],
            'note_layers': CONFIG['note_layers'],
            'voice_layers': CONFIG['voice_layers'],
            'beat_layers': CONFIG['beat_layers'],
            'measure_layers': CONFIG['measure_layers'],
            'num_attention_heads': CONFIG['num_attention_heads'],
            'final_hidden': CONFIG['final_hidden'],
            'dropout': CONFIG['dropout'],
        },
        'dimensions': [str(d) for d in dimensions],
        'metrics': {
            'overall_r2': float(overall_r2),
            'overall_mae': float(overall_mae),
            'overall_rmse': float(overall_rmse),
            'pearson_r': float(overall_pearson_r),
            'per_dimension': {
                str(m['dim']): {'r2': float(m['r2']), 'r': float(m['r']), 'mae': float(m['mae'])}
                for m in dim_metrics
            },
            'split_metrics': {
                name: {
                    'r2': float(s['r2']),
                    'mae': float(s['mae']),
                    'rmse': float(s['rmse']),
                    'n': int(s['n']),
                }
                for name, s in split_metrics.items()
            },
        },
    }, teacher_path)

    print(f"\n  Saved teacher model to: {teacher_path}")
    print(f"  Teacher R2: {overall_r2:.4f}")
    print(f"  Teacher MAE: {overall_mae:.4f}")
    print(f"\n  This model can be used for pseudo-labeling MAESTRO!")
else:
    print(f"\n  R2 = {overall_r2:.4f} is below threshold ({TEACHER_R2_THRESHOLD}) for teacher model.")
    print(f"  Not saving teacher model.")
    print(f"\n  Consider:")
    print(f"    1. Training for more epochs")
    print(f"    2. Checking data quality")
    print(f"    3. Reviewing the diagnostic summary above")

# Final sync to Google Drive. This is best-effort: failures are reported, not raised.
print(f"\n{'='*60}")
print("SYNC TO GOOGLE DRIVE")
print("="*60)

if RCLONE_AVAILABLE:
    print(f"\n  Syncing checkpoints to: {GDRIVE_CHECKPOINT_PATH}")
    sync_result = subprocess.run(
        ['rclone', 'copy', CONFIG['checkpoint_dir'], GDRIVE_CHECKPOINT_PATH, '--progress'],
        capture_output=False
    )
    if sync_result.returncode == 0:
        print(f"\n  Sync complete!")
        print(f"  Remote location: {GDRIVE_CHECKPOINT_PATH}")
    else:
        print(f"\n  Sync FAILED (rclone exit code {sync_result.returncode})")
        print(f"  Files remain only in: {CONFIG['checkpoint_dir']}")
else:
    print(f"\n  rclone not available - skipping sync")

print(f"\n{'='*60}")
print("TRAINING COMPLETE")
print("="*60)