# PercePiano Replica Training - Phase 2 Incremental Build

**Upload and Run**: Edit the "Training Config Vars" cell in Step 1 (models, fold, epochs), then Run All.

## Phase 2 Hierarchy Isolation
```
VirtuosoNetSingle (7-layer flat)     -> R2 = 0.19 (validated)
    + Beat hierarchy                  -> R2 = ??? (target: ~0.25-0.30)
    + Measure hierarchy               -> R2 = ??? (target: ~0.35-0.40)
Full HAN (VirtuosoNetMultiLevel)     -> R2 = 0.40 (SOTA target)
```

## Model Types
| Model | Architecture | Expected R2 |
|-------|-------------|-------------|
| `baseline` | 7-layer BiLSTM | ~0.19 |
| `baseline_beat` | 7-layer BiLSTM + Beat hierarchy | ~0.25-0.30 |
| `baseline_beat_measure` | 7-layer BiLSTM + Beat + Measure | ~0.35-0.40 |
| `han` | Full HAN (note+voice+beat+measure) | ~0.40 |

## Step 1: Environment Setup

In [None]:
# Check GPU: report CUDA availability and, when a GPU is present, the name and
# total memory of device 0.
import torch

cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {total_gb:.1f} GB")

In [None]:
# Install rclone
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
# Install uv and clone repository
!curl -LsSf https://astral.sh/uv/install.sh | sh

import os
os.environ['PATH'] = f"{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"

# Clone repository
if not os.path.exists('/tmp/crescendai'):
    !git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai

%cd /tmp/crescendai/model
!git pull
!git log -1 --oneline

# Install dependencies
!uv pip install --system -e .
!pip install tensorboard rich

import torch
import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

In [None]:
# Training Config Vars
# Edit this cell before "Run All"; later cells read these names.

# Which models to train (Phase 2 hierarchy isolation).
TRAIN_BASELINE = True               # 7-layer BiLSTM (~0.19)
TRAIN_BASELINE_BEAT = True          # + Beat hierarchy (~0.25-0.30)
TRAIN_BASELINE_BEAT_MEASURE = True  # + Measure hierarchy (~0.35-0.40)
TRAIN_HAN = True                    # Full HAN (~0.40)

# Training settings.
FOLD_ID = 2                  # the single fold trained in Phase 2
MAX_EPOCHS = 200
EARLY_STOPPING_PATIENCE = 20
RESTART_TRAINING = True      # True -> Step 2 deletes existing checkpoints and logs

# (label, flag) pairs for the summary line printed below.
_model_flags = (
    ("baseline", TRAIN_BASELINE),
    ("beat", TRAIN_BASELINE_BEAT),
    ("beat+measure", TRAIN_BASELINE_BEAT_MEASURE),
    ("han", TRAIN_HAN),
)

print("Configuration:")
print("  Models: " + ", ".join(f"{label}={flag}" for label, flag in _model_flags))
print(f"  Fold: {FOLD_ID}, Epochs: {MAX_EPOCHS}, Patience: {EARLY_STOPPING_PATIENCE}")

## Step 2: Configure Paths

In [None]:
import os
import subprocess
import shutil
from pathlib import Path

# Paths: local working directories under /tmp and their Google Drive (rclone) counterparts.
DATA_ROOT = Path('/tmp/percepiano_vnet_84dim')
CHECKPOINT_ROOT = Path('/tmp/checkpoints/percepiano_kfold')
LOG_ROOT = Path('/tmp/logs/percepiano_kfold')
GDRIVE_DATA_PATH = 'gdrive:crescendai_data/percepiano_vnet_84dim'
GDRIVE_CHECKPOINT_PATH = 'gdrive:crescendai_checkpoints/percepiano_kfold'

print("="*60)
print("PERCEPIANO REPLICA TRAINING - PHASE 2")
print("="*60)

# Clear checkpoints and logs if restarting (uses RESTART_TRAINING from config).
if RESTART_TRAINING and CHECKPOINT_ROOT.exists():
    print("\nRESTART_TRAINING=True: Clearing checkpoints")
    shutil.rmtree(CHECKPOINT_ROOT)

if RESTART_TRAINING and LOG_ROOT.exists():
    shutil.rmtree(LOG_ROOT)

# Create directories
for _dir in (CHECKPOINT_ROOT, LOG_ROOT, DATA_ROOT):
    _dir.mkdir(parents=True, exist_ok=True)

# Check rclone. Compare whole remote names (whitespace-separated tokens of the
# `listremotes` output) so a remote such as "mygdrive:" is not mistaken for
# "gdrive:". A missing rclone binary is reported as "NOT CONFIGURED" here
# instead of crashing; Step 3 raises if rclone is unavailable.
try:
    _remotes = subprocess.run(
        ['rclone', 'listremotes'], capture_output=True, text=True
    ).stdout.split()
except FileNotFoundError:
    _remotes = []
RCLONE_AVAILABLE = 'gdrive:' in _remotes

print(f"\nrclone gdrive: {'CONFIGURED' if RCLONE_AVAILABLE else 'NOT CONFIGURED'}")
print(f"Data: {DATA_ROOT}")
print(f"Checkpoints: {CHECKPOINT_ROOT}")

## Step 3: Download Data

In [None]:
# Download the preprocessed dataset from Google Drive into DATA_ROOT.
if not RCLONE_AVAILABLE:
    raise RuntimeError("rclone not configured. Run 'rclone config' first.")

print("Downloading data from Google Drive...")
# check=True: a failed copy stops the notebook here instead of surfacing later
# as an empty dataset or confusing fold-assignment errors.
subprocess.run(
    ['rclone', 'copy', GDRIVE_DATA_PATH, str(DATA_ROOT), '--progress'],
    capture_output=False,
    check=True,
)

# Verify
total = sum(1 for _ in DATA_ROOT.glob('**/*.pkl'))
print(f"\nTotal samples: {total}")
if total == 0:
    raise RuntimeError(
        f"No .pkl samples found under {DATA_ROOT} after copying from {GDRIVE_DATA_PATH}"
    )

## Step 4: Create Fold Assignments

In [None]:
from src.percepiano.data.kfold_split import (
    create_piece_based_folds,
    save_fold_assignments,
    load_fold_assignments,
    print_fold_statistics,
)

# Piece-based K-fold assignments are cached as JSON next to the data, so
# repeated runs (and every model in Step 7) reuse the same split.
# NOTE: an existing file is reused as-is, even if N_FOLDS or SEED changed.
FOLD_FILE = DATA_ROOT / 'fold_assignments.json'
N_FOLDS = 4
SEED = 42

if not FOLD_FILE.exists():
    fold_assignments = create_piece_based_folds(DATA_ROOT, N_FOLDS, test_ratio=0.15, seed=SEED)
    save_fold_assignments(fold_assignments, FOLD_FILE)
    print("Created new fold assignments")
else:
    fold_assignments = load_fold_assignments(FOLD_FILE)
    print("Loaded existing fold assignments")

print_fold_statistics(fold_assignments, n_folds=N_FOLDS)

## Step 5: Training Configuration

In [None]:
import torch
torch.set_float32_matmul_precision('medium')

import os
# NOTE(review): CUDA_LAUNCH_BLOCKING is normally read when the CUDA context is
# created, and Step 1 already queried the GPU, so in a Run-All session this
# likely has no effect. Where it does take effect it serialises kernel launches
# and slows training. Confirm it is still wanted (debugging aid?).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

from src.percepiano.training.kfold_trainer import (
    MODEL_TYPE_BASELINE,
    MODEL_TYPE_BASELINE_BEAT,
    MODEL_TYPE_BASELINE_BEAT_MEASURE,
    MODEL_TYPE_HAN,
)

# Build MODELS_TO_TRAIN from the config booleans, keeping this order
# (baseline -> +beat -> +beat+measure -> full HAN).
MODELS_TO_TRAIN = [
    model_type
    for enabled, model_type in (
        (TRAIN_BASELINE, MODEL_TYPE_BASELINE),
        (TRAIN_BASELINE_BEAT, MODEL_TYPE_BASELINE_BEAT),
        (TRAIN_BASELINE_BEAT_MEASURE, MODEL_TYPE_BASELINE_BEAT_MEASURE),
        (TRAIN_HAN, MODEL_TYPE_HAN),
    )
    if enabled
]

CONFIG = {
    'n_folds': N_FOLDS,
    'test_ratio': 0.15,
    'data_dir': str(DATA_ROOT),
    'checkpoint_dir': str(CHECKPOINT_ROOT),
    'log_dir': str(LOG_ROOT),
    # NOTE(review): data directory is named "84dim" but input_size is 79 -
    # presumably the dataset drops some features; confirm.
    'input_size': 79,
    'hidden_size': 256,
    'note_layers': 2,
    'voice_layers': 2,
    'beat_layers': 2,
    'measure_layers': 1,
    'num_attention_heads': 8,
    'learning_rate': 2.5e-5,
    'weight_decay': 1e-5,
    'dropout': 0.2,
    'batch_size': 8,
    'max_epochs': MAX_EPOCHS,
    'early_stopping_patience': EARLY_STOPPING_PATIENCE,
    'gradient_clip_val': 2.0,
    'precision': '32',
    'max_notes': 5000,
    'slice_len': 5000,
    'num_workers': 4,
    'augment_train': False,
}

# Expected validation R2 per model. The dict is keyed by the same constants used
# to build MODELS_TO_TRAIN, so a constant whose value differs from its name
# cannot silently fall back to "?".
_expected_r2 = {
    MODEL_TYPE_BASELINE: "~0.19",
    MODEL_TYPE_BASELINE_BEAT: "~0.25-0.30",
    MODEL_TYPE_BASELINE_BEAT_MEASURE: "~0.35-0.40",
    MODEL_TYPE_HAN: "~0.40",
}

print("="*60)
print("TRAINING PLAN")
print("="*60)
print(f"\nModels to train on Fold {FOLD_ID}:")
for m in MODELS_TO_TRAIN:
    print(f"  - {m} (expected R2: {_expected_r2.get(m, '?')})")
print(f"\nMax epochs: {MAX_EPOCHS}, Early stopping: {EARLY_STOPPING_PATIENCE}")

## Step 6: Pre-Training Data Diagnostics

In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from src.percepiano.data.percepiano_vnet_dataset import PercePianoKFoldDataset
from src.percepiano.training.diagnostics import analyze_indices

print("="*60)
print("PRE-TRAINING DATA DIAGNOSTICS")
print("="*60)

# Inspect the training split of the fold that Step 7 actually trains
# (FOLD_ID), using the same note limits as CONFIG.
diag_ds = PercePianoKFoldDataset(
    data_dir=DATA_ROOT, fold_assignments=fold_assignments, fold_id=FOLD_ID, mode="train",
    max_notes=CONFIG['max_notes'], slice_len=CONFIG['slice_len'],
)
# First 4 samples with the default collate, just to check shapes and indices.
batch = next(iter(DataLoader(diag_ds, batch_size=4, shuffle=False)))

print(f"\nBatch shapes:")
print(f"  input_features: {batch['input_features'].shape}")
print(f"  beat indices: {batch['note_locations_beat'].shape}")
print(f"  measure indices: {batch['note_locations_measure'].shape}")

idx_stats = analyze_indices(batch['note_locations_beat'], batch['note_locations_measure'])
print(f"\nIndex analysis:")
print(f"  Beat range: [{idx_stats['beat_min']}, {idx_stats['beat_max']}]")
print(f"  Measure range: [{idx_stats['measure_min']}, {idx_stats['measure_max']}]")

# Beat indices are expected to start at 1; negative values after the
# zero-shift indicate a broken index pipeline.
issues = []
if idx_stats['beat_min'] != 1:
    issues.append(f"Beat indices start from {idx_stats['beat_min']} (expected 1)")
if idx_stats['negative_beat_count'] > 0:
    issues.append(f"{idx_stats['negative_beat_count']} negative zero-shifted beat values")

if issues:
    print(f"\n[ISSUES FOUND]")
    for issue in issues:
        print(f"  - {issue}")
else:
    print(f"\n[OK] Data pipeline looks correct")

## Step 7: Train Model

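Trains every model enabled by the `TRAIN_*` flags in the Step 1 config cell on fold `FOLD_ID`, then prints an R2 comparison. Change the model/fold selection there, not here.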

In [None]:
from src.percepiano.training.kfold_trainer import KFoldTrainer
import pytorch_lightning as pl

# Each model is seeded with the same value right before it is built. Seeding
# only once before the loop would let earlier models' random draws shift later
# models' initialisation and data order, confounding the hierarchy comparison.
TRAIN_SEED = 42

# Uses FOLD_ID and MODELS_TO_TRAIN from config cells above
all_results = {}

for MODEL_TYPE in MODELS_TO_TRAIN:
    print("\n" + "="*70)
    print(f"TRAINING: {MODEL_TYPE} on Fold {FOLD_ID}")
    print("="*70)

    pl.seed_everything(TRAIN_SEED, workers=True)

    trainer = KFoldTrainer(
        config=CONFIG,
        fold_assignments=fold_assignments,
        data_dir=DATA_ROOT,
        checkpoint_dir=CHECKPOINT_ROOT,
        log_dir=LOG_ROOT,
        n_folds=N_FOLDS,
        model_type=MODEL_TYPE,
    )

    # NOTE(review): always trains from scratch; with RESTART_TRAINING=False the
    # kept checkpoints are not resumed from. Confirm this is intended.
    metrics = trainer.train_fold(fold_id=FOLD_ID, verbose=True, resume_from_checkpoint=False)
    trainer.save_results()

    # The trainer is kept so Step 8 can analyse the model it produced.
    all_results[MODEL_TYPE] = {
        'r2': metrics.val_r2,
        'pearson': metrics.val_pearson,
        'mae': metrics.val_mae,
        'epochs': metrics.epochs_trained,
        'trainer': trainer,
    }

    print(f"\n  {MODEL_TYPE}: R2 = {metrics.val_r2:+.4f}")

# Summary comparison
print("\n" + "="*70)
print(f"PHASE 2 RESULTS COMPARISON (Fold {FOLD_ID})")
print("="*70)

# Gains are measured against this run's baseline when it was trained (same
# fold and seed); otherwise against the previously validated baseline R2.
VALIDATED_BASELINE_R2 = 0.1931
if MODEL_TYPE_BASELINE in all_results:
    baseline_r2 = all_results[MODEL_TYPE_BASELINE]['r2']
    gain_reference = "this run's baseline"
else:
    baseline_r2 = VALIDATED_BASELINE_R2
    gain_reference = "validated baseline"

expected_r2 = {
    MODEL_TYPE_BASELINE: "~0.19",
    MODEL_TYPE_BASELINE_BEAT: "~0.25-0.30",
    MODEL_TYPE_BASELINE_BEAT_MEASURE: "~0.35-0.40",
    MODEL_TYPE_HAN: "~0.40",
}

print(f"\n  {'Model':<30} {'R2':>10} {'Gain':>10} {'Expected':>12}")
print(f"  {'-'*30} {'-'*10} {'-'*10} {'-'*12}")
print(f"  {'baseline (reference)':<30} {VALIDATED_BASELINE_R2:>+10.4f} {'---':>10} {'~0.19':>12}")

for model_type, result in all_results.items():
    gain = result['r2'] - baseline_r2
    expected = expected_r2.get(model_type, "?")
    print(f"  {model_type:<30} {result['r2']:>+10.4f} {gain:>+10.4f} {expected:>12}")

print(f"\n  Gain is relative to {gain_reference} (R2 = {baseline_r2:+.4f})")
print("="*70)

## Step 8: Post-Training Analysis

Evaluates the last model trained in Step 7 on the validation split of fold `FOLD_ID`. Copy/paste these results for tracking.

In [None]:
import torch
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader
from sklearn.metrics import r2_score
from scipy.stats import pearsonr, spearmanr
from torch.nn.utils.rnn import PackedSequence

from src.percepiano.data.percepiano_vnet_dataset import (
    PercePianoKFoldDataset,
    percepiano_pack_collate,
)
from src.percepiano.models.percepiano_replica import PERCEPIANO_DIMENSIONS


def _batch_to_device(batch, device):
    """Return a copy of ``batch`` with tensors and PackedSequences moved to ``device``.

    Values of any other type are passed through unchanged. ``PackedSequence.to``
    moves ``data`` and the sort/unsort index tensors while leaving
    ``batch_sizes`` on the CPU, as PyTorch requires.
    """
    moved = {}
    for key, value in batch.items():
        if isinstance(value, (torch.Tensor, PackedSequence)):
            moved[key] = value.to(device)
        else:
            moved[key] = value
    return moved


# R2 of the validated 7-layer BiLSTM baseline (see the table at the top).
VALIDATED_BASELINE_R2 = 0.1931

# Analyze last trained model
if not MODELS_TO_TRAIN:
    print("[ERROR] No models were trained. Check config.")
else:
    LAST_MODEL_TYPE = MODELS_TO_TRAIN[-1]

    print("="*70)
    print(f"POST-TRAINING ANALYSIS: {LAST_MODEL_TYPE}")
    print("="*70)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Use the trainer recorded for LAST_MODEL_TYPE by the training cell. The
    # bare `trainer` loop variable can belong to an earlier model if training
    # was interrupted, which would mislabel every number printed below.
    last_trainer = globals().get('all_results', {}).get(LAST_MODEL_TYPE, {}).get('trainer')
    model = None
    if last_trainer is not None:
        model = last_trainer.get_trained_model(FOLD_ID)
        if model is not None:
            model = model.to(device).eval()

    if model is None:
        print("\n[ERROR] No model available. Run training first.")
    else:
        # Validation split of the trained fold.
        val_ds = PercePianoKFoldDataset(
            data_dir=DATA_ROOT, fold_assignments=fold_assignments, fold_id=FOLD_ID,
            mode="val", max_notes=CONFIG['max_notes'], slice_len=CONFIG['slice_len'],
        )
        val_loader = DataLoader(
            val_ds, batch_size=4, shuffle=False, num_workers=0, collate_fn=percepiano_pack_collate,
        )

        # Collect predictions over the whole validation set.
        all_preds, all_targets = [], []

        with torch.no_grad():
            for batch in val_loader:
                batch_on_device = _batch_to_device(batch, device)

                note_locations = {
                    'beat': batch_on_device['note_locations_beat'],
                    'measure': batch_on_device['note_locations_measure'],
                    'voice': batch_on_device['note_locations_voice'],
                }

                outputs = model(
                    batch_on_device['input_features'], note_locations,
                    batch_on_device.get('attention_mask'), batch_on_device.get('lengths'),
                )
                all_preds.append(outputs['predictions'].cpu())
                all_targets.append(batch_on_device['scores'].cpu())

        preds = torch.cat(all_preds).numpy()
        targets = torch.cat(all_targets).numpy()
        n_dims = targets.shape[1]  # number of score dimensions (19 for PercePiano)

        # Overall metrics
        r2 = r2_score(targets, preds)
        pearson_r = np.mean([pearsonr(targets[:, i], preds[:, i])[0] for i in range(n_dims)])
        mae = np.mean(np.abs(targets - preds))

        print(f"\n  Model:   {LAST_MODEL_TYPE}")
        print(f"  R2:      {r2:+.4f}")
        print(f"  Pearson: {pearson_r:+.4f}")
        print(f"  MAE:     {mae:.4f}")

        # Prediction health. Compare per-dimension spread: a model that outputs
        # a different constant per dimension has a large global std but has
        # still collapsed, so the std is taken along the sample axis first.
        pred_std = preds.std(axis=0).mean()
        target_std = targets.std(axis=0).mean()
        print(f"\n  Prediction std (mean per-dim): {pred_std:.4f} (target: {target_std:.4f})")
        if pred_std < target_std * 0.5:
            print(f"  [WARN] Prediction collapse detected!")

        # Per-dimension R2
        print(f"\n{'='*70}")
        print(f"PER-DIMENSION R2")
        print(f"{'='*70}")

        dim_r2s = [(dim, r2_score(targets[:, i], preds[:, i])) for i, dim in enumerate(PERCEPIANO_DIMENSIONS)]
        dim_r2s.sort(key=lambda x: x[1], reverse=True)

        print(f"\n  {'Dimension':<30} {'R2':>10} {'Status':>10}")
        print(f"  {'-'*30} {'-'*10} {'-'*10}")
        for dim, dim_r2 in dim_r2s:
            status = "[GOOD]" if dim_r2 >= 0.3 else "[OK]" if dim_r2 >= 0.1 else "[WEAK]" if dim_r2 >= 0 else "[NEG]"
            print(f"  {dim:<30} {dim_r2:>+10.4f} {status:>10}")

        positive_dims = sum(1 for _, score in dim_r2s if score > 0)
        print(f"\n  Positive R2: {positive_dims}/{len(dim_r2s)}")

        # Hierarchy diagnostics (for models with hierarchy)
        if hasattr(model, 'beat_attention'):
            print(f"\n{'='*70}")
            print(f"HIERARCHY DIAGNOSTICS")
            print(f"{'='*70}")

            # Run one batch with diagnose=True
            batch = next(iter(val_loader))
            batch_on_device = _batch_to_device(batch, device)

            note_locations = {
                'beat': batch_on_device['note_locations_beat'],
                'measure': batch_on_device['note_locations_measure'],
                'voice': batch_on_device['note_locations_voice'],
            }

            with torch.no_grad():
                _ = model(batch_on_device['input_features'], note_locations, diagnose=True)

            # Contractor weight analysis: mean |weight| per concatenated input
            # branch, to see whether the hierarchy branches are being used.
            if hasattr(model, 'note_contractor'):
                # Assumed width of each concatenated branch (bidirectional,
                # 2 * hidden_size with hidden_size=256) - update if CONFIG changes.
                BRANCH_WIDTH = 512
                w = model.note_contractor.weight.data
                if w.shape[1] == 2 * BRANCH_WIDTH:  # baseline_beat: [LSTM | Beat]
                    lstm_w = w[:, :BRANCH_WIDTH].abs().mean().item()
                    beat_w = w[:, BRANCH_WIDTH:].abs().mean().item()
                    ratio = beat_w / lstm_w if lstm_w > 0 else float('inf')
                    print(f"\n  Contractor weights:")
                    print(f"    LSTM branch:  {lstm_w:.4f}")
                    print(f"    Beat branch:  {beat_w:.4f}")
                    print(f"    Ratio (Beat/LSTM): {ratio:.2f}x")
                    if beat_w < lstm_w * 0.1:
                        print(f"    [WARN] Contractor may be ignoring beat branch!")
                elif w.shape[1] == 3 * BRANCH_WIDTH:  # baseline_beat_measure: [LSTM | Beat | Measure]
                    lstm_w = w[:, :BRANCH_WIDTH].abs().mean().item()
                    beat_w = w[:, BRANCH_WIDTH:2 * BRANCH_WIDTH].abs().mean().item()
                    meas_w = w[:, 2 * BRANCH_WIDTH:].abs().mean().item()
                    print(f"\n  Contractor weights:")
                    print(f"    LSTM branch:    {lstm_w:.4f}")
                    print(f"    Beat branch:    {beat_w:.4f}")
                    print(f"    Measure branch: {meas_w:.4f}")

        # Comparison to baseline
        print(f"\n{'='*70}")
        print(f"COMPARISON TO BASELINE")
        print(f"{'='*70}")

        baseline_r2 = VALIDATED_BASELINE_R2
        hierarchy_gain = r2 - baseline_r2

        print(f"\n  {'Model':<30} {'R2':>10} {'Expected':>10}")
        print(f"  {'-'*30} {'-'*10} {'-'*10}")
        print(f"  {'Baseline (7-layer BiLSTM)':<30} {'+0.1931':>10} {'~0.19':>10}")

        expected = {
            MODEL_TYPE_BASELINE_BEAT: "~0.25-0.30",
            MODEL_TYPE_BASELINE_BEAT_MEASURE: "~0.35-0.40",
            MODEL_TYPE_HAN: "~0.40",
        }.get(LAST_MODEL_TYPE, "~0.19")
        print(f"  {LAST_MODEL_TYPE:<30} {r2:>+10.4f} {expected:>10}")
        print(f"  {'-'*30} {'-'*10} {'-'*10}")
        print(f"  {'Hierarchy Gain':<30} {hierarchy_gain:>+10.4f}")

        if hierarchy_gain > 0.15:
            print(f"\n  [GOOD] Significant hierarchy contribution!")
        elif hierarchy_gain > 0.05:
            print(f"\n  [PARTIAL] Some hierarchy contribution")
        elif hierarchy_gain > 0:
            print(f"\n  [WEAK] Minimal hierarchy contribution")
        else:
            print(f"\n  [NONE] No hierarchy contribution (may be hurting)")

print(f"\n{'='*70}")

## Step 9: Test Set Evaluation (Optional)

In [None]:
# Test-set evaluation needs all 4 folds trained. Phase 2 hierarchy isolation
# trains a single fold (fold 2), so this step only explains why it is skipped.
for _line in (
    "Skipping test evaluation (Phase 2 trains single fold only)",
    "To evaluate on test set, train all 4 folds with full K-fold training",
):
    print(_line)

## Step 10: Save to Google Drive

In [None]:
print("="*60)
print("SYNC TO GOOGLE DRIVE")
print("="*60)

if RCLONE_AVAILABLE:
    # (progress message, what was synced, local source, Drive destination)
    _sync_jobs = (
        ("Syncing checkpoints...", "checkpoints", str(CHECKPOINT_ROOT), GDRIVE_CHECKPOINT_PATH),
        ("Syncing fold assignments...", "fold assignments", str(FOLD_FILE), GDRIVE_DATA_PATH),
    )

    # Track each copy's exit status so a failed upload is reported instead of
    # being followed by an unconditional "Sync complete!".
    failed = []
    for message, what, src, dst in _sync_jobs:
        print(f"\n{message}")
        completed = subprocess.run(
            ['rclone', 'copy', src, dst, '--progress'],
            capture_output=False,
        )
        if completed.returncode != 0:
            failed.append(what)

    if failed:
        print(f"\n[WARN] Sync failed for: {', '.join(failed)} - re-run this cell or sync manually")
    else:
        print(f"\nSync complete!")
        print(f"  Checkpoints: {GDRIVE_CHECKPOINT_PATH}")
else:
    print("rclone not available - skipping sync")

print(f"\n{'='*60}")
print("DONE")
print(f"{'='*60}")