# PercePiano Replica Training (4-Fold Cross-Validation)

Train the PercePiano replica model with 4-fold, piece-based cross-validation,
matching the methodology and hyperparameters of the state-of-the-art (SOTA) configuration published in the PercePiano paper.

## Attribution

> **PercePiano: Piano Performance Evaluation Dataset with Multi-level Perceptual Features**  
> Park, Kim et al.  
> Scientific Reports (Nature Portfolio), 2024  
> Paper: https://pmc.ncbi.nlm.nih.gov/articles/PMC11450231/  
> GitHub: https://github.com/JonghoKimSNU/PercePiano

## Methodology

Following the approach in `m2pf_dataset_compositionfold.py`:

- **Piece-based splits**: All performances of the same piece stay in the same fold
- **Test set**: Pieces are selected at random until the test set holds ~15% of the samples (not ~15% of the pieces)
- **4-fold CV**: The original distributes the remaining pieces round-robin across folds; Step 4 of this notebook uses greedy bin-packing instead, to balance sample counts per fold
- **Per-fold normalization**: Stats computed from training folds only
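
The test-set selection and round-robin fold assignment can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in `m2pf_dataset_compositionfold.py` or `kfold_split.py`: the function name `sketch_piece_split` and the input format (a mapping from piece ID to its performance IDs) are hypothetical.

```python
import random
from collections import defaultdict


def sketch_piece_split(piece_to_samples, n_folds=4, test_ratio=0.15, seed=42):
    """Hypothetical sketch: piece-based test split plus round-robin CV folds."""
    rng = random.Random(seed)
    pieces = sorted(piece_to_samples)  # sort first so the shuffle is reproducible
    rng.shuffle(pieces)

    total = sum(len(s) for s in piece_to_samples.values())
    target = test_ratio * total

    # Test set: add whole pieces until ~test_ratio of SAMPLES is reached.
    test_pieces, n_test = [], 0
    while pieces and n_test < target:
        piece = pieces.pop()
        test_pieces.append(piece)
        n_test += len(piece_to_samples[piece])

    # CV folds: deal the remaining pieces round-robin.
    folds = defaultdict(list)
    for i, piece in enumerate(pieces):
        folds[i % n_folds].append(piece)
    return test_pieces, dict(folds)


# Toy data: 20 pieces with 1-5 performances each (60 performances in total).
toy = {f"piece_{i:02d}": [f"piece_{i:02d}_perf{j}" for j in range(1 + i % 5)]
       for i in range(20)}
test_pieces, folds = sketch_piece_split(toy)
print(len(test_pieces), {k: len(v) for k, v in folds.items()})
```

Because whole pieces are moved at once, the test set usually overshoots the 15% target slightly, and round-robin assignment balances piece counts rather than sample counts.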

## Hyperparameters (SOTA Configuration - R2 = 0.397)

These parameters match the published SOTA from `2_run_comp_multilevel_total.sh` and `han_bigger256_concat.yml`:

| Parameter | SOTA Value | Notes |
|-----------|------------|-------|
| input_size | 79 | SOTA uses 79 base features (includes section_tempo) |
| batch_size | 8 | From SOTA training script |
| learning_rate | 2.5e-5 | From SOTA training script |
| hidden_size | 256 | HAN encoder dimension |
| prediction_head | 512->512->19 | From model_m2pf.py (NOT config's final_fc_size) |
| dropout | 0.2 | Regularization |
| augment_train | False | SOTA doesn't use key augmentation |
| max_epochs | 200 | Extended training window |
| early_stopping_patience | 40 | More patience for convergence |
| gradient_clip_val | 2.0 | From parser.py |
| **precision** | **32** | **FP32 (original uses FP32, not mixed precision)** |
| **max_notes/slice_len** | **5000** | **SOTA slice size for overlapping sampling** |
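
Because these values appear both in the table above and in the `CONFIG` dictionary in Step 5, a small consistency check can catch drift between the two. The sketch below is one possible way to do that; `SOTA_REFERENCE` and `find_config_drift` are hypothetical names, and the reference values are copied from the table.

```python
# Reference values copied from the hyperparameter table (hypothetical helper).
SOTA_REFERENCE = {
    "input_size": 79,
    "batch_size": 8,
    "learning_rate": 2.5e-5,
    "hidden_size": 256,
    "dropout": 0.2,
    "augment_train": False,
    "max_epochs": 200,
    "early_stopping_patience": 40,
    "gradient_clip_val": 2.0,
    "precision": "32",
    "max_notes": 5000,
    "slice_len": 5000,
}


def find_config_drift(config, reference=SOTA_REFERENCE):
    """Return {key: (config_value, reference_value)} for every mismatch."""
    drift = {}
    for key, expected in reference.items():
        actual = config.get(key)
        if actual != expected:
            drift[key] = (actual, expected)
    return drift


# Example: a config whose patience differs from the table.
example_config = dict(SOTA_REFERENCE, early_stopping_patience=20)
print(find_config_drift(example_config))  # {'early_stopping_patience': (20, 40)}
```

Running `find_config_drift(CONFIG)` after Step 5 would list any keys whose values no longer match the table.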

## Key Fixes Applied (Round 8 - 2025-12-26)

### CRITICAL: Slice Sampling (Round 8)

The single most impactful discrepancy identified:

| Aspect | Original PercePiano | Previous Implementation | Impact |
|--------|---------------------|------------------------|--------|
| Samples/performance | 3-5 overlapping slices | 1 sample | 3-5x less training data |
| Training samples | ~600-1000 slices | ~200 samples | Critical for learning |
| Slice regeneration | Each epoch | None | No variation |

**Fix**: Added `make_slicing_indexes_by_measure()` and `SliceRegenerationCallback`.
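
The idea behind measure-aligned slice sampling can be illustrated with a simplified sketch. It is not the original `make_slicing_indexes_by_measure()` and does not reproduce its exact rules: the function `sketch_measure_slices`, its stride heuristic, and the toy data are assumptions made for illustration only.

```python
import random


def sketch_measure_slices(measure_numbers, slice_len, rng):
    """Hypothetical sketch: overlapping (start, end) note slices whose starts
    fall on measure boundaries. measure_numbers[i] is the measure of note i."""
    num_notes = len(measure_numbers)
    if num_notes <= slice_len:
        return [(0, num_notes)]

    # Note indices at which a new measure begins.
    measure_starts = [0] + [
        i for i in range(1, num_notes)
        if measure_numbers[i] != measure_numbers[i - 1]
    ]

    slices = []
    start = 0
    while start + slice_len < num_notes:
        slices.append((start, start + slice_len))
        # Advance by a random stride of a quarter to half a slice so that
        # consecutive windows usually overlap, then snap to a measure start.
        stride = max(1, rng.randrange(slice_len // 4, slice_len // 2 + 1))
        later_starts = [m for m in measure_starts if m >= start + stride]
        if not later_starts:
            break
        start = later_starts[0]

    # Tail slice: ends at the last note and starts at the first measure
    # boundary that keeps it within slice_len (mid-measure if there is none).
    tail_floor = num_notes - slice_len
    tail_start = next((m for m in measure_starts if m >= tail_floor), tail_floor)
    slices.append((tail_start, num_notes))
    return slices


# Toy performance: 200 notes, 10 notes per measure.
measures = [i // 10 for i in range(200)]
rng = random.Random(0)
for epoch in range(2):
    # Each call draws new random strides, so the slices may differ per epoch.
    print(epoch, sketch_measure_slices(measures, slice_len=50, rng=rng))
```

Drawing new strides on every call is what per-epoch regeneration relies on: the model sees different windows of the same performance across epochs.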

### Round History

| Round | Changes | Result |
|-------|---------|--------|
| 1-5 | Various fixes (precision, attention, init) | R2 stuck around 0 |
| 6 | Match original architecture (no LayerNorm, 512->512, LR 2.5e-5) | Zero context_vector gradients |
| 7 | Fix data pipeline (79 features, PackedSequence) | R2 = 0.0017 (prediction collapse) |
| **8** | **Add SLICE SAMPLING (3-5 slices/sample, epoch regeneration)** | **Pending** |

See `docs/EXPERIMENT_LOG.md` for full investigation details.
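
The "prediction collapse" result in Round 7 is easier to interpret with the definition of R2 in mind: a model that outputs the target mean for every input scores exactly 0. The sketch below is a plain-Python illustration with made-up numbers; `r2_score_1d` is a hypothetical helper, not the metric code used by the trainer.

```python
def r2_score_1d(targets, preds):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_t = sum(targets) / len(targets)
    ss_res = sum((t - p) ** 2 for t, p in zip(targets, preds))
    ss_tot = sum((t - mean_t) ** 2 for t in targets)
    return 1.0 - ss_res / ss_tot


# Made-up targets in [0, 1].
targets = [0.25, 0.5, 0.75, 1.0]

# A collapsed model predicts the target mean regardless of input.
collapsed = [sum(targets) / len(targets)] * len(targets)
# A perfect model reproduces every target.
perfect = list(targets)

print(r2_score_1d(targets, collapsed))  # 0.0
print(r2_score_1d(targets, perfect))    # 1.0
```

An R2 near zero therefore suggests the model has learned little beyond the average label, and negative values mean it does worse than that constant prediction.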

## Expected Results

- Target R2: 0.35-0.40 (matching published SOTA of 0.397)
- Training time: ~8-12 hours on T4, ~3-5 hours on A100 (all 4 folds)
- With slice sampling: ~600-1000 training slices (vs ~200 samples before)

## Step 1: Environment Setup

In [None]:
# Check GPU
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install rclone
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "No install confirmation found; run 'rclone version' to verify"

In [None]:
# Install uv and clone repository
!curl -LsSf https://astral.sh/uv/install.sh | sh

import os
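# Recent uv installers place the binary in ~/.local/bin; older ones used ~/.cargo/bin
home = os.environ['HOME']
os.environ['PATH'] = f"{home}/.local/bin:{home}/.cargo/bin:{os.environ['PATH']}"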

# Clone repository
if not os.path.exists('/tmp/crescendai'):
    !git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai

%cd /tmp/crescendai/model
!git pull
!git log -1 --oneline

# Clone original PercePiano for comparison (needed for data diagnostics)
PERCEPIANO_PATH = '/tmp/crescendai/model/data/raw/PercePiano'
if not os.path.exists(PERCEPIANO_PATH):
    print("\nCloning original PercePiano repository...")
    !git clone https://github.com/JonghoKimSNU/PercePiano.git {PERCEPIANO_PATH}
else:
    print(f"\nPercePiano already present at {PERCEPIANO_PATH}")

# Install dependencies
!uv pip install --system -e .
!pip install tensorboard rich

import torch
import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

## Step 2: Configure Paths and Check rclone

In [None]:
import os
import subprocess
import shutil
from pathlib import Path

# Paths
DATA_ROOT = Path('/tmp/percepiano_vnet_84dim')
CHECKPOINT_ROOT = Path('/tmp/checkpoints/percepiano_kfold')
LOG_ROOT = Path('/tmp/logs/percepiano_kfold')
GDRIVE_DATA_PATH = 'gdrive:crescendai_data/percepiano_vnet_84dim'
GDRIVE_CHECKPOINT_PATH = 'gdrive:crescendai_checkpoints/percepiano_kfold'

# Training control
RESTART_TRAINING = True  # Set to True to clear checkpoints and start fresh

print("="*60)
print("PERCEPIANO REPLICA TRAINING (4-FOLD CV)")
print("="*60)

# Clear checkpoints if restarting
if RESTART_TRAINING and CHECKPOINT_ROOT.exists():
    print(f"\nRESTART_TRAINING=True: Clearing checkpoints at {CHECKPOINT_ROOT}")
    shutil.rmtree(CHECKPOINT_ROOT)
    print("  Checkpoints cleared!")

if RESTART_TRAINING and LOG_ROOT.exists():
    print(f"RESTART_TRAINING=True: Clearing logs at {LOG_ROOT}")
    shutil.rmtree(LOG_ROOT)
    print("  Logs cleared!")

# Create directories
CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
LOG_ROOT.mkdir(parents=True, exist_ok=True)
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Check rclone
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)

if 'gdrive:' in result.stdout:
    print("\nrclone 'gdrive' remote: CONFIGURED")
    RCLONE_AVAILABLE = True
else:
    print("\nrclone 'gdrive' remote: NOT CONFIGURED")
    print("Run 'rclone config' in terminal to set up Google Drive")
    RCLONE_AVAILABLE = False

print(f"\nData directory: {DATA_ROOT}")
print(f"Checkpoint directory: {CHECKPOINT_ROOT}")
print(f"Log directory: {LOG_ROOT}")
print(f"\nRESTART_TRAINING: {RESTART_TRAINING}")

## Step 3: Download Data from Google Drive

In [None]:
import subprocess

if not RCLONE_AVAILABLE:
    raise RuntimeError("rclone not configured. Run 'rclone config' first.")

# Download preprocessed data
print("Downloading preprocessed VirtuosoNet features from Google Drive...")
subprocess.run(
    ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
    check=True,  # Fail loudly if the download did not complete
)

# Verify data
print("\n" + "="*60)
print("DATA VERIFICATION")
print("="*60)

total_samples = 0
for split in ['train', 'val', 'test']:
    split_dir = DATA_ROOT / split
    if split_dir.exists():
        count = len(list(split_dir.glob('*.pkl')))
        total_samples += count
        print(f"  {split}: {count} samples")
    else:
        print(f"  {split}: MISSING!")

print(f"  Total: {total_samples} samples")

stat_file = DATA_ROOT / 'stat.pkl'
print(f"  stat.pkl: {'present' if stat_file.exists() else 'MISSING!'}")

fold_file = DATA_ROOT / 'fold_assignments.json'
print(f"  fold_assignments.json: {'present' if fold_file.exists() else 'will be created'}")

## Step 4: Create Fold Assignments

In [None]:
from src.percepiano.data.kfold_split import (
    create_piece_based_folds,
    save_fold_assignments,
    load_fold_assignments,
    print_fold_statistics,
)

FOLD_FILE = DATA_ROOT / 'fold_assignments.json'
N_FOLDS = 4
TEST_RATIO = 0.15
SEED = 42

print("="*60)
print("FOLD ASSIGNMENT CREATION")
print("="*60)

# Force regeneration to use corrected methodology
# - Test set: select pieces until ~15% of SAMPLES (PercePiano methodology)
# - CV folds: greedy bin-packing for balanced sample counts (improvement over round-robin)
FORCE_REGENERATE = True

if FOLD_FILE.exists() and not FORCE_REGENERATE:
    print(f"\nLoading existing fold assignments from {FOLD_FILE}")
    fold_assignments = load_fold_assignments(FOLD_FILE)
else:
    if FOLD_FILE.exists():
        print(f"\nRemoving old fold assignments (regenerating with balanced methodology)...")
        FOLD_FILE.unlink()
    
    print(f"\nCreating new {N_FOLDS}-fold piece-based splits...")
    print("  Test set: select pieces until ~15% of SAMPLES")
    print("  CV folds: greedy bin-packing for balanced sample counts")
    fold_assignments = create_piece_based_folds(
        data_dir=DATA_ROOT,
        n_folds=N_FOLDS,
        test_ratio=TEST_RATIO,
        seed=SEED,
    )
    save_fold_assignments(fold_assignments, FOLD_FILE)

# Print statistics
print_fold_statistics(fold_assignments, n_folds=N_FOLDS)

## Step 5: Training Configuration

In [None]:
import torch
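# 'highest' keeps float32 matmuls in full FP32, matching the precision=32 setting;
# 'medium' would allow reduced-precision internals on GPUs that support them
torch.set_float32_matmul_precision('highest')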

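# Synchronous kernel launches make CUDA errors point at the failing call, at the
# cost of slower training. The variable is read when CUDA initializes, so set it
# in a fresh runtime before any GPU query if it appears to have no effect.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'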

# Import model type constants
from src.percepiano.training.kfold_trainer import MODEL_TYPE_HAN, MODEL_TYPE_BASELINE

CONFIG = {
    # K-Fold settings
    'n_folds': N_FOLDS,
    'test_ratio': TEST_RATIO,
    # Data
    'data_dir': str(DATA_ROOT),
    'checkpoint_dir': str(CHECKPOINT_ROOT),
    'log_dir': str(LOG_ROOT),
    'input_size': 79,
    'hidden_size': 256,
    'note_layers': 2,
    'voice_layers': 2,
    'beat_layers': 2,
    'measure_layers': 1,
    'num_attention_heads': 8,
    'learning_rate': 2.5e-5,
    'weight_decay': 1e-5,
    'dropout': 0.2,
    'batch_size': 8,
    'max_epochs': 200,
    'early_stopping_patience': 40,  # Matches the hyperparameter table
    'gradient_clip_val': 2.0,
    'precision': '32',
    'max_notes': 5000,
    'slice_len': 5000,
    'num_workers': 4,
    'augment_train': False,
}

print("="*60)
print("TRAINING CONFIGURATION (SOTA - ROUND 13)")
print("="*60)
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

print(f"\nModel types available:")
print(f"  MODEL_TYPE_BASELINE = '{MODEL_TYPE_BASELINE}' (7-layer Bi-LSTM, expected R2 ~0.19)")
print(f"  MODEL_TYPE_HAN = '{MODEL_TYPE_HAN}' (Hierarchical, expected R2 ~0.40)")

## Step 5b: Pre-Training Data Diagnostics (CRITICAL)

Run this cell before training to validate the data and catch index problems that would break the hierarchical (beat/measure) components.
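
The index checks can be illustrated on a single toy sequence. This is a simplified sketch, not the repository's `analyze_indices`: the helper names are hypothetical, and "zero-shifting" is assumed here to mean subtracting the first index of the sequence.

```python
def densify(indices):
    """Remap indices to consecutive integers starting at 1, preserving order."""
    mapping = {v: i + 1 for i, v in enumerate(sorted(set(indices)))}
    return [mapping[v] for v in indices]


def sketch_index_checks(beat_indices):
    """Hypothetical checks for one unpadded sequence of per-note beat indices."""
    diffs = [b - a for a, b in zip(beat_indices, beat_indices[1:])]
    zero_shifted = [b - beat_indices[0] for b in beat_indices]
    return {
        "starts_at_1": beat_indices[0] == 1,
        "negative_zero_shifted": sum(1 for z in zero_shifted if z < 0),
        "diff_equals_1": sum(1 for d in diffs if d == 1),
        "diff_greater_0": sum(1 for d in diffs if d > 0),
    }


raw = [3, 3, 5, 5, 5, 8, 9]  # beats 4, 6 and 7 carry no notes
dense = densify(raw)         # [1, 1, 2, 2, 2, 3, 4]

print(sketch_index_checks(raw))
print(sketch_index_checks(dense))
```

For the raw sequence, only one of the three beat changes is a step of exactly 1, so a boundary test written as `diff == 1` would miss two boundaries that `diff > 0` finds. After densification the two counts agree, which is the equivalence check [4] in the cell below looks for.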

In [None]:
"""
PRE-TRAINING DATA DIAGNOSTICS

This cell validates the data pipeline before training to catch issues that
would cause the hierarchical components (beat/measure attention) to fail.

Key checks:
1. Index format (should start from 1 after densification)
2. Zero-shifted values (should have no negatives)
3. Boundary detection (== 1 vs > 0 equivalence)
4. Slice statistics
"""

import torch
import numpy as np
from torch.utils.data import DataLoader
from src.percepiano.data.percepiano_vnet_dataset import (
    PercePianoKFoldDataset,
    percepiano_pack_collate,
)
from src.percepiano.training.diagnostics import analyze_indices

print("=" * 70)
print("PRE-TRAINING DATA DIAGNOSTICS")
print("=" * 70)

# Create a test dataset for fold 0
test_ds = PercePianoKFoldDataset(
    data_dir=DATA_ROOT,
    fold_assignments=fold_assignments,
    fold_id=0,
    mode="train",
    max_notes=CONFIG['max_notes'],
    slice_len=CONFIG.get('slice_len', CONFIG['max_notes']),
)

# Create a simple dataloader (no packing for easier analysis)
simple_loader = DataLoader(test_ds, batch_size=4, shuffle=False, num_workers=0)

# Get one batch
batch = next(iter(simple_loader))

print("\n[1] BATCH SHAPE ANALYSIS:")
print(f"  input_features: {batch['input_features'].shape}")
print(f"  note_locations_beat: {batch['note_locations_beat'].shape}")
print(f"  note_locations_measure: {batch['note_locations_measure'].shape}")
print(f"  note_locations_voice: {batch['note_locations_voice'].shape}")
print(f"  scores: {batch['scores'].shape}")
print(f"  num_notes: {[int(n) for n in batch['num_notes']]}")

print("\n[2] INDEX ANALYSIS:")
idx_stats = analyze_indices(
    batch['note_locations_beat'],
    batch['note_locations_measure'],
)

print(f"  Beat indices: min={idx_stats['beat_min']}, max={idx_stats['beat_max']}")
print(f"  Measure indices: min={idx_stats['measure_min']}, max={idx_stats['measure_max']}")

# Check if indices start from 1 (required for hierarchy_utils)
if idx_stats['beat_min'] == 1:
    print(f"  [OK] Beat indices start from 1 (required)")
elif idx_stats['beat_min'] == 0:
    print(f"  [WARNING] Beat indices start from 0 (should be 1)")
else:
    print(f"  [ERROR] Beat indices start from {idx_stats['beat_min']} (unexpected)")

print("\n[3] ZERO-SHIFTED INDEX CHECK:")
print(f"  Beat zero-shifted range: [{idx_stats['beat_zero_shifted_min']}, {idx_stats['beat_zero_shifted_max']}]")
print(f"  Measure zero-shifted range: [{idx_stats['measure_zero_shifted_min']}, {idx_stats['measure_zero_shifted_max']}]")

if idx_stats['negative_beat_count'] > 0:
    print(f"  [CRITICAL ERROR] {idx_stats['negative_beat_count']} negative zero-shifted beat values!")
    print(f"    This WILL break span_beat_to_note_num!")
else:
    print(f"  [OK] No negative zero-shifted values")

if idx_stats['negative_measure_count'] > 0:
    print(f"  [WARNING] {idx_stats['negative_measure_count']} negative zero-shifted measure values")

print("\n[4] BOUNDARY DETECTION CHECK:")
print(f"  diff == 1 count: {idx_stats['diff_equals_1_count']}")
print(f"  diff > 0 count: {idx_stats['diff_greater_0_count']}")
print(f"  diff < 0 count: {idx_stats['diff_less_0_count']} (should only be at padding boundary)")

if idx_stats['diff_equals_1_count'] == idx_stats['diff_greater_0_count']:
    print(f"  [OK] == 1 and > 0 are equivalent (indices are sequential)")
else:
    print(f"  [WARNING] == 1 ({idx_stats['diff_equals_1_count']}) != > 0 ({idx_stats['diff_greater_0_count']})")
    print(f"    Indices may not be properly densified!")

if idx_stats['non_sequential_samples'] > 0:
    print(f"  [WARNING] {idx_stats['non_sequential_samples']}/{len(batch['num_notes'])} samples have non-sequential indices")

print("\n[5] SLICE STATISTICS:")
print(f"  Total slices in dataset: {len(test_ds)}")
print(f"  Underlying samples: {len(test_ds.sample_files)}")
print(f"  Avg slices per sample: {len(test_ds) / len(test_ds.sample_files):.1f}")

print("\n[6] FIRST SAMPLE BEAT INDICES (first 50 values):")
first_beat = batch['note_locations_beat'][0].numpy()
first_num_notes = batch['num_notes'][0] if isinstance(batch['num_notes'][0], int) else batch['num_notes'][0].item()
print(f"  {first_beat[:min(50, first_num_notes)].tolist()}")

# Check for unique beats
unique_beats = np.unique(first_beat[:first_num_notes])
print(f"  Unique beat values: {len(unique_beats)}")
print(f"  First 10 unique beats: {unique_beats[:10].tolist()}")

print("\n[7] INPUT FEATURE STATISTICS:")
features = batch['input_features']
print(f"  Overall: mean={features.mean():.4f}, std={features.std():.4f}")
print(f"  Range: [{features.min():.4f}, {features.max():.4f}]")

# Check for NaN/Inf
nan_count = torch.isnan(features).sum().item()
inf_count = torch.isinf(features).sum().item()
if nan_count > 0:
    print(f"  [ERROR] {nan_count} NaN values detected!")
if inf_count > 0:
    print(f"  [ERROR] {inf_count} Inf values detected!")
if nan_count == 0 and inf_count == 0:
    print(f"  [OK] No NaN or Inf values")

print("\n[8] LABEL STATISTICS:")
scores = batch['scores']
print(f"  Mean per dimension: {scores.mean(dim=0).tolist()[:5]} ... (first 5)")
print(f"  Std per dimension: {scores.std(dim=0).tolist()[:5]} ... (first 5)")
print(f"  Range: [{scores.min():.4f}, {scores.max():.4f}]")

if scores.min() < 0 or scores.max() > 1:
    print(f"  [WARNING] Labels outside [0, 1] range!")
else:
    print(f"  [OK] Labels in [0, 1] range")

print("\n" + "=" * 70)
print("DIAGNOSTIC SUMMARY")
print("=" * 70)

issues = []
if idx_stats['beat_min'] != 1:
    issues.append("Beat indices don't start from 1")
if idx_stats['negative_beat_count'] > 0:
    issues.append(f"Negative zero-shifted values ({idx_stats['negative_beat_count']})")
if idx_stats['diff_equals_1_count'] != idx_stats['diff_greater_0_count']:
    issues.append("Non-sequential indices detected")
if nan_count > 0 or inf_count > 0:
    issues.append("NaN/Inf in features")

if issues:
    print(f"\n[ISSUES FOUND] {len(issues)} issues that may affect training:")
    for i, issue in enumerate(issues, 1):
        print(f"  {i}. {issue}")
    print("\nConsider fixing these before training!")
else:
    print("\n[ALL CHECKS PASSED] Data pipeline looks correct!")
    print("Proceed to training.")

print("=" * 70)

## Step 6: Initialize K-Fold Trainer

In [None]:
from src.percepiano.training.kfold_trainer import KFoldTrainer, MODEL_TYPE_HAN, MODEL_TYPE_BASELINE
import pytorch_lightning as pl

# Set seed for reproducibility
pl.seed_everything(SEED, workers=True)

# Create K-Fold trainers for BOTH models
# This allows us to compare Baseline vs HAN directly

print("="*60)
print("INITIALIZING TRAINERS")
print("="*60)

# 1. Baseline trainer (7-layer Bi-LSTM)
print("\n[1] Bi-LSTM Baseline Trainer:")
baseline_trainer = KFoldTrainer(
    config=CONFIG,
    fold_assignments=fold_assignments,
    data_dir=DATA_ROOT,
    checkpoint_dir=CHECKPOINT_ROOT,
    log_dir=LOG_ROOT,
    n_folds=N_FOLDS,
    model_type=MODEL_TYPE_BASELINE,
)

# 2. HAN trainer (Hierarchical)
print("\n[2] HAN (Hierarchical) Trainer:")
han_trainer = KFoldTrainer(
    config=CONFIG,
    fold_assignments=fold_assignments,
    data_dir=DATA_ROOT,
    checkpoint_dir=CHECKPOINT_ROOT,
    log_dir=LOG_ROOT,
    n_folds=N_FOLDS,
    model_type=MODEL_TYPE_HAN,
)

print("\n" + "="*60)
print("Both trainers initialized!")
print(f"  Baseline checkpoints: {baseline_trainer.checkpoint_dir}")
print(f"  HAN checkpoints: {han_trainer.checkpoint_dir}")
print("="*60)

## Step 7: Train All Folds

In [None]:
"""
TRAIN BOTH MODELS FOR COMPARISON - ALL 4 FOLDS

This cell trains:
1. Bi-LSTM Baseline (7-layer) - Expected R2 ~0.19
2. HAN (Hierarchical) - Expected R2 ~0.40

The difference (HAN R2 - Baseline R2) is the TRUE hierarchy gain.
Expected hierarchy gain: ~+0.21 R2
"""

print("="*60)
print("TRAINING BOTH MODELS (ALL 4 FOLDS)")
print("="*60)
print("\nPercePiano SOTA baselines:")
print("  Bi-LSTM (VirtuosoNetSingle): R2 = 0.185")
print("  HAN (VirtuosoNetMultiLevel): R2 = 0.397")
print("  Hierarchy gain: +0.212")
print("="*60)

# Train all 4 folds for both models
FOLDS_TO_TRAIN = [1, 2, 3, 4]

# ========================================
# STEP 1: Train Bi-LSTM Baseline (All 4 Folds)
# ========================================
print("\n" + "="*60)
print("STEP 1: TRAINING BI-LSTM BASELINE (ALL 4 FOLDS)")
print("="*60)
print("Expected R2: ~0.19 (matching original VirtuosoNetSingle)")

for fold_id in FOLDS_TO_TRAIN:
    print(f"\n{'='*60}")
    print(f"BASELINE - FOLD {fold_id}/{len(FOLDS_TO_TRAIN)}")
    print(f"{'='*60}")
    
    baseline_metrics = baseline_trainer.train_fold(
        fold_id=fold_id,
        verbose=True,
        resume_from_checkpoint=False,
    )
    
    print(f"\n  Fold {fold_id} Val R2: {baseline_metrics.val_r2:+.4f}")

baseline_trainer.save_results()

# Compute aggregate baseline metrics
baseline_agg = baseline_trainer._compute_aggregate_metrics()
print(f"\n{'='*60}")
print(f"BASELINE COMPLETE - Mean R2: {baseline_agg.mean_r2:+.4f} (+/- {baseline_agg.std_r2:.4f})")
print(f"{'='*60}")

# ========================================
# STEP 2: Train HAN (All 4 Folds)
# ========================================
print("\n" + "="*60)
print("STEP 2: TRAINING HAN (ALL 4 FOLDS)")
print("="*60)
print("Expected R2: ~0.40 (matching original VirtuosoNetMultiLevel)")

for fold_id in FOLDS_TO_TRAIN:
    print(f"\n{'='*60}")
    print(f"HAN - FOLD {fold_id}/{len(FOLDS_TO_TRAIN)}")
    print(f"{'='*60}")
    
    han_metrics = han_trainer.train_fold(
        fold_id=fold_id,
        verbose=True,
        resume_from_checkpoint=False,
    )
    
    print(f"\n  Fold {fold_id} Val R2: {han_metrics.val_r2:+.4f}")

han_trainer.save_results()

# Compute aggregate HAN metrics
han_agg = han_trainer._compute_aggregate_metrics()
print(f"\n{'='*60}")
print(f"HAN COMPLETE - Mean R2: {han_agg.mean_r2:+.4f} (+/- {han_agg.std_r2:.4f})")
print(f"{'='*60}")

# ========================================
# COMPARISON SUMMARY
# ========================================
print("\n" + "="*60)
print("TRAINING COMPLETE - FULL COMPARISON (4-FOLD CV)")
print("="*60)

hierarchy_gain = han_agg.mean_r2 - baseline_agg.mean_r2

print(f"\n  {'Model':<25} {'Mean R2':>10} {'Std R2':>10} {'Expected':>10}")
print(f"  {'-'*25} {'-'*10} {'-'*10} {'-'*10}")
print(f"  {'Bi-LSTM Baseline':<25} {baseline_agg.mean_r2:>+10.4f} {baseline_agg.std_r2:>10.4f} {'~0.19':>10}")
print(f"  {'HAN (Hierarchical)':<25} {han_agg.mean_r2:>+10.4f} {han_agg.std_r2:>10.4f} {'~0.40':>10}")
print(f"  {'-'*25} {'-'*10} {'-'*10} {'-'*10}")
print(f"  {'Hierarchy Gain':<25} {hierarchy_gain:>+10.4f} {'':>10} {'~+0.21':>10}")

print(f"\n  Interpretation:")
if baseline_agg.mean_r2 < 0.10:
    print(f"  [CRITICAL] Baseline R2 ({baseline_agg.mean_r2:.3f}) is very low!")
    print(f"  This suggests a fundamental issue with data or training pipeline.")
elif hierarchy_gain < 0.05:
    print(f"  [WARNING] Hierarchy gain ({hierarchy_gain:.3f}) is too low!")
    print(f"  Baseline works but HAN is not adding value.")
elif hierarchy_gain < 0.15:
    print(f"  [PARTIAL] Hierarchy gain ({hierarchy_gain:.3f}) is below expected (~0.21).")
    print(f"  Some hierarchy contribution, but room for improvement.")
else:
    print(f"  [GOOD] Hierarchy gain ({hierarchy_gain:.3f}) is close to expected (~0.21)!")

print("="*60)

# Store for later use
trained_models = {
    'baseline': baseline_trainer.get_trained_model(FOLDS_TO_TRAIN[-1]),
    'han': han_trainer.get_trained_model(FOLDS_TO_TRAIN[-1]),
}
trained_metrics = {
    'baseline': baseline_agg,
    'han': han_agg,
}

In [None]:
# Sync checkpoints to Google Drive after training
if RCLONE_AVAILABLE:
    print("Syncing checkpoints to Google Drive...")
    subprocess.run(
        ['rclone', 'copy', str(CHECKPOINT_ROOT), GDRIVE_CHECKPOINT_PATH, '--progress'],
        capture_output=False
    )
    print("Sync complete!")

## Step 8: Test Set Evaluation

In [None]:
# Evaluate both models on held-out test set
print("="*60)
print("TEST SET EVALUATION (BOTH MODELS)")
print("="*60)

print("\n[1] Bi-LSTM Baseline on Test Set:")
baseline_test_results = baseline_trainer.evaluate_on_test(verbose=True)

print("\n[2] HAN on Test Set:")
han_test_results = han_trainer.evaluate_on_test(verbose=True)

# Store for comparison
test_results = {
    'baseline': baseline_test_results,
    'han': han_test_results,
}

## Step 8b: Post-Training Hierarchy Diagnostics

Compare the Bi-LSTM baseline and the HAN on the same validation fold to measure how much the hierarchical components (beat/measure attention) contribute.
This cell reuses the in-memory models from the last trained fold and reports overall and per-dimension R2.
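
The verdicts printed for the hierarchy gain follow a simple threshold rule. The sketch below mirrors the thresholds used in the Step 7 interpretation block (0.10 for the baseline, 0.05 and 0.15 for the gain); `classify_hierarchy_gain` is a hypothetical helper written for illustration, not a function from the repository.

```python
def classify_hierarchy_gain(baseline_r2, han_r2):
    """Return (gain, verdict) using the notebook's interpretation thresholds."""
    gain = han_r2 - baseline_r2
    if baseline_r2 < 0.10:
        return gain, "CRITICAL: baseline too low, fix the shared pipeline first"
    if gain < 0.05:
        return gain, "WARNING: hierarchy adds little over the baseline"
    if gain < 0.15:
        return gain, "PARTIAL: some hierarchy contribution"
    return gain, "GOOD: gain close to the expected ~0.21"


# Published values quoted in this notebook: Bi-LSTM 0.185, HAN 0.397.
gain, verdict = classify_hierarchy_gain(0.185, 0.397)
print(f"{gain:+.3f} {verdict}")
```

Checking the baseline first matters: a low baseline points to a data or training problem shared by both models, so the gain is not meaningful until that is fixed.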

In [None]:
"""
POST-TRAINING MODEL COMPARISON

Compare Bi-LSTM Baseline vs HAN on the same validation set.
This gives the TRUE hierarchy contribution measurement.
"""

import torch
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader
from sklearn.metrics import r2_score

print("=" * 70)
print("POST-TRAINING MODEL COMPARISON")
print("=" * 70)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Get models from training
if 'trained_models' in dir() and trained_models.get('baseline') and trained_models.get('han'):
    baseline_model = trained_models['baseline'].to(device).eval()
    han_model = trained_models['han'].to(device).eval()
    print(f"\nUsing in-memory trained models")
else:
    print(f"\nNo trained models found - run training cell first!")
    baseline_model = None
    han_model = None

if baseline_model is not None and han_model is not None:
    # Create validation dataloader
    from src.percepiano.data.percepiano_vnet_dataset import (
        PercePianoKFoldDataset,
        percepiano_pack_collate,
    )
    
    val_ds = PercePianoKFoldDataset(
        data_dir=DATA_ROOT,
        fold_assignments=fold_assignments,
        fold_id=FOLDS_TO_TRAIN[-1],  # Same fold the in-memory models were trained on
        mode="val",
        max_notes=CONFIG['max_notes'],
        slice_len=CONFIG.get('slice_len', CONFIG['max_notes']),
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=4,
        shuffle=False,
        num_workers=0,
        collate_fn=percepiano_pack_collate,
    )
    
    print(f"Validation samples: {len(val_ds)}")
    print(f"Device: {device}")
    
    # Collect predictions from both models
    baseline_preds = []
    han_preds = []
    targets = []
    
    from torch.nn.utils.rnn import PackedSequence
    
    with torch.no_grad():
        for batch in val_loader:
            # Move batch to device
            batch_on_device = {}
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch_on_device[k] = v.to(device)
                elif isinstance(v, PackedSequence):
                    batch_on_device[k] = PackedSequence(
                        v.data.to(device),
                        v.batch_sizes,
                        v.sorted_indices.to(device) if v.sorted_indices is not None else None,
                        v.unsorted_indices.to(device) if v.unsorted_indices is not None else None,
                    )
                else:
                    batch_on_device[k] = v
            
            note_locations = {
                'beat': batch_on_device['note_locations_beat'],
                'measure': batch_on_device['note_locations_measure'],
                'voice': batch_on_device['note_locations_voice'],
            }
            
            # Baseline prediction
            baseline_out = baseline_model(
                batch_on_device['input_features'],
                note_locations,
                batch_on_device.get('attention_mask'),
                batch_on_device.get('lengths'),
            )
            baseline_preds.append(baseline_out['predictions'].cpu())
            
            # HAN prediction
            han_out = han_model(
                batch_on_device['input_features'],
                note_locations,
                batch_on_device.get('attention_mask'),
                batch_on_device.get('lengths'),
            )
            han_preds.append(han_out['predictions'].cpu())
            
            targets.append(batch_on_device['scores'].cpu())
    
    # Compute R2 scores
    baseline_preds = torch.cat(baseline_preds).numpy()
    han_preds = torch.cat(han_preds).numpy()
    targets = torch.cat(targets).numpy()
    
    baseline_r2_fresh = r2_score(targets, baseline_preds)
    han_r2_fresh = r2_score(targets, han_preds)
    hierarchy_gain_fresh = han_r2_fresh - baseline_r2_fresh
    
    # Per-dimension analysis
    from src.percepiano.models.percepiano_replica import PERCEPIANO_DIMENSIONS
    
    print("\n" + "=" * 70)
    print("OVERALL COMPARISON")
    print("=" * 70)
    
    print(f"\n  {'Model':<25} {'R2':>10} {'Expected':>10} {'Status':<15}")
    print(f"  {'-'*25} {'-'*10} {'-'*10} {'-'*15}")
    print(f"  {'Bi-LSTM Baseline':<25} {baseline_r2_fresh:>+10.4f} {'~0.19':>10} ", end="")
    if baseline_r2_fresh >= 0.15:
        print("[OK]")
    elif baseline_r2_fresh >= 0.10:
        print("[LOW]")
    else:
        print("[CRITICAL]")
    
    print(f"  {'HAN (Hierarchical)':<25} {han_r2_fresh:>+10.4f} {'~0.40':>10} ", end="")
    if han_r2_fresh >= 0.35:
        print("[OK]")
    elif han_r2_fresh >= 0.20:
        print("[LOW]")
    else:
        print("[CRITICAL]")
    
    print(f"  {'-'*25} {'-'*10} {'-'*10} {'-'*15}")
    print(f"  {'Hierarchy Gain':<25} {hierarchy_gain_fresh:>+10.4f} {'~+0.21':>10} ", end="")
    if hierarchy_gain_fresh >= 0.15:
        print("[GOOD]")
    elif hierarchy_gain_fresh >= 0.05:
        print("[PARTIAL]")
    else:
        print("[NONE]")
    
    # Per-dimension comparison
    print("\n" + "=" * 70)
    print("PER-DIMENSION R2 COMPARISON")
    print("=" * 70)
    
    print(f"\n  {'Dimension':<30} {'Baseline':>10} {'HAN':>10} {'Gain':>10}")
    print(f"  {'-'*30} {'-'*10} {'-'*10} {'-'*10}")
    
    dim_gains = []
    for i, dim in enumerate(PERCEPIANO_DIMENSIONS):
        baseline_dim_r2 = r2_score(targets[:, i], baseline_preds[:, i])
        han_dim_r2 = r2_score(targets[:, i], han_preds[:, i])
        dim_gain = han_dim_r2 - baseline_dim_r2
        dim_gains.append((dim, baseline_dim_r2, han_dim_r2, dim_gain))
    
    # Sort by gain
    dim_gains.sort(key=lambda x: x[3], reverse=True)
    
    for dim, baseline_dim, han_dim, gain in dim_gains:
        # The '+' format flag keeps the sign attached to the number
        print(f"  {dim:<30} {baseline_dim:>+10.4f} {han_dim:>+10.4f} {gain:>+10.4f}")
    
    positive_gain_dims = sum(1 for _, _, _, g in dim_gains if g > 0)
    print(f"\n  Dimensions with positive hierarchy gain: {positive_gain_dims}/{len(PERCEPIANO_DIMENSIONS)}")
    
    # Diagnosis
    print("\n" + "=" * 70)
    print("DIAGNOSIS")
    print("=" * 70)
    
    if baseline_r2_fresh < 0.10:
        print(f"\n  [CRITICAL] Baseline R2 = {baseline_r2_fresh:.4f} is very low!")
        print(f"  This indicates a problem in the data pipeline or training loop.")
        print(f"  The issue is NOT specific to HAN - fix the baseline first.")
        print(f"\n  Likely causes:")
        print(f"    1. Data preprocessing issue (feature scaling, normalization)")
        print(f"    2. Label alignment problem")
        print(f"    3. Training instability (learning rate, gradient clipping)")
    elif hierarchy_gain_fresh < 0.05:
        print(f"\n  [WARNING] Baseline works (R2={baseline_r2_fresh:.4f}) but hierarchy gain is low ({hierarchy_gain_fresh:.4f})")
        print(f"  The Bi-LSTM is learning, but HAN hierarchy is not helping.")
        print(f"\n  Likely causes:")
        print(f"    1. Beat/measure attention collapsed to uniform")
        print(f"    2. span_beat_to_note_num index mapping issue")
        print(f"    3. HAN architecture mismatch with original")
    elif han_r2_fresh < 0.30:
        print(f"\n  [PARTIAL] Both models learning, but below SOTA.")
        print(f"  Baseline: {baseline_r2_fresh:.4f} (expected ~0.19)")
        print(f"  HAN: {han_r2_fresh:.4f} (expected ~0.40)")
        print(f"  Hierarchy gain: {hierarchy_gain_fresh:.4f} (expected ~0.21)")
        print(f"\n  Suggestions:")
        print(f"    1. Train for more epochs")
        print(f"    2. Try all 4 folds and average")
        print(f"    3. Check if slice sampling is working correctly")
    else:
        print(f"\n  [GOOD] Results approaching SOTA!")
        print(f"  Baseline: {baseline_r2_fresh:.4f}")
        print(f"  HAN: {han_r2_fresh:.4f}")
        print(f"  Hierarchy gain: {hierarchy_gain_fresh:.4f}")
    
    print("=" * 70)
else:
    print("\n[SKIPPED] No models available for comparison.")
    print("Run the training cell first.")

## Step 8c: Manual Ablation Test

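Deprecated: the zeroed-hierarchy ablation that used to live here has been replaced by the Baseline vs HAN comparison in Steps 7 and 8b. The cell below only prints a pointer to those results.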

In [None]:
"""
NOTE: This cell is DEPRECATED.

The proper Baseline vs HAN comparison is now done in the training cell (Step 7)
and the comparison cell (Step 8b) using the actual PercePianoBiLSTMBaseline model.

The old "zeroed hierarchy" ablation approach was incorrect because it:
1. Used 4-layer split LSTMs instead of 7-layer single LSTM
2. Still included voice processing (which baseline doesn't have)
3. Fed 2048-dim (half zeros) to contractor instead of 512-dim dense

See the comparison cell above for the correct Baseline vs HAN comparison.
"""

print("="*70)
print("NOTE: Manual ablation cell deprecated")
print("="*70)
print("\nThe proper comparison is now done using:")
print("  1. PercePianoBiLSTMBaseline (7-layer LSTM matching VirtuosoNetSingle)")
print("  2. PercePianoVNetModule (HAN matching VirtuosoNetMultiLevel)")
print("\nSee the comparison results in the cells above.")
print("="*70)

## Step 9: Per-Fold Results Summary

In [None]:
import numpy as np
from src.percepiano.models.percepiano_replica import PERCEPIANO_DIMENSIONS

# Per-fold results for BOTH models
print("="*80)
print("PER-FOLD VALIDATION RESULTS")
print("="*80)

for model_name, trainer in [("Bi-LSTM Baseline", baseline_trainer), ("HAN", han_trainer)]:
    aggregate_metrics = trainer._compute_aggregate_metrics()
    
    print(f"\n{model_name}:")
    print(f"{'Fold':<6} {'Val R2':>10} {'Val Pearson':>12} {'Val MAE':>10} {'Val RMSE':>10} {'Epochs':>8} {'Time (s)':>10}")
    print(f"{'-'*6} {'-'*10} {'-'*12} {'-'*10} {'-'*10} {'-'*8} {'-'*10}")

    for m in trainer.fold_metrics:
        print(f"{m.fold_id:<6} {m.val_r2:>+10.4f} {m.val_pearson:>+12.4f} {m.val_mae:>10.4f} {m.val_rmse:>10.4f} {m.epochs_trained:>8} {m.training_time_seconds:>10.1f}")

    print(f"{'-'*6} {'-'*10} {'-'*12} {'-'*10} {'-'*10} {'-'*8} {'-'*10}")
    print(f"{'Mean':<6} {aggregate_metrics.mean_r2:>+10.4f} {aggregate_metrics.mean_pearson:>+12.4f} {aggregate_metrics.mean_mae:>10.4f} {aggregate_metrics.mean_rmse:>10.4f}")
    # Standard deviations are non-negative, so the R2 and Pearson columns need no explicit sign
    print(f"{'Std':<6} {aggregate_metrics.std_r2:>10.4f} {aggregate_metrics.std_pearson:>12.4f} {aggregate_metrics.std_mae:>10.4f} {aggregate_metrics.std_rmse:>10.4f}")
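
For reference, the Mean and Std rows above are plain statistics over the per-fold values. The sketch below shows one way to compute them with NumPy. `FoldResult` and `aggregate` are hypothetical names, the numbers are made up, and whether the trainer uses `ddof=0` or `ddof=1` for the standard deviation is an assumption.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class FoldResult:
    """Hypothetical per-fold record; field names mirror the table columns."""
    fold_id: int
    val_r2: float
    val_pearson: float


def aggregate(folds, ddof=0):
    """Return {metric: (mean, std)} across folds.

    ddof=0 is the population std; the trainer's actual choice is not known here.
    """
    out = {}
    for name in ("val_r2", "val_pearson"):
        values = np.array([getattr(f, name) for f in folds], dtype=float)
        out[name] = (float(values.mean()), float(values.std(ddof=ddof)))
    return out


# Made-up values purely for illustration, not real results.
folds = [
    FoldResult(0, 0.30, 0.55),
    FoldResult(1, 0.40, 0.65),
    FoldResult(2, 0.35, 0.60),
    FoldResult(3, 0.35, 0.60),
]
summary = aggregate(folds)
mean_r2, std_r2 = summary["val_r2"]
print(f"R2: {mean_r2:+.4f} +/- {std_r2:.4f}")
```

With only four folds the two `ddof` conventions differ noticeably, so it is worth stating which one is used when comparing against published numbers.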

## Step 10: Per-Dimension Analysis

In [None]:
# Per-dimension analysis for HAN model (the one we care about for SOTA comparison)
han_agg = han_trainer._compute_aggregate_metrics()

print("="*80)
print("PER-DIMENSION R2 FOR HAN MODEL (Mean +/- Std across folds)")
print("="*80)

# Sort dimensions by mean R2
sorted_dims = sorted(
    han_agg.per_dim_mean_r2.items(),
    key=lambda x: x[1],
    reverse=True
)

print(f"\n{'Dimension':<25} {'Mean R2':>10} {'Std R2':>10} {'Status':<12}")
print(f"{'-'*25} {'-'*10} {'-'*10} {'-'*12}")

for dim, mean_r2 in sorted_dims:
    std_r2 = han_agg.per_dim_std_r2[dim]
    
    if mean_r2 >= 0.3:
        status = "[GOOD]"
    elif mean_r2 >= 0.1:
        status = "[OK]"
    elif mean_r2 >= 0:
        status = "[WEAK]"
    else:
        status = "[FAILED]"
    
    print(f"{dim:<25} {mean_r2:>+10.4f} {std_r2:>10.4f} {status:<12}")

# Summary
positive = sum(1 for d, r in sorted_dims if r > 0)
strong = sum(1 for d, r in sorted_dims if r >= 0.2)
n_dims = len(sorted_dims)

print(f"\nSummary: {positive}/{n_dims} positive R2, {strong}/{n_dims} strong (>= 0.2)")
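
As a reminder of what the per-dimension numbers mean, the sketch below computes R2 column by column with NumPy and applies the same status thresholds as the cell above. The helper names, dimension names and toy data are hypothetical. The real values come from `han_trainer`.

```python
import numpy as np


def r2_per_dimension(y_true, y_pred):
    """R2 = 1 - SS_res / SS_tot, computed independently for each column."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot


def status(mean_r2):
    """Same thresholds as the per-dimension cell above."""
    if mean_r2 >= 0.3:
        return "[GOOD]"
    if mean_r2 >= 0.1:
        return "[OK]"
    if mean_r2 >= 0:
        return "[WEAK]"
    return "[FAILED]"


# Toy data: 4 samples, 2 dimensions (invented for illustration).
y_true = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
y_pred = np.array([[0.0, 3.0], [1.0, 2.0], [2.0, 1.0], [3.0, 0.0]])

scores = r2_per_dimension(y_true, y_pred)
for name, score in zip(["dim_a", "dim_b"], scores):
    print(f"{name:<10} {score:+.4f} {status(score)}")
```

A negative R2, as in the second toy column, means the predictions are worse than always predicting that dimension's mean. Note that this sketch divides by zero if a target column is constant.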

## Step 11: Final Summary and Save

In [None]:
import json
import torch
from pathlib import Path

print("="*80)
print("FINAL SUMMARY - BASELINE VS HAN COMPARISON")
print("="*80)

# Get metrics from trainers
baseline_agg = baseline_trainer._compute_aggregate_metrics()
han_agg = han_trainer._compute_aggregate_metrics()

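# Cross-validation results comparison (means over however many folds were actually trained)
n_folds_trained = len(han_trainer.fold_metrics)
print(f"\n{'='*80}")
print(f"CROSS-VALIDATION RESULTS (MEAN OVER {n_folds_trained} FOLD(S))")
print(f"{'='*80}")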

print(f"\n  {'Metric':<20} {'Baseline':>12} {'HAN':>12} {'Gain':>12} {'Expected':>12}")
print(f"  {'-'*20} {'-'*12} {'-'*12} {'-'*12} {'-'*12}")

baseline_cv_r2 = baseline_agg.mean_r2
han_cv_r2 = han_agg.mean_r2
cv_gain = han_cv_r2 - baseline_cv_r2

print(f"  {'R2':<20} {baseline_cv_r2:>+12.4f} {han_cv_r2:>+12.4f} {cv_gain:>+12.4f} {'+0.21':>12}")
print(f"  {'Pearson':<20} {baseline_agg.mean_pearson:>+12.4f} {han_agg.mean_pearson:>+12.4f}")
print(f"  {'MAE':<20} {baseline_agg.mean_mae:>12.4f} {han_agg.mean_mae:>12.4f}")
print(f"  {'RMSE':<20} {baseline_agg.mean_rmse:>12.4f} {han_agg.mean_rmse:>12.4f}")

# Test set results
print(f"\n{'='*80}")
print("TEST SET RESULTS (ENSEMBLE)")
print(f"{'='*80}")

if 'test_results' in globals() and test_results.get('baseline') and test_results.get('han'):
    baseline_test = test_results['baseline']['ensemble']
    han_test = test_results['han']['ensemble']
    test_gain = han_test['r2'] - baseline_test['r2']
    
    print(f"\n  {'Metric':<20} {'Baseline':>12} {'HAN':>12} {'Gain':>12}")
    print(f"  {'-'*20} {'-'*12} {'-'*12} {'-'*12}")
    print(f"  {'R2':<20} {baseline_test['r2']:>+12.4f} {han_test['r2']:>+12.4f} {test_gain:>+12.4f}")
    print(f"  {'Pearson':<20} {baseline_test['pearson']:>+12.4f} {han_test['pearson']:>+12.4f}")
    print(f"  {'MAE':<20} {baseline_test['mae']:>12.4f} {han_test['mae']:>12.4f}")
else:
    print("\n  [Test results not available - run test evaluation cell]")

# Comparison to PercePiano baselines
print(f"\n{'='*80}")
print("COMPARISON TO PERCEPIANO PAPER")
print(f"{'='*80}")

print(f"\n  {'Model':<30} {'Paper R2':>12} {'Our R2':>12} {'Match':>10}")
print(f"  {'-'*30} {'-'*12} {'-'*12} {'-'*10}")
print(f"  {'Bi-LSTM (VirtuosoNetSingle)':<30} {'0.185':>12} {baseline_cv_r2:>+12.4f}", end="")
if abs(baseline_cv_r2 - 0.185) < 0.05:
    print(f"{'[OK]':>10}")
elif baseline_cv_r2 > 0.135:
    print(f"{'[CLOSE]':>10}")
else:
    print(f"{'[LOW]':>10}")

print(f"  {'HAN (VirtuosoNetMultiLevel)':<30} {'0.397':>12} {han_cv_r2:>+12.4f}", end="")
if abs(han_cv_r2 - 0.397) < 0.05:
    print(f"{'[OK]':>10}")
elif han_cv_r2 > 0.30:
    print(f"{'[CLOSE]':>10}")
else:
    print(f"{'[LOW]':>10}")

print(f"  {'Hierarchy Gain':<30} {'+0.212':>12} {cv_gain:>+12.4f}", end="")
if cv_gain > 0.15:
    print(f"{'[GOOD]':>10}")
elif cv_gain > 0.05:
    print(f"{'[PARTIAL]':>10}")
else:
    print(f"{'[NONE]':>10}")

# Final interpretation
print(f"\n{'='*80}")
print("INTERPRETATION")
print(f"{'='*80}")

if baseline_cv_r2 < 0.10:
    print(f"\n  [CRITICAL] Baseline R2 ({baseline_cv_r2:.3f}) is very low!")
    print(f"  There is a fundamental issue with the data or training pipeline.")
    print(f"  Fix the baseline first before debugging HAN.")
elif cv_gain < 0.05:  # baseline_cv_r2 >= 0.10 is already guaranteed by the branch above
    print(f"\n  [WARNING] Baseline works but hierarchy gain is low ({cv_gain:.3f})!")
    print(f"  The Bi-LSTM learns but HAN hierarchy is not helping.")
    print(f"  Debug: beat/measure attention, span_beat_to_note_num, index mapping.")
elif han_cv_r2 >= 0.35:
    print(f"\n  [SUCCESS] HAN R2 ({han_cv_r2:.3f}) approaching SOTA (0.397)!")
    print(f"  Model is ready for use or further tuning.")
elif han_cv_r2 >= 0.25:
    print(f"\n  [GOOD] HAN R2 ({han_cv_r2:.3f}) is usable for pseudo-labeling.")
    print(f"  Consider training all 4 folds and training for more epochs.")
else:
    print(f"\n  [NEEDS WORK] HAN R2 ({han_cv_r2:.3f}) needs improvement.")
    print(f"  Check hierarchy diagnostics for specific issues.")

print(f"\n{'='*80}")
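
The interpretation branches above are order-dependent, which is easier to see (and to test) as a pure function. The sketch below mirrors the thresholds from the cell. `interpret` is a hypothetical helper, not part of the training code.

```python
def interpret(baseline_r2: float, han_r2: float) -> str:
    """Return the label the final-summary cell would print for these two scores."""
    gain = han_r2 - baseline_r2
    if baseline_r2 < 0.10:
        return "CRITICAL"    # pipeline problem: fix the baseline first
    if gain < 0.05:
        return "WARNING"     # baseline learns, hierarchy adds little
    if han_r2 >= 0.35:
        return "SUCCESS"
    if han_r2 >= 0.25:
        return "GOOD"
    return "NEEDS WORK"


# Paper figures quoted in the cell above: Bi-LSTM 0.185, HAN 0.397.
paper_label = interpret(0.185, 0.397)
print(paper_label)

# A very low baseline masks everything else, even a high HAN score.
print(interpret(0.05, 0.40))
```

Because the baseline check comes first, a high HAN R2 is never reported as a success while the baseline is below 0.10.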

In [None]:
# Final sync to Google Drive
print("="*60)
print("SYNC TO GOOGLE DRIVE")
print("="*60)

import subprocess  # imported here so this cell does not depend on an earlier import

if RCLONE_AVAILABLE:
    print("\nSyncing all checkpoints and results...")

    # (label, local source, remote destination) for each rclone copy
    sync_jobs = [
        ("Baseline checkpoints", baseline_trainer.checkpoint_dir, f"{GDRIVE_CHECKPOINT_PATH}/percepiano_baseline"),
        ("HAN checkpoints", han_trainer.checkpoint_dir, f"{GDRIVE_CHECKPOINT_PATH}/percepiano"),
        # Also sync fold assignments back to the data directory
        ("Fold assignments", FOLD_FILE, GDRIVE_DATA_PATH),
    ]

    failed = []
    for i, (label, src, dst) in enumerate(sync_jobs, start=1):
        print(f"\n[{i}] Syncing {label}...")
        result = subprocess.run(['rclone', 'copy', str(src), dst, '--progress'])
        # Only report a job as synced if rclone exited cleanly
        if result.returncode != 0:
            failed.append(label)
        else:
            print(f"  {label}: {dst}")

    if failed:
        print(f"\n[WARNING] Sync failed for: {', '.join(failed)}")
    else:
        print("\nSync complete!")
else:
    print("\nrclone not available - skipping sync")

print(f"\n{'='*60}")
print("TRAINING COMPLETE")
print(f"{'='*60}")