# PercePiano Replica Training (4-Fold Cross-Validation)

Train the PercePiano replica model with 4-fold piece-based cross-validation,
matching the methodology and hyperparameters of the published PercePiano SOTA.

## Attribution

> **PercePiano: Piano Performance Evaluation Dataset with Multi-level Perceptual Features**  
> Park, Kim et al.  
> Nature Scientific Reports 2024  
> Paper: https://pmc.ncbi.nlm.nih.gov/articles/PMC11450231/  
> GitHub: https://github.com/JonghoKimSNU/PercePiano

## Methodology

Following the approach in `m2pf_dataset_compositionfold.py`, with one deliberate deviation in how CV folds are filled:

- **Piece-based splits**: All performances of the same piece stay in the same fold
- **Test set**: Select pieces at random until the test set holds ~15% of the SAMPLES (not of the pieces)
- **4-fold CV**: The original distributes the remaining pieces round-robin across folds; this replica uses greedy bin-packing (Step 4) to balance per-fold sample counts
- **Per-fold normalization**: Statistics are computed from the training folds only
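The split procedure above can be sketched in plain Python. This is a hypothetical `piece_based_split`, not the project's `create_piece_based_folds`: it assumes a mapping from sample file name to piece name, fills the test set by sample count, and then uses greedy bin-packing (the replica's variant) instead of the original round-robin.

```python
import random
from collections import Counter
from typing import Dict, List, Tuple


def piece_based_split(
    sample_to_piece: Dict[str, str],
    n_folds: int = 4,
    test_ratio: float = 0.15,
    seed: int = 42,
) -> Tuple[List[str], List[List[str]]]:
    """Return (test_pieces, cv_folds); all performances of a piece land in one place."""
    counts = Counter(sample_to_piece.values())  # performances per piece
    total = sum(counts.values())

    # 1) Shuffle pieces, then move whole pieces into the test set until it
    #    holds at least test_ratio of the SAMPLES (not of the pieces).
    pieces = sorted(counts)
    random.Random(seed).shuffle(pieces)
    test_pieces: List[str] = []
    remaining: List[str] = []
    test_count = 0
    for piece in pieces:
        if test_count < test_ratio * total:
            test_pieces.append(piece)
            test_count += counts[piece]
        else:
            remaining.append(piece)

    # 2) Greedy bin-packing: largest pieces first, each into the fold that
    #    currently holds the fewest samples (ties go to the lowest fold index).
    folds: List[List[str]] = [[] for _ in range(n_folds)]
    fold_sizes = [0] * n_folds
    for piece in sorted(remaining, key=lambda p: (-counts[p], p)):
        k = min(range(n_folds), key=lambda i: fold_sizes[i])
        folds[k].append(piece)
        fold_sizes[k] += counts[piece]
    return test_pieces, folds
```

Placing the largest pieces first keeps the final gap between the biggest and smallest fold no larger than the largest single piece, which plain round-robin does not guarantee.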

## Hyperparameters (SOTA Configuration - R2 = 0.397)

These parameters match the published SOTA from `2_run_comp_multilevel_total.sh` and `han_bigger256_concat.yml`:

| Parameter | SOTA Value | Notes |
|-----------|------------|-------|
| input_size | 79 | SOTA uses 79 base features (includes section_tempo) |
| batch_size | 8 | From SOTA training script |
| learning_rate | 2.5e-5 | From SOTA training script |
| hidden_size | 256 | HAN encoder dimension |
| prediction_head | 512->512->19 | From model_m2pf.py (NOT config's final_fc_size) |
| dropout | 0.2 | Regularization |
| augment_train | False | SOTA doesn't use key augmentation |
| max_epochs | 200 | Extended training window |
| early_stopping_patience | 40 | More patience for convergence (the Step 5 `CONFIG` currently sets 20) |
| gradient_clip_val | 2.0 | From parser.py |
| **precision** | **32** | **FP32 (original uses FP32, not mixed precision)** |
| **max_notes/slice_len** | **5000** | **SOTA slice size for overlapping sampling** |
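A small check can compare a run's `CONFIG` against this table before launching training. The names `SOTA_HPARAMS` and `diff_against_reference` are hypothetical; the values are transcribed from the table, keeping only keys that also appear in the Step 5 `CONFIG`. With the Step 5 values as written, it would flag `early_stopping_patience` (20 vs 40).

```python
from typing import Any, Dict, List

# Values transcribed from the hyperparameter table (prediction_head omitted:
# it is an architecture detail with no matching CONFIG key).
SOTA_HPARAMS: Dict[str, Any] = {
    "input_size": 79,
    "batch_size": 8,
    "learning_rate": 2.5e-5,
    "hidden_size": 256,
    "dropout": 0.2,
    "augment_train": False,
    "max_epochs": 200,
    "early_stopping_patience": 40,
    "gradient_clip_val": 2.0,
    "precision": "32",
    "max_notes": 5000,
    "slice_len": 5000,
}


def diff_against_reference(config: Dict[str, Any], reference: Dict[str, Any]) -> List[str]:
    """Human-readable mismatches between a run config and a reference config."""
    problems: List[str] = []
    for key, expected in reference.items():
        if key not in config:
            problems.append(f"{key}: missing (expected {expected!r})")
        elif config[key] != expected:
            problems.append(f"{key}: {config[key]!r} != {expected!r}")
    return problems
```

Usage in the notebook: `for p in diff_against_reference(CONFIG, SOTA_HPARAMS): print("[MISMATCH]", p)`.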

## Key Fixes Applied (Round 8 - 2025-12-26)

### CRITICAL: Slice Sampling (Round 8)

The single most impactful discrepancy identified:

| Aspect | Original PercePiano | Previous Implementation | Impact |
|--------|---------------------|------------------------|--------|
| Samples/performance | 3-5 overlapping slices | 1 sample | 3-5x less training data |
| Training samples | ~600-1000 slices | ~200 samples | Critical for learning |
| Slice regeneration | Each epoch | None | No variation |

**Fix**: Added `make_slicing_indexes_by_measure()` and `SliceRegenerationCallback`.
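The idea behind measure-aligned overlapping slices can be sketched as follows. This is a simplified, hypothetical `make_measure_slices`, not the original `make_slicing_indexes_by_measure`. It assumes one measure number per note and uses a fixed ~50% overlap. The random offset is re-drawn on every call, so calling it again each epoch (which `SliceRegenerationCallback` is described as triggering) yields different windows over the same performance.

```python
import bisect
import random
from typing import List, Sequence, Tuple


def measure_starts(measure_numbers: Sequence[int]) -> List[int]:
    """Note indices where a new measure begins (index 0 always counts)."""
    return [i for i in range(len(measure_numbers))
            if i == 0 or measure_numbers[i] != measure_numbers[i - 1]]


def make_measure_slices(measure_numbers: Sequence[int], slice_len: int,
                        rng: random.Random) -> List[Tuple[int, int]]:
    """Overlapping (start, end) note windows whose starts snap to measure boundaries."""
    n = len(measure_numbers)
    if n <= slice_len:
        return [(0, n)]  # short performance: a single slice covers everything
    starts = measure_starts(measure_numbers)
    stride = max(1, slice_len // 2)  # ~50% overlap between neighbouring windows

    def snap_down(idx: int) -> int:
        # Latest measure start at or before idx
        return starts[bisect.bisect_right(starts, idx) - 1]

    # Head window plus strided windows from a random offset (re-drawn on every call)
    window_starts = {0}
    pos = rng.randrange(stride)
    while pos < n - slice_len:
        window_starts.add(snap_down(pos))
        pos += stride
    slices = [(s, s + slice_len) for s in sorted(window_starts)]

    # Tail window: first measure start that leaves at most slice_len notes
    k = bisect.bisect_left(starts, n - slice_len)
    tail_start = starts[k] if k < len(starts) else n - slice_len
    slices.append((tail_start, n))
    return slices
```

A long performance therefore contributes several windows per epoch instead of one truncated sample, which is where the 3-5x increase in training items comes from.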

### Round History

| Round | Changes | Result |
|-------|---------|--------|
| 1-5 | Various fixes (precision, attention, init) | R2 stuck around 0 |
| 6 | Match original architecture (no LayerNorm, 512->512, LR 2.5e-5) | Zero context_vector gradients |
| 7 | Fix data pipeline (79 features, PackedSequence) | R2 = 0.0017 (prediction collapse) |
| **8** | **Add SLICE SAMPLING (3-5 slices/sample, epoch regeneration)** | **Pending** |

See `docs/EXPERIMENT_LOG.md` for full investigation details.

## Expected Results

- Target R2: 0.35-0.40 (matching published SOTA of 0.397)
- Training time: ~8-12 hours on T4, ~3-5 hours on A100 (all 4 folds)
- With slice sampling: ~600-1000 training slices (vs ~200 samples before)

## Step 1: Environment Setup

In [None]:
# Check GPU
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install rclone (skipped if already installed), then print the installed version
!command -v rclone >/dev/null 2>&1 || (curl -fsSL https://rclone.org/install.sh | sudo bash)
!rclone version | head -n 1

In [None]:
# Install uv and clone repository
!curl -LsSf https://astral.sh/uv/install.sh | sh

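import os
# Recent uv installers place the binary in ~/.local/bin; older ones used ~/.cargo/bin
os.environ['PATH'] = f"{os.environ['HOME']}/.local/bin:{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"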

# Clone repository
if not os.path.exists('/tmp/crescendai'):
    !git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai

%cd /tmp/crescendai/model
!git pull
!git log -1 --oneline

# Clone original PercePiano for comparison (needed for data diagnostics)
PERCEPIANO_PATH = '/tmp/crescendai/model/data/raw/PercePiano'
if not os.path.exists(PERCEPIANO_PATH):
    print("\nCloning original PercePiano repository...")
    !git clone https://github.com/JonghoKimSNU/PercePiano.git {PERCEPIANO_PATH}
else:
    print(f"\nPercePiano already present at {PERCEPIANO_PATH}")

# Install dependencies
!uv pip install --system -e .
!pip install tensorboard rich

import torch
import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

## Step 2: Configure Paths and Check rclone

In [None]:
import os
import subprocess
import shutil
from pathlib import Path

# Paths
DATA_ROOT = Path('/tmp/percepiano_vnet_84dim')
CHECKPOINT_ROOT = Path('/tmp/checkpoints/percepiano_kfold')
LOG_ROOT = Path('/tmp/logs/percepiano_kfold')
GDRIVE_DATA_PATH = 'gdrive:crescendai_data/percepiano_vnet_84dim'
GDRIVE_CHECKPOINT_PATH = 'gdrive:crescendai_checkpoints/percepiano_kfold'

# Training control
RESTART_TRAINING = True  # True: delete existing checkpoints and logs before training

print("="*60)
print("PERCEPIANO REPLICA TRAINING (4-FOLD CV)")
print("="*60)

# Clear checkpoints if restarting
if RESTART_TRAINING and CHECKPOINT_ROOT.exists():
    print(f"\nRESTART_TRAINING=True: Clearing checkpoints at {CHECKPOINT_ROOT}")
    shutil.rmtree(CHECKPOINT_ROOT)
    print("  Checkpoints cleared!")

if RESTART_TRAINING and LOG_ROOT.exists():
    print(f"RESTART_TRAINING=True: Clearing logs at {LOG_ROOT}")
    shutil.rmtree(LOG_ROOT)
    print("  Logs cleared!")

# Create directories
CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
LOG_ROOT.mkdir(parents=True, exist_ok=True)
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Check rclone
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)

if 'gdrive:' in result.stdout:
    print("\nrclone 'gdrive' remote: CONFIGURED")
    RCLONE_AVAILABLE = True
else:
    print("\nrclone 'gdrive' remote: NOT CONFIGURED")
    print("Run 'rclone config' in terminal to set up Google Drive")
    RCLONE_AVAILABLE = False

print(f"\nData directory: {DATA_ROOT}")
print(f"Checkpoint directory: {CHECKPOINT_ROOT}")
print(f"Log directory: {LOG_ROOT}")
print(f"\nRESTART_TRAINING: {RESTART_TRAINING}")

## Step 3: Download Data from Google Drive

In [None]:
import subprocess

if not RCLONE_AVAILABLE:
    raise RuntimeError("rclone not configured. Run 'rclone config' first.")

# Download preprocessed data
print("Downloading preprocessed VirtuosoNet features from Google Drive...")
subprocess.run(
    ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
    check=True,  # stop here if the download fails
)

# Verify data
print("\n" + "="*60)
print("DATA VERIFICATION")
print("="*60)

total_samples = 0
for split in ['train', 'val', 'test']:
    split_dir = DATA_ROOT / split
    if split_dir.exists():
        count = len(list(split_dir.glob('*.pkl')))
        total_samples += count
        print(f"  {split}: {count} samples")
    else:
        print(f"  {split}: MISSING!")

print(f"  Total: {total_samples} samples")

stat_file = DATA_ROOT / 'stat.pkl'
print(f"  stat.pkl: {'present' if stat_file.exists() else 'MISSING!'}")

fold_file = DATA_ROOT / 'fold_assignments.json'
print(f"  fold_assignments.json: {'present' if fold_file.exists() else 'will be created'}")

## Step 4: Create Fold Assignments

In [None]:
from src.percepiano.data.kfold_split import (
    create_piece_based_folds,
    save_fold_assignments,
    load_fold_assignments,
    print_fold_statistics,
)

FOLD_FILE = DATA_ROOT / 'fold_assignments.json'
N_FOLDS = 4
TEST_RATIO = 0.15
SEED = 42

print("="*60)
print("FOLD ASSIGNMENT CREATION")
print("="*60)

# Force regeneration to use corrected methodology
# - Test set: select pieces until ~15% of SAMPLES (PercePiano methodology)
# - CV folds: greedy bin-packing for balanced sample counts (improvement over round-robin)
FORCE_REGENERATE = True

if FOLD_FILE.exists() and not FORCE_REGENERATE:
    print(f"\nLoading existing fold assignments from {FOLD_FILE}")
    fold_assignments = load_fold_assignments(FOLD_FILE)
else:
    if FOLD_FILE.exists():
        print(f"\nRemoving old fold assignments (regenerating with balanced methodology)...")
        FOLD_FILE.unlink()
    
    print(f"\nCreating new {N_FOLDS}-fold piece-based splits...")
    print("  Test set: select pieces until ~15% of SAMPLES")
    print("  CV folds: greedy bin-packing for balanced sample counts")
    fold_assignments = create_piece_based_folds(
        data_dir=DATA_ROOT,
        n_folds=N_FOLDS,
        test_ratio=TEST_RATIO,
        seed=SEED,
    )
    save_fold_assignments(fold_assignments, FOLD_FILE)

# Print statistics
print_fold_statistics(fold_assignments, n_folds=N_FOLDS)

## Step 5: Training Configuration

In [None]:
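import torch
# 'highest' keeps float32 matmuls in full FP32, matching the FP32 setting in the table;
# 'medium' may let supported GPUs (e.g. A100) use bfloat16 internally
torch.set_float32_matmul_precision('highest')

# Synchronous CUDA launches give precise error locations but slow training, and the
# variable only takes effect if set before CUDA is initialized (Step 1 already queries the GPU).
# To debug CUDA errors, uncomment these lines in the first cell and restart the runtime:
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'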

CONFIG = {
    # K-Fold settings
    'n_folds': N_FOLDS,
    'test_ratio': TEST_RATIO,
    # Data
    'data_dir': str(DATA_ROOT),
    'checkpoint_dir': str(CHECKPOINT_ROOT),
    'log_dir': str(LOG_ROOT),
    'input_size': 79,
    'hidden_size': 256,
    'note_layers': 2,
    'voice_layers': 2,
    'beat_layers': 2,
    'measure_layers': 1,
    'num_attention_heads': 8,
    'learning_rate': 2.5e-5,
    'weight_decay': 1e-5,
    'dropout': 0.2,
    'batch_size': 8,
    'max_epochs': 200,
    'early_stopping_patience': 20,
    'gradient_clip_val': 2.0,
    'precision': '32',
    'max_notes': 5000,
    'slice_len': 5000,
    'num_workers': 4,
    'augment_train': False,
}

print("="*60)
print("TRAINING CONFIGURATION (SOTA - ROUND 8: SLICE SAMPLING)")
print("="*60)
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

# Print training dynamics info
print(f"  LR after one 0.98x decay step: {CONFIG['learning_rate'] * 0.98:.2e}")

## Step 5b: Pre-Training Data Diagnostics (CRITICAL)

Run this BEFORE training to validate the data and detect index issues that can break the hierarchical (beat/measure) components.

In [None]:
"""
PRE-TRAINING DATA DIAGNOSTICS

This cell validates the data pipeline before training to catch issues that
would cause the hierarchical components (beat/measure attention) to fail.

Key checks:
1. Index format (should start from 1 after densification)
2. Zero-shifted values (should have no negatives)
3. Boundary detection (== 1 vs > 0 equivalence)
4. Slice statistics
"""

import torch
import numpy as np
from torch.utils.data import DataLoader
from src.percepiano.data.percepiano_vnet_dataset import (
    PercePianoKFoldDataset,
    percepiano_pack_collate,
)
from src.percepiano.training.diagnostics import analyze_indices

print("=" * 70)
print("PRE-TRAINING DATA DIAGNOSTICS")
print("=" * 70)

# Build the fold-0 TRAINING dataset for inspection (despite the name, not the held-out test set)
test_ds = PercePianoKFoldDataset(
    data_dir=DATA_ROOT,
    fold_assignments=fold_assignments,
    fold_id=0,
    mode="train",
    max_notes=CONFIG['max_notes'],
    slice_len=CONFIG.get('slice_len', CONFIG['max_notes']),
)

# Create a simple dataloader (no packing for easier analysis)
simple_loader = DataLoader(test_ds, batch_size=4, shuffle=False, num_workers=0)

# Get one batch
batch = next(iter(simple_loader))

print("\n[1] BATCH SHAPE ANALYSIS:")
print(f"  input_features: {batch['input_features'].shape}")
print(f"  note_locations_beat: {batch['note_locations_beat'].shape}")
print(f"  note_locations_measure: {batch['note_locations_measure'].shape}")
print(f"  note_locations_voice: {batch['note_locations_voice'].shape}")
print(f"  scores: {batch['scores'].shape}")
print(f"  num_notes: {[int(n) for n in batch['num_notes']]}")

print("\n[2] INDEX ANALYSIS:")
idx_stats = analyze_indices(
    batch['note_locations_beat'],
    batch['note_locations_measure'],
)

print(f"  Beat indices: min={idx_stats['beat_min']}, max={idx_stats['beat_max']}")
print(f"  Measure indices: min={idx_stats['measure_min']}, max={idx_stats['measure_max']}")

# Check if indices start from 1 (required for hierarchy_utils)
if idx_stats['beat_min'] == 1:
    print(f"  [OK] Beat indices start from 1 (required)")
elif idx_stats['beat_min'] == 0:
    print(f"  [WARNING] Beat indices start from 0 (should be 1)")
else:
    print(f"  [ERROR] Beat indices start from {idx_stats['beat_min']} (unexpected)")

print("\n[3] ZERO-SHIFTED INDEX CHECK:")
print(f"  Beat zero-shifted range: [{idx_stats['beat_zero_shifted_min']}, {idx_stats['beat_zero_shifted_max']}]")
print(f"  Measure zero-shifted range: [{idx_stats['measure_zero_shifted_min']}, {idx_stats['measure_zero_shifted_max']}]")

if idx_stats['negative_beat_count'] > 0:
    print(f"  [CRITICAL ERROR] {idx_stats['negative_beat_count']} negative zero-shifted beat values!")
    print(f"    This WILL break span_beat_to_note_num!")
else:
    print(f"  [OK] No negative zero-shifted values")

if idx_stats['negative_measure_count'] > 0:
    print(f"  [WARNING] {idx_stats['negative_measure_count']} negative zero-shifted measure values")

print("\n[4] BOUNDARY DETECTION CHECK:")
print(f"  diff == 1 count: {idx_stats['diff_equals_1_count']}")
print(f"  diff > 0 count: {idx_stats['diff_greater_0_count']}")
print(f"  diff < 0 count: {idx_stats['diff_less_0_count']} (should only be at padding boundary)")

if idx_stats['diff_equals_1_count'] == idx_stats['diff_greater_0_count']:
    print(f"  [OK] == 1 and > 0 are equivalent (indices are sequential)")
else:
    print(f"  [WARNING] == 1 ({idx_stats['diff_equals_1_count']}) != > 0 ({idx_stats['diff_greater_0_count']})")
    print(f"    Indices may not be properly densified!")

n_in_batch = len(batch['num_notes'])
if idx_stats['non_sequential_samples'] > 0:
    print(f"  [WARNING] {idx_stats['non_sequential_samples']}/{n_in_batch} samples have non-sequential indices")

print("\n[5] SLICE STATISTICS:")
print(f"  Total slices in dataset: {len(test_ds)}")
print(f"  Underlying samples: {len(test_ds.sample_files)}")
print(f"  Avg slices per sample: {len(test_ds) / len(test_ds.sample_files):.1f}")

print("\n[6] FIRST SAMPLE BEAT INDICES (first 50 values):")
first_beat = batch['note_locations_beat'][0].numpy()
first_num_notes = int(batch['num_notes'][0])
print(f"  {first_beat[:min(50, first_num_notes)].tolist()}")

# Check for unique beats
unique_beats = np.unique(first_beat[:first_num_notes])
print(f"  Unique beat values: {len(unique_beats)}")
print(f"  First 10 unique beats: {unique_beats[:10].tolist()}")

print("\n[7] INPUT FEATURE STATISTICS:")
features = batch['input_features']
print(f"  Overall (includes any padding): mean={features.mean():.4f}, std={features.std():.4f}")
print(f"  Range: [{features.min():.4f}, {features.max():.4f}]")

# Check for NaN/Inf
nan_count = torch.isnan(features).sum().item()
inf_count = torch.isinf(features).sum().item()
if nan_count > 0:
    print(f"  [ERROR] {nan_count} NaN values detected!")
if inf_count > 0:
    print(f"  [ERROR] {inf_count} Inf values detected!")
if nan_count == 0 and inf_count == 0:
    print(f"  [OK] No NaN or Inf values")

print("\n[8] LABEL STATISTICS:")
scores = batch['scores']
print(f"  Mean per dimension: {scores.mean(dim=0).tolist()[:5]} ... (first 5)")
print(f"  Std per dimension: {scores.std(dim=0).tolist()[:5]} ... (first 5)")
print(f"  Range: [{scores.min():.4f}, {scores.max():.4f}]")

if scores.min() < 0 or scores.max() > 1:
    print(f"  [WARNING] Labels outside [0, 1] range!")
else:
    print(f"  [OK] Labels in [0, 1] range")

print("\n" + "=" * 70)
print("DIAGNOSTIC SUMMARY")
print("=" * 70)

issues = []
if idx_stats['beat_min'] != 1:
    issues.append("Beat indices don't start from 1")
if idx_stats['negative_beat_count'] > 0:
    issues.append(f"Negative zero-shifted values ({idx_stats['negative_beat_count']})")
if idx_stats['diff_equals_1_count'] != idx_stats['diff_greater_0_count']:
    issues.append("Non-sequential indices detected")
if nan_count > 0 or inf_count > 0:
    issues.append("NaN/Inf in features")

if issues:
    print(f"\n[ISSUES FOUND] {len(issues)} issues that may affect training:")
    for i, issue in enumerate(issues, 1):
        print(f"  {i}. {issue}")
    print("\nConsider fixing these before training!")
else:
    print("\n[ALL CHECKS PASSED] Data pipeline looks correct!")
    print("Proceed to training.")

print("=" * 70)
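Checks [2]-[4] can be reproduced by hand on a single index sequence. The sketch below is hypothetical (it does not call `analyze_indices`) and assumes that "zero-shifted" means subtracting the sequence's first valid index and that "densification" means remapping raw beat or measure numbers to consecutive integers starting at 1.

```python
import numpy as np


def densify(raw_idx) -> np.ndarray:
    """Remap non-decreasing ids (e.g. raw beat numbers) to 1, 2, 3, ... without gaps."""
    _, inverse = np.unique(np.asarray(raw_idx), return_inverse=True)
    return inverse.reshape(-1) + 1


def index_checks(idx, num_notes: int) -> dict:
    """Checks [2]-[4] for one sequence; entries past num_notes (padding) are ignored."""
    valid = np.asarray(idx)[:num_notes]
    zero_shifted = valid - valid[0]  # assumption: shift relative to the first note
    diffs = np.diff(valid)
    return {
        "starts_at_one": bool(valid.min() == 1),
        "negative_zero_shifted": int((zero_shifted < 0).sum()),
        "diff_equals_1": int((diffs == 1).sum()),
        "diff_greater_0": int((diffs > 0).sum()),
        # Densified indices only ever step by 0 or 1, so "== 1" and "> 0" agree
        "sequential": bool((diffs == 1).sum() == (diffs > 0).sum()),
    }
```

For example, raw beats `[0, 0, 2, 2, 2, 5, 7, 7]` densify to `[1, 1, 2, 2, 2, 3, 4, 4]`. Forgetting to cut at `num_notes` makes the zero padding show up as negative zero-shifted values, which is the failure check [3] guards against.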

## Step 6: Initialize K-Fold Trainer

In [None]:
from src.percepiano.training.kfold_trainer import KFoldTrainer
import pytorch_lightning as pl

# Set seed for reproducibility (same SEED as the fold split in Step 4)
pl.seed_everything(SEED, workers=True)

# Create K-Fold trainer
kfold_trainer = KFoldTrainer(
    config=CONFIG,
    fold_assignments=fold_assignments,
    data_dir=DATA_ROOT,
    checkpoint_dir=CHECKPOINT_ROOT,
    log_dir=LOG_ROOT,
    n_folds=N_FOLDS,
)

print("K-Fold Trainer initialized")
print(f"  Folds: {N_FOLDS}")
print(f"  Checkpoints: {CHECKPOINT_ROOT}")
print(f"  Logs: {LOG_ROOT}")

## Step 7: Train a Single Fold (Diagnostic Run)

In [None]:
print("="*60)
print("TRAINING SINGLE FOLD (DIAGNOSTIC RUN)")
print("="*60)
print("\nPercePiano published baselines:")
print("  Bi-LSTM: R2 = 0.185")
print("  MidiBERT: R2 = 0.313")
print("  Bi-LSTM + SA + HAN: R2 = 0.397 (SOTA)")
print("="*60)

# Train single fold for diagnostic purposes
FOLD_TO_TRAIN = 1  # Fold 1 is typically well-balanced

fold_metrics = kfold_trainer.train_fold(
    fold_id=FOLD_TO_TRAIN,
    verbose=True,
    resume_from_checkpoint=False,
)

# Save results
kfold_trainer.save_results()

# IMPORTANT: Store trained model for diagnostics (Round 12 fix)
# This avoids needing to load from checkpoint in subsequent cells
trained_model = kfold_trainer.get_trained_model(FOLD_TO_TRAIN)
if trained_model is not None:
    print(f"\n  Trained model stored in memory for diagnostics")
else:
    print(f"\n  [WARNING] Could not retrieve trained model - will need checkpoint")

print("\n" + "="*60)
print(f"FOLD {FOLD_TO_TRAIN} TRAINING COMPLETE")
print("="*60)
print(f"  Val R2: {fold_metrics.val_r2:+.4f}")
print(f"  Val Pearson: {fold_metrics.val_pearson:+.4f}")
print(f"  Epochs: {fold_metrics.epochs_trained}")
print("="*60)
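Step 7 trains only one fold. When all four folds are trained, their validation scores need to be summarized. The helper below is a hypothetical sketch that reports mean, sample standard deviation and range of per-fold R2; the commented loop reuses the `train_fold` call and `val_r2` attribute shown in the cell above and assumes fold ids run from 0 to `N_FOLDS - 1`.

```python
import statistics
from typing import Dict, Sequence


def summarize_fold_scores(fold_r2: Sequence[float]) -> Dict[str, float]:
    """Mean, sample standard deviation, and range of per-fold validation R2."""
    if len(fold_r2) < 2:
        raise ValueError("need at least two folds to estimate the spread")
    return {
        "mean": statistics.mean(fold_r2),
        "std": statistics.stdev(fold_r2),  # sample (n - 1) standard deviation
        "min": min(fold_r2),
        "max": max(fold_r2),
    }


# Hypothetical full run, reusing the train_fold call from the cell above:
# fold_r2 = [
#     kfold_trainer.train_fold(fold_id=k, verbose=True, resume_from_checkpoint=False).val_r2
#     for k in range(N_FOLDS)
# ]
# print(summarize_fold_scores(fold_r2))
```

Reporting the spread alongside the mean makes it easier to tell whether a single fold's R2 is representative.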

In [None]:
# Sync checkpoints to Google Drive after training
if RCLONE_AVAILABLE:
    print("Syncing checkpoints to Google Drive...")
    subprocess.run(
        ['rclone', 'copy', str(CHECKPOINT_ROOT), GDRIVE_CHECKPOINT_PATH, '--progress'],
        check=True,  # raise if the upload fails instead of printing "Sync complete!"
    )
    print("Sync complete!")

## Step 8: Test Set Evaluation

In [None]:
# Evaluate the trained fold model(s) on the held-out test set
test_results = kfold_trainer.evaluate_on_test(verbose=True)

## Step 8b: Post-Training Hierarchy Diagnostics

Analyze why the hierarchical components (beat/measure attention) may not be contributing.
This cell uses the model kept in memory by Step 7 when available; otherwise it loads the newest best checkpoint for the trained fold. It then runs comprehensive diagnostics.

In [None]:
"""
POST-TRAINING HIERARCHY DIAGNOSTICS

This cell runs comprehensive diagnostics to understand why hierarchical 
components (beat/measure attention) may not be contributing to model performance.

Key metrics:
1. Activation variances at each hierarchy level
2. Attention entropy (collapsed vs distributed)
3. Contribution analysis (% variance from each component)
4. Ablation comparison (full model vs Bi-LSTM only)
"""

import torch
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader
from src.percepiano.models.percepiano_replica import PercePianoVNetModule
from src.percepiano.data.percepiano_vnet_dataset import (
    PercePianoKFoldDataset,
    percepiano_pack_collate,
)
from src.percepiano.training.diagnostics import (
    DiagnosticCallback,
    run_full_diagnostics,
    compute_attention_entropy,
)
from sklearn.metrics import r2_score

print("=" * 70)
print("POST-TRAINING HIERARCHY DIAGNOSTICS")
print("=" * 70)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = None

# PRIORITY 1: Use in-memory model from training (Round 12 fix)
if 'trained_model' in globals() and trained_model is not None:
    model = trained_model.to(device)
    model.eval()
    print(f"\nUsing in-memory trained model (no checkpoint loading needed)")
    print(f"Device: {device}")
else:
    # FALLBACK: Try loading from checkpoint
    print(f"\nNo in-memory model found - attempting checkpoint loading...")
    fold_dir = CHECKPOINT_ROOT / f"fold_{globals().get('FOLD_TO_TRAIN', 1)}"
    best_ckpts = list(fold_dir.glob("best-*.ckpt"))
    if not best_ckpts:
        print(f"No checkpoints found in {fold_dir}")
        print("Run Step 7 (training) first to create model.")
    else:
        best_ckpt = max(best_ckpts, key=lambda p: p.stat().st_mtime)
        print(f"Loading checkpoint: {best_ckpt}")
        
        model = PercePianoVNetModule.load_from_checkpoint(
            str(best_ckpt),
            input_size=CONFIG.get("input_size", 79),
            hidden_size=CONFIG.get("hidden_size", 256),
            note_layers=CONFIG.get("note_layers", 2),
            voice_layers=CONFIG.get("voice_layers", 2),
            beat_layers=CONFIG.get("beat_layers", 2),
            measure_layers=CONFIG.get("measure_layers", 1),
            num_attention_heads=CONFIG.get("num_attention_heads", 8),
            dropout=CONFIG.get("dropout", 0.2),
        )
        model = model.to(device)
        model.eval()
        print(f"Device: {device}")

if model is not None:
    # Create validation dataset
    val_ds = PercePianoKFoldDataset(
        data_dir=DATA_ROOT,
        fold_assignments=fold_assignments,
        fold_id=globals().get('FOLD_TO_TRAIN', 1),  # must match the fold the model was trained on
        mode="val",
        max_notes=CONFIG['max_notes'],
        slice_len=CONFIG.get('slice_len', CONFIG['max_notes']),
    )
    val_loader = DataLoader(
        val_ds, 
        batch_size=4, 
        shuffle=False, 
        num_workers=0,
        collate_fn=percepiano_pack_collate,
    )
    print(f"Validation samples: {len(val_ds)}")
    
    # Run full diagnostics
    print("\nRunning full diagnostic analysis...")
    diag_results = run_full_diagnostics(model, val_loader, device, num_batches=5)
    
    # Print activation statistics
    print("\n" + "=" * 70)
    print("ACTIVATION VARIANCE ANALYSIS")
    print("=" * 70)
    
    print(f"\n{'Component':<25} {'Mean Std':>12} {'Min Std':>12} {'Max Std':>12} {'Status':<15}")
    print(f"{'-'*25} {'-'*12} {'-'*12} {'-'*12} {'-'*15}")
    
    key_activations = [
        ('input_std', 0.3, 0.6, 'Input features'),
        ('x_embedded_std', 0.15, 0.4, 'After embedding'),
        ('note_out_std', 0.1, 0.4, 'Note LSTM'),
        ('voice_out_std', 0.1, 0.4, 'Voice LSTM'),
        ('hidden_out_std', 0.1, 0.4, 'Hidden (note+voice)'),
        ('beat_nodes_std', 0.05, 0.3, 'Beat nodes'),
        ('beat_out_std', 0.05, 0.3, 'Beat LSTM'),
        ('beat_spanned_std', 0.05, 0.3, 'Beat spanned'),
        ('measure_nodes_std', 0.05, 0.3, 'Measure nodes'),
        ('measure_out_std', 0.05, 0.3, 'Measure LSTM'),
        ('measure_spanned_std', 0.05, 0.3, 'Measure spanned'),
        ('contracted_std', 0.05, 0.3, 'Contracted'),
        ('aggregated_std', 0.1, 0.5, 'Aggregated'),
        ('predictions_std', 0.1, 0.25, 'Predictions'),
    ]
    
    issues = []
    for key, low_thresh, high_thresh, name in key_activations:
        if key in diag_results['activation_stats']:
            stats = diag_results['activation_stats'][key]
            mean_std = stats['mean']
            min_std = stats['min']
            max_std = stats['max']
            
            if mean_std < low_thresh:
                status = "[LOW - ISSUE]"
                issues.append(f"{name}: std={mean_std:.4f} < {low_thresh}")
            elif mean_std > high_thresh:
                status = "[HIGH]"
            else:
                status = "[OK]"
            
            print(f"{name:<25} {mean_std:>12.4f} {min_std:>12.4f} {max_std:>12.4f} {status:<15}")
    
    # Contribution analysis
    print("\n" + "=" * 70)
    print("HIERARCHY CONTRIBUTION ANALYSIS")
    print("=" * 70)
    
    contribution_keys = [
        ('hidden_out_contribution', 'Bi-LSTM (note+voice)'),
        ('beat_spanned_contribution', 'Beat hierarchy'),
        ('measure_spanned_contribution', 'Measure hierarchy'),
    ]
    
    print(f"\nFraction of total variance from each component:")
    for key, name in contribution_keys:
        if key in diag_results['activation_stats']:
            contrib = diag_results['activation_stats'][key]['mean']
            bar = '#' * int(contrib * 50)
            print(f"  {name:<25} {contrib:>6.1%} |{bar}")
            
            if key == 'beat_spanned_contribution' and contrib < 0.10:
                issues.append(f"Beat hierarchy contributing only {contrib:.1%} (expected ~30%)")
            if key == 'measure_spanned_contribution' and contrib < 0.05:
                issues.append(f"Measure hierarchy contributing only {contrib:.1%} (expected ~10%)")
    
    # Index statistics
    print("\n" + "=" * 70)
    print("INDEX STATISTICS (HIERARCHY HEALTH)")
    print("=" * 70)
    
    idx_stats = diag_results['index_stats']
    print(f"\n  Negative zero-shifted beats: {idx_stats.get('negative_beat_count', {}).get('total', 'N/A')}")
    print(f"  Negative zero-shifted measures: {idx_stats.get('negative_measure_count', {}).get('total', 'N/A')}")
    
    if idx_stats.get('negative_beat_count', {}).get('total', 0) > 0:
        issues.append("Negative zero-shifted beat indices detected!")
    
    # Summary
    print("\n" + "=" * 70)
    print("DIAGNOSTIC SUMMARY")
    print("=" * 70)
    
    if issues:
        print(f"\n[{len(issues)} ISSUES DETECTED]")
        for i, issue in enumerate(issues, 1):
            print(f"  {i}. {issue}")
        
        print("\nLikely root causes:")
        if any('beat' in issue.lower() for issue in issues):
            print("  - Beat attention may be collapsed (uniform weights)")
            print("  - span_beat_to_note_num may have index mapping issues")
        if any('measure' in issue.lower() for issue in issues):
            print("  - Measure hierarchy not aggregating properly")
        if any('lstm' in issue.lower() for issue in issues):
            print("  - LSTM outputs have very low variance")
            print("  - Check PackedSequence handling")
    else:
        print("\n[ALL CHECKS PASSED]")
        print("Hierarchy components appear to be working correctly.")
        print("If R2 is still low, investigate:")
        print("  - Learning rate scheduling")
        print("  - Number of training epochs")
        print("  - Label alignment issues")
    
    print("=" * 70)
else:
    print("\n[SKIPPED] No model available for diagnostics.")
    print("Run Step 7 (training) first.")
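The diagnostics above rely on `compute_attention_entropy` and on per-component "contribution" values computed inside `run_full_diagnostics`, whose definitions are not shown here. The sketch below shows one common way to compute both and is not the project's implementation. It assumes attention weights sum to 1 over valid positions, and it assumes "contribution" means each block's share of the summed per-feature variance in the concatenated note representation.

```python
import numpy as np


def normalized_attention_entropy(weights, mask=None, eps: float = 1e-12) -> np.ndarray:
    """Entropy of each attention row divided by log(number of valid positions).

    Values near 1 mean near-uniform weights (what this notebook calls "collapsed");
    values near 0 mean the attention concentrates on a single position.
    """
    w = np.asarray(weights, dtype=np.float64)
    m = np.ones_like(w, dtype=bool) if mask is None else np.asarray(mask, dtype=bool)
    w = np.where(m, w, 0.0)  # ignore padded positions
    ent = -(w * np.log(w + eps)).sum(axis=-1)
    n_valid = m.sum(axis=-1)
    # Rows with a single valid position carry no choice, so report 0
    return np.where(n_valid > 1, ent / np.log(np.maximum(n_valid, 2)), 0.0)


def variance_share(blocks: dict) -> dict:
    """Fraction of total per-feature variance contributed by each named block."""
    totals = {
        name: float(np.asarray(x, dtype=np.float64).reshape(-1, np.shape(x)[-1]).var(axis=0).sum())
        for name, x in blocks.items()
    }
    grand = sum(totals.values())
    return {name: (v / grand if grand > 0 else 0.0) for name, v in totals.items()}
```

Usage, with hypothetical arrays: `variance_share({"hidden": hidden_np, "beat": beat_spanned_np, "measure": measure_spanned_np})` returns fractions that sum to 1, comparable in spirit to the contribution bars printed above.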

## Step 8c: Manual Ablation Test

Directly compare the full model's R2 against a Bi-LSTM-only variant (beat/measure outputs zeroed at inference) to measure how much the trained model relies on the hierarchy.
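Zeroing a block at inference measures how much an already trained predictor relies on that block. It differs from the paper's ablation, where separate models were trained with and without each component. The toy below is a hypothetical stand-in that uses a linear predictor instead of the HAN. It shows that the gain is positive when the predictor uses the zeroed columns and exactly zero when it ignores them. `r2_uniform` averages R2 over output columns like scikit-learn's default `r2_score` (for non-constant targets).

```python
import numpy as np


def r2_uniform(y_true, y_pred) -> float:
    """R2 averaged uniformly over output columns."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))


def zero_ablation_gain(predict, features: np.ndarray, block: slice, y_true) -> float:
    """R2(full) - R2(features[:, block] set to 0), using the same trained predictor."""
    ablated = features.copy()
    ablated[:, block] = 0.0
    return r2_uniform(y_true, predict(features)) - r2_uniform(y_true, predict(ablated))


# Toy setup: columns 3-5 play the role of the beat/measure ("hierarchy") features
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 6))
W = rng.normal(size=(6, 2))
y = X @ W
print(zero_ablation_gain(lambda f: f @ W, X, slice(3, 6), y))  # > 0: predictor uses the block
```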

In [None]:
"""
MANUAL ABLATION TEST

Compare full-model performance vs Bi-LSTM only (hierarchy outputs zeroed at inference).
This measures how much the trained model relies on the beat/measure components.

Reference results (from the PercePiano paper, separately trained models):
- Bi-LSTM alone: R2 = 0.185
- + Score Alignment: R2 = 0.304 (+0.119)
- + HAN: R2 = 0.397 (+0.093)

Zeroing the hierarchy keeps the score-aligned inputs, so the closest paper figure is
the HAN gain (+0.093). A hierarchy_gain < 0.05 suggests the hierarchy is not contributing.
"""

import torch
import numpy as np
from sklearn.metrics import r2_score
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence

print("=" * 70)
print("MANUAL ABLATION TEST: Full Model vs Bi-LSTM Only")
print("=" * 70)

# Check if model is available (from previous cell or needs to be loaded)
if 'model' not in globals() or model is None:
    # Try to use in-memory trained model (Round 12 fix)
    if 'trained_model' in globals() and trained_model is not None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = trained_model.to(device)
        model.eval()
        print(f"\nUsing in-memory trained model")
        print(f"Device: {device}")
    else:
        print("Model not loaded. Run the previous diagnostic cell first,")
        print("or run Step 7 (training) to create a trained model.")
        model = None

if model is not None:
    model.eval()
    device = next(model.parameters()).device
    
    # Create val_loader if not already created
    if 'val_loader' not in globals():
        from torch.utils.data import DataLoader
        from src.percepiano.data.percepiano_vnet_dataset import (
            PercePianoKFoldDataset,
            percepiano_pack_collate,
        )
        val_ds = PercePianoKFoldDataset(
            data_dir=DATA_ROOT,
            fold_assignments=fold_assignments,
            fold_id=globals().get('FOLD_TO_TRAIN', 1),  # must match the fold the model was trained on
            mode="val",
            max_notes=CONFIG['max_notes'],
            slice_len=CONFIG.get('slice_len', CONFIG['max_notes']),
        )
        val_loader = DataLoader(
            val_ds, 
            batch_size=4, 
            shuffle=False, 
            num_workers=0,
            collate_fn=percepiano_pack_collate,
        )
    
    # Collect predictions
    full_preds = []
    ablated_preds = []
    targets = []
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            if batch_idx >= 10:  # Limit for speed
                break
                
            # Move tensors and PackedSequences to device (PackedSequence is not a torch.Tensor)
            batch = {k: v.to(device) if isinstance(v, (torch.Tensor, PackedSequence)) else v for k, v in batch.items()}
            
            input_features = batch['input_features']
            note_locations = {
                'beat': batch['note_locations_beat'],
                'measure': batch['note_locations_measure'],
                'voice': batch['note_locations_voice'],
            }
            
            # Full model prediction
            outputs = model(input_features, note_locations)
            full_preds.append(outputs['predictions'].cpu())
            targets.append(batch['scores'].cpu())
            
            # Ablated prediction (Bi-LSTM only)
            # We need to manually forward through the model zeroing out hierarchy
            
            # Handle PackedSequence
            if isinstance(input_features, PackedSequence):
                x_padded, lengths = pad_packed_sequence(input_features, batch_first=True)
            else:
                x_padded = input_features
                lengths = batch.get('lengths', torch.tensor([x_padded.shape[1]] * x_padded.shape[0]))
            
            han = model.han_encoder
            
            # Project input
            x_embedded = han.note_fc(x_padded)
            
            # Compute actual lengths
            from src.percepiano.models.hierarchy_utils import compute_actual_lengths
            actual_lengths = compute_actual_lengths(note_locations['beat'])
            
            # Note LSTM
            x_packed = pack_padded_sequence(
                x_embedded,
                actual_lengths.cpu().clamp(min=1),
                batch_first=True,
                enforce_sorted=False,
            )
            note_out, _ = han.note_lstm(x_packed)
            note_out, _ = pad_packed_sequence(note_out, batch_first=True, total_length=x_embedded.shape[1])
            
            # Voice processing
            voice_out = han._run_voice_processing(x_embedded, note_locations['voice'], actual_lengths)
            
            # Combined hidden (Bi-LSTM output)
            hidden_out = torch.cat([note_out, voice_out], dim=-1)
            
            # Zero out hierarchy (ablation)
            batch_size, seq_len = hidden_out.shape[:2]
            beat_spanned = torch.zeros(batch_size, seq_len, han.beat_size * 2, device=device)
            measure_spanned = torch.zeros(batch_size, seq_len, han.measure_size * 2, device=device)
            
            # Concatenate with zeroed hierarchy
            total_note_cat = torch.cat([hidden_out, beat_spanned, measure_spanned], dim=-1)
            
            # Performance contractor
            contracted = model.performance_contractor(total_note_cat)
            
            # Final attention
            attention_mask = note_locations['beat'] > 0
            aggregated = model.final_attention(contracted, mask=attention_mask)
            
            # Prediction head
            logits = model.prediction_head(aggregated)
            ablated_pred = torch.sigmoid(logits)
            ablated_preds.append(ablated_pred.cpu())
    
    # Compute R2 scores
    full_preds = torch.cat(full_preds).numpy()
    ablated_preds = torch.cat(ablated_preds).numpy()
    targets = torch.cat(targets).numpy()
    
    full_r2 = r2_score(targets, full_preds)
    ablated_r2 = r2_score(targets, ablated_preds)
    hierarchy_gain = full_r2 - ablated_r2
    
    print(f"\nResults on {len(targets)} validation samples:")
    print(f"\n  {'Model':<25} {'R2':>10}")
    print(f"  {'-'*25} {'-'*10}")
    print(f"  {'Full Model':<25} {full_r2:>+10.4f}")
    print(f"  {'Bi-LSTM Only (ablated)':<25} {ablated_r2:>+10.4f}")
    print(f"  {'-'*25} {'-'*10}")
    print(f"  {'Hierarchy Gain':<25} {hierarchy_gain:>+10.4f}")
    
    print("\n" + "=" * 70)
    print("INTERPRETATION")
    print("=" * 70)
    
    print(f"\nExpected hierarchy gains (from PercePiano paper):")
    print(f"  Score Alignment: +0.119 R2")
    print(f"  HAN:             +0.093 R2")
    print(f"  Total:           +0.212 R2")
    
    print(f"\nOur hierarchy gain: {hierarchy_gain:+.4f} R2")
    
    if hierarchy_gain < 0.01:
        print(f"\n[CRITICAL] Hierarchy is NOT contributing!")
        print(f"  The beat/measure components add < 0.01 R2")
        print(f"  This suggests the hierarchical processing is broken.")
        print(f"\n  Possible causes:")
        print(f"  1. span_beat_to_note_num mapping is incorrect")
        print(f"  2. Beat/measure attention weights collapsed to uniform")
        print(f"  3. Zero-shifted indices have negative values")
        print(f"  4. LSTM outputs have near-zero variance")
    elif hierarchy_gain < 0.05:
        print(f"\n[WARNING] Hierarchy contributing only {hierarchy_gain:.3f} R2")
        print(f"  Expected ~0.21, getting < 0.05")
        print(f"  Hierarchy is partially working but underperforming.")
    elif hierarchy_gain < 0.15:
        print(f"\n[PARTIAL] Hierarchy contributing {hierarchy_gain:.3f} R2")
        print(f"  Better than nothing, but expected ~0.21")
        print(f"  Some component may still have issues.")
    else:
        print(f"\n[GOOD] Hierarchy contributing {hierarchy_gain:.3f} R2")
        print(f"  This is in the expected range (~0.21).")
        print(f"  If overall R2 is still low, investigate other factors.")
    
    print("=" * 70)
else:
    print("\n[SKIPPED] No model available for ablation test.")
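The hierarchy gain printed above is the difference between two multi-output R2 scores. As a minimal sketch (numpy only, toy data), the computation can be reproduced as follows, assuming `r2_score` is called with its default `multioutput='uniform_average'`, i.e. the unweighted mean of per-dimension R2. The helper names `multioutput_r2` and `hierarchy_gain` are hypothetical. Zeroing the beat/measure branches of an already-trained model measures how much that model relies on them, which is not necessarily the same quantity as the paper's gap between separately trained architectures.

```python
import numpy as np


def multioutput_r2(y_true, y_pred) -> float:
    """Mean of per-dimension R2, mirroring r2_score's 'uniform_average'.

    Assumes every target column has non-zero variance (sklearn special-cases
    constant columns; this sketch does not).
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))


def hierarchy_gain(y_true, full_pred, ablated_pred) -> float:
    """Positive values mean the beat/measure branches improve R2."""
    return multioutput_r2(y_true, full_pred) - multioutput_r2(y_true, ablated_pred)


# Toy example: 4 samples x 2 dimensions (not real PercePiano data)
y = np.array([[0.2, 0.8], [0.4, 0.6], [0.6, 0.4], [0.8, 0.2]])
full = y + 0.01                 # near-perfect predictions
ablated = np.full_like(y, 0.5)  # constant at the column mean, so R2 = 0
print(round(hierarchy_gain(y, full, ablated), 4))
```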

## Step 9: Per-Fold Results Summary

In [None]:
import numpy as np
from src.percepiano.models.percepiano_replica import PERCEPIANO_DIMENSIONS

# Compute aggregate metrics from trainer (in case cell was run separately)
aggregate_metrics = kfold_trainer._compute_aggregate_metrics()

print("="*80)
print("PER-FOLD VALIDATION RESULTS")
print("="*80)

print(f"\n{'Fold':<6} {'Val R2':>10} {'Val Pearson':>12} {'Val MAE':>10} {'Val RMSE':>10} {'Epochs':>8} {'Time (s)':>10}")
print(f"{'-'*6} {'-'*10} {'-'*12} {'-'*10} {'-'*10} {'-'*8} {'-'*10}")

for m in kfold_trainer.fold_metrics:
    print(f"{m.fold_id:<6} {m.val_r2:>+10.4f} {m.val_pearson:>+12.4f} {m.val_mae:>10.4f} {m.val_rmse:>10.4f} {m.epochs_trained:>8} {m.training_time_seconds:>10.1f}")

print(f"{'-'*6} {'-'*10} {'-'*12} {'-'*10} {'-'*10} {'-'*8} {'-'*10}")
print(f"{'Mean':<6} {aggregate_metrics.mean_r2:>+10.4f} {aggregate_metrics.mean_pearson:>+12.4f} {aggregate_metrics.mean_mae:>10.4f} {aggregate_metrics.mean_rmse:>10.4f}")
print(f"{'Std':<6} {aggregate_metrics.std_r2:>+10.4f} {aggregate_metrics.std_pearson:>+12.4f} {aggregate_metrics.std_mae:>10.4f} {aggregate_metrics.std_rmse:>10.4f}")
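The Mean and Std rows summarize one number per fold. A small sketch of that aggregation follows; whether `_compute_aggregate_metrics` uses the population (`ddof=0`) or sample (`ddof=1`) standard deviation is not shown in this notebook, so both are computed here. `aggregate_folds` and the fold values are hypothetical.

```python
import numpy as np


def aggregate_folds(fold_values, ddof: int = 0) -> tuple[float, float]:
    """Return (mean, std) of one metric across CV folds.

    ddof=0 gives the population std (numpy's default); ddof=1 applies
    Bessel's correction, which gives a slightly larger value with few folds.
    """
    values = np.asarray(fold_values, dtype=float)
    return float(values.mean()), float(values.std(ddof=ddof))


fold_r2 = [0.30, 0.34, 0.28, 0.36]  # hypothetical per-fold R2 values
print(aggregate_folds(fold_r2))          # population std
print(aggregate_folds(fold_r2, ddof=1))  # sample std
```

With only 4 folds the two conventions differ noticeably, so it is worth stating which one a reported "+/-" uses.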

## Step 10: Per-Dimension Analysis

In [None]:
# Ensure aggregate_metrics is available (in case cell was run separately)
if 'aggregate_metrics' not in globals():
    aggregate_metrics = kfold_trainer._compute_aggregate_metrics()

print("="*80)
print("PER-DIMENSION R2 (Mean +/- Std across folds)")
print("="*80)

# Sort dimensions by mean R2
sorted_dims = sorted(
    aggregate_metrics.per_dim_mean_r2.items(),
    key=lambda x: x[1],
    reverse=True
)

print(f"\n{'Dimension':<25} {'Mean R2':>10} {'Std R2':>10} {'Status':<12}")
print(f"{'-'*25} {'-'*10} {'-'*10} {'-'*12}")

for dim, mean_r2 in sorted_dims:
    std_r2 = aggregate_metrics.per_dim_std_r2[dim]
    
    if mean_r2 >= 0.3:
        status = "[GOOD]"
    elif mean_r2 >= 0.1:
        status = "[OK]"
    elif mean_r2 >= 0:
        status = "[WEAK]"
    else:
        status = "[FAILED]"
    
    print(f"{dim:<25} {mean_r2:>+10.4f} {std_r2:>10.4f} {status:<12}")

# Summary
positive = sum(1 for d, r in sorted_dims if r > 0)
strong = sum(1 for d, r in sorted_dims if r >= 0.2)
n_dims = len(sorted_dims)

print(f"\nSummary: {positive}/{n_dims} positive R2, {strong}/{n_dims} strong (>= 0.2)")
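The Step 10 bucketing and summary counts can be written as two small pure functions, which makes the thresholds easy to test. This is a sketch that copies the thresholds from the cell above; `r2_status`, `summarize`, and the dimension names in the example are hypothetical. Note that the "strong" cutoff (0.2) falls inside the `[OK]` band (0.1 to 0.3), so a dimension can be counted as strong while being labeled `[OK]`.

```python
def r2_status(r2: float) -> str:
    """Bucket a per-dimension R2 using the Step 10 thresholds."""
    if r2 >= 0.3:
        return "[GOOD]"
    if r2 >= 0.1:
        return "[OK]"
    if r2 >= 0:
        return "[WEAK]"
    return "[FAILED]"


def summarize(per_dim_r2: dict[str, float], strong_threshold: float = 0.2) -> dict:
    """Sort dimensions by R2 (best first) and count positive / strong ones."""
    ranked = sorted(per_dim_r2.items(), key=lambda kv: kv[1], reverse=True)
    return {
        "ranked": [(dim, r2, r2_status(r2)) for dim, r2 in ranked],
        "positive": sum(1 for _, r2 in ranked if r2 > 0),
        "strong": sum(1 for _, r2 in ranked if r2 >= strong_threshold),
        "n_dims": len(ranked),
    }


# Hypothetical values for three of the 19 dimensions
example = {"dim_a": 0.41, "dim_b": 0.15, "dim_c": -0.05}
print(summarize(example))
```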

## Step 11: Final Summary and Save

In [None]:
import json
import torch
from pathlib import Path

# Ensure aggregate_metrics is available (in case cell was run separately)
if 'aggregate_metrics' not in globals():
    aggregate_metrics = kfold_trainer._compute_aggregate_metrics()

print("="*80)
print("FINAL SUMMARY")
print("="*80)

# Cross-validation results
print(f"\n4-Fold Cross-Validation Results:")
print(f"  Mean R2:       {aggregate_metrics.mean_r2:.4f} +/- {aggregate_metrics.std_r2:.4f}")
print(f"  Mean Pearson:  {aggregate_metrics.mean_pearson:.4f} +/- {aggregate_metrics.std_pearson:.4f}")
print(f"  Mean Spearman: {aggregate_metrics.mean_spearman:.4f} +/- {aggregate_metrics.std_spearman:.4f}")
print(f"  Mean MAE:      {aggregate_metrics.mean_mae:.4f} +/- {aggregate_metrics.std_mae:.4f}")
print(f"  Mean RMSE:     {aggregate_metrics.mean_rmse:.4f} +/- {aggregate_metrics.std_rmse:.4f}")
print(f"  Training time: {aggregate_metrics.total_training_time/60:.1f} minutes")

# Test set results
print(f"\nTest Set (Ensemble of 4 models):")
print(f"  R2:       {test_results['ensemble']['r2']:.4f}")
print(f"  Pearson:  {test_results['ensemble']['pearson']:.4f}")
print(f"  Spearman: {test_results['ensemble']['spearman']:.4f}")
print(f"  MAE:      {test_results['ensemble']['mae']:.4f}")
print(f"  RMSE:     {test_results['ensemble']['rmse']:.4f}")

# Comparison to baselines
print(f"\nComparison to PercePiano baselines:")
print(f"  Bi-LSTM:      R2 = 0.185")
print(f"  MidiBERT:     R2 = 0.313")
print(f"  HAN SOTA:     R2 = 0.397")
print(f"  Ours (CV):    R2 = {aggregate_metrics.mean_r2:.3f} +/- {aggregate_metrics.std_r2:.3f}")
print(f"  Ours (Test):  R2 = {test_results['ensemble']['r2']:.3f}")

# Interpretation
cv_r2 = aggregate_metrics.mean_r2
test_r2 = test_results['ensemble']['r2']

print(f"\nInterpretation:")
if cv_r2 >= 0.35:
    print(f"  [EXCELLENT] CV R2 >= 0.35 approaches the published SOTA (0.397)!")
elif cv_r2 >= 0.25:
    print(f"  [GOOD] CV R2 >= 0.25 is usable for pseudo-labeling")
elif cv_r2 >= 0.10:
    print(f"  [FAIR] CV R2 >= 0.10 shows learning, needs improvement")
else:
    print(f"  [NEEDS WORK] CV R2 < 0.10, significant improvement needed")

# Check whether the CV score clears the pseudo-labeling threshold
if cv_r2 >= 0.25:
    print(f"\nModel qualifies for pseudo-labeling the MAESTRO dataset!")
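The test-set figure comes from an ensemble of the 4 fold models. The notebook does not show how `test_results['ensemble']` is produced; a common approach, assumed here, is to average the fold models' predictions per sample and score the averaged predictions once. The sketch below (hypothetical helpers `ensemble_predictions` and `r2_uniform`, toy data) also shows why the ensemble R2 need not equal the mean of the individual models' R2.

```python
import numpy as np


def ensemble_predictions(fold_preds) -> np.ndarray:
    """Average a list of (n_samples, n_dims) prediction arrays element-wise.

    Assumes every fold model predicted the same test rows in the same order
    and on the same scale (e.g. sigmoid outputs in [0, 1]).
    """
    stacked = np.stack([np.asarray(p, dtype=float) for p in fold_preds], axis=0)
    return stacked.mean(axis=0)


def r2_uniform(y_true, y_pred) -> float:
    """Unweighted mean of per-dimension R2 (constant columns not handled)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))


# Toy test set: 4 samples x 1 dimension
y = np.array([[0.2], [0.4], [0.6], [0.8]])
fold_a = y + 0.1  # biased high
fold_b = y - 0.1  # biased low; the two errors cancel when averaged

individual = [r2_uniform(y, p) for p in (fold_a, fold_b)]
ens = r2_uniform(y, ensemble_predictions([fold_a, fold_b]))
print(f"mean of individual R2: {np.mean(individual):.3f}")
print(f"ensemble R2:           {ens:.3f}")
```

Because averaging cancels errors that differ between folds, an ensemble R2 above the mean CV R2 is expected rather than suspicious.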

In [None]:
# Final sync to Google Drive
print("="*60)
print("SYNC TO GOOGLE DRIVE")
print("="*60)

if RCLONE_AVAILABLE:
    print(f"\nSyncing all checkpoints and results...")
    ckpt_result = subprocess.run(
        ['rclone', 'copy', str(CHECKPOINT_ROOT), GDRIVE_CHECKPOINT_PATH, '--progress'],
    )
    
    # Also sync fold assignments back to data directory
    fold_result = subprocess.run(
        ['rclone', 'copy', str(FOLD_FILE), GDRIVE_DATA_PATH, '--progress'],
    )
    
    # Only report success if both rclone calls exited cleanly
    if ckpt_result.returncode == 0 and fold_result.returncode == 0:
        print(f"\nSync complete!")
        print(f"  Checkpoints: {GDRIVE_CHECKPOINT_PATH}")
        print(f"  Fold assignments: {GDRIVE_DATA_PATH}")
    else:
        print(f"\n[WARNING] rclone reported errors "
              f"(checkpoints exit={ckpt_result.returncode}, folds exit={fold_result.returncode})")
else:
    print(f"\nrclone not available - skipping sync")

print(f"\n{'='*60}")
print("TRAINING COMPLETE")
print(f"{'='*60}")
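The sync cell relies on `RCLONE_AVAILABLE` (set earlier in the notebook) and on `rclone` exiting cleanly. A minimal sketch of a reusable guard, using only the standard library, checks that the executable is on `PATH` with `shutil.which` and treats a non-zero exit status as failure instead of reporting success. `run_sync` is a hypothetical helper; the demo uses the Python interpreter in place of `rclone` so it runs anywhere.

```python
import shutil
import subprocess
import sys


def run_sync(cmd: list[str]) -> bool:
    """Run a sync command (e.g. an rclone invocation) and report success.

    Returns False instead of raising when the executable is missing or the
    command exits with a non-zero status, so a notebook cell can keep going.
    """
    if shutil.which(cmd[0]) is None:
        print(f"{cmd[0]} not found on PATH - skipping")
        return False
    result = subprocess.run(cmd)
    if result.returncode != 0:
        print(f"{cmd[0]} exited with status {result.returncode}")
        return False
    return True


# Demo with the Python interpreter standing in for rclone
print(run_sync([sys.executable, "-c", "pass"]))
print(run_sync([sys.executable, "-c", "import sys; sys.exit(3)"]))
```

In this notebook it would be called as `run_sync(['rclone', 'copy', str(CHECKPOINT_ROOT), GDRIVE_CHECKPOINT_PATH, '--progress'])`, which also removes the need for a separate availability flag.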