# PercePiano Replica Training (4-Fold Cross-Validation)

Train the PercePiano replica model with 4-fold, piece-based cross-validation,
matching the methodology and hyperparameters of the paper's best published (SOTA) configuration.

## Attribution

> **PercePiano: Piano Performance Evaluation Dataset with Multi-level Perceptual Features**  
> Park, Kim et al.  
> Nature Scientific Reports 2024  
> Paper: https://pmc.ncbi.nlm.nih.gov/articles/PMC11450231/  
> GitHub: https://github.com/JonghoKimSNU/PercePiano

## Methodology

Following the approach in `m2pf_dataset_compositionfold.py`, with one deliberate change to fold balancing:

- **Piece-based splits**: All performances of the same piece stay in the same fold
- **Test set**: Pieces are selected at random until the test set holds ~15% of SAMPLES (not ~15% of pieces)
- **4-fold CV**: Remaining pieces are assigned by greedy bin-packing to balance sample counts per fold (the original uses round-robin)
- **Per-fold normalization**: Stats computed from training folds only

## Hyperparameters (SOTA Configuration - R2 = 0.397)

These parameters match the published SOTA from `2_run_comp_multilevel_total.sh` and `han_bigger256_concat.yml`:

| Parameter | SOTA Value | Notes |
|-----------|------------|-------|
| input_size | 79 | SOTA uses 79 base features (includes section_tempo) |
| batch_size | 8 | From SOTA training script |
| learning_rate | 2.5e-5 | From SOTA training script |
| hidden_size | 256 | HAN encoder dimension |
| prediction_head | 512->512->19 | From model_m2pf.py (NOT config's final_fc_size) |
| dropout | 0.2 | Regularization |
| augment_train | False | SOTA doesn't use key augmentation |
| max_epochs | 200 | Extended training window |
| early_stopping_patience | 40 | More patience for convergence (note: `CONFIG` in Step 5 currently sets 20) |
| gradient_clip_val | 2.0 | From parser.py |
| **precision** | **32** | **FP32 (original uses FP32, not mixed precision)** |
| **max_notes/slice_len** | **5000** | **SOTA slice size for overlapping sampling** |

## Step 1: Environment Setup

In [None]:
# Check GPU
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install rclone
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "No install confirmation found; run 'rclone version' to verify"

In [None]:
# Install uv and clone repository
!curl -LsSf https://astral.sh/uv/install.sh | sh

import os
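# Recent uv installers place the binary in ~/.local/bin; older ones used ~/.cargo/bin
os.environ['PATH'] = f"{os.environ['HOME']}/.local/bin:{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"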

# Clone repository
if not os.path.exists('/tmp/crescendai'):
    !git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai

%cd /tmp/crescendai/model
!git pull
!git log -1 --oneline

# Clone original PercePiano for comparison (needed for data diagnostics)
PERCEPIANO_PATH = '/tmp/crescendai/model/data/raw/PercePiano'
if not os.path.exists(PERCEPIANO_PATH):
    print("\nCloning original PercePiano repository...")
    !git clone https://github.com/JonghoKimSNU/PercePiano.git {PERCEPIANO_PATH}
else:
    print(f"\nPercePiano already present at {PERCEPIANO_PATH}")

# Install dependencies
!uv pip install --system -e .
!pip install tensorboard rich

import torch
import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")
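
Before moving on, it can help to confirm that the command-line tools installed above are actually reachable from the notebook process. A minimal sketch using only the standard library; `missing_tools` is a hypothetical helper, not part of the repository:

```python
import shutil


def missing_tools(names):
    """Return the subset of `names` that cannot be found on PATH."""
    return [name for name in names if shutil.which(name) is None]


# Tool names this notebook installs or calls.
missing = missing_tools(["git", "rclone", "uv"])
if missing:
    print(f"Not found on PATH: {', '.join(missing)}")
else:
    print("All required tools found")
```

If `uv` is reported missing right after installation, the install directory is probably not on `PATH` for the current kernel.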

## Step 2: Configure Paths and Check rclone

In [None]:
import os
import subprocess
import shutil
from pathlib import Path

# Paths
DATA_ROOT = Path('/tmp/percepiano_vnet_84dim')
CHECKPOINT_ROOT = Path('/tmp/checkpoints/percepiano_kfold')
LOG_ROOT = Path('/tmp/logs/percepiano_kfold')
GDRIVE_DATA_PATH = 'gdrive:crescendai_data/percepiano_vnet_84dim'
GDRIVE_CHECKPOINT_PATH = 'gdrive:crescendai_checkpoints/percepiano_kfold'

# Training control
RESTART_TRAINING = True  # Set to True to clear checkpoints and start fresh

print("="*60)
print("PERCEPIANO REPLICA TRAINING (4-FOLD CV)")
print("="*60)

# Clear checkpoints if restarting
if RESTART_TRAINING and CHECKPOINT_ROOT.exists():
    print(f"\nRESTART_TRAINING=True: Clearing checkpoints at {CHECKPOINT_ROOT}")
    shutil.rmtree(CHECKPOINT_ROOT)
    print("  Checkpoints cleared!")

if RESTART_TRAINING and LOG_ROOT.exists():
    print(f"RESTART_TRAINING=True: Clearing logs at {LOG_ROOT}")
    shutil.rmtree(LOG_ROOT)
    print("  Logs cleared!")

# Create directories
CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
LOG_ROOT.mkdir(parents=True, exist_ok=True)
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Check rclone
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)

if 'gdrive:' in result.stdout:
    print("\nrclone 'gdrive' remote: CONFIGURED")
    RCLONE_AVAILABLE = True
else:
    print("\nrclone 'gdrive' remote: NOT CONFIGURED")
    print("Run 'rclone config' in terminal to set up Google Drive")
    RCLONE_AVAILABLE = False

print(f"\nData directory: {DATA_ROOT}")
print(f"Checkpoint directory: {CHECKPOINT_ROOT}")
print(f"Log directory: {LOG_ROOT}")
print(f"\nRESTART_TRAINING: {RESTART_TRAINING}")
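
The check above uses a substring test, which would also match a remote whose name merely ends in `gdrive`. A stricter variant compares whole lines of the `rclone listremotes` output, which lists one remote per line with a trailing colon. This is a sketch; `has_remote` is a hypothetical helper name:

```python
def has_remote(listremotes_output, name):
    """True if `name` appears as its own line (with trailing colon) in the output."""
    remotes = {line.strip() for line in listremotes_output.splitlines() if line.strip()}
    return f"{name}:" in remotes


sample = "backup_gdrive:\ns3:\n"
print(has_remote(sample, "gdrive"))  # False: no remote is named exactly 'gdrive'
print("gdrive:" in sample)           # True: the substring test is fooled
```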

## Step 3: Download Data from Google Drive

In [None]:
import subprocess

if not RCLONE_AVAILABLE:
    raise RuntimeError("rclone not configured. Run 'rclone config' first.")

# Download preprocessed data
print("Downloading preprocessed VirtuosoNet features from Google Drive...")
# check=True raises CalledProcessError if rclone exits with a non-zero status
subprocess.run(
    ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
    check=True,
)

# Verify data
print("\n" + "="*60)
print("DATA VERIFICATION")
print("="*60)

total_samples = 0
for split in ['train', 'val', 'test']:
    split_dir = DATA_ROOT / split
    if split_dir.exists():
        count = len(list(split_dir.glob('*.pkl')))
        total_samples += count
        print(f"  {split}: {count} samples")
    else:
        print(f"  {split}: MISSING!")

print(f"  Total: {total_samples} samples")

stat_file = DATA_ROOT / 'stat.pkl'
print(f"  stat.pkl: {'present' if stat_file.exists() else 'MISSING!'}")

fold_file = DATA_ROOT / 'fold_assignments.json'
print(f"  fold_assignments.json: {'present' if fold_file.exists() else 'will be created'}")
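
The Methodology section calls for normalization statistics computed from the training folds only, so that validation data never influences the scaling. The sketch below illustrates that idea on plain lists of floats. The function names are hypothetical, and the actual layout of `stat.pkl` is not shown in this notebook, so nothing here should be read as its format:

```python
from statistics import mean, pstdev


def fit_stats(train_rows):
    """Per-feature mean and std computed from TRAINING rows only."""
    columns = list(zip(*train_rows))
    means = [mean(col) for col in columns]
    # Fall back to 1.0 for constant features to avoid dividing by zero.
    stds = [pstdev(col) or 1.0 for col in columns]
    return means, stds


def apply_stats(rows, means, stds):
    """Z-score rows using statistics that were fitted elsewhere."""
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]


train = [[1.0, 10.0], [3.0, 10.0]]
val = [[2.0, 12.0]]
means, stds = fit_stats(train)
print(apply_stats(val, means, stds))  # [[0.0, 2.0]]
```

For each fold, `fit_stats` would be called on that fold's training samples and the result reused unchanged for its validation samples.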

## Step 4: Create Fold Assignments

In [None]:
from src.percepiano.data.kfold_split import (
    create_piece_based_folds,
    save_fold_assignments,
    load_fold_assignments,
    print_fold_statistics,
)

FOLD_FILE = DATA_ROOT / 'fold_assignments.json'
N_FOLDS = 4
TEST_RATIO = 0.15
SEED = 42

print("="*60)
print("FOLD ASSIGNMENT CREATION")
print("="*60)

# Force regeneration to use corrected methodology
# - Test set: select pieces until ~15% of SAMPLES (PercePiano methodology)
# - CV folds: greedy bin-packing for balanced sample counts (improvement over round-robin)
FORCE_REGENERATE = True

if FOLD_FILE.exists() and not FORCE_REGENERATE:
    print(f"\nLoading existing fold assignments from {FOLD_FILE}")
    fold_assignments = load_fold_assignments(FOLD_FILE)
else:
    if FOLD_FILE.exists():
        print(f"\nRemoving old fold assignments (regenerating with balanced methodology)...")
        FOLD_FILE.unlink()
    
    print(f"\nCreating new {N_FOLDS}-fold piece-based splits...")
    print("  Test set: select pieces until ~15% of SAMPLES")
    print("  CV folds: greedy bin-packing for balanced sample counts")
    fold_assignments = create_piece_based_folds(
        data_dir=DATA_ROOT,
        n_folds=N_FOLDS,
        test_ratio=TEST_RATIO,
        seed=SEED,
    )
    save_fold_assignments(fold_assignments, FOLD_FILE)

# Print statistics
print_fold_statistics(fold_assignments, n_folds=N_FOLDS)
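
The implementation of `create_piece_based_folds` is not shown in this notebook. The two steps the cell describes (fill the test set with whole pieces until ~15% of samples, then balance the remaining pieces across folds by greedy bin-packing) can be sketched as follows. `split_pieces` and its input format are assumptions made for illustration only:

```python
import random


def split_pieces(samples_per_piece, n_folds=4, test_ratio=0.15, seed=42):
    """Assign whole pieces to a test set or to one of `n_folds` CV folds.

    `samples_per_piece` maps piece id -> number of performances (samples).
    """
    rng = random.Random(seed)
    pieces = sorted(samples_per_piece)
    rng.shuffle(pieces)

    total = sum(samples_per_piece.values())
    test, test_count, remaining = [], 0, []
    for piece in pieces:
        # Add whole pieces to the test set until ~test_ratio of SAMPLES is reached.
        if test_count < test_ratio * total:
            test.append(piece)
            test_count += samples_per_piece[piece]
        else:
            remaining.append(piece)

    # Greedy bin-packing: largest piece first, always into the lightest fold.
    folds = [[] for _ in range(n_folds)]
    loads = [0] * n_folds
    for piece in sorted(remaining, key=lambda p: (-samples_per_piece[p], p)):
        target = loads.index(min(loads))
        folds[target].append(piece)
        loads[target] += samples_per_piece[piece]
    return test, folds, loads


# Illustrative sample counts, not the real dataset.
counts = {f"piece_{i}": n for i, n in enumerate([12, 9, 7, 7, 5, 4, 3, 3, 2, 1])}
test, folds, loads = split_pieces(counts)
print("test pieces:", test)
print("samples per fold:", loads)
```

Because every piece lands in exactly one bucket, no performance of a test or validation piece can leak into training.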

## Step 5: Training Configuration

In [None]:
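import os
import torch

# 'medium' lets float32 matmuls use reduced-precision (bfloat16) internals on supported GPUs.
# Use 'highest' for strict FP32 matmuls, which is what the hyperparameter table assumes.
torch.set_float32_matmul_precision('medium')

# Better CUDA error reporting (synchronous kernel launches; this slows training).
# The variable generally has to be set before CUDA is initialized, so if the GPU
# check in Step 1 has already run, restart the kernel and set it first.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'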

# Import model type constants (including Round 17 true incremental models)
from src.percepiano.training.kfold_trainer import (
    MODEL_TYPE_HAN,
    MODEL_TYPE_BASELINE,
    MODEL_TYPE_BASELINE_BEAT,
    MODEL_TYPE_BASELINE_BEAT_MEASURE,
    # True incremental models (Round 17)
    MODEL_TYPE_NOTE_ONLY,
    MODEL_TYPE_NOTE_VOICE,
    MODEL_TYPE_NOTE_VOICE_BEAT,
)

# Round 17 Configuration: True Incremental Architecture
# 
# KEY INSIGHT: The original HAN architecture has note and voice LSTMs
# process the SAME embedded input IN PARALLEL, then concatenate.
# The previous "incremental" approach (7-layer LSTM + beat) was wrong.
#
# True incremental progression:
# 1. NoteOnly: 2-layer note BiLSTM (expected R2 ~0.10)
# 2. NoteVoice: + parallel 2-layer voice LSTM (expected R2 ~0.15)
# 3. NoteVoiceBeat: + beat hierarchy (expected R2 ~0.25-0.30)
# 4. Full HAN: + measure hierarchy (expected R2 ~0.35-0.40)

CONFIG = {
    # K-Fold settings
    'n_folds': N_FOLDS,
    'test_ratio': TEST_RATIO,
    # Data
    'data_dir': str(DATA_ROOT),
    'checkpoint_dir': str(CHECKPOINT_ROOT),
    'log_dir': str(LOG_ROOT),
    'input_size': 79,
    'hidden_size': 256,
    'note_layers': 2,
    'voice_layers': 2,
    'beat_layers': 2,
    'measure_layers': 1,
    'num_attention_heads': 8,
    'learning_rate': 2.5e-5,
    'weight_decay': 1e-5,
    'dropout': 0.2,
    'batch_size': 8,
    'max_epochs': 200,
    'early_stopping_patience': 20,
    'gradient_clip_val': 2.0,
    'precision': '32',
    'max_notes': 5000,
    'slice_len': 5000,
    'num_workers': 4,
    'augment_train': False,
    # Attention improvement hyperparameters
    'attention_lr_multiplier': 10.0,  # Higher LR for attention params
    'entropy_weight': 0.01,           # Penalty for uniform attention
    'entropy_target': 0.6,            # Target entropy (0=focused, 1=uniform)
}

print("="*60)
print("TRAINING CONFIGURATION (ROUND 17 - TRUE INCREMENTAL)")
print("="*60)
print("\nRound 17 Changes:")
print("  - Fixed architecture: note + voice process SAME input in PARALLEL")
print("  - New true incremental models: NoteOnly, NoteVoice, NoteVoiceBeat")
print("  - DiagnosticCallback runs automatically at end of training")
print("  - Model-specific diagnostic insights for each architecture")
print("\n" + "-"*60)
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

print(f"\nModel types available:")
print(f"  MODEL_TYPE_BASELINE = '{MODEL_TYPE_BASELINE}' (7-layer Bi-LSTM, R2 ~0.19)")
print(f"  MODEL_TYPE_NOTE_ONLY = '{MODEL_TYPE_NOTE_ONLY}' (2-layer note, R2 ~0.10)")
print(f"  MODEL_TYPE_NOTE_VOICE = '{MODEL_TYPE_NOTE_VOICE}' (note + voice parallel, R2 ~0.15)")
print(f"  MODEL_TYPE_NOTE_VOICE_BEAT = '{MODEL_TYPE_NOTE_VOICE_BEAT}' (+ beat, R2 ~0.25-0.30)")
print(f"  MODEL_TYPE_HAN = '{MODEL_TYPE_HAN}' (full hierarchical, R2 ~0.35-0.40)")
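
The hyperparameter table at the top of the notebook and the hand-typed `CONFIG` dictionary can drift apart. A small consistency check makes such drift visible. This is a sketch: `SOTA_REFERENCE` is copied from the table, `config_mismatches` is a hypothetical helper, and `example_config` repeats the corresponding values from the `CONFIG` cell above:

```python
# Reference values copied from the hyperparameter table.
SOTA_REFERENCE = {
    "input_size": 79,
    "batch_size": 8,
    "learning_rate": 2.5e-5,
    "hidden_size": 256,
    "dropout": 0.2,
    "augment_train": False,
    "max_epochs": 200,
    "early_stopping_patience": 40,
    "gradient_clip_val": 2.0,
}


def config_mismatches(config, reference):
    """Return {key: (config_value, reference_value)} for keys that differ or are missing."""
    return {
        key: (config.get(key), expected)
        for key, expected in reference.items()
        if config.get(key) != expected
    }


# The same keys as they appear in CONFIG in this notebook.
example_config = {
    "input_size": 79,
    "batch_size": 8,
    "learning_rate": 2.5e-5,
    "hidden_size": 256,
    "dropout": 0.2,
    "augment_train": False,
    "max_epochs": 200,
    "early_stopping_patience": 20,
    "gradient_clip_val": 2.0,
}

for key, (got, expected) in config_mismatches(example_config, SOTA_REFERENCE).items():
    print(f"{key}: CONFIG={got!r}, table={expected!r}")
# early_stopping_patience: CONFIG=20, table=40
```

In the notebook itself, `config_mismatches(CONFIG, SOTA_REFERENCE)` would run the same comparison against the live configuration.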

## Step 6: Initialize True Incremental Trainers (Round 17)

Create trainers for the TRUE incremental models that match the actual HAN architecture:
1. NoteOnly (2-layer note BiLSTM) - expected R2 ~0.10
2. NoteVoice (+ parallel 2-layer voice LSTM) - expected R2 ~0.15
3. NoteVoiceBeat (+ beat hierarchy) - expected R2 ~0.25-0.30
4. Full HAN (+ measure hierarchy) - expected R2 ~0.35-0.40

**Key Insight**: The original HAN has note and voice LSTMs process the SAME embedded input IN PARALLEL, then concatenate. This is different from the old "incremental" approach which added beat on top of a 7-layer LSTM.

In [None]:
from src.percepiano.training.kfold_trainer import (
    KFoldTrainer,
    MODEL_TYPE_BASELINE,
    MODEL_TYPE_NOTE_ONLY,
    MODEL_TYPE_NOTE_VOICE,
    MODEL_TYPE_NOTE_VOICE_BEAT,
    MODEL_TYPE_HAN,
)
import pytorch_lightning as pl

# Set seed for reproducibility
pl.seed_everything(42, workers=True)

# Round 17: True Incremental Training on Fold 2
# This matches the ACTUAL HAN architecture progression
FOLD_ID = 2  # Best performing fold with longest pieces (1-54 beats)

print("="*60)
print("ROUND 17: TRUE INCREMENTAL TRAINERS")
print("="*60)
print(f"\nTraining Fold: {FOLD_ID}")
print("\nTrue Incremental Models (matching HAN architecture):")
print("  1. NoteOnly (2-layer note BiLSTM) - R2 ~0.10")
print("  2. NoteVoice (+ parallel voice LSTM) - R2 ~0.15")
print("  3. NoteVoiceBeat (+ beat hierarchy) - R2 ~0.25-0.30")
print("  4. Full HAN (+ measure hierarchy) - R2 ~0.35-0.40")

# 1. NoteOnly trainer (2-layer note BiLSTM)
print("\n[1] NoteOnly Trainer (2-layer note BiLSTM):")
note_only_trainer = KFoldTrainer(
    config=CONFIG,
    fold_assignments=fold_assignments,
    data_dir=DATA_ROOT,
    checkpoint_dir=CHECKPOINT_ROOT / "note_only",
    log_dir=LOG_ROOT / "note_only",
    n_folds=N_FOLDS,
    model_type=MODEL_TYPE_NOTE_ONLY,
)

# 2. NoteVoice trainer (+ parallel voice LSTM)
print("\n[2] NoteVoice Trainer (+ parallel voice LSTM):")
note_voice_trainer = KFoldTrainer(
    config=CONFIG,
    fold_assignments=fold_assignments,
    data_dir=DATA_ROOT,
    checkpoint_dir=CHECKPOINT_ROOT / "note_voice",
    log_dir=LOG_ROOT / "note_voice",
    n_folds=N_FOLDS,
    model_type=MODEL_TYPE_NOTE_VOICE,
)

# 3. NoteVoiceBeat trainer (+ beat hierarchy)
print("\n[3] NoteVoiceBeat Trainer (+ beat hierarchy):")
note_voice_beat_trainer = KFoldTrainer(
    config=CONFIG,
    fold_assignments=fold_assignments,
    data_dir=DATA_ROOT,
    checkpoint_dir=CHECKPOINT_ROOT / "note_voice_beat",
    log_dir=LOG_ROOT / "note_voice_beat",
    n_folds=N_FOLDS,
    model_type=MODEL_TYPE_NOTE_VOICE_BEAT,
)

# 4. Full HAN trainer (+ measure hierarchy)
print("\n[4] Full HAN Trainer (+ measure hierarchy):")
han_trainer = KFoldTrainer(
    config=CONFIG,
    fold_assignments=fold_assignments,
    data_dir=DATA_ROOT,
    checkpoint_dir=CHECKPOINT_ROOT / "han",
    log_dir=LOG_ROOT / "han",
    n_folds=N_FOLDS,
    model_type=MODEL_TYPE_HAN,
)

print("\n" + "="*60)
print("All 4 trainers initialized!")
print("Checkpoints will be saved to:")
print(f"  NoteOnly: {note_only_trainer.checkpoint_dir}")
print(f"  NoteVoice: {note_voice_trainer.checkpoint_dir}")
print(f"  NoteVoiceBeat: {note_voice_beat_trainer.checkpoint_dir}")
print(f"  Full HAN: {han_trainer.checkpoint_dir}")
print("\nDiagnostics run automatically at end of each training!")
print("="*60)

In [None]:
# Optional: Enable hierarchy debug mode for detailed logging
# This logs boundary detection, attention weights, and other internals

from src.percepiano.models.hierarchy_utils import set_hierarchy_debug

# Uncomment to enable debug mode:
# set_hierarchy_debug(True)

print("Hierarchy debug mode: DISABLED")
print("To enable, uncomment set_hierarchy_debug(True) above")
print("\nNote: Comprehensive diagnostics run automatically at end of each training!")
print("Look for 'END OF TRAINING DIAGNOSTICS' in the output.")

## Step 7: True Incremental Training (Round 17)

Train the TRUE incremental models that match the actual HAN architecture:
1. **NoteOnly** (2-layer note BiLSTM) - R2 ~0.10
2. **NoteVoice** (+ parallel voice LSTM) - R2 ~0.15 (gain ~+0.05 from voice)
3. **NoteVoiceBeat** (+ beat hierarchy) - R2 ~0.25-0.30 (gain ~+0.10 from beat)
4. **Full HAN** (+ measure hierarchy) - R2 ~0.35-0.40 (gain ~+0.05 from measure)

Each training run will automatically output comprehensive diagnostics at the end,
including activation variances, attention entropy, and model-specific insights.

In [None]:
"""
ROUND 17: TRUE INCREMENTAL TRAINING

Train the models that match the ACTUAL HAN architecture progression:
1. NoteOnly (2-layer note BiLSTM) - R2 ~0.10
2. NoteVoice (+ parallel voice LSTM) - R2 ~0.15
3. NoteVoiceBeat (+ beat hierarchy) - R2 ~0.25-0.30
4. Full HAN (+ measure hierarchy) - R2 ~0.35-0.40

Diagnostics run automatically at the end of each training via on_fit_end callback.
"""

print("="*70)
print("ROUND 17: TRUE INCREMENTAL TRAINING")
print("="*70)
print("\nGoal: Validate each hierarchy level contributes to performance")
print("Expected gains:")
print("  Voice (parallel LSTM): ~+0.05 R2")
print("  Beat hierarchy: ~+0.10 R2")
print("  Measure hierarchy: ~+0.05 R2")
print("  Total gain from NoteOnly to HAN: ~+0.25-0.30 R2")
print("\nNote: Diagnostics run automatically at end of each training!")
print("="*70)

# ========================================
# Model 1: NoteOnly (2-layer note BiLSTM)
# ========================================
print("\n" + "="*70)
print("MODEL 1: NoteOnly (2-layer note BiLSTM)")
print("="*70)
print("Expected R2: ~0.10")

note_only_metrics = note_only_trainer.train_fold(
    fold_id=FOLD_ID,
    verbose=True,
    resume_from_checkpoint=False,
)
note_only_trainer.save_results()

print(f"\nNoteOnly training complete! Val R2 = {note_only_metrics.val_r2:+.4f}")

# ========================================
# Model 2: NoteVoice (+ parallel voice LSTM)
# ========================================
print("\n" + "="*70)
print("MODEL 2: NoteVoice (+ parallel voice LSTM)")
print("="*70)
print("Expected R2: ~0.15 (voice contribution: +0.05)")

note_voice_metrics = note_voice_trainer.train_fold(
    fold_id=FOLD_ID,
    verbose=True,
    resume_from_checkpoint=False,
)
note_voice_trainer.save_results()

voice_gain = note_voice_metrics.val_r2 - note_only_metrics.val_r2
print(f"\nNoteVoice training complete!")
print(f"  Val R2: {note_voice_metrics.val_r2:+.4f}")
print(f"  Voice gain: {voice_gain:+.4f} (expected: ~+0.05)")

# ========================================
# Model 3: NoteVoiceBeat (+ beat hierarchy)
# ========================================
print("\n" + "="*70)
print("MODEL 3: NoteVoiceBeat (+ beat hierarchy)")
print("="*70)
print("Expected R2: ~0.25-0.30 (beat contribution: +0.10)")

note_voice_beat_metrics = note_voice_beat_trainer.train_fold(
    fold_id=FOLD_ID,
    verbose=True,
    resume_from_checkpoint=False,
)
note_voice_beat_trainer.save_results()

beat_gain = note_voice_beat_metrics.val_r2 - note_voice_metrics.val_r2
print(f"\nNoteVoiceBeat training complete!")
print(f"  Val R2: {note_voice_beat_metrics.val_r2:+.4f}")
print(f"  Beat gain: {beat_gain:+.4f} (expected: ~+0.10)")

# ========================================
# Model 4: Full HAN (+ measure hierarchy)
# ========================================
print("\n" + "="*70)
print("MODEL 4: Full HAN (+ measure hierarchy)")
print("="*70)
print("Expected R2: ~0.35-0.40 (approaching SOTA of 0.397)")

han_metrics = han_trainer.train_fold(
    fold_id=FOLD_ID,
    verbose=True,
    resume_from_checkpoint=False,
)
han_trainer.save_results()

measure_gain = han_metrics.val_r2 - note_voice_beat_metrics.val_r2
print(f"\nFull HAN training complete!")
print(f"  Val R2: {han_metrics.val_r2:+.4f}")
print(f"  Measure gain: {measure_gain:+.4f} (expected: ~+0.05)")

# ========================================
# Store models and metrics for analysis
# ========================================
trained_models = {
    'note_only': note_only_trainer.get_trained_model(FOLD_ID),
    'note_voice': note_voice_trainer.get_trained_model(FOLD_ID),
    'note_voice_beat': note_voice_beat_trainer.get_trained_model(FOLD_ID),
    'han': han_trainer.get_trained_model(FOLD_ID),
}

trained_metrics = {
    'note_only': note_only_metrics,
    'note_voice': note_voice_metrics,
    'note_voice_beat': note_voice_beat_metrics,
    'han': han_metrics,
}

# ========================================
# Training Summary
# ========================================
print("\n" + "="*70)
print("TRUE INCREMENTAL TRAINING COMPLETE")
print("="*70)

total_gain = han_metrics.val_r2 - note_only_metrics.val_r2

print(f"\n  {'Model':<30} {'Val R2':>10} {'Gain':>10} {'Expected':>12}")
print(f"  {'-'*30} {'-'*10} {'-'*10} {'-'*12}")
print(f"  {'NoteOnly (2L note)':<30} {note_only_metrics.val_r2:>+10.4f} {'-':>10} {'~0.10':>12}")
print(f"  {'NoteVoice (+ voice)':<30} {note_voice_metrics.val_r2:>+10.4f} {voice_gain:>+10.4f} {'~0.15':>12}")
print(f"  {'NoteVoiceBeat (+ beat)':<30} {note_voice_beat_metrics.val_r2:>+10.4f} {beat_gain:>+10.4f} {'~0.25-0.30':>12}")
print(f"  {'Full HAN (+ measure)':<30} {han_metrics.val_r2:>+10.4f} {measure_gain:>+10.4f} {'~0.35-0.40':>12}")
print(f"  {'-'*30} {'-'*10} {'-'*10} {'-'*12}")
print(f"  {'Total Gain (NoteOnly->HAN)':<30} {'-':>10} {total_gain:>+10.4f} {'~+0.25':>12}")

print(f"\n  Target: R2 = 0.397 (PercePiano SOTA)")

if han_metrics.val_r2 >= 0.35:
    print(f"  [SUCCESS] Approaching SOTA performance!")
elif han_metrics.val_r2 >= 0.30:
    print(f"  [GOOD] Strong performance, minor tuning may help")
elif han_metrics.val_r2 >= 0.25:
    print(f"  [PARTIAL] Hierarchy helping but below target - check diagnostics")
else:
    print(f"  [ISSUE] Below expected - review end-of-training diagnostics above")

# Validate incremental gains
print("\n  Incremental Validation:")
if voice_gain > 0:
    print(f"    Voice contribution: POSITIVE (+{voice_gain:.4f})")
else:
    print(f"    Voice contribution: NEGATIVE ({voice_gain:.4f}) - INVESTIGATE")

if beat_gain > 0:
    print(f"    Beat contribution: POSITIVE (+{beat_gain:.4f})")
else:
    print(f"    Beat contribution: NEGATIVE ({beat_gain:.4f}) - INVESTIGATE")

if measure_gain > 0:
    print(f"    Measure contribution: POSITIVE (+{measure_gain:.4f})")
else:
    print(f"    Measure contribution: NEGATIVE ({measure_gain:.4f}) - may need longer pieces")

print("="*70)
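
Every comparison in this notebook is expressed as validation R2. For reference, the coefficient of determination is `1 - SS_res / SS_tot`, so always predicting the mean scores 0 and a perfect prediction scores 1. The sketch below computes it per dimension and then averages uniformly across dimensions. That averaging is an assumption (it matches the default of scikit-learn's `r2_score` for multi-output targets); how `KFoldTrainer` aggregates is not shown here:

```python
def r2_single(y_true, y_pred):
    """R2 = 1 - SS_res / SS_tot for one target dimension."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R2 is undefined when the targets are constant")
    return 1.0 - ss_res / ss_tot


def r2_uniform_average(targets, predictions):
    """Mean of per-dimension R2; rows are samples, columns are dimensions."""
    per_dim = [
        r2_single(list(t_col), list(p_col))
        for t_col, p_col in zip(zip(*targets), zip(*predictions))
    ]
    return sum(per_dim) / len(per_dim)


targets = [[0.0, 1.0], [1.0, 3.0], [2.0, 5.0]]
mean_only = [[1.0, 3.0]] * 3  # always predict the per-dimension mean

print(r2_uniform_average(targets, targets))    # 1.0
print(r2_uniform_average(targets, mean_only))  # 0.0
```

A negative value therefore means the model does worse than a constant mean predictor, which is why the gains above are printed with an explicit sign.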

## Step 8: Post-Training Diagnostics (Optional)

**Note**: Comprehensive diagnostics now run automatically at the end of each model's training via the `on_fit_end` callback. The output includes:
- Activation variances at each level
- Attention entropy analysis
- Model-specific insights with expected R2 values

This cell provides **additional manual diagnostics** if you need to re-run analysis or investigate specific issues after training.
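
The diagnostics report attention entropy on a 0-to-1 scale (0 = focused, 1 = uniform). One common way to obtain such a scale is to divide the Shannon entropy by its maximum, `log(n)`. Whether `DiagnosticCallback` computes it exactly this way is an assumption; the sketch only illustrates how to read the numbers:

```python
import math


def normalized_entropy(weights):
    """Shannon entropy of an attention distribution, scaled to [0, 1]."""
    n = len(weights)
    if n < 2:
        return 0.0  # a single position carries no uncertainty
    total = sum(weights)
    if total <= 0:
        raise ValueError("attention weights must sum to a positive value")
    probs = [w / total for w in weights]
    entropy = sum(-p * math.log(p) for p in probs if p > 0)
    return entropy / math.log(n)


print(normalized_entropy([0.25, 0.25, 0.25, 0.25]))  # uniform attention
print(normalized_entropy([0.70, 0.10, 0.10, 0.10]))  # partially focused
print(normalized_entropy([1.0, 0.0, 0.0, 0.0]))      # fully focused
```

On this scale, values near 1 correspond to the "too uniform" case and values near 0 to the "too collapsed" case.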

In [None]:
"""
POST-TRAINING DIAGNOSTICS (OPTIONAL)

Diagnostics already ran automatically at end of each training.
This cell provides additional manual analysis if needed.

Key metrics to check from automatic diagnostics:
1. Beat attention entropy: 0.3-0.8 is good; >0.95 means uniform (not learning)
2. Activation variances: Should be in "OK" range, not "LOW - ISSUE"
3. Incremental gains: Each level should improve over previous
"""

import torch
import numpy as np
from torch.utils.data import DataLoader
from src.percepiano.data.percepiano_vnet_dataset import (
    PercePianoKFoldDataset,
    percepiano_pack_collate,
)
from src.percepiano.training.diagnostics import DiagnosticCallback

print("="*70)
print("POST-TRAINING DIAGNOSTICS (OPTIONAL)")
print("="*70)
print("\nNote: Diagnostics already ran automatically at end of each training!")
print("This cell is for additional analysis if needed.")

# Check if we have trained models
if 'trained_models' not in dir():
    print("\n[WARNING] No trained_models found - run training cell first!")
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Prepare validation batch for manual diagnostics
    val_ds = PercePianoKFoldDataset(
        data_dir=DATA_ROOT,
        fold_assignments=fold_assignments,
        fold_id=FOLD_ID,
        mode="val",
        max_notes=CONFIG['max_notes'],
        slice_len=CONFIG.get('slice_len', CONFIG['max_notes']),
    )
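    # percepiano_pack_collate is imported above; pass it so variable-length
    # samples are batched the same way as during training (assumed).
    val_loader = DataLoader(
        val_ds,
        batch_size=4,
        shuffle=False,
        num_workers=0,
        collate_fn=percepiano_pack_collate,
    )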
    
    # Create diagnostic callback for manual analysis
    diag = DiagnosticCallback()
    
    print("\nModels available for manual diagnostics:")
    for name, model in trained_models.items():
        if model is not None:
            model_type = diag._detect_model_type(model)
            print(f"  - {name}: detected as '{model_type}'")
        else:
            print(f"  - {name}: not available")
    
    print("\nTo run manual diagnostics on a specific model:")
    print("  model = trained_models['note_voice_beat'].to(device).eval()")
    print("  batch = next(iter(val_loader))")
    print("  stats = diag._run_diagnostic_forward(model, {k: v.to(device) for k, v in batch.items()})")
    print("  diag._print_diagnostic_summary(stats, 0)")

print("\n" + "="*70)
print("INTERPRETATION GUIDE")
print("="*70)
print("""
Key metrics from automatic diagnostics:

[1] ACTIVATION VARIANCES:
    - Should show "OK" status for most components
    - "LOW - ISSUE" indicates signal collapse
    - predictions_std should be 0.10-0.25 (not 0.008!)

[2] ATTENTION ENTROPY (for beat/measure models):
    - 0.3-0.8: Good (attention is learning to focus)
    - >0.95: Too uniform (attention not learning)
    - <0.1: Too collapsed (may be overfitting)

[3] HIERARCHY CONTRIBUTION (for beat/measure models):
    - beat_spanned should contribute >10%
    - measure_spanned should contribute >5%
    - Low contribution = model ignoring that branch

[4] MODEL-SPECIFIC INSIGHTS:
    - Shows expected R2 for each model type
    - Highlights specific issues for that architecture
""")

## Step 9: Comprehensive Analysis

Compare all 4 true incremental models:
1. Overall metrics table with incremental gains
2. Per-dimension R2 comparison
3. Hierarchy contribution analysis
4. Comparison to PercePiano paper
5. Next steps recommendations

In [None]:
"""
ROUND 17: COMPREHENSIVE ANALYSIS

Compare all four true incremental models and analyze hierarchy contribution.
"""

import numpy as np
from sklearn.metrics import r2_score
from src.percepiano.models.percepiano_replica import PERCEPIANO_DIMENSIONS

print("="*80)
print("ROUND 17 COMPREHENSIVE ANALYSIS")
print("="*80)

# Validate we have metrics
if 'trained_metrics' not in dir():
    raise RuntimeError("No trained_metrics found - run training cell first!")

note_only_m = trained_metrics.get('note_only')
note_voice_m = trained_metrics.get('note_voice')
note_voice_beat_m = trained_metrics.get('note_voice_beat')
han_m = trained_metrics.get('han')

if not all([note_only_m, note_voice_m, note_voice_beat_m, han_m]):
    missing = [k for k, v in [('note_only', note_only_m), ('note_voice', note_voice_m), 
                               ('note_voice_beat', note_voice_beat_m), ('han', han_m)] if v is None]
    print(f"\n[WARNING] Missing metrics for: {missing}")
    print("Some analysis sections will be incomplete.")

# ========================================
# Section 1: Overall Metrics Table
# ========================================
print("\n" + "-"*80)
print("SECTION 1: OVERALL METRICS")
print("-"*80)

print(f"\n  {'Model':<32} {'Val R2':>10} {'Epochs':>8} {'Gain':>10} {'Expected':>12}")
print(f"  {'-'*32} {'-'*10} {'-'*8} {'-'*10} {'-'*12}")

if note_only_m:
    print(f"  {'NoteOnly (2L note)':<32} {note_only_m.val_r2:>+10.4f} {note_only_m.epochs_trained:>8} {'-':>10} {'~0.10':>12}")

if note_voice_m:
    voice_gain = note_voice_m.val_r2 - note_only_m.val_r2 if note_only_m else 0
    print(f"  {'NoteVoice (+ voice)':<32} {note_voice_m.val_r2:>+10.4f} {note_voice_m.epochs_trained:>8} {voice_gain:>+10.4f} {'~0.15':>12}")

if note_voice_beat_m:
    beat_gain = note_voice_beat_m.val_r2 - note_voice_m.val_r2 if note_voice_m else 0
    print(f"  {'NoteVoiceBeat (+ beat)':<32} {note_voice_beat_m.val_r2:>+10.4f} {note_voice_beat_m.epochs_trained:>8} {beat_gain:>+10.4f} {'~0.25-0.30':>12}")

if han_m:
    measure_gain = han_m.val_r2 - note_voice_beat_m.val_r2 if note_voice_beat_m else 0
    print(f"  {'Full HAN (+ measure)':<32} {han_m.val_r2:>+10.4f} {han_m.epochs_trained:>8} {measure_gain:>+10.4f} {'~0.35-0.40':>12}")

if note_only_m and han_m:
    total_gain = han_m.val_r2 - note_only_m.val_r2
    print(f"  {'-'*32} {'-'*10} {'-'*8} {'-'*10} {'-'*12}")
    print(f"  {'Total Gain (NoteOnly->HAN)':<32} {'-':>10} {'-':>8} {total_gain:>+10.4f} {'~+0.25':>12}")

# ========================================
# Section 2: Per-Dimension R2 Comparison
# ========================================
print("\n" + "-"*80)
print("SECTION 2: PER-DIMENSION R2 COMPARISON")
print("-"*80)

# Check if we have per-dimension metrics
has_per_dim = all([
    m and hasattr(m, 'per_dim_r2') and m.per_dim_r2
    for m in [note_only_m, note_voice_m, note_voice_beat_m, han_m]
])

if has_per_dim:
    print(f"\n  {'Dimension':<20} {'NoteOnly':>10} {'+ Voice':>10} {'+ Beat':>10} {'Full HAN':>10}")
    print(f"  {'-'*20} {'-'*10} {'-'*10} {'-'*10} {'-'*10}")
    
    dim_data = []
    for dim in PERCEPIANO_DIMENSIONS:
        no_r2 = note_only_m.per_dim_r2.get(dim, 0)
        nv_r2 = note_voice_m.per_dim_r2.get(dim, 0)
        nvb_r2 = note_voice_beat_m.per_dim_r2.get(dim, 0)
        han_r2 = han_m.per_dim_r2.get(dim, 0)
        dim_data.append((dim, no_r2, nv_r2, nvb_r2, han_r2))
    
    # Sort by final HAN R2 (best first)
    dim_data.sort(key=lambda x: x[4], reverse=True)
    
    for dim, no_r2, nv_r2, nvb_r2, han_r2 in dim_data:
        print(f"  {dim:<20} {no_r2:>+10.4f} {nv_r2:>+10.4f} {nvb_r2:>+10.4f} {han_r2:>+10.4f}")
else:
    print("\n  [Per-dimension R2 not available - check trainer saves per_dim_r2]")

# ========================================
# Section 3: Hierarchy Contribution Analysis
# ========================================
print("\n" + "-"*80)
print("SECTION 3: HIERARCHY CONTRIBUTION ANALYSIS")
print("-"*80)

if all([note_only_m, note_voice_m, note_voice_beat_m, han_m]):
    voice_contrib = note_voice_m.val_r2 - note_only_m.val_r2
    beat_contrib = note_voice_beat_m.val_r2 - note_voice_m.val_r2
    measure_contrib = han_m.val_r2 - note_voice_beat_m.val_r2
    total_hierarchy = han_m.val_r2 - note_only_m.val_r2
    
    print(f"\n  {'Component':<30} {'Actual':>10} {'Expected':>10} {'Status':>12}")
    print(f"  {'-'*30} {'-'*10} {'-'*10} {'-'*12}")
    
    # Voice contribution
    status = "[OK]" if voice_contrib >= 0.03 else "[LOW]" if voice_contrib >= 0 else "[NEGATIVE]"
    print(f"  {'Voice (parallel LSTM)':<30} {voice_contrib:>+10.4f} {'~+0.05':>10} {status:>12}")
    
    # Beat contribution
    status = "[OK]" if beat_contrib >= 0.07 else "[LOW]" if beat_contrib >= 0 else "[NEGATIVE]"
    print(f"  {'Beat hierarchy':<30} {beat_contrib:>+10.4f} {'~+0.10':>10} {status:>12}")
    
    # Measure contribution
    status = "[OK]" if measure_contrib >= 0.03 else "[LOW]" if measure_contrib >= 0 else "[NEGATIVE]"
    print(f"  {'Measure hierarchy':<30} {measure_contrib:>+10.4f} {'~+0.05':>10} {status:>12}")
    
    print(f"  {'-'*30} {'-'*10} {'-'*10} {'-'*12}")
    status = "[OK]" if total_hierarchy >= 0.20 else "[PARTIAL]" if total_hierarchy >= 0.10 else "[LOW]"
    print(f"  {'Total hierarchy gain':<30} {total_hierarchy:>+10.4f} {'~+0.25':>10} {status:>12}")

# ========================================
# Section 4: Comparison to PercePiano Paper
# ========================================
print("\n" + "-"*80)
print("SECTION 4: COMPARISON TO PERCEPIANO PAPER")
print("-"*80)

print(f"\n  {'Model':<32} {'Paper R2':>10} {'Our R2':>10} {'Match':>10}")
print(f"  {'-'*32} {'-'*10} {'-'*10} {'-'*10}")

# Note: Paper doesn't report intermediate models, but we can compare endpoints
if note_only_m:
    # No direct comparison for NoteOnly
    print(f"  {'NoteOnly (our baseline)':<32} {'N/A':>10} {note_only_m.val_r2:>+10.4f} {'-':>10}")

if han_m:
    if han_m.val_r2 >= 0.35:
        status = "[OK]"
    elif han_m.val_r2 >= 0.30:
        status = "[CLOSE]"
    elif han_m.val_r2 >= 0.25:
        status = "[PARTIAL]"
    else:
        status = "[LOW]"
    print(f"  {'VirtuosoNetMultiLevel (HAN)':<32} {'0.397':>10} {han_m.val_r2:>+10.4f} {status:>10}")

if note_only_m and han_m:
    total_gain = han_m.val_r2 - note_only_m.val_r2
    # Paper shows ~+0.21 gain from baseline to HAN
    status = "[GOOD]" if total_gain >= 0.20 else "[PARTIAL]" if total_gain >= 0.10 else "[LOW]"
    print(f"  {'Hierarchy Gain (paper: +0.21)':<32} {'+0.212':>10} {total_gain:>+10.4f} {status:>10}")

# ========================================
# Section 5: Next Steps Recommendations
# ========================================
print("\n" + "-"*80)
print("SECTION 5: NEXT STEPS RECOMMENDATIONS")
print("-"*80)

if all([note_only_m, note_voice_m, note_voice_beat_m, han_m]):
    voice_contrib = note_voice_m.val_r2 - note_only_m.val_r2
    beat_contrib = note_voice_beat_m.val_r2 - note_voice_m.val_r2
    measure_contrib = han_m.val_r2 - note_voice_beat_m.val_r2
    total_hierarchy = han_m.val_r2 - note_only_m.val_r2
    
    print("\n  Based on Round 17 results:")
    
    if han_m.val_r2 >= 0.35:
        print("\n  [SUCCESS] Approaching SOTA performance!")
        print("    - Model architecture is working correctly")
        print("    - Consider training on all 4 folds for robust evaluation")
        print("    - Ready to explore enhancements (MERT embeddings, etc.)")
    elif han_m.val_r2 >= 0.25:
        print("\n  [GOOD] Hierarchy contributing, some room for improvement")
        print("    - Check attention entropy in automatic diagnostics")
        print("    - Consider longer training or LR tuning")
        print("    - May proceed with experiments")
    else:
        print("\n  [INVESTIGATE] Performance below expected")
        
        if voice_contrib < 0:
            print("    - Voice contribution NEGATIVE: Check voice LSTM implementation")
        if beat_contrib < 0.05:
            print("    - Beat contribution LOW: Check beat attention entropy (should be 0.3-0.8)")
        if measure_contrib < 0:
            print("    - Measure contribution NEGATIVE: May need longer pieces in training fold")
        
        print("\n    Review the END OF TRAINING DIAGNOSTICS output for each model")
        print("    Look for: [WARNING] messages and [LOW - ISSUE] activations")

print("\n" + "="*80)
print("ANALYSIS COMPLETE")
print("="*80)
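
Section 3 of the analysis cell repeats the same nested conditional for each component. The status logic can be read more easily as a small function. This is a sketch: `contribution_status` is a hypothetical helper, the thresholds are the ones used in the cell above, and the example gains are illustrative values, not results:

```python
def contribution_status(gain, ok_threshold):
    """Map an R2 gain to the status labels used in Section 3."""
    if gain >= ok_threshold:
        return "[OK]"
    if gain >= 0:
        return "[LOW]"
    return "[NEGATIVE]"


# Thresholds taken from the analysis cell.
THRESHOLDS = {"voice": 0.03, "beat": 0.07, "measure": 0.03}

# Illustrative gains only; these are not measured results.
example_gains = {"voice": 0.05, "beat": 0.02, "measure": -0.01}
for name, gain in example_gains.items():
    print(f"{name:<8} {gain:+.4f} {contribution_status(gain, THRESHOLDS[name])}")
```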

## Step 10: Sync Checkpoints to Google Drive

Sync all 4 true incremental model checkpoints to Google Drive.

In [None]:
"""
ROUND 17: SYNC ALL CHECKPOINTS TO GOOGLE DRIVE
"""

import subprocess

print("="*60)
print("SYNC ALL ROUND 17 CHECKPOINTS TO GOOGLE DRIVE")
print("="*60)

if RCLONE_AVAILABLE:
    # Sync all model checkpoints
    model_dirs = [
        ('note_only', note_only_trainer.checkpoint_dir),
        ('note_voice', note_voice_trainer.checkpoint_dir),
        ('note_voice_beat', note_voice_beat_trainer.checkpoint_dir),
        ('han', han_trainer.checkpoint_dir),
    ]
    
    for model_name, ckpt_dir in model_dirs:
        if ckpt_dir.exists():
            gdrive_path = f"{GDRIVE_CHECKPOINT_PATH}/{ckpt_dir.name}"
            print(f"\nSyncing {model_name} ({ckpt_dir.name})...")
            subprocess.run(
                ['rclone', 'copy', str(ckpt_dir), gdrive_path, '--progress'],
                capture_output=False
            )
        else:
            print(f"\n[SKIP] {model_name} checkpoint dir not found: {ckpt_dir}")
    
    # Sync fold assignments
    print(f"\nSyncing fold assignments...")
    subprocess.run(
        ['rclone', 'copy', str(FOLD_FILE), GDRIVE_DATA_PATH, '--progress'],
        capture_output=False
    )
    
    print("\n" + "="*60)
    print("SYNC COMPLETE")
    print("="*60)
    print(f"\nCheckpoints synced to: {GDRIVE_CHECKPOINT_PATH}")
    print("  - note_only/")
    print("  - note_voice/")
    print("  - note_voice_beat/")
    print("  - han/")
    print(f"\nFold assignments synced to: {GDRIVE_DATA_PATH}")
else:
    print("\nrclone not available - skipping sync")
    print("Checkpoints saved locally at:", CHECKPOINT_ROOT)

print("\n" + "="*60)
print("ROUND 17 COMPLETE")
print("="*60)
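
Before uploading large checkpoint directories, it can be useful to preview what rclone would transfer. rclone's `--dry-run` flag reports the planned operations without copying anything. The helper below only builds the argument list, so it is safe to run anywhere; `rclone_copy_command` is a hypothetical name, and the example paths follow the ones configured in Step 2:

```python
def rclone_copy_command(source, destination, dry_run=True):
    """Build the argument list for `rclone copy`, optionally as a dry run."""
    command = ["rclone", "copy", str(source), str(destination), "--progress"]
    if dry_run:
        # Report what would be transferred without copying any files.
        command.append("--dry-run")
    return command


print(rclone_copy_command(
    "/tmp/checkpoints/percepiano_kfold/han",
    "gdrive:crescendai_checkpoints/percepiano_kfold/han",
))
```

The returned list can be passed to `subprocess.run(..., check=True)` so that a failed transfer raises an error instead of passing silently.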