# PercePiano Replica Training (4-Fold Cross-Validation)

Train the PercePiano replica model using 4-fold piece-based cross-validation,
matching the methodology and hyperparameters of the state-of-the-art (SOTA) configuration reported in the PercePiano paper.

## Attribution

> **PercePiano: Piano Performance Evaluation Dataset with Multi-level Perceptual Features**  
> Park, Kim et al.  
> Nature Scientific Reports 2024  
> Paper: https://pmc.ncbi.nlm.nih.gov/articles/PMC11450231/  
> GitHub: https://github.com/JonghoKimSNU/PercePiano

## Methodology

Following the approach from `m2pf_dataset_compositionfold.py`, with one deviation in how the CV folds are filled:

- **Piece-based splits**: All performances of the same piece stay in the same fold
- **Test set**: Select pieces randomly until reaching ~15% of SAMPLES (not pieces)
- **4-fold CV**: Remaining pieces are assigned to folds by greedy bin-packing to balance sample counts (the original distributes them round-robin)
- **Per-fold normalization**: Stats computed from training folds only
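
The split procedure above can be sketched in plain Python. This is a minimal illustration, not the repository's `create_piece_based_folds`: the function name `piece_based_folds` and the toy `sample_to_piece` mapping are hypothetical, and the fold-filling loop uses the greedy bin-packing that the Step 4 cell describes (round-robin would change only that final loop).

```python
import random
from collections import defaultdict

def piece_based_folds(sample_to_piece, n_folds=4, test_ratio=0.15, seed=42):
    """Hypothetical helper: every piece ends up in exactly one split."""
    by_piece = defaultdict(list)
    for sample, piece in sample_to_piece.items():
        by_piece[piece].append(sample)

    pieces = sorted(by_piece)
    random.Random(seed).shuffle(pieces)

    # Test set: add whole pieces until ~test_ratio of all SAMPLES is reached.
    target = test_ratio * len(sample_to_piece)
    test, remaining = [], []
    for piece in pieces:
        if len(test) < target:
            test.extend(by_piece[piece])
        else:
            remaining.append(piece)

    # CV folds: largest piece first, always into the currently smallest fold.
    folds = [[] for _ in range(n_folds)]
    for piece in sorted(remaining, key=lambda p: len(by_piece[p]), reverse=True):
        min(folds, key=len).extend(by_piece[piece])
    return test, folds

# Toy data: 12 pieces with 2 to 5 performances each (42 samples in total).
toy = {f"p{p}_perf{i}": f"p{p}" for p in range(12) for i in range(p % 4 + 2)}
test, folds = piece_based_folds(toy)
print(f"test: {len(test)} samples, folds: {[len(f) for f in folds]}")
```

Because whole pieces are moved at once, the test set usually lands slightly above the 15% target rather than exactly on it.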

## Hyperparameters (SOTA Configuration - R2 = 0.397)

These parameters match the published SOTA from `2_run_comp_multilevel_total.sh` and `han_bigger256_concat.yml`:

| Parameter | SOTA Value | Notes |
|-----------|------------|-------|
| input_size | 79 | SOTA uses 79 base features (includes section_tempo) |
| batch_size | 8 | From SOTA training script |
| learning_rate | 2.5e-5 | From SOTA training script |
| hidden_size | 256 | HAN encoder dimension |
| prediction_head | 512->512->19 | From model_m2pf.py (NOT config's final_fc_size) |
| dropout | 0.2 | Regularization |
| augment_train | False | SOTA doesn't use key augmentation |
| max_epochs | 200 | Extended training window |
| early_stopping_patience | 20 | Value set in `CONFIG` (Step 5) |
| gradient_clip_val | 2.0 | From parser.py |
| **precision** | **32** | **FP32 (original uses FP32, not mixed precision)** |
| **max_notes/slice_len** | **5000** | **SOTA slice size for overlapping sampling** |
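
For reference, the R2 figure in the heading is the coefficient of determination. The sketch below shows how it could be computed for multi-dimensional labels, assuming a uniform average of per-dimension R2 (the default behaviour of scikit-learn's `r2_score`); whether the paper aggregates in exactly this way is an assumption, and the helper names are hypothetical.

```python
def r2_score_1d(y_true, y_pred):
    """R2 = 1 - SS_res / SS_tot for one label dimension (target must not be constant)."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

def mean_r2(y_true_rows, y_pred_rows):
    """Uniform average of per-dimension R2; rows are samples, columns are dimensions."""
    true_cols = list(zip(*y_true_rows))
    pred_cols = list(zip(*y_pred_rows))
    scores = [r2_score_1d(t, p) for t, p in zip(true_cols, pred_cols)]
    return sum(scores) / len(scores), scores

# Toy labels with 2 dimensions (the real task has 19).
y_true = [[0.2, 0.8], [0.4, 0.6], [0.6, 0.4], [0.8, 0.2]]
perfect, _ = mean_r2(y_true, y_true)            # predictions equal the targets
constant, _ = mean_r2(y_true, [[0.5, 0.5]] * 4)  # always predict the mean
print(perfect, constant)
```

A model that always predicts the per-dimension mean scores about 0, so an R2 near 0.40 means roughly 40% of the label variance is explained.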

## Step 1: Environment Setup

In [None]:
# Check GPU
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
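# Install rclone, then verify the binary rather than assuming the installer succeeded
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "No success message from installer; checking the binary below"
!rclone version | head -n 1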

In [None]:
# Install uv and clone repository
!curl -LsSf https://astral.sh/uv/install.sh | sh

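import os
# The uv installer uses ~/.local/bin (newer releases) or ~/.cargo/bin (older ones); add both
home = os.environ['HOME']
os.environ['PATH'] = f"{home}/.local/bin:{home}/.cargo/bin:{os.environ['PATH']}"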

# Clone repository
if not os.path.exists('/tmp/crescendai'):
    !git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai

%cd /tmp/crescendai/model
!git pull
!git log -1 --oneline

# Clone original PercePiano for comparison (needed for data diagnostics)
PERCEPIANO_PATH = '/tmp/crescendai/model/data/raw/PercePiano'
if not os.path.exists(PERCEPIANO_PATH):
    print("\nCloning original PercePiano repository...")
    !git clone https://github.com/JonghoKimSNU/PercePiano.git {PERCEPIANO_PATH}
else:
    print(f"\nPercePiano already present at {PERCEPIANO_PATH}")

# Install dependencies
!uv pip install --system -e .
!pip install tensorboard rich

import torch
import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

## Step 2: Configure Paths and Check rclone

In [None]:
import os
import subprocess
import shutil
from pathlib import Path

# Paths
DATA_ROOT = Path('/tmp/percepiano_vnet_84dim')
CHECKPOINT_ROOT = Path('/tmp/checkpoints/percepiano_kfold')
LOG_ROOT = Path('/tmp/logs/percepiano_kfold')
GDRIVE_DATA_PATH = 'gdrive:crescendai_data/percepiano_vnet_84dim'
GDRIVE_CHECKPOINT_PATH = 'gdrive:crescendai_checkpoints/percepiano_kfold'

# Training control
RESTART_TRAINING = True  # True clears existing checkpoints and logs; set to False to keep them

print("="*60)
print("PERCEPIANO REPLICA TRAINING (4-FOLD CV)")
print("="*60)

# Clear checkpoints if restarting
if RESTART_TRAINING and CHECKPOINT_ROOT.exists():
    print(f"\nRESTART_TRAINING=True: Clearing checkpoints at {CHECKPOINT_ROOT}")
    shutil.rmtree(CHECKPOINT_ROOT)
    print("  Checkpoints cleared!")

if RESTART_TRAINING and LOG_ROOT.exists():
    print(f"RESTART_TRAINING=True: Clearing logs at {LOG_ROOT}")
    shutil.rmtree(LOG_ROOT)
    print("  Logs cleared!")

# Create directories
CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
LOG_ROOT.mkdir(parents=True, exist_ok=True)
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Check rclone
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)

if 'gdrive:' in result.stdout:
    print("\nrclone 'gdrive' remote: CONFIGURED")
    RCLONE_AVAILABLE = True
else:
    print("\nrclone 'gdrive' remote: NOT CONFIGURED")
    print("Run 'rclone config' in terminal to set up Google Drive")
    RCLONE_AVAILABLE = False

print(f"\nData directory: {DATA_ROOT}")
print(f"Checkpoint directory: {CHECKPOINT_ROOT}")
print(f"Log directory: {LOG_ROOT}")
print(f"\nRESTART_TRAINING: {RESTART_TRAINING}")

## Step 3: Download Data from Google Drive

In [None]:
import subprocess

if not RCLONE_AVAILABLE:
    raise RuntimeError("rclone not configured. Run 'rclone config' first.")

# Download preprocessed data (check=True raises if rclone exits with an error)
print("Downloading preprocessed VirtuosoNet features from Google Drive...")
subprocess.run(
    ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
    check=True,
)

# Verify data
print("\n" + "="*60)
print("DATA VERIFICATION")
print("="*60)

total_samples = 0
for split in ['train', 'val', 'test']:
    split_dir = DATA_ROOT / split
    if split_dir.exists():
        count = len(list(split_dir.glob('*.pkl')))
        total_samples += count
        print(f"  {split}: {count} samples")
    else:
        print(f"  {split}: MISSING!")

print(f"  Total: {total_samples} samples")

stat_file = DATA_ROOT / 'stat.pkl'
print(f"  stat.pkl: {'present' if stat_file.exists() else 'MISSING!'}")

fold_file = DATA_ROOT / 'fold_assignments.json'
print(f"  fold_assignments.json: {'present' if fold_file.exists() else 'will be created'}")

## Step 4: Create Fold Assignments

In [None]:
from src.percepiano.data.kfold_split import (
    create_piece_based_folds,
    save_fold_assignments,
    load_fold_assignments,
    print_fold_statistics,
)

FOLD_FILE = DATA_ROOT / 'fold_assignments.json'
N_FOLDS = 4
TEST_RATIO = 0.15
SEED = 42

print("="*60)
print("FOLD ASSIGNMENT CREATION")
print("="*60)

# Force regeneration to use corrected methodology
# - Test set: select pieces until ~15% of SAMPLES (PercePiano methodology)
# - CV folds: greedy bin-packing for balanced sample counts (improvement over round-robin)
FORCE_REGENERATE = True

if FOLD_FILE.exists() and not FORCE_REGENERATE:
    print(f"\nLoading existing fold assignments from {FOLD_FILE}")
    fold_assignments = load_fold_assignments(FOLD_FILE)
else:
    if FOLD_FILE.exists():
        print(f"\nRemoving old fold assignments (regenerating with balanced methodology)...")
        FOLD_FILE.unlink()
    
    print(f"\nCreating new {N_FOLDS}-fold piece-based splits...")
    print("  Test set: select pieces until ~15% of SAMPLES")
    print("  CV folds: greedy bin-packing for balanced sample counts")
    fold_assignments = create_piece_based_folds(
        data_dir=DATA_ROOT,
        n_folds=N_FOLDS,
        test_ratio=TEST_RATIO,
        seed=SEED,
    )
    save_fold_assignments(fold_assignments, FOLD_FILE)

# Print statistics
print_fold_statistics(fold_assignments, n_folds=N_FOLDS)
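
Per-fold normalization, listed under Methodology, means the statistics come from the training folds only and are then applied unchanged to the held-out fold. The sketch below illustrates the idea on toy data; `fit_stats`, `apply_stats`, and the `folds` dictionary are hypothetical and do not mirror the repository's data loader, and the use of population standard deviation is an assumption.

```python
from statistics import fmean, pstdev

def fit_stats(train_rows):
    """Per-feature mean and std computed from training rows only."""
    cols = list(zip(*train_rows))
    means = [fmean(c) for c in cols]
    stds = [pstdev(c) or 1.0 for c in cols]  # fall back to 1.0 for constant features
    return means, stds

def apply_stats(rows, means, stds):
    """Z-score rows with statistics that were fitted elsewhere."""
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]

# Toy data: fold id -> feature rows. Fold 3 holds an outlier on purpose.
folds = {
    0: [[1.0, 10.0], [3.0, 30.0]],
    1: [[5.0, 50.0]],
    2: [[7.0, 70.0]],
    3: [[100.0, 1000.0]],
}
val_fold = 3
train_rows = [row for f, rows in folds.items() if f != val_fold for row in rows]

means, stds = fit_stats(train_rows)
train_norm = apply_stats(train_rows, means, stds)
val_norm = apply_stats(folds[val_fold], means, stds)
print(means, val_norm)
```

The validation outlier does not influence `means` or `stds`, which is the point: no information from the held-out fold leaks into preprocessing.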

## Step 5: Training Configuration

In [None]:
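import torch
# NOTE: 'medium' lets float32 matmuls use bfloat16 internally on supported GPUs.
# Use 'highest' if strict FP32 parity with the original training is required.
torch.set_float32_matmul_precision('medium')

# Enable better CUDA error reporting.
# NOTE: this only takes effect if set before CUDA is initialized, and it
# serializes kernel launches (slower training); remove it once debugging is done.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'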

# Import model type constants
from src.percepiano.training.kfold_trainer import MODEL_TYPE_HAN, MODEL_TYPE_BASELINE

CONFIG = {
    # K-Fold settings
    'n_folds': N_FOLDS,
    'test_ratio': TEST_RATIO,
    # Data
    'data_dir': str(DATA_ROOT),
    'checkpoint_dir': str(CHECKPOINT_ROOT),
    'log_dir': str(LOG_ROOT),
    'input_size': 79,
    'hidden_size': 256,
    'note_layers': 2,
    'voice_layers': 2,
    'beat_layers': 2,
    'measure_layers': 1,
    'num_attention_heads': 8,
    'learning_rate': 2.5e-5,
    'weight_decay': 1e-5,
    'dropout': 0.2,
    'batch_size': 8,
    'max_epochs': 200,
    'early_stopping_patience': 20,
    'gradient_clip_val': 2.0,
    'precision': '32',
    'max_notes': 5000,
    'slice_len': 5000,
    'num_workers': 4,
    'augment_train': False,
}

print("="*60)
print("TRAINING CONFIGURATION (SOTA - ROUND 13)")
print("="*60)
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

print(f"\nModel types available:")
print(f"  MODEL_TYPE_BASELINE = '{MODEL_TYPE_BASELINE}' (7-layer Bi-LSTM, expected R2 ~0.19)")
print(f"  MODEL_TYPE_HAN = '{MODEL_TYPE_HAN}' (Hierarchical, expected R2 ~0.40)")

## Step 6: Initialize Phase 2 Trainers

Create trainers for all 3 incremental models:
1. Baseline (7-layer Bi-LSTM)
2. Baseline + Beat hierarchy
3. Baseline + Beat + Measure hierarchy

In [None]:
from src.percepiano.training.kfold_trainer import (
    KFoldTrainer,
    MODEL_TYPE_BASELINE,
    MODEL_TYPE_BASELINE_BEAT,
    MODEL_TYPE_BASELINE_BEAT_MEASURE,
)
import pytorch_lightning as pl

# Set seed for reproducibility (SEED is defined in the Step 4 cell)
pl.seed_everything(SEED, workers=True)

# Phase 2: Train 3 incremental models on Fold 2
# 1. Baseline (7-layer Bi-LSTM) - R2 ~0.19 (Phase 1 validated)
# 2. Baseline + Beat - R2 ~0.25-0.30 (expected +0.10-0.15 gain)
# 3. Baseline + Beat + Measure - R2 ~0.35-0.40 (expected +0.05-0.10 more)

FOLD_ID = 2  # Best performing fold with longest pieces (1-54 beats)

print("="*60)
print("PHASE 2: INITIALIZE INCREMENTAL TRAINERS")
print("="*60)
print(f"\nTraining Fold: {FOLD_ID}")
print("\nModels to train:")
print("  1. Baseline (7-layer Bi-LSTM) - expected R2 ~0.19")
print("  2. Baseline + Beat hierarchy - expected R2 ~0.25-0.30")
print("  3. Baseline + Beat + Measure - expected R2 ~0.35-0.40")

# 1. Baseline trainer (7-layer Bi-LSTM)
print("\n[1] Bi-LSTM Baseline Trainer:")
baseline_trainer = KFoldTrainer(
    config=CONFIG,
    fold_assignments=fold_assignments,
    data_dir=DATA_ROOT,
    checkpoint_dir=CHECKPOINT_ROOT,
    log_dir=LOG_ROOT,
    n_folds=N_FOLDS,
    model_type=MODEL_TYPE_BASELINE,
)

# 2. Baseline + Beat trainer
print("\n[2] Baseline + Beat Trainer:")
beat_trainer = KFoldTrainer(
    config=CONFIG,
    fold_assignments=fold_assignments,
    data_dir=DATA_ROOT,
    checkpoint_dir=CHECKPOINT_ROOT,
    log_dir=LOG_ROOT,
    n_folds=N_FOLDS,
    model_type=MODEL_TYPE_BASELINE_BEAT,
)

# 3. Baseline + Beat + Measure trainer
print("\n[3] Baseline + Beat + Measure Trainer:")
beat_measure_trainer = KFoldTrainer(
    config=CONFIG,
    fold_assignments=fold_assignments,
    data_dir=DATA_ROOT,
    checkpoint_dir=CHECKPOINT_ROOT,
    log_dir=LOG_ROOT,
    n_folds=N_FOLDS,
    model_type=MODEL_TYPE_BASELINE_BEAT_MEASURE,
)

print("\n" + "="*60)
print("All 3 trainers initialized!")
print(f"  Baseline checkpoints: {baseline_trainer.checkpoint_dir}")
print(f"  Beat checkpoints: {beat_trainer.checkpoint_dir}")
print(f"  Beat+Measure checkpoints: {beat_measure_trainer.checkpoint_dir}")
print("="*60)

## Step 7: Phase 2 Incremental Training

Train three models incrementally on Fold 2 to isolate where hierarchy helps:
1. **Baseline** (7-layer Bi-LSTM) - R2 ~0.19 (skip if checkpoint exists)
2. **Baseline + Beat** - R2 ~0.25-0.30 (expected beat gain: +0.10-0.15)
3. **Baseline + Beat + Measure** - R2 ~0.35-0.40 (expected measure gain: +0.05-0.10)

In [None]:
"""
PHASE 2: INCREMENTAL HIERARCHY TRAINING

Train three models incrementally to isolate where hierarchy helps:
1. Baseline (7-layer Bi-LSTM) - R2 ~0.19 (validated in Phase 1, skip if checkpoint exists)
2. Baseline + Beat hierarchy - R2 ~0.25-0.30 (expected beat gain: +0.10-0.15)
3. Baseline + Beat + Measure hierarchy - R2 ~0.35-0.40 (expected measure gain: +0.05-0.10)
"""

print("="*70)
print("PHASE 2: INCREMENTAL HIERARCHY TRAINING")
print("="*70)
print("\nGoal: Isolate where hierarchy contributes to performance")
print("Expected gains:")
print("  Beat hierarchy: +0.10 to +0.15 R2")
print("  Measure hierarchy: +0.05 to +0.10 R2")
print("  Total hierarchy gain: ~+0.21 R2")
print("="*70)

# ========================================
# Model 1: Baseline (skip if checkpoint exists)
# ========================================
print("\n" + "="*70)
print("MODEL 1: BASELINE (7-layer Bi-LSTM)")
print("="*70)
print("Expected R2: ~0.19 (matching VirtuosoNetSingle)")

# Check for existing checkpoint
baseline_checkpoint = baseline_trainer._find_checkpoint(FOLD_ID, "best")
if baseline_checkpoint:
    print(f"\nFound existing baseline checkpoint: {baseline_checkpoint}")
    print("Loading metrics from saved results...")
    
    # Try to load saved metrics from training_results.json
    import json
    results_file = baseline_trainer.checkpoint_dir / "training_results.json"
    baseline_metrics = None
    
    if results_file.exists():
        with open(results_file) as f:
            saved_results = json.load(f)
        if str(FOLD_ID) in saved_results.get('fold_metrics', {}):
            fm = saved_results['fold_metrics'][str(FOLD_ID)]
            # Load model from checkpoint using Lightning's method
            from src.percepiano.models.percepiano_replica import PercePianoBiLSTMBaseline
            baseline_model = PercePianoBiLSTMBaseline.load_from_checkpoint(
                str(baseline_checkpoint),
                strict=False,
            )
            baseline_trainer.trained_folds[FOLD_ID] = baseline_model
            
            # Create a simple namespace to hold the metrics we need
            class SimpleMetrics:
                def __init__(self, d):
                    self.val_r2 = d.get('val_r2', 0)
                    self.val_pearson = d.get('val_pearson', 0)
                    self.val_mae = d.get('val_mae', 0)
                    self.val_rmse = d.get('val_rmse', 0)
                    self.epochs_trained = d.get('epochs_trained', 0)
                    self.per_dim_r2 = d.get('per_dim_r2', {})
            
            baseline_metrics = SimpleMetrics(fm)
            print(f"Loaded baseline metrics: Val R2 = {baseline_metrics.val_r2:+.4f}")
        else:
            print("[WARNING] No saved metrics for this fold - will retrain")
    
    if baseline_metrics is None:
        print("Retraining baseline to get metrics...")
        baseline_metrics = baseline_trainer.train_fold(
            fold_id=FOLD_ID,
            verbose=True,
            resume_from_checkpoint=True,
        )
        baseline_trainer.save_results()
else:
    print("\nNo checkpoint found - training baseline...")
    baseline_metrics = baseline_trainer.train_fold(
        fold_id=FOLD_ID,
        verbose=True,
        resume_from_checkpoint=False,
    )
    baseline_trainer.save_results()
    print(f"\nBaseline training complete! Val R2 = {baseline_metrics.val_r2:+.4f}")

if baseline_metrics and baseline_metrics.val_r2 < 0.10:
    print("\n[WARNING] Baseline underperforming - check data pipeline before continuing!")
else:
    print("\n[OK] Baseline validated - proceeding to incremental hierarchy")

# ========================================
# Model 2: Baseline + Beat
# ========================================
print("\n" + "="*70)
print("MODEL 2: BASELINE + BEAT HIERARCHY")
print("="*70)
print("Expected R2: ~0.25-0.30 (beat contribution: +0.10-0.15)")

beat_metrics = beat_trainer.train_fold(
    fold_id=FOLD_ID,
    verbose=True,
    resume_from_checkpoint=False,
)
beat_trainer.save_results()

if baseline_metrics:
    beat_gain = beat_metrics.val_r2 - baseline_metrics.val_r2
    print(f"\nBaseline + Beat training complete!")
    print(f"  Val R2: {beat_metrics.val_r2:+.4f}")
    print(f"  Beat gain: {beat_gain:+.4f} (expected: +0.10 to +0.15)")
else:
    print(f"\nBaseline + Beat training complete! Val R2 = {beat_metrics.val_r2:+.4f}")

# ========================================
# Model 3: Baseline + Beat + Measure
# ========================================
print("\n" + "="*70)
print("MODEL 3: BASELINE + BEAT + MEASURE HIERARCHY")
print("="*70)
print("Expected R2: ~0.35-0.40 (approaching SOTA)")

beat_measure_metrics = beat_measure_trainer.train_fold(
    fold_id=FOLD_ID,
    verbose=True,
    resume_from_checkpoint=False,
)
beat_measure_trainer.save_results()

measure_gain = beat_measure_metrics.val_r2 - beat_metrics.val_r2
print(f"\nBaseline + Beat + Measure training complete!")
print(f"  Val R2: {beat_measure_metrics.val_r2:+.4f}")
print(f"  Measure gain: {measure_gain:+.4f} (expected: +0.05 to +0.10)")

# ========================================
# Store models and metrics for diagnostics
# ========================================
trained_models = {
    'baseline': baseline_trainer.get_trained_model(FOLD_ID),
    'baseline_beat': beat_trainer.get_trained_model(FOLD_ID),
    'baseline_beat_measure': beat_measure_trainer.get_trained_model(FOLD_ID),
}

trained_metrics = {
    'baseline': baseline_metrics,
    'baseline_beat': beat_metrics,
    'baseline_beat_measure': beat_measure_metrics,
}

# ========================================
# Quick Progress Summary
# ========================================
print("\n" + "="*70)
print("PHASE 2 TRAINING COMPLETE")
print("="*70)

if baseline_metrics:
    total_gain = beat_measure_metrics.val_r2 - baseline_metrics.val_r2
    beat_contribution = beat_metrics.val_r2 - baseline_metrics.val_r2
    measure_contribution = beat_measure_metrics.val_r2 - beat_metrics.val_r2
    
    print(f"\n  {'Model':<30} {'Val R2':>10} {'Gain':>10}")
    print(f"  {'-'*30} {'-'*10} {'-'*10}")
    print(f"  {'Baseline (7-layer BiLSTM)':<30} {baseline_metrics.val_r2:>+10.4f} {'-':>10}")
    print(f"  {'Baseline + Beat':<30} {beat_metrics.val_r2:>+10.4f} {beat_contribution:>+10.4f}")
    print(f"  {'Baseline + Beat + Measure':<30} {beat_measure_metrics.val_r2:>+10.4f} {measure_contribution:>+10.4f}")
    print(f"  {'-'*30} {'-'*10} {'-'*10}")
    print(f"  {'Total Hierarchy Gain':<30} {'-':>10} {total_gain:>+10.4f}")
    
    print(f"\n  Expected total gain: ~+0.21")
    
    if total_gain >= 0.15:
        print(f"  [SUCCESS] Hierarchy providing significant gain!")
    elif total_gain >= 0.10:
        print(f"  [GOOD] Hierarchy helping but slightly below expected")
    elif total_gain >= 0.05:
        print(f"  [PARTIAL] Hierarchy providing modest gain")
    else:
        print(f"  [ISSUE] Hierarchy not contributing enough - run diagnostics")
else:
    print("\n[WARNING] Baseline metrics not available for comparison")
    print(f"  Baseline + Beat Val R2: {beat_metrics.val_r2:+.4f}")
    print(f"  Baseline + Beat + Measure Val R2: {beat_measure_metrics.val_r2:+.4f}")

print("="*70)

## Step 8: Phase 2 Hierarchy Diagnostics

Three diagnostic checks to understand why hierarchy may or may not be contributing:
1. **span_beat_to_note_num Check**: Verify beat representations are properly distributed to notes
2. **Attention Entropy**: Check if beat/measure attention weights are near-uniform (not learning)
3. **Contractor Weight Analysis**: Check if contractor layer ignores beat/measure dimensions
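
The attention entropy check relies on normalized Shannon entropy: entropy divided by the log of the sequence length, so 1.0 means uniform weights and values near 0.0 mean the mass sits on a single position. A minimal pure-Python sketch for one attention head follows; the helper names are hypothetical and the real diagnostic operates on batched tensors.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def normalized_entropy(weights, eps=1e-10):
    """Shannon entropy divided by log(n): 1.0 = uniform, 0.0 = one-hot."""
    n = len(weights)
    if n < 2:
        return 0.0  # a single position has no spread to measure
    h = -sum(w * math.log(w + eps) for w in weights)
    return h / math.log(n)

uniform = normalized_entropy(softmax([0.0] * 8))          # no preference at all
peaked = normalized_entropy(softmax([10.0] + [0.0] * 7))  # one dominant position
print(f"uniform={uniform:.3f}, peaked={peaked:.3f}")
```

Dividing by `log(n)` is what makes values comparable across sequences of different lengths.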

In [None]:
"""
PHASE 2: POST-TRAINING HIERARCHY DIAGNOSTICS

Three diagnostic checks to understand hierarchy contribution:
1. span_beat_to_note_num: Verify beat representations properly distributed
2. Attention entropy: Check if beat/measure attention is near-uniform
3. Contractor weights: Check if contractor ignores beat/measure dimensions
"""

import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
from src.percepiano.data.percepiano_vnet_dataset import (
    PercePianoKFoldDataset,
    percepiano_pack_collate,
)
from src.percepiano.models.hierarchy_utils import (
    span_beat_to_note_num,
    compute_actual_lengths,
)

# ========================================
# Diagnostic Functions
# ========================================

def diagnose_span_beat_to_note_num(beat_numbers, device):
    """
    Verify beat representations are correctly mapped back to note level.
    
    CHECKS:
    1. Beat indices are monotonically non-decreasing within valid positions
    2. Zero-shifted indices have no negative values
    3. Beat-to-note mapping produces correct shapes
    """
    results = {'status': 'OK', 'issues': []}
    
    actual_lengths = compute_actual_lengths(beat_numbers)
    
    # Check 1: Beat indices are valid (monotonically non-decreasing)
    for i in range(beat_numbers.shape[0]):
        valid_len = actual_lengths[i].item()
        if valid_len > 1:
            valid_beats = beat_numbers[i, :valid_len]
            diffs = valid_beats[1:] - valid_beats[:-1]
            if (diffs < 0).any():
                results['issues'].append(f"Sample {i}: Non-monotonic beat indices")
                results['status'] = 'FAIL'
    
    # Check 2: Zero-shifted values
    first_beats = beat_numbers[:, 0:1]
    zero_shifted = beat_numbers - first_beats
    for i in range(beat_numbers.shape[0]):
        valid_len = actual_lengths[i].item()
        if valid_len > 0 and (zero_shifted[i, :valid_len] < 0).any():
            results['issues'].append(f"Sample {i}: Negative zero-shifted values")
            results['status'] = 'FAIL'
    
    # Check 3: Test span_beat_to_note_num function
    batch_size, num_notes = beat_numbers.shape
    max_beat = int(beat_numbers.max().item())  # cast so num_beats is a valid tensor size
    if max_beat > 0:
        num_beats = max_beat + 1
        dummy_beat_out = torch.randn(batch_size, num_beats, 512, device=device)
        
        try:
            spanned = span_beat_to_note_num(dummy_beat_out, beat_numbers, actual_lengths)
            results['spanned_shape'] = tuple(spanned.shape)
            results['spanned_non_zero'] = (spanned.abs() > 1e-6).any().item()
            
            if not results['spanned_non_zero']:
                results['issues'].append("span_beat_to_note_num output is all zeros!")
                results['status'] = 'FAIL'
        except Exception as e:
            results['issues'].append(f"span_beat_to_note_num failed: {str(e)}")
            results['status'] = 'FAIL'
    
    return results


def diagnose_attention_entropy(model, x_embedded, actual_lengths, model_name):
    """
    Analyze attention weight distribution in hierarchical components.
    
    CHECKS:
    1. Beat attention entropy (normalized 0-1, where 1=uniform)
    2. Whether attention is too uniform (>0.95) or too collapsed (<0.1)
    """
    results = {'beat_entropy': None, 'status': 'OK', 'issues': []}
    
    if not hasattr(model, 'beat_attention'):
        results['status'] = 'SKIPPED'
        results['issues'].append(f"{model_name} has no beat_attention")
        return results
    
    try:
        # Run through LSTM to get hidden states
        x_packed = pack_padded_sequence(
            x_embedded, actual_lengths.cpu().clamp(min=1),
            batch_first=True, enforce_sorted=False
        )
        lstm_out, _ = model.lstm(x_packed)
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True, total_length=x_embedded.shape[1])
        
        # Get beat attention logits
        beat_similarity = model.beat_attention.get_attention(lstm_out)  # [B, T, num_head]
        
        # Mask padded positions so they receive no attention mass
        seq_len = lstm_out.shape[1]
        lengths = actual_lengths.to(lstm_out.device).clamp(min=1)
        positions = torch.arange(seq_len, device=lstm_out.device)
        pad_mask = positions[None, :] >= lengths[:, None]  # [B, T]
        beat_similarity = beat_similarity.masked_fill(pad_mask.unsqueeze(-1), float('-inf'))
        
        # Apply softmax with temperature over the time axis
        temp = getattr(model.beat_attention, 'temperature', 1.0)
        beat_attention = torch.softmax(beat_similarity / temp, dim=1)
        
        # Entropy per sample and head: [B, T, H] -> [B, H]
        eps = 1e-10
        beat_entropy = -torch.sum(beat_attention * torch.log(beat_attention + eps), dim=1)
        beat_entropy = beat_entropy.clamp(min=0.0)
        
        # Normalize by each sample's max possible entropy: log(valid length)
        max_entropy = torch.log(lengths.float()).clamp(min=eps).unsqueeze(-1)  # [B, 1]
        normalized_entropy = (beat_entropy / max_entropy).mean().item()
        
        results['beat_entropy'] = normalized_entropy
        results['raw_entropy'] = beat_entropy.mean().item()
        results['max_entropy'] = max_entropy.mean().item()
        
        # Check for issues
        if normalized_entropy > 0.95:
            results['issues'].append(f"Beat attention near-uniform (entropy={normalized_entropy:.3f})")
            results['status'] = 'WARNING'
        elif normalized_entropy < 0.1:
            results['issues'].append(f"Beat attention collapsed (entropy={normalized_entropy:.3f})")
            results['status'] = 'WARNING'
            
    except Exception as e:
        results['issues'].append(f"Attention analysis failed: {str(e)}")
        results['status'] = 'ERROR'
    
    return results


def diagnose_contractor_weights(model, model_name):
    """
    Analyze contractor weights to check if hierarchy dimensions are being used.
    
    CHECKS:
    1. Weight magnitude for LSTM vs beat vs measure input dimensions
    2. Whether hierarchy weights are significantly smaller (model ignoring them)
    """
    results = {
        'lstm_weight_mag': None,
        'beat_weight_mag': None,
        'measure_weight_mag': None,
        'ratio_beat_to_lstm': None,
        'ratio_measure_to_lstm': None,
        'status': 'OK',
        'issues': [],
    }
    
    if not hasattr(model, 'note_contractor'):
        results['status'] = 'SKIPPED'
        results['issues'].append(f"{model_name} has no note_contractor")
        return results
    
    contractor = model.note_contractor.weight.data  # [out_dim, in_dim]
    in_dim = contractor.shape[1]
    
    # Determine layout based on input dimension
    if in_dim == 1536:
        # Baseline + Beat + Measure: [LSTM:512, beat:512, measure:512]
        results['lstm_weight_mag'] = contractor[:, :512].abs().mean().item()
        results['beat_weight_mag'] = contractor[:, 512:1024].abs().mean().item()
        results['measure_weight_mag'] = contractor[:, 1024:].abs().mean().item()
        
        results['ratio_beat_to_lstm'] = results['beat_weight_mag'] / (results['lstm_weight_mag'] + 1e-8)
        results['ratio_measure_to_lstm'] = results['measure_weight_mag'] / (results['lstm_weight_mag'] + 1e-8)
        
        if results['ratio_beat_to_lstm'] < 0.1:
            results['issues'].append(f"Contractor ignoring beat (ratio={results['ratio_beat_to_lstm']:.3f})")
            results['status'] = 'WARNING'
        if results['ratio_measure_to_lstm'] < 0.1:
            results['issues'].append(f"Contractor ignoring measure (ratio={results['ratio_measure_to_lstm']:.3f})")
            results['status'] = 'WARNING'
            
    elif in_dim == 1024:
        # Baseline + Beat: [LSTM:512, beat:512]
        results['lstm_weight_mag'] = contractor[:, :512].abs().mean().item()
        results['beat_weight_mag'] = contractor[:, 512:].abs().mean().item()
        
        results['ratio_beat_to_lstm'] = results['beat_weight_mag'] / (results['lstm_weight_mag'] + 1e-8)
        
        if results['ratio_beat_to_lstm'] < 0.1:
            results['issues'].append(f"Contractor ignoring beat (ratio={results['ratio_beat_to_lstm']:.3f})")
            results['status'] = 'WARNING'
    else:
        results['status'] = 'SKIPPED'
        results['issues'].append(f"Unexpected contractor input dim: {in_dim}")
    
    return results


# ========================================
# Run Diagnostics
# ========================================

print("="*70)
print("PHASE 2 HIERARCHY DIAGNOSTICS")
print("="*70)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prepare validation batch
val_ds = PercePianoKFoldDataset(
    data_dir=DATA_ROOT,
    fold_assignments=fold_assignments,
    fold_id=FOLD_ID,
    mode="val",
    max_notes=CONFIG['max_notes'],
    slice_len=CONFIG.get('slice_len', CONFIG['max_notes']),
)
val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=0)
batch = next(iter(val_loader))

# Move batch to device
input_features = batch['input_features'].to(device)
beat_numbers = batch['note_locations_beat'].to(device)
measure_numbers = batch['note_locations_measure'].to(device)
actual_lengths = compute_actual_lengths(beat_numbers)

# Models to diagnose (only the ones with hierarchy)
models_to_diagnose = [
    ('baseline_beat', trained_models.get('baseline_beat')),
    ('baseline_beat_measure', trained_models.get('baseline_beat_measure')),
]

for model_name, model in models_to_diagnose:
    if model is None:
        print(f"\n[SKIPPED] {model_name} not available")
        continue
    
    model = model.to(device).eval()
    
    print(f"\n{'='*70}")
    print(f"DIAGNOSTICS: {model_name.upper()}")
    print(f"{'='*70}")
    
    # Prepare embedded input for attention analysis
    with torch.no_grad():
        x_embedded = model.note_embedder(input_features)
    
    # Diagnostic 1: span_beat_to_note_num check
    print("\n[1] SPAN_BEAT_TO_NOTE_NUM CHECK:")
    span_results = diagnose_span_beat_to_note_num(beat_numbers, device)
    print(f"    Status: {span_results['status']}")
    if span_results.get('spanned_shape'):
        print(f"    Spanned output shape: {span_results['spanned_shape']}")
        print(f"    Non-zero output: {span_results['spanned_non_zero']}")
    if span_results['issues']:
        for issue in span_results['issues']:
            print(f"    - {issue}")
    else:
        print("    All checks passed")
    
    # Diagnostic 2: Attention entropy
    print("\n[2] ATTENTION ENTROPY CHECK:")
    with torch.no_grad():
        entropy_results = diagnose_attention_entropy(model, x_embedded, actual_lengths, model_name)
    
    if entropy_results['beat_entropy'] is not None:
        print(f"    Beat attention entropy: {entropy_results['beat_entropy']:.3f} (1.0=uniform, 0.0=focused)")
        print(f"    Raw entropy: {entropy_results['raw_entropy']:.3f}, Max possible: {entropy_results['max_entropy']:.3f}")
        
        if entropy_results['beat_entropy'] > 0.8:
            print("    [WARNING] Attention is relatively uniform - may not be learning to focus")
        elif entropy_results['beat_entropy'] < 0.1:
            print("    [WARNING] Attention has collapsed onto very few positions")
        elif entropy_results['beat_entropy'] < 0.3:
            print("    [GOOD] Attention is focused on specific positions")
        else:
            print("    [OK] Attention entropy in reasonable range")
    
    print(f"    Status: {entropy_results['status']}")
    if entropy_results['issues']:
        for issue in entropy_results['issues']:
            print(f"    - {issue}")
    
    # Diagnostic 3: Contractor weights
    print("\n[3] CONTRACTOR WEIGHT ANALYSIS:")
    contractor_results = diagnose_contractor_weights(model, model_name)
    
    if contractor_results['lstm_weight_mag'] is not None:
        print(f"    LSTM weight magnitude: {contractor_results['lstm_weight_mag']:.4f}")
        print(f"    Beat weight magnitude: {contractor_results['beat_weight_mag']:.4f}")
        if contractor_results['measure_weight_mag'] is not None:
            print(f"    Measure weight magnitude: {contractor_results['measure_weight_mag']:.4f}")
        
        print(f"    Beat/LSTM ratio: {contractor_results['ratio_beat_to_lstm']:.3f}")
        if contractor_results['ratio_measure_to_lstm'] is not None:
            print(f"    Measure/LSTM ratio: {contractor_results['ratio_measure_to_lstm']:.3f}")
        
        # Interpretation
        if contractor_results['ratio_beat_to_lstm'] >= 0.5:
            print("    [GOOD] Contractor using beat branch effectively")
        elif contractor_results['ratio_beat_to_lstm'] >= 0.2:
            print("    [OK] Contractor partially using beat branch")
        else:
            print("    [WARNING] Contractor may be ignoring beat branch")
    
    print(f"    Status: {contractor_results['status']}")
    if contractor_results['issues']:
        for issue in contractor_results['issues']:
            print(f"    - {issue}")

print("\n" + "="*70)
print("DIAGNOSTICS COMPLETE")
print("="*70)
print("\nInterpretation guide:")
print("  - span_beat_to_note_num: Should produce non-zero output with correct shapes")
print("  - Attention entropy: 0.1-0.8 is a healthy range; >0.95 means near-uniform (not learning); <0.1 means collapsed")
print("  - Contractor ratio: >0.2 means hierarchy is being used; <0.1 means ignored")
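
The entropy thresholds in the guide above are easiest to read if the entropy is normalized by `log(n)`, so that 1.0 means perfectly uniform attention over `n` positions. That normalization is an assumption about how the diagnostic helper works, and `normalized_entropy` below is a hypothetical stand-alone helper, not a function from this repository. A minimal sketch:

```python
import math

def normalized_entropy(weights):
    """Shannon entropy of an attention distribution, scaled to [0, 1].

    Values near 1.0 mean near-uniform attention; values near 0 mean
    almost all of the weight sits on a single position.
    """
    n = len(weights)
    if n < 2:
        return 0.0  # a single position carries no entropy
    total = sum(weights)
    probs = [w / total for w in weights]  # renormalize defensively
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return entropy / math.log(n)

# Illustrative distributions, not values from a trained model
uniform = normalized_entropy([0.25, 0.25, 0.25, 0.25])
peaked = normalized_entropy([0.97, 0.01, 0.01, 0.01])
print(f"uniform: {uniform:.3f}, peaked: {peaked:.3f}")
```

Under this reading, a value above 0.95 would mean the attention layer is spreading weight almost evenly, which matches the "not learning" warning.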

## Step 9: Comprehensive Analysis

A single analysis cell compares the three incremental models (baseline, baseline + beat, baseline + beat + measure) in five sections:
1. Overall metrics table
2. Per-dimension R2 comparison
3. Hierarchy contribution analysis
4. Comparison to PercePiano paper
5. Next steps recommendations
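
The hierarchy contribution analysis is plain subtraction of validation R2 values between consecutive models. The sketch below illustrates that arithmetic; `hierarchy_gains` is a hypothetical helper, and the middle value of 0.30 is invented for illustration. Only 0.185 and 0.397 are the paper reference points used in this notebook.

```python
def hierarchy_gains(baseline_r2, beat_r2, beat_measure_r2):
    """Split the total R2 improvement into beat and measure contributions."""
    beat_gain = beat_r2 - baseline_r2
    measure_gain = beat_measure_r2 - beat_r2
    return {
        "beat": beat_gain,
        "measure": measure_gain,
        "total": beat_gain + measure_gain,
    }

# 0.185 and 0.397 are the paper reference R2 values; 0.30 is a made-up midpoint
gains = hierarchy_gains(0.185, 0.30, 0.397)
for name, value in gains.items():
    print(f"{name:<8} {value:+.4f}")
```

Because the two gains telescope, the total always equals the full model's R2 minus the baseline's, regardless of how the intermediate model performs.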

In [None]:
"""
PHASE 2: COMPREHENSIVE ANALYSIS

Compare all three incremental models and analyze hierarchy contribution.
"""

import numpy as np
from sklearn.metrics import r2_score
from src.percepiano.models.percepiano_replica import PERCEPIANO_DIMENSIONS

print("="*80)
print("PHASE 2 COMPREHENSIVE ANALYSIS")
print("="*80)

# Validate we have metrics
if 'trained_metrics' not in globals():
    raise RuntimeError("No trained_metrics found - run training cell first!")

baseline_m = trained_metrics.get('baseline')
beat_m = trained_metrics.get('baseline_beat')
beat_measure_m = trained_metrics.get('baseline_beat_measure')

if not all([baseline_m, beat_m, beat_measure_m]):
    # Use the same truthiness test as the check above so the list is never empty here
    missing = [k for k, v in [('baseline', baseline_m), ('baseline_beat', beat_m), 
                               ('baseline_beat_measure', beat_measure_m)] if not v]
    print(f"\n[WARNING] Missing metrics for: {missing}")
    print("Some analysis sections will be incomplete.")

# ========================================
# Section 1: Overall Metrics Table
# ========================================
print("\n" + "-"*80)
print("SECTION 1: OVERALL METRICS")
print("-"*80)

print(f"\n  {'Model':<32} {'Val R2':>10} {'Epochs':>8} {'Gain':>10}")
print(f"  {'-'*32} {'-'*10} {'-'*8} {'-'*10}")

if baseline_m:
    print(f"  {'Baseline (7-layer BiLSTM)':<32} {baseline_m.val_r2:>+10.4f} {baseline_m.epochs_trained:>8} {'-':>10}")

if beat_m:
    # Show '-' instead of a misleading +0.0000 when the previous model is missing
    beat_gain_str = f"{beat_m.val_r2 - baseline_m.val_r2:+.4f}" if baseline_m else '-'
    print(f"  {'Baseline + Beat':<32} {beat_m.val_r2:>+10.4f} {beat_m.epochs_trained:>8} {beat_gain_str:>10}")

if beat_measure_m:
    measure_gain_str = f"{beat_measure_m.val_r2 - beat_m.val_r2:+.4f}" if beat_m else '-'
    print(f"  {'Baseline + Beat + Measure':<32} {beat_measure_m.val_r2:>+10.4f} {beat_measure_m.epochs_trained:>8} {measure_gain_str:>10}")

if baseline_m and beat_measure_m:
    total_gain = beat_measure_m.val_r2 - baseline_m.val_r2
    print(f"  {'-'*32} {'-'*10} {'-'*8} {'-'*10}")
    print(f"  {'Total Hierarchy Gain':<32} {'-':>10} {'-':>8} {total_gain:>+10.4f}")
    print("\n  Expected total gain: ~+0.21")

# ========================================
# Section 2: Per-Dimension R2 Comparison
# ========================================
print("\n" + "-"*80)
print("SECTION 2: PER-DIMENSION R2 COMPARISON")
print("-"*80)

# Check if we have per-dimension metrics
# True only if all three models exist and carry a non-empty per_dim_r2 mapping
has_per_dim = all(
    m and getattr(m, 'per_dim_r2', None)
    for m in (baseline_m, beat_m, beat_measure_m)
)

if has_per_dim:
    print(f"\n  {'Dimension':<22} {'Baseline':>10} {'+Beat':>10} {'+Beat+Meas':>12} {'Beat Gain':>10} {'Meas Gain':>10}")
    print(f"  {'-'*22} {'-'*10} {'-'*10} {'-'*12} {'-'*10} {'-'*10}")
    
    dim_data = []
    for dim in PERCEPIANO_DIMENSIONS:
        b_r2 = baseline_m.per_dim_r2.get(dim, 0)
        bb_r2 = beat_m.per_dim_r2.get(dim, 0)
        bbm_r2 = beat_measure_m.per_dim_r2.get(dim, 0)
        beat_gain = bb_r2 - b_r2
        meas_gain = bbm_r2 - bb_r2
        dim_data.append((dim, b_r2, bb_r2, bbm_r2, beat_gain, meas_gain))
    
    # Sort by final model R2 (best first)
    dim_data.sort(key=lambda x: x[3], reverse=True)
    
    for dim, b_r2, bb_r2, bbm_r2, beat_gain, meas_gain in dim_data:
        print(f"  {dim:<22} {b_r2:>+10.4f} {bb_r2:>+10.4f} {bbm_r2:>+12.4f} {beat_gain:>+10.4f} {meas_gain:>+10.4f}")
    
    # Summary
    positive_beat_gains = sum(1 for _, _, _, _, bg, _ in dim_data if bg > 0)
    positive_meas_gains = sum(1 for _, _, _, _, _, mg in dim_data if mg > 0)
    print(f"\n  Beat helps {positive_beat_gains}/{len(PERCEPIANO_DIMENSIONS)} dimensions")
    print(f"  Measure helps {positive_meas_gains}/{len(PERCEPIANO_DIMENSIONS)} dimensions")
else:
    print("\n  [Per-dimension R2 not available in saved metrics]")
    print("  To enable: ensure trainer saves per_dim_r2 in FoldMetrics")

# ========================================
# Section 3: Hierarchy Contribution Analysis
# ========================================
print("\n" + "-"*80)
print("SECTION 3: HIERARCHY CONTRIBUTION ANALYSIS")
print("-"*80)

if baseline_m and beat_m and beat_measure_m:
    beat_contribution = beat_m.val_r2 - baseline_m.val_r2
    measure_contribution = beat_measure_m.val_r2 - beat_m.val_r2
    total_hierarchy = beat_contribution + measure_contribution
    
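    # Approximate targets. The total is the gap between the paper R2 values
    # shown in Section 4 (0.397 - 0.185 = 0.212). The beat/measure split is a
    # rough point estimate used only for the %Achieved column, which is why it
    # differs from the ranges printed in the Expected column.
    expected_beat = 0.15
    expected_measure = 0.06
    expected_total = 0.21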
    
    print(f"\n  {'Component':<30} {'Actual':>10} {'Expected':>10} {'%Achieved':>12}")
    print(f"  {'-'*30} {'-'*10} {'-'*10} {'-'*12}")
    print(f"  {'Beat hierarchy contribution':<30} {beat_contribution:>+10.4f} {'+0.10-0.15':>10} {100*beat_contribution/expected_beat:>11.0f}%")
    print(f"  {'Measure hierarchy contribution':<30} {measure_contribution:>+10.4f} {'+0.05-0.10':>10} {100*measure_contribution/expected_measure:>11.0f}%")
    print(f"  {'-'*30} {'-'*10} {'-'*10} {'-'*12}")
    print(f"  {'Total hierarchy gain':<30} {total_hierarchy:>+10.4f} {'~+0.21':>10} {100*total_hierarchy/expected_total:>11.0f}%")
    
    # Interpretation
    print("\n  Interpretation:")
    if beat_contribution >= 0.10:
        print("    - Beat hierarchy: SIGNIFICANT contribution [OK]")
    elif beat_contribution >= 0.05:
        print("    - Beat hierarchy: PARTIAL contribution [CHECK DIAGNOSTICS]")
    else:
        print("    - Beat hierarchy: MINIMAL contribution [INVESTIGATE]")
    
    if measure_contribution >= 0.05:
        print("    - Measure hierarchy: SIGNIFICANT contribution [OK]")
    elif measure_contribution >= 0.02:
        print("    - Measure hierarchy: PARTIAL contribution [CHECK DIAGNOSTICS]")
    else:
        print("    - Measure hierarchy: MINIMAL contribution [INVESTIGATE]")

# ========================================
# Section 4: Comparison to PercePiano Paper
# ========================================
print("\n" + "-"*80)
print("SECTION 4: COMPARISON TO PERCEPIANO PAPER")
print("-"*80)

print(f"\n  {'Model':<32} {'Paper R2':>10} {'Our R2':>10} {'Match':>10}")
print(f"  {'-'*32} {'-'*10} {'-'*10} {'-'*10}")

if baseline_m:
    status = "[OK]" if baseline_m.val_r2 >= 0.15 else "[LOW]"
    print(f"  {'VirtuosoNetSingle (Baseline)':<32} {'0.185':>10} {baseline_m.val_r2:>+10.4f} {status:>10}")

if beat_measure_m:
    if beat_measure_m.val_r2 >= 0.35:
        status = "[OK]"
    elif beat_measure_m.val_r2 >= 0.30:
        status = "[CLOSE]"
    elif beat_measure_m.val_r2 >= 0.25:
        status = "[PARTIAL]"
    else:
        status = "[LOW]"
    print(f"  {'VirtuosoNetMultiLevel (Full)':<32} {'0.397':>10} {beat_measure_m.val_r2:>+10.4f} {status:>10}")

if baseline_m and beat_measure_m:
    total_gain = beat_measure_m.val_r2 - baseline_m.val_r2
    status = "[GOOD]" if total_gain >= 0.15 else "[PARTIAL]" if total_gain >= 0.05 else "[NONE]"
    print(f"  {'Hierarchy Gain':<32} {'+0.212':>10} {total_gain:>+10.4f} {status:>10}")

# ========================================
# Section 5: Next Steps Recommendations
# ========================================
print("\n" + "-"*80)
print("SECTION 5: NEXT STEPS RECOMMENDATIONS")
print("-"*80)

if baseline_m and beat_m and beat_measure_m:
    beat_contribution = beat_m.val_r2 - baseline_m.val_r2
    measure_contribution = beat_measure_m.val_r2 - beat_m.val_r2
    total_hierarchy = beat_contribution + measure_contribution
    
    print("\n  Based on Phase 2 results:")
    
    if total_hierarchy >= 0.15:
        print("\n  [SUCCESS] Hierarchy contributing as expected!")
        print("    - Ready for Phase 3: MERT embeddings, Conformer, etc.")
        print("    - Consider training on all 4 folds for robust evaluation")
    elif total_hierarchy >= 0.10:
        print("\n  [GOOD] Hierarchy helping but slightly below expected")
        print("    - Check diagnostics for attention entropy and contractor weights")
        print("    - May proceed to Phase 3 or investigate further")
    elif beat_contribution < 0.05:
        print("\n  [ACTION NEEDED] Beat hierarchy underperforming")
        print("    Recommendations:")
        print("    1. Check beat attention entropy - if >0.95, attention not learning")
        print("    2. Check span_beat_to_note_num - verify beat indices are valid")
        print("    3. Check contractor beat/lstm ratio - if <0.1, model ignoring beat")
        print("    4. Try longer training or different learning rate")
    elif measure_contribution < 0.02:
        print("\n  [ACTION NEEDED] Measure hierarchy not contributing")
        print("    Recommendations:")
        print("    1. Check measure attention entropy")
        print("    2. Verify beat-to-measure boundary detection")
        print("    3. Check contractor measure/lstm ratio")
        print("    4. Consider piece length distribution in fold")
    else:
        print("\n  [PARTIAL] Some hierarchy contribution but below target")
        print("    - Review diagnostics for specific issues")
        print("    - Consider hyperparameter tuning")

print("\n" + "="*80)
print("ANALYSIS COMPLETE")
print("="*80)

## Step 10: Sync Checkpoints to Google Drive

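A single cell copies the checkpoints of all three Phase 2 models, plus the fold assignment file, to Google Drive with rclone. If rclone is not available, the sync is skipped and the checkpoints stay in the local checkpoint directory.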
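
As a rough illustration of the sync step, the sketch below wraps `rclone copy` and reports whether the copy succeeded. The helper names and the `gdrive:` remote path are placeholders chosen for this example, not values from this notebook.

```python
import shutil
import subprocess
from pathlib import Path

def rclone_copy_command(src, dest):
    """Build the argument list for copying a local path to an rclone remote."""
    return ["rclone", "copy", str(src), dest, "--progress"]

def sync_checkpoint_dir(src, dest):
    """Copy src to dest with rclone; return True only if the copy succeeded."""
    src = Path(src)
    if shutil.which("rclone") is None or not src.exists():
        return False  # rclone is not installed, or there is nothing to copy
    result = subprocess.run(rclone_copy_command(src, dest))
    return result.returncode == 0

# Placeholder paths for illustration only
cmd = rclone_copy_command(Path("checkpoints/baseline"), "gdrive:checkpoints/baseline")
print(cmd)
```

Returning a boolean lets the caller collect failed transfers instead of reporting success unconditionally.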

In [None]:
"""
PHASE 2: SYNC ALL CHECKPOINTS TO GOOGLE DRIVE
"""

import subprocess

print("="*60)
print("SYNC ALL PHASE 2 CHECKPOINTS TO GOOGLE DRIVE")
print("="*60)

if RCLONE_AVAILABLE:
    # Sync all model checkpoints
    model_dirs = [
        ('baseline', baseline_trainer.checkpoint_dir),
        ('baseline_beat', beat_trainer.checkpoint_dir),
        ('baseline_beat_measure', beat_measure_trainer.checkpoint_dir),
    ]
    
    failed = []  # names of transfers where rclone exited with a non-zero code
    for model_name, ckpt_dir in model_dirs:
        if ckpt_dir.exists():
            gdrive_path = f"{GDRIVE_CHECKPOINT_PATH}/{ckpt_dir.name}"
            print(f"\nSyncing {model_name} ({ckpt_dir.name})...")
            result = subprocess.run(
                ['rclone', 'copy', str(ckpt_dir), gdrive_path, '--progress']
            )
            if result.returncode != 0:
                failed.append(model_name)
        else:
            print(f"\n[SKIP] {model_name} checkpoint dir not found: {ckpt_dir}")
    
    # Sync fold assignments
    print("\nSyncing fold assignments...")
    result = subprocess.run(
        ['rclone', 'copy', str(FOLD_FILE), GDRIVE_DATA_PATH, '--progress']
    )
    if result.returncode != 0:
        failed.append('fold assignments')
    
    print("\n" + "="*60)
    print("SYNC COMPLETE" if not failed else f"SYNC FINISHED WITH ERRORS: {failed}")
    print("="*60)
    print(f"\nCheckpoints synced to: {GDRIVE_CHECKPOINT_PATH}")
    print(f"Fold assignments synced to: {GDRIVE_DATA_PATH}")
else:
    print("\nrclone not available - skipping sync")
    print("Checkpoints saved locally at:", CHECKPOINT_ROOT)

print("\n" + "="*60)
print("PHASE 2 COMPLETE")
print("="*60)