# Piano Performance Evaluation - Model Training

## Loss Functions Enabled

- **CORAL**: Ordinal regression for rank-consistent predictions (0-100 scale -> 20 bins)
- **FDS**: Feature Distribution Smoothing for imbalanced targets (implemented, but currently disabled via `fds_enabled: False` to reduce CPU memory)
- **LDS**: Label Distribution Smoothing
- **Bootstrap Loss**: Handles noisy labels
- **Huber Loss**: Robust to outliers
- **Ranking Loss**: Pairwise ranking consistency
- **Contrastive Loss**: Cross-modal alignment (InfoNCE)

## Step 1: Setup Environment

In [None]:
# Report whether a CUDA device is visible and, if so, its name and total memory (GB).
import torch

cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Install uv
!curl -LsSf https://astral.sh/uv/install.sh | sh
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

import os
os.environ['PATH'] = f"{os.environ['HOME']}/.cargo/bin:{os.environ['PATH']}"

# Clone repository (if not already present)
if not os.path.exists('/tmp/crescendai'):
    !git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai

%cd /tmp/crescendai/model
!git pull
!git log -1 --oneline

# Install package (torchaudio included for GPU-accelerated audio loading)
!uv pip install --system -e .

import torch
import pytorch_lightning as pl
print(f"\nPyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

In [None]:
# TensorBoard backs the TensorBoardLogger used by the training functions in Step 5.
!pip install tensorboard


#
# Reminder: Run this in terminal: rclone config
# The next cell looks for a remote named 'gdrive' in `rclone listremotes`.
#

In [None]:
import os
from pathlib import Path
import subprocess

# Paths
# Local checkpoint root; the training functions create one subdirectory per model
# (crossattn, gated, concat, audio_only, midi_only) beneath it.
CHECKPOINT_ROOT = '/tmp/checkpoints'
# rclone destinations on the 'gdrive' remote.
GDRIVE_CHECKPOINT_PATH = 'gdrive:crescendai_checkpoints/fusion_comparison'
GDRIVE_MERT_PATH = 'gdrive:MERT-v1-95M'

print("="*70)
print("SETUP: CHECKPOINTS AND MERT MODEL")
print("="*70)

# 1. Create checkpoint directories
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
print(f"\nCheckpoint directory: {CHECKPOINT_ROOT}")

# 2. Check if rclone is configured for Google Drive
print("\nChecking rclone configuration...")
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)

# NOTE: this is a substring test, so a remote named e.g. 'mygdrive:' would also match.
if 'gdrive:' in result.stdout:
    print("  rclone 'gdrive' remote: CONFIGURED")
    RCLONE_AVAILABLE = True
    
    # Try to restore existing checkpoints so that training can resume from them.
    # `rclone copy` only adds/updates files; it never deletes local ones.
    print("\nRestoring checkpoints from Google Drive (if any)...")
    subprocess.run(
        ['rclone', 'copy', GDRIVE_CHECKPOINT_PATH, CHECKPOINT_ROOT, '--progress'],
        capture_output=False
    )
    
    # List restored checkpoints
    for fusion_type in ['crossattn', 'gated', 'concat', 'audio_only', 'midi_only']:
        ckpt_dir = Path(f"{CHECKPOINT_ROOT}/{fusion_type}")
        if ckpt_dir.exists():
            ckpts = list(ckpt_dir.glob('*.ckpt'))
            if ckpts:
                print(f"  {fusion_type}: Restored {len(ckpts)} checkpoint(s)")
else:
    print("  rclone 'gdrive' remote: NOT CONFIGURED")
    print("  WARNING: Checkpoints will NOT be backed up!")
    RCLONE_AVAILABLE = False

# 3. MERT-95M model - download from Google Drive to HuggingFace cache
# NOTE: the download below is attempted even when RCLONE_AVAILABLE is False;
# in that case it is expected to fail and print the ERROR message.
print("\n" + "="*70)
print("MERT-95M MODEL")
print("="*70)

HF_CACHE = Path.home() / ".cache" / "huggingface" / "hub"
MERT_CACHE = HF_CACHE / "models--m-a-p--MERT-v1-95M"
# NOTE(review): the Hugging Face cache normally keys snapshots by commit hash
# (with refs/main pointing at it); whether a 'snapshots/main' directory is picked
# up depends on how the model is loaded (e.g. offline mode) — confirm.
MERT_SNAPSHOT = MERT_CACHE / "snapshots" / "main"

if MERT_SNAPSHOT.exists() and (MERT_SNAPSHOT / "model.safetensors").exists():
    print(f"MERT-95M already cached at: {MERT_SNAPSHOT}")
else:
    print("Downloading MERT-95M from Google Drive...")
    
    # Create directory structure
    MERT_SNAPSHOT.mkdir(parents=True, exist_ok=True)
    
    # Copy from Google Drive
    result = subprocess.run(
        ['rclone', 'copy', GDRIVE_MERT_PATH, str(MERT_SNAPSHOT), '--progress'],
        capture_output=False
    )
    
    # Verify download by checking for the weights file rather than rclone's exit code.
    if (MERT_SNAPSHOT / "model.safetensors").exists():
        print(f"MERT-95M downloaded to: {MERT_SNAPSHOT}")
        model_size = (MERT_SNAPSHOT / "model.safetensors").stat().st_size / (1024**2)
        print(f"Model size: {model_size:.1f} MB")
    else:
        # Reported but not raised: execution continues to the next cells.
        print("ERROR: Failed to download MERT-95M from Google Drive!")
        print(f"Make sure {GDRIVE_MERT_PATH} exists and contains model.safetensors")

print(f"\nGoogle Drive checkpoint path: {GDRIVE_CHECKPOINT_PATH}")
print(f"rclone available: {RCLONE_AVAILABLE}")


## Step 2: Extract Dataset

In [None]:
import tarfile
from pathlib import Path
import os
import subprocess

# Base directory for the local tarball fallback search. This is the kernel's
# current working directory, which Step 1's `%cd` set to /tmp/crescendai/model.
NOTEBOOK_DIR = Path.cwd()
DATA_ROOT = Path("/tmp/maestro_data")  # Extract to /tmp for speed

# Google Drive paths (files in root directory)
GDRIVE_DATASET_PATH = "gdrive:maestro_with_variance.tar.gz"
LOCAL_TARBALL_PATH = Path("/tmp/maestro_with_variance.tar.gz")

print(f"Working directory: {NOTEBOOK_DIR}")
print(f"Data will be extracted to: {DATA_ROOT}")

# Check if already extracted (more than 100 wav files counts as a prior extraction)
if (DATA_ROOT / "audio").exists() and len(list((DATA_ROOT / "audio").glob("*.wav"))) > 100:
    print(f"\nDataset already extracted at {DATA_ROOT}")
    audio_files = list((DATA_ROOT / "audio").glob("*.wav"))
    midi_files = list((DATA_ROOT / "midi").glob("*.mid"))
    print(f"  Audio files: {len(audio_files):,}")
    print(f"  MIDI files: {len(midi_files):,}")
else:
    # Download from Google Drive (root directory)
    print("\nDownloading dataset from Google Drive...")
    print(f"Source: {GDRIVE_DATASET_PATH}")
    print(f"Destination: {LOCAL_TARBALL_PATH}")
    
    result = subprocess.run(
        ['rclone', 'copy', GDRIVE_DATASET_PATH, '/tmp/', '--progress'],
        capture_output=False
    )
    
    if LOCAL_TARBALL_PATH.exists():
        print(f"\nDownload complete! Size: {LOCAL_TARBALL_PATH.stat().st_size / (1024**3):.2f} GB")
        tarball_path = LOCAL_TARBALL_PATH
    else:
        # Fallback to local search
        print("\nGoogle Drive download failed, searching locally...")
        TARBALL_SEARCH_PATHS = [
            NOTEBOOK_DIR / "maestro_with_variance.tar.gz",
            Path("maestro_with_variance.tar.gz"),
            Path.home() / "maestro_with_variance.tar.gz",
            NOTEBOOK_DIR.parent / "maestro_with_variance.tar.gz",
        ]
        
        tarball_path = next((p for p in TARBALL_SEARCH_PATHS if p.exists()), None)
        
        if tarball_path is None:
            print("="*60)
            print("ERROR: maestro_with_variance.tar.gz not found!")
            print("="*60)
            print("\nTo fix:")
            print("  1. Upload to Google Drive root: gdrive:maestro_with_variance.tar.gz")
            print("  2. Or place in notebook directory")
            raise FileNotFoundError("Dataset not found")
    
    # Extract dataset
    DATA_ROOT.mkdir(parents=True, exist_ok=True)
    print(f"\nExtracting to {DATA_ROOT}...")
    with tarfile.open(tarball_path, "r:gz") as tar:
        # The 'data' filter (Python 3.12+, backported to recent 3.8-3.11 security
        # releases) rejects absolute paths, '..' traversal and links escaping
        # DATA_ROOT. Older interpreters without it fall back to plain extraction.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(DATA_ROOT, filter="data")
        else:
            tar.extractall(DATA_ROOT)
    print("Extraction complete!")
    
    # Clean up the downloaded tarball to save space (a local fallback copy is kept)
    if LOCAL_TARBALL_PATH.exists():
        print(f"Removing tarball to save space...")
        LOCAL_TARBALL_PATH.unlink()

# Verify extraction
audio_files = list((DATA_ROOT / "audio").glob("*.wav"))
midi_files = list((DATA_ROOT / "midi").glob("*.mid"))
annotation_files = list((DATA_ROOT / "annotations").glob("*.jsonl"))

print(f"\nDataset contents:")
print(f"  Audio files: {len(audio_files):,}")
print(f"  MIDI files: {len(midi_files):,}")
print(f"  Annotation files: {len(annotation_files)}")

# Verify annotation structure by printing the first record of one annotation file
if annotation_files:
    import json as json_module
    with open(annotation_files[0]) as f:
        sample = json_module.loads(f.readline())
    print(f"\nSample annotation:")
    print(f"  Dimensions: {list(sample['labels'].keys())}")
    print(f"  Quality tier: {sample.get('quality_tier', 'N/A')}")
    print(f"  Quality score: {sample.get('quality_score', 'N/A')}")

# Store paths for later cells
DATASET_PATHS = {
    'train_path': str(DATA_ROOT / "annotations" / "train.jsonl"),
    'val_path': str(DATA_ROOT / "annotations" / "validation.jsonl"),
    'test_path': str(DATA_ROOT / "annotations" / "test.jsonl"),
    'audio_dir': str(DATA_ROOT / "audio"),
    'midi_dir': str(DATA_ROOT / "midi"),
}

print(f"\n{'='*60}")
print("DATASET PATHS (for training)")
print(f"{'='*60}")
for k, v in DATASET_PATHS.items():
    print(f"  {k}: {v}")

# ======================================================
# Load pre-trained MIDIBert checkpoint from Google Drive
# ======================================================

MIDI_PRETRAIN_CONFIG = {
    'output_dir': '/tmp/checkpoints/midi_pretrain',
    'gdrive_output': 'gdrive:crescendai_checkpoints/midi_pretrain',
    'hidden_size': 256,
    'num_layers': 6,
    'max_seq_length': 512,
}

output_dir = Path(MIDI_PRETRAIN_CONFIG['output_dir'])
encoder_checkpoint = output_dir / 'encoder_pretrained.pt'
best_checkpoint = output_dir / 'best.pt'

print(f"{'='*60}")
print("LOADING PRE-TRAINED MIDIBERT CHECKPOINT")
print(f"{'='*60}")

# Check if already downloaded
if encoder_checkpoint.exists():
    print(f"Checkpoint already exists: {encoder_checkpoint}")
else:
    # Download the whole pretrain directory from Google Drive
    print(f"Downloading from: {MIDI_PRETRAIN_CONFIG['gdrive_output']}")
    output_dir.mkdir(parents=True, exist_ok=True)
    
    result = subprocess.run(
        ['rclone', 'copy', MIDI_PRETRAIN_CONFIG['gdrive_output'], str(output_dir), '--progress'],
        capture_output=False
    )
    
    if not encoder_checkpoint.exists() and not best_checkpoint.exists():
        raise FileNotFoundError(
            f"Failed to download MIDIBert checkpoint from Google Drive.\n"
            f"Expected: {MIDI_PRETRAIN_CONFIG['gdrive_output']}/encoder_pretrained.pt"
        )
    
    print("Download complete!")

# Set checkpoint path: prefer the encoder-only export, fall back to best.pt
if encoder_checkpoint.exists():
    MIDI_PRETRAINED_CHECKPOINT_PATH = str(encoder_checkpoint)
elif best_checkpoint.exists():
    MIDI_PRETRAINED_CHECKPOINT_PATH = str(best_checkpoint)
else:
    raise FileNotFoundError("No MIDIBert checkpoint found")

# Show training stats if available. The log is informational only, so an
# unreadable, empty, or incomplete log is reported instead of crashing the cell
# after a usable checkpoint has already been located.
log_path = output_dir / 'training_log.json'
if log_path.exists():
    import json as json_module
    try:
        with open(log_path) as f:
            log = json_module.load(f)
    except (OSError, ValueError) as err:  # JSONDecodeError subclasses ValueError
        log = None
        print(f"\nCould not read pre-training log {log_path}: {err}")

    if log is not None:
        if not (isinstance(log, list) and log):
            print("\nPre-training log is empty; skipping stats.")
        else:
            final_epoch = log[-1]
            try:
                pretrain_epochs = final_epoch['epoch']
                pretrain_val_loss = final_epoch['val_loss']
                pretrain_pitch_loss = final_epoch['val_losses']['pitch']
            except (KeyError, TypeError) as err:
                print(f"\nPre-training log is missing expected field {err!r}; skipping stats.")
            else:
                print(f"\nPre-training stats (from GiantMIDI-Piano):")
                print(f"  Epochs: {pretrain_epochs}")
                print(f"  Final val loss: {pretrain_val_loss:.4f}")
                print(f"  Pitch val loss: {pretrain_pitch_loss:.4f}")

print(f"\n{'='*60}")
print(f"MIDIBERT CHECKPOINT: {MIDI_PRETRAINED_CHECKPOINT_PATH}")
print(f"{'='*60}")


## Step 3: Training Configuration

In [None]:
# Training configuration

# Allow reduced-precision float32 matmuls ('medium') so Tensor Cores (e.g. A100) are used
import torch
torch.set_float32_matmul_precision('medium')

CONFIG = {
    # =========================================================================
    # DATA PATHS (from Step 2 extraction)
    # =========================================================================
    'train_path': DATASET_PATHS['train_path'],
    'val_path': DATASET_PATHS['val_path'],
    'test_path': DATASET_PATHS['test_path'],
    
    # =========================================================================
    # DIMENSIONS (8 total)
    # =========================================================================
    'dimensions': [
        'note_accuracy', 'rhythmic_stability', 'articulation_clarity', 'pedal_technique',
        'tone_quality', 'dynamic_range', 'musical_expression', 'overall_interpretation'
    ],
    
    # =========================================================================
    # MODEL ARCHITECTURE
    # =========================================================================
    'audio_dim': 768,       # MERT-95M output dimension
    'midi_dim': 256,        # MIDIBert output dimension
    'shared_dim': 512,      # Projection space dimension
    'use_projection': True, # Use projection heads before fusion
    
    # MIDIBert pretraining - checkpoint located in Step 2 (None if that cell was skipped)
    'midi_pretrained_checkpoint': globals().get('MIDI_PRETRAINED_CHECKPOINT_PATH'),
    
    # =========================================================================
    # LOSS WEIGHTS
    # =========================================================================
    'mse_weight': 1.0,
    'ranking_weight': 0.2,
    'contrastive_weight': 0.1,
    
    # =========================================================================
    # BASE LOSS FUNCTION
    # =========================================================================
    'base_loss': 'huber',      # 'mse', 'huber', or 'mae' - Huber is robust to outliers
    'huber_delta': 1.0,
    
    # =========================================================================
    # LABEL DISTRIBUTION SMOOTHING (LDS)
    # Handles imbalanced label distribution
    # =========================================================================
    'lds_enabled': True,
    'lds_num_bins': 100,
    'lds_sigma': 2.0,
    'lds_reweight_scale': 1.0,
    
    # =========================================================================
    # FEATURE DISTRIBUTION SMOOTHING (FDS) - NEW
    # Calibrates features across target bins for better generalization
    # =========================================================================
    'fds_enabled': False,  # Disabled to reduce CPU memory
    'fds_num_bins': 100,
    'fds_momentum': 0.9,
    'fds_kernel_sigma': 2.0,
    'fds_start_epoch': 2,      # Start after model warms up
    
    # =========================================================================
    # CORAL ORDINAL REGRESSION - NEW
    # Converts regression to ordinal classification for rank consistency
    # =========================================================================
    'coral_enabled': True,
    'coral_num_classes': 20,   # 20 bins = 5-point resolution on 0-100 scale
    'coral_weight': 0.3,       # Weight for CORAL loss component
    
    # =========================================================================
    # BOOTSTRAP LOSS (handles noisy labels)
    # =========================================================================
    'bootstrap_enabled': True,
    'bootstrap_beta': 0.8,          # 80% label, 20% model prediction
    'bootstrap_warmup_epochs': 2,   # Train with pure labels first
    
    # =========================================================================
    # MODALITY DROPOUT (prevents modality collapse)
    # =========================================================================
    'modality_dropout': {
        'enabled': True,
        'audio_prob': 0.15,
        'midi_prob': 0.15,
    },
    
    # =========================================================================
    # TRAINING HYPERPARAMETERS
    # =========================================================================
    'epochs': 5,
    'batch_size': 8,
    'backbone_lr': 5e-6,
    'heads_lr': 1e-4,
    'warmup_steps': 500,
    
    # =========================================================================
    # STAGED UNFREEZING (applied to fusion models only, see train_fusion_model)
    # =========================================================================
    'staged_unfreezing': {
        'enabled': True,
        'schedule': [
            {'epoch': 0, 'freeze': ['audio_encoder', 'midi_encoder'], 'unfreeze': ['projection']},
            {'epoch': 3, 'unfreeze': ['audio_encoder.top_4', 'midi_encoder.top_2'], 'lr_scale': 0.1},
        ]
    },
    
    # =========================================================================
    # FUSION TYPES TO COMPARE
    # =========================================================================
    'fusion_types': ['crossattn', 'gated', 'concat'],
    
    # =========================================================================
    # PHASE 2 SUCCESS CRITERIA
    # =========================================================================
    'fusion_improvement_target': 10.0,  # Fusion must beat single-modal by >= 10%
}

# Echo the effective configuration so each run's notebook output records it.
_rule = "=" * 70
print(_rule)
print("TRAINING CONFIGURATION")
print(_rule)

print("\nDATA PATHS:")
for _label, _key in (("Train", 'train_path'), ("Val", 'val_path'), ("Test", 'test_path')):
    print(f"  {_label}: {CONFIG[_key]}")

print("\nMODEL:")
for _label, _key in (
    ("Audio dim", 'audio_dim'),
    ("MIDI dim", 'midi_dim'),
    ("Shared dim", 'shared_dim'),
    ("MIDIBert checkpoint", 'midi_pretrained_checkpoint'),
):
    print(f"  {_label}: {CONFIG[_key]}")

print("\nLOSS FUNCTIONS:")
print(f"  Base loss: {CONFIG['base_loss']}")
for _label, _key in (
    ("LDS", 'lds_enabled'),
    ("FDS", 'fds_enabled'),
    ("CORAL", 'coral_enabled'),
    ("Bootstrap", 'bootstrap_enabled'),
):
    _state = 'enabled' if CONFIG[_key] else 'disabled'
    print(f"  {_label}: {_state}")

print("\nTRAINING:")
print(f"  Epochs: {CONFIG['epochs']}")
print(f"  Batch size: {CONFIG['batch_size']}")
_dropout = CONFIG['modality_dropout']
print(f"  Modality dropout: {_dropout['audio_prob']:.0%} audio, {_dropout['midi_prob']:.0%} MIDI")

print("\nFUSION TYPES TO COMPARE:")
for ft in CONFIG['fusion_types']:
    print(f"  - {ft}")

## Step 4: Create DataLoaders

In [None]:
from src.data.dataset import create_dataloaders

# Translate the notebook's modality-dropout settings into the dict that
# create_dataloaders expects, or None when dropout is disabled.
_md_cfg = CONFIG['modality_dropout']
if _md_cfg['enabled']:
    modality_dropout_config = {key: _md_cfg[key] for key in ('audio_prob', 'midi_prob')}
    print(f"Modality dropout enabled: audio={modality_dropout_config['audio_prob']}, midi={modality_dropout_config['midi_prob']}")
else:
    modality_dropout_config = None

_loader_kwargs = dict(
    train_annotation_path=CONFIG['train_path'],
    val_annotation_path=CONFIG['val_path'],
    test_annotation_path=CONFIG['test_path'],
    dimension_names=CONFIG['dimensions'],
    batch_size=CONFIG['batch_size'],
    num_workers=4,
    augmentation_config=None,  # Disable augmentation for clean comparison
    modality_dropout_config=modality_dropout_config,
    audio_sample_rate=24000,
    max_audio_length=240000,   # 10 s at 24 kHz
    max_midi_events=512,
)
train_loader, val_loader, test_loader = create_dataloaders(**_loader_kwargs)

for _split, _loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{_split} samples: {len(_loader.dataset):,}")

## Step 5: Train

In [None]:
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

from src.models.lightning_module import PerformanceEvaluationModel
from src.utils.memory_profiler import MemoryProfilerCallback, log_memory
from src.callbacks.unfreezing import StagedUnfreezingCallback
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pathlib import Path
import gc

# Collect training labels for LDS fitting (if enabled)
if CONFIG['lds_enabled']:
    print("Collecting training labels for LDS fitting...")
    all_labels = []
    for batch in train_loader:
        all_labels.append(batch['labels'])
    all_train_labels = torch.cat(all_labels, dim=0)
    print(f"Collected {len(all_train_labels):,} training labels")

# Handle MIDIBert pretrained checkpoint (download from Google Drive if needed)
midi_pretrained_local = None
if CONFIG.get('midi_pretrained_checkpoint'):
    midi_ckpt_path = CONFIG['midi_pretrained_checkpoint']
    if midi_ckpt_path.startswith('gdrive:'):
        # Copy from Google Drive to local
        local_path = '/tmp/midi_pretrained.pt'
        print(f"Copying MIDIBert checkpoint from Google Drive: {midi_ckpt_path}")
        !rclone copy {midi_ckpt_path} /tmp/ --progress
        ckpt_name = midi_ckpt_path.split('/')[-1]
        midi_pretrained_local = f'/tmp/{ckpt_name}'
        if Path(midi_pretrained_local).exists():
            print(f"MIDIBert checkpoint loaded: {midi_pretrained_local}")
        else:
            print(f"WARNING: MIDIBert checkpoint not found at {midi_pretrained_local}")
            midi_pretrained_local = None
    else:
        midi_pretrained_local = midi_ckpt_path
        if not Path(midi_pretrained_local).exists():
            print(f"WARNING: MIDIBert checkpoint not found at {midi_pretrained_local}")
            midi_pretrained_local = None

# Store trained models (will be populated by individual training cells)
trained_models = {}

def find_latest_checkpoint(checkpoint_dir):
    """Return the checkpoint a training run should resume from, or None.

    ``last.ckpt`` wins when present. Otherwise the ``*.ckpt`` file with the
    newest modification time is chosen (the later one in glob order on ties).
    Returns None when the directory is missing or holds no ``*.ckpt`` files.
    """
    directory = Path(checkpoint_dir)
    if not directory.exists():
        return None

    preferred = directory / 'last.ckpt'
    if preferred.exists():
        return str(preferred)

    candidates = list(directory.glob('*.ckpt'))
    if not candidates:
        return None
    # Scanning in reverse makes max() keep the last of any equal-mtime files.
    newest = max(reversed(candidates), key=lambda p: p.stat().st_mtime)
    return str(newest)

def train_single_modal_model(modality):
    """Train a single-modal baseline (audio-only or MIDI-only) for the Phase 2 comparison.

    Resumes from ``CHECKPOINT_ROOT/<modality>_only`` when a checkpoint exists,
    records the outcome in ``trained_models['<modality>_only']`` and, unless the
    setup cell found no ``gdrive`` rclone remote, copies the checkpoint directory
    to Google Drive.

    Args:
        modality: ``'audio'`` (MERT encoder only) or ``'midi'`` (MIDIBert encoder only).

    Returns:
        dict with ``'best_checkpoint'`` (path string as reported by ModelCheckpoint)
        and ``'best_val_loss'`` (float, or None when no score was recorded).

    Raises:
        ValueError: if ``modality`` is not ``'audio'`` or ``'midi'`` (raised before
            any directory is created).
    """
    import subprocess

    if modality not in ('audio', 'midi'):
        raise ValueError(f"Unknown modality: {modality}. Use 'audio' or 'midi'.")

    model_name = f"{modality}_only"
    print("="*70)
    print(f"TRAINING: {model_name.upper()} BASELINE")
    print("="*70)
    
    # Check for existing checkpoint to resume
    checkpoint_dir = Path(f'{CHECKPOINT_ROOT}/{model_name}')
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    resume_ckpt = find_latest_checkpoint(checkpoint_dir)
    
    if resume_ckpt:
        print(f"Found checkpoint to resume: {resume_ckpt}")
    
    # Set dimensions based on modality (a zero dim disables that branch)
    if modality == 'audio':
        audio_dim, midi_dim = CONFIG['audio_dim'], 0
        print("Audio-only mode: MERT encoder only")
    else:  # 'midi' (validated above)
        audio_dim, midi_dim = 0, CONFIG['midi_dim']
        print("MIDI-only mode: MIDIBert encoder only")
        if midi_pretrained_local:
            print(f"Using pretrained MIDIBert: {midi_pretrained_local}")
    
    # NOTE(review): train_loader applies CONFIG['modality_dropout'] to every batch,
    # so a single-modal baseline will sometimes see its only modality dropped —
    # confirm the dataset/model handle this as intended.
    
    # Create model with CORAL and FDS support
    model = PerformanceEvaluationModel(
        audio_dim=audio_dim,
        midi_dim=midi_dim,
        shared_dim=CONFIG['shared_dim'],
        aggregator_dim=512,
        num_dimensions=len(CONFIG['dimensions']),
        dimension_names=CONFIG['dimensions'],
        modality=modality,
        fusion_type='gated',
        use_projection=CONFIG['use_projection'],
        midi_pretrained_checkpoint=midi_pretrained_local if modality == 'midi' else None,
        # Loss weights
        mse_weight=CONFIG['mse_weight'],
        ranking_weight=CONFIG['ranking_weight'],
        contrastive_weight=0,  # No contrastive loss for single-modal
        # Base loss
        base_loss=CONFIG['base_loss'],
        huber_delta=CONFIG['huber_delta'],
        # LDS
        lds_enabled=CONFIG['lds_enabled'],
        lds_num_bins=CONFIG['lds_num_bins'],
        lds_sigma=CONFIG['lds_sigma'],
        lds_reweight_scale=CONFIG['lds_reweight_scale'],
        # FDS
        fds_enabled=CONFIG['fds_enabled'],
        fds_num_bins=CONFIG['fds_num_bins'],
        fds_momentum=CONFIG['fds_momentum'],
        fds_kernel_sigma=CONFIG['fds_kernel_sigma'],
        fds_start_epoch=CONFIG['fds_start_epoch'],
        # CORAL
        coral_enabled=CONFIG['coral_enabled'],
        coral_num_classes=CONFIG['coral_num_classes'],
        coral_weight=CONFIG['coral_weight'],
        # Bootstrap
        bootstrap_enabled=CONFIG['bootstrap_enabled'],
        bootstrap_beta=CONFIG['bootstrap_beta'],
        bootstrap_warmup_epochs=CONFIG['bootstrap_warmup_epochs'],
        # Training
        backbone_lr=CONFIG['backbone_lr'],
        heads_lr=CONFIG['heads_lr'],
        warmup_steps=CONFIG['warmup_steps'],
        max_epochs=CONFIG['epochs'],
        gradient_checkpointing=True,
    )
    
    # Fit LDS if enabled (all_train_labels is collected in the setup cell)
    if CONFIG['lds_enabled']:
        model.fit_lds(all_train_labels)
    
    # Callbacks (the ModelCheckpoint must stay first: results are read from it below)
    checkpoint_cb = ModelCheckpoint(
        dirpath=str(checkpoint_dir),
        filename=f'{model_name}-{{epoch:02d}}-{{val_loss:.4f}}',
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        save_last=True,
    )
    callbacks = [
        checkpoint_cb,
        EarlyStopping(monitor='val_loss', patience=3, mode='min'),
        LearningRateMonitor(logging_interval='step'),
        MemoryProfilerCallback(log_every_n_steps=100, log_to_file=f'/tmp/{model_name}_memory.log'),
    ]
    
    # Trainer
    trainer = pl.Trainer(
        max_epochs=CONFIG['epochs'],
        precision='16-mixed',
        accelerator='auto',
        devices='auto',
        callbacks=callbacks,
        logger=TensorBoardLogger(save_dir='logs', name=model_name),
        log_every_n_steps=50,
        gradient_clip_val=1.0,
        accumulate_grad_batches=2,
        val_check_interval=0.5,
    )
    
    # Train
    trainer.fit(model, train_loader, val_loader, ckpt_path=resume_ckpt)
    
    # Store results. best_model_score is a tensor or None: compare against None
    # explicitly so a legitimate score of 0.0 is not discarded as falsy.
    best_score = checkpoint_cb.best_model_score
    result = {
        'best_checkpoint': checkpoint_cb.best_model_path,
        'best_val_loss': float(best_score) if best_score is not None else None,
    }
    trained_models[model_name] = result
    
    print(f"\nBest checkpoint: {result['best_checkpoint']}")
    if result['best_val_loss'] is not None:
        print(f"Best val loss: {result['best_val_loss']:.4f}")
    else:
        print("Best val loss: N/A")
    
    # Sync to Google Drive. A failed sync is reported, never raised, so the
    # freshly trained checkpoints on local disk are not lost to an exception.
    gdrive_target = f"{GDRIVE_CHECKPOINT_PATH}/{model_name}"
    if globals().get('RCLONE_AVAILABLE', True):
        print(f"\nSyncing to Google Drive...")
        try:
            sync = subprocess.run(
                ['rclone', 'copy', str(checkpoint_dir), gdrive_target, '--progress'],
                capture_output=False,
            )
            sync_ok = sync.returncode == 0
            sync_detail = f"exit code {sync.returncode}"
        except OSError as err:  # e.g. rclone binary missing
            sync_ok = False
            sync_detail = str(err)
        if sync_ok:
            print("Sync complete!")
        else:
            print(f"WARNING: Google Drive sync to {gdrive_target} failed ({sync_detail}); "
                  f"checkpoints remain in {checkpoint_dir}")
    else:
        print(f"\nSkipping Google Drive sync (rclone 'gdrive' remote not configured); "
              f"checkpoints remain in {checkpoint_dir}")
    
    # Cleanup: release GPU memory before the next model is trained
    del model, trainer
    gc.collect()
    torch.cuda.empty_cache()
    
    return result

def train_fusion_model(fusion_type):
    """Train one audio+MIDI fusion model with CORAL and FDS support.

    Resumes from ``CHECKPOINT_ROOT/<fusion_type>`` when a checkpoint exists,
    applies the staged-unfreezing schedule when enabled, records the outcome in
    ``trained_models[fusion_type]`` and, unless the setup cell found no
    ``gdrive`` rclone remote, copies the checkpoint directory to Google Drive.

    Args:
        fusion_type: fusion variant passed to PerformanceEvaluationModel
            (this notebook compares ``'crossattn'``, ``'gated'`` and ``'concat'``).

    Returns:
        dict with ``'best_checkpoint'`` (path string as reported by ModelCheckpoint)
        and ``'best_val_loss'`` (float, or None when no score was recorded).
    """
    import subprocess

    print("="*70)
    print(f"TRAINING: {fusion_type.upper()} FUSION")
    print("="*70)
    
    # Check for existing checkpoint to resume
    checkpoint_dir = Path(f'{CHECKPOINT_ROOT}/{fusion_type}')
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    resume_ckpt = find_latest_checkpoint(checkpoint_dir)
    
    if resume_ckpt:
        print(f"Found checkpoint to resume: {resume_ckpt}")
    
    # Report pretrained MIDIBert status
    if midi_pretrained_local:
        print(f"Using pretrained MIDIBert: {midi_pretrained_local}")
    else:
        print("Training MIDIBert from scratch (no pretrained checkpoint)")
    
    # Report CORAL and FDS status
    print(f"CORAL: {'enabled' if CONFIG['coral_enabled'] else 'disabled'}")
    print(f"FDS: {'enabled' if CONFIG['fds_enabled'] else 'disabled'}")
    
    # Create model with CORAL and FDS support
    model = PerformanceEvaluationModel(
        audio_dim=CONFIG['audio_dim'],
        midi_dim=CONFIG['midi_dim'],
        shared_dim=CONFIG['shared_dim'],
        aggregator_dim=512,
        num_dimensions=len(CONFIG['dimensions']),
        dimension_names=CONFIG['dimensions'],
        fusion_type=fusion_type,
        use_projection=CONFIG['use_projection'],
        midi_pretrained_checkpoint=midi_pretrained_local,
        # Loss weights
        mse_weight=CONFIG['mse_weight'],
        ranking_weight=CONFIG['ranking_weight'],
        contrastive_weight=CONFIG['contrastive_weight'],
        # Base loss
        base_loss=CONFIG['base_loss'],
        huber_delta=CONFIG['huber_delta'],
        # LDS
        lds_enabled=CONFIG['lds_enabled'],
        lds_num_bins=CONFIG['lds_num_bins'],
        lds_sigma=CONFIG['lds_sigma'],
        lds_reweight_scale=CONFIG['lds_reweight_scale'],
        # FDS
        fds_enabled=CONFIG['fds_enabled'],
        fds_num_bins=CONFIG['fds_num_bins'],
        fds_momentum=CONFIG['fds_momentum'],
        fds_kernel_sigma=CONFIG['fds_kernel_sigma'],
        fds_start_epoch=CONFIG['fds_start_epoch'],
        # CORAL
        coral_enabled=CONFIG['coral_enabled'],
        coral_num_classes=CONFIG['coral_num_classes'],
        coral_weight=CONFIG['coral_weight'],
        # Bootstrap
        bootstrap_enabled=CONFIG['bootstrap_enabled'],
        bootstrap_beta=CONFIG['bootstrap_beta'],
        bootstrap_warmup_epochs=CONFIG['bootstrap_warmup_epochs'],
        # Training
        backbone_lr=CONFIG['backbone_lr'],
        heads_lr=CONFIG['heads_lr'],
        warmup_steps=CONFIG['warmup_steps'],
        max_epochs=CONFIG['epochs'],
        gradient_checkpointing=True,
    )
    
    # Fit LDS if enabled (all_train_labels is collected in the setup cell)
    if CONFIG['lds_enabled']:
        model.fit_lds(all_train_labels)
    
    # Callbacks (the ModelCheckpoint must stay first: results are read from it below)
    checkpoint_cb = ModelCheckpoint(
        dirpath=str(checkpoint_dir),
        filename=f'{fusion_type}-{{epoch:02d}}-{{val_loss:.4f}}',
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        save_last=True,
    )
    callbacks = [
        checkpoint_cb,
        EarlyStopping(monitor='val_loss', patience=3, mode='min'),
        LearningRateMonitor(logging_interval='step'),
        MemoryProfilerCallback(log_every_n_steps=100, log_to_file=f'/tmp/{fusion_type}_memory.log'),
    ]
    
    if CONFIG['staged_unfreezing']['enabled']:
        callbacks.append(StagedUnfreezingCallback(
            schedule=CONFIG['staged_unfreezing']['schedule'],
            verbose=True,
        ))
    
    # Trainer
    trainer = pl.Trainer(
        max_epochs=CONFIG['epochs'],
        precision='16-mixed',
        accelerator='auto',
        devices='auto',
        callbacks=callbacks,
        logger=TensorBoardLogger(save_dir='logs', name=fusion_type),
        log_every_n_steps=50,
        gradient_clip_val=1.0,
        accumulate_grad_batches=2,
        val_check_interval=0.5,
    )
    
    # Train
    trainer.fit(model, train_loader, val_loader, ckpt_path=resume_ckpt)
    
    # Store results. best_model_score is a tensor or None: compare against None
    # explicitly so a legitimate score of 0.0 is not discarded as falsy.
    best_score = checkpoint_cb.best_model_score
    result = {
        'best_checkpoint': checkpoint_cb.best_model_path,
        'best_val_loss': float(best_score) if best_score is not None else None,
    }
    trained_models[fusion_type] = result
    
    print(f"\nBest checkpoint: {result['best_checkpoint']}")
    if result['best_val_loss'] is not None:
        print(f"Best val loss: {result['best_val_loss']:.4f}")
    else:
        print("Best val loss: N/A")
    
    # Sync to Google Drive. A failed sync is reported, never raised, so the
    # freshly trained checkpoints on local disk are not lost to an exception.
    gdrive_target = f"{GDRIVE_CHECKPOINT_PATH}/{fusion_type}"
    if globals().get('RCLONE_AVAILABLE', True):
        print(f"\nSyncing to Google Drive...")
        try:
            sync = subprocess.run(
                ['rclone', 'copy', str(checkpoint_dir), gdrive_target, '--progress'],
                capture_output=False,
            )
            sync_ok = sync.returncode == 0
            sync_detail = f"exit code {sync.returncode}"
        except OSError as err:  # e.g. rclone binary missing
            sync_ok = False
            sync_detail = str(err)
        if sync_ok:
            print("Sync complete!")
        else:
            print(f"WARNING: Google Drive sync to {gdrive_target} failed ({sync_detail}); "
                  f"checkpoints remain in {checkpoint_dir}")
    else:
        print(f"\nSkipping Google Drive sync (rclone 'gdrive' remote not configured); "
              f"checkpoints remain in {checkpoint_dir}")
    
    # Cleanup: release GPU memory before the next model is trained
    del model, trainer
    gc.collect()
    torch.cuda.empty_cache()
    
    return result

# Summarize which training entry points and loss features the cells below will use.
print("=" * 70)
print("TRAINING FUNCTIONS READY")
print("=" * 70)
print("\nAvailable functions:")
for _usage in (
    "  - train_single_modal_model(modality): Train audio-only or MIDI-only baseline",
    "  - train_fusion_model(fusion_type): Train fusion model (crossattn, gated, concat)",
):
    print(_usage)

print("\nNew features enabled:")
for _desc, _key in (
    ("CORAL ordinal regression", 'coral_enabled'),
    ("FDS feature smoothing", 'fds_enabled'),
    ("LDS label smoothing", 'lds_enabled'),
    ("Bootstrap loss", 'bootstrap_enabled'),
):
    print(f"  - {_desc}: {CONFIG[_key]}")

print(
    f"\nMIDIBert pretrained checkpoint: {midi_pretrained_local}"
    if midi_pretrained_local
    else "\nMIDIBert will train from scratch (no pretrained checkpoint)"
)

print("\nRun the cells below to train each model.")

In [None]:
# Preflight

import subprocess
import sys

_BANNER = "=" * 70
print(_BANNER)
print("PREFLIGHT CHECK")
print(_BANNER)

# Smoke-test the repository's train.py with its --fast-dev-run flag before
# committing to full training runs (flag semantics are defined by train.py).
_preflight_cmd = [
    sys.executable, "train.py",
    "--config", "configs/experiment.yaml",
    "--fast-dev-run",
]
result = subprocess.run(
    _preflight_cmd,
    cwd="/tmp/crescendai/model",
    capture_output=True,
    text=True,
)

# Replay the captured output
print(result.stdout)
if result.stderr:
    print("STDERR:", result.stderr)

_passed = result.returncode == 0
print(_BANNER)
print(
    "PREFLIGHT CHECK PASSED - Ready for full training"
    if _passed
    else "PREFLIGHT CHECK FAILED - Fix errors before training"
)
print(_BANNER)
if not _passed:
    raise RuntimeError("Preflight check failed")

### 5a. Train Audio-Only Baseline (Phase 2)

Train audio-only model using MERT encoder only. This establishes a baseline for comparing fusion performance.

In [None]:
# Returns {'best_checkpoint', 'best_val_loss'}; also stored in trained_models['audio_only'].
train_single_modal_model('audio')

### 5b. Train MIDI-Only Baseline (Phase 2)

Train MIDI-only model using MIDIBert encoder only. This establishes a baseline for comparing fusion performance.

In [None]:
# Returns {'best_checkpoint', 'best_val_loss'}; also stored in trained_models['midi_only'].
train_single_modal_model('midi')

### 5c–5e. Train Fusion Models (CrossAttention, Gated, Concat)

In [None]:
# Result is also stored in trained_models['crossattn'] for the Phase 2 gate check.
train_fusion_model('crossattn')

In [None]:
# 5d. Gated fusion; result is also stored in trained_models['gated'].
train_fusion_model('gated')

In [None]:
# 5e. Concatenation fusion; result is also stored in trained_models['concat'].
train_fusion_model('concat')

## Step 5f: Phase 2 Gate Check

Evaluate all 5 models and check if fusion beats single-modal baselines by >= 10%.

**GO Criteria:**
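- Best fusion model's mean Pearson r (averaged over all 8 dimensions) beats the best single-modal model's by >= 10% (`CONFIG['fusion_improvement_target']`)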
- Models learn quality (higher scores for pristine vs degraded)

**NO-GO:** Debug fusion architecture before proceeding to Phase 3.

In [None]:
# Phase 2 Gate Check
#
# Evaluates the best checkpoint of each Phase 2 model (2 single-modal
# baselines + 3 fusion variants) on `test_loader`, averages test Pearson r
# over CONFIG['dimensions'], and passes the gate when the best fusion model's
# mean r beats the best single-modal model's by at least
# CONFIG['fusion_improvement_target'] percent (relative improvement).
#
# Globals produced for later cells: phase2_results (raw test metrics per
# model), mean_rs, PHASE2_PASSED, phase2_gate_result.

import gc

import numpy as np
from scipy import stats

print("="*80)
print("PHASE 2 GATE CHECK: EVALUATING ALL 5 MODELS")
print("="*80)

# Define all model types (also used as checkpoint sub-directory names under
# CHECKPOINT_ROOT and as keys looked up in trained_models)
single_modal_types = ['audio_only', 'midi_only']
fusion_types = ['crossattn', 'gated', 'concat']
all_model_types = single_modal_types + fusion_types

# Load models and evaluate
phase2_results = {}

for model_type in all_model_types:
    print(f"\nEvaluating {model_type}...")

    # Find checkpoint: prefer the path recorded by this session's training
    # cell, otherwise the lexicographically last non-"last.ckpt" file on disk.
    if model_type in trained_models:
        ckpt_path = trained_models[model_type]['best_checkpoint']
    else:
        ckpt_dir = Path(f'{CHECKPOINT_ROOT}/{model_type}')
        ckpts = [c for c in ckpt_dir.glob('*.ckpt') if c.name != 'last.ckpt']
        if not ckpts:
            print("  No checkpoint found, skipping...")
            continue
        ckpt_path = str(sorted(ckpts)[-1])

    if not ckpt_path or not Path(ckpt_path).exists():
        print(f"  Checkpoint not found: {ckpt_path}")
        continue

    print(f"  Loading: {ckpt_path}")
    model = PerformanceEvaluationModel.load_from_checkpoint(ckpt_path)
    model.eval()
    model = model.cuda()

    # Evaluate. '16-mixed' matches the other Trainers in this notebook; the
    # bare integer 16 is a legacy alias that Lightning 2.x discourages.
    trainer = pl.Trainer(accelerator='auto', devices='auto', precision='16-mixed', logger=False)
    test_results = trainer.test(model, dataloaders=test_loader, verbose=False)
    phase2_results[model_type] = test_results[0]

    # The Trainer holds a reference to the model, so release both before
    # emptying the CUDA cache; otherwise the weights stay resident on the GPU.
    del model, trainer
    gc.collect()
    torch.cuda.empty_cache()

# Calculate mean Pearson r for each model. Missing per-dimension metrics are
# counted as 0, but reported, since they silently drag the mean down.
mean_rs = {}
for model_type in all_model_types:
    if model_type not in phase2_results:
        continue
    model_metrics = phase2_results[model_type]
    missing_dims = [d for d in CONFIG['dimensions'] if f'test_pearson_{d}' not in model_metrics]
    if missing_dims:
        print(f"WARNING: {model_type} has no test_pearson metric for "
              f"{len(missing_dims)} dimension(s) (counted as 0): {missing_dims}")
    mean_rs[model_type] = np.mean([model_metrics.get(f'test_pearson_{d}', 0)
                                   for d in CONFIG['dimensions']])

# Print comparison table
print("\n" + "="*80)
print("PHASE 2 RESULTS: MODEL COMPARISON")
print("="*80)
print(f"\n{'Model':<15} {'Mean Pearson r':<15} {'Type'}")
print("-"*50)

for model_type in all_model_types:
    if model_type in mean_rs:
        model_kind = "Single-modal" if model_type in single_modal_types else "Fusion"
        print(f"{model_type:<15} {mean_rs[model_type]:>14.4f} {model_kind}")

# Phase 2 Gate Check
print("\n" + "="*80)
print("PHASE 2 GATE CHECK")
print("="*80)

# Best single-modal
single_modal_rs = {k: v for k, v in mean_rs.items() if k in single_modal_types}
best_single_modal = max(single_modal_rs, key=single_modal_rs.get) if single_modal_rs else None
best_single_r = single_modal_rs[best_single_modal] if best_single_modal else 0

# Best fusion
fusion_rs = {k: v for k, v in mean_rs.items() if k in fusion_types}
best_fusion = max(fusion_rs, key=fusion_rs.get) if fusion_rs else None
best_fusion_r = fusion_rs[best_fusion] if best_fusion else 0

if best_single_modal is None or best_fusion is None:
    print("\nWARNING: no evaluated single-modal and/or fusion model - "
          "the gate decision below is not a real comparison")

# Relative improvement in percent; undefined (reported as 0) when the
# single-modal baseline has a non-positive mean r.
if best_single_r > 0:
    improvement = ((best_fusion_r - best_single_r) / best_single_r) * 100
else:
    improvement = 0

print(f"\nBest single-modal: {best_single_modal} (r = {best_single_r:.4f})")
print(f"Best fusion: {best_fusion} (r = {best_fusion_r:.4f})")
print(f"Improvement: {improvement:+.1f}%")
print(f"Target: >= {CONFIG['fusion_improvement_target']:.0f}%")

# Gate decision (plain bool rather than numpy.bool_)
PHASE2_PASSED = bool(improvement >= CONFIG['fusion_improvement_target'])

if PHASE2_PASSED:
    print(f"\n{'='*40}")
    print("PHASE 2 GATE: PASS")
    print(f"{'='*40}")
    print(f"{best_fusion} fusion beats {best_single_modal} by {improvement:.1f}%")
    print("-> Proceed to Phase 3: Contrastive Pre-training")
else:
    print(f"\n{'='*40}")
    print("PHASE 2 GATE: FAIL")
    print(f"{'='*40}")
    print(f"Fusion improvement ({improvement:.1f}%) < target ({CONFIG['fusion_improvement_target']:.0f}%)")
    print("-> Debug fusion architecture before proceeding")
    print("Suggestions:")
    print("  1. Check cross-modal alignment scores")
    print("  2. Try different fusion types")
    print("  3. Adjust learning rates")

# Store for later use
phase2_gate_result = {
    'passed': PHASE2_PASSED,
    'best_single_modal': best_single_modal,
    'best_single_r': best_single_r,
    'best_fusion': best_fusion,
    'best_fusion_r': best_fusion_r,
    'improvement': improvement,
}

## Step 6: Detailed Analysis of Phase 2 Models

Analyze the 5 Phase 2 models in detail. This section:
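- Reuses `phase2_results` from the gate check for the per-dimension breakdown and diagnostics (no redundant evaluation; only the significance test in 6c reloads the two best models to collect per-sample predictions)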
- Provides per-dimension breakdown for all 5 models
- Shows fusion diagnostics (alignment, gate values)
- Runs statistical significance tests

**Run this AFTER Phase 2 Gate Check (Step 5f).**
Can be run while Phase 3/4 are training or after everything completes.

In [None]:
# Step 6a: Per-Dimension Analysis (All 5 Models)
#
# Uses `phase2_results` from the Phase 2 Gate Check (no re-evaluation) to
# print a per-dimension Pearson r table, the winning model per dimension,
# how many dimensions each model wins, and each model's mean r / mean MAE.
# Metrics missing from a model's results are read as 0.

from collections import Counter

import numpy as np

print("="*90)
print("PER-DIMENSION PEARSON CORRELATION (ALL 5 MODELS)")
print("="*90)

# Check that phase2_results exists
if 'phase2_results' not in dir() or not phase2_results:
    raise RuntimeError("phase2_results not found! Run Phase 2 Gate Check (Step 5f) first.")

# All model types
single_modal_types = ['audio_only', 'midi_only']
fusion_types = CONFIG['fusion_types']
all_model_types = single_modal_types + fusion_types

# Build header (one column per evaluated model, plus the per-row winner)
header = f"{'Dimension':<25}"
for mt in all_model_types:
    if mt in phase2_results:
        header += f" {mt[:10]:>10}"
header += f" {'Best':>10}"  # f-prefix was missing: previously printed "{'Best':>10}" literally
print(header)
print("-"*90)

# Per-dimension results
dimension_winners = {}
for dim in CONFIG['dimensions']:
    row = f"{dim:<25}"
    dim_scores = {}

    for mt in all_model_types:
        if mt in phase2_results:
            r = phase2_results[mt].get(f'test_pearson_{dim}', 0)
            dim_scores[mt] = r
            row += f" {r:>10.3f}"

    # Find best for this dimension
    if dim_scores:
        best_mt = max(dim_scores, key=dim_scores.get)
        dimension_winners[dim] = best_mt
        row += f" {best_mt[:10]:>10}"

    print(row)

print("-"*90)

# Summary: which model wins most dimensions
print("\n" + "="*90)
print("DIMENSION WINNERS SUMMARY")
print("="*90)

winner_counts = Counter(dimension_winners.values())
for mt, count in winner_counts.most_common():
    pct = 100 * count / len(CONFIG['dimensions'])
    print(f"  {mt:<15}: {count}/{len(CONFIG['dimensions'])} dimensions ({pct:.0f}%)")

# Overall mean scores (same as gate check, for reference)
print("\n" + "="*90)
print("OVERALL PERFORMANCE (from Phase 2 Gate Check)")
print("="*90)
print(f"{'Model':<15} {'Mean r':>10} {'Mean MAE':>12} {'Type':<12}")
print("-"*50)

for mt in all_model_types:
    if mt in phase2_results:
        mean_r = np.mean([phase2_results[mt].get(f'test_pearson_{d}', 0) for d in CONFIG['dimensions']])
        mean_mae = np.mean([phase2_results[mt].get(f'test_mae_{d}', 0) for d in CONFIG['dimensions']])
        model_kind = "Single-modal" if mt in single_modal_types else "Fusion"
        print(f"{mt:<15} {mean_r:>10.4f} {mean_mae:>12.2f} {model_kind:<12}")

In [None]:
# Step 6b: Fusion Diagnostics
# Check cross-modal alignment, gate values, feature diversity
#
# Prints whichever optional diagnostic metrics each fusion model logged
# during test (read from `phase2_results`), then compares the audio-only
# and MIDI-only baselines' mean Pearson r.

import numpy as np

print("="*90)
print("FUSION MODEL DIAGNOSTICS")
print("="*90)

for ft in CONFIG['fusion_types']:
    if ft not in phase2_results:
        print(f"\n{ft.upper()}: No results available")
        continue

    print(f"\n{ft.upper()}:")

    # Cross-modal alignment (how well audio and MIDI embeddings align)
    align = phase2_results[ft].get('test_cross_modal_alignment', None)
    if align is not None:
        align_status = "GOOD" if align > 0.5 else "LOW - consider more contrastive training"
        print(f"  Cross-modal alignment: {align:.4f} ({align_status})")
    else:
        print(f"  Cross-modal alignment: Not logged")

    # Gate values (gated fusion only).
    # NOTE(review): the labels below assume a high gate value weights the
    # audio branch - confirm against the gated fusion module.
    if ft == 'gated':
        gate_mean = phase2_results[ft].get('test_gate_mean', None)
        if gate_mean is not None:
            if gate_mean > 0.6:
                gate_status = "Audio-dominant"
            elif gate_mean < 0.4:
                gate_status = "MIDI-dominant"
            else:
                gate_status = "Balanced"
            print(f"  Gate mean: {gate_mean:.4f} ({gate_status})")
        else:
            print(f"  Gate mean: Not logged")

    # Feature diversity (prevents mode collapse)
    audio_div = phase2_results[ft].get('test_audio_diversity', None)
    midi_div = phase2_results[ft].get('test_midi_diversity', None)
    if audio_div is not None:
        print(f"  Audio feature diversity: {audio_div:.4f}")
    if midi_div is not None:
        print(f"  MIDI feature diversity: {midi_div:.4f}")

    # Loss components (if logged). Test presence with `is not None`: a
    # logged loss of exactly 0.0 is falsy and would otherwise be hidden.
    mse_loss = phase2_results[ft].get('test_mse_loss', None)
    ranking_loss = phase2_results[ft].get('test_ranking_loss', None)
    coral_loss = phase2_results[ft].get('test_coral_loss', None)

    if any(loss is not None for loss in (mse_loss, ranking_loss, coral_loss)):
        print(f"  Loss breakdown:")
        if mse_loss is not None:
            print(f"    MSE/Huber: {mse_loss:.4f}")
        if ranking_loss is not None:
            print(f"    Ranking: {ranking_loss:.4f}")
        if coral_loss is not None:
            print(f"    CORAL: {coral_loss:.4f}")

# Compare audio-only vs MIDI-only to understand modality contribution
print("\n" + "="*90)
print("MODALITY CONTRIBUTION ANALYSIS")
print("="*90)

audio_only_r = np.mean([phase2_results.get('audio_only', {}).get(f'test_pearson_{d}', 0) for d in CONFIG['dimensions']])
midi_only_r = np.mean([phase2_results.get('midi_only', {}).get(f'test_pearson_{d}', 0) for d in CONFIG['dimensions']])

print(f"\nAudio-only mean r: {audio_only_r:.4f}")
print(f"MIDI-only mean r:  {midi_only_r:.4f}")

# A missing baseline shows up as r = 0 above; say so explicitly.
missing_baselines = [mt for mt in ('audio_only', 'midi_only') if mt not in phase2_results]
if missing_baselines:
    print(f"NOTE: no Phase 2 results for {', '.join(missing_baselines)}; its mean r is shown as 0")

if audio_only_r > 0 and midi_only_r > 0:
    ratio = audio_only_r / midi_only_r
    if ratio > 1.5:
        print(f"\nAudio is {ratio:.1f}x stronger than MIDI")
        print("  -> Audio encoder (MERT) is primary contributor")
        print("  -> MIDI encoder may need more pre-training or different architecture")
    elif ratio < 0.67:
        print(f"\nMIDI is {1/ratio:.1f}x stronger than audio")
        print("  -> MIDI encoder is primary contributor")
        print("  -> This is unusual - check audio data quality")
    else:
        print(f"\nModalities are balanced (ratio: {ratio:.2f})")
        print("  -> Both encoders contribute meaningfully")
        print("  -> Fusion should provide complementary information")

In [None]:
# Step 6c: Statistical Significance Tests
# Compare best fusion vs best single-modal with paired t-test
#
# Reloads the best fusion and best single-modal checkpoints chosen by the
# Phase 2 gate, runs both over `test_loader`, and for each dimension runs a
# paired t-test (scipy.stats.ttest_rel) on the two models' per-sample
# absolute errors.
#
# NOTE(review): pairing by index assumes `test_loader` yields samples in the
# same order on every pass (no shuffling). A warning is printed if the two
# runs' targets disagree. p < 0.05 is applied per dimension with no
# multiple-comparison correction.

import numpy as np
from scipy import stats

print("="*90)
print("STATISTICAL SIGNIFICANCE ANALYSIS")
print("="*90)

# Get best models from Phase 2 gate check
if 'phase2_gate_result' not in dir():
    raise RuntimeError("phase2_gate_result not found! Run Phase 2 Gate Check first.")

best_fusion = phase2_gate_result['best_fusion']
best_single = phase2_gate_result['best_single_modal']

print(f"\nComparing: {best_fusion} (fusion) vs {best_single} (single-modal)")
print("Loading models to collect predictions for paired t-test...")

# Load both models and collect predictions
comparison_preds = {}

# The gate stores None when no model of a family was evaluated.
if best_fusion is None or best_single is None:
    print("  Phase 2 has no best fusion and/or single-modal model, skipping significance test")
    models_to_compare = []
else:
    models_to_compare = [best_fusion, best_single]

for model_type in models_to_compare:
    # Find checkpoint (same lookup rule as the Phase 2 gate check)
    if model_type in trained_models:
        ckpt_path = trained_models[model_type]['best_checkpoint']
    else:
        ckpt_dir = Path(f'{CHECKPOINT_ROOT}/{model_type}')
        ckpts = [c for c in ckpt_dir.glob('*.ckpt') if c.name != 'last.ckpt']
        ckpt_path = str(sorted(ckpts)[-1]) if ckpts else None

    if not ckpt_path or not Path(ckpt_path).exists():
        print(f"  Checkpoint not found for {model_type}, skipping significance test")
        continue

    print(f"  Loading {model_type}...")
    model = PerformanceEvaluationModel.load_from_checkpoint(ckpt_path)
    model.eval()
    model = model.cuda()

    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in test_loader:
            audio = batch['audio_waveform'].cuda()
            midi = batch.get('midi_tokens')
            if midi is not None:
                midi = midi.cuda()

            output = model(audio_waveform=audio, midi_tokens=midi)
            # Batches with no output are skipped for both preds and targets,
            # keeping the two lists aligned.
            if output is not None:
                all_preds.append(output['scores'].cpu().numpy())
                all_targets.append(batch['labels'].numpy())

    del model
    torch.cuda.empty_cache()

    # np.concatenate([]) raises ValueError; report instead of crashing.
    if not all_preds:
        print(f"  {model_type} produced no predictions, skipping significance test")
        continue

    comparison_preds[model_type] = {
        'preds': np.concatenate(all_preds, axis=0),
        'targets': np.concatenate(all_targets, axis=0),
    }

# Perform paired t-test for each dimension
if len(comparison_preds) == 2:
    n_samples = min(len(comparison_preds[best_fusion]['preds']),
                    len(comparison_preds[best_single]['preds']))

    fusion_targets = comparison_preds[best_fusion]['targets'][:n_samples]
    single_targets = comparison_preds[best_single]['targets'][:n_samples]
    if not np.allclose(fusion_targets, single_targets, equal_nan=True):
        print("\nWARNING: targets differ between the two runs - test_loader order is not "
              "deterministic, so the paired t-test below compares mismatched samples")

    print(f"\n{'Dimension':<25} {'p-value':>12} {'Winner':>15} {'Significant?':>12}")
    print("-"*70)

    significant_wins = {best_fusion: 0, best_single: 0}

    for dim_idx, dim in enumerate(CONFIG['dimensions']):
        # Calculate absolute errors for each model
        fusion_errors = np.abs(
            comparison_preds[best_fusion]['preds'][:n_samples, dim_idx] -
            fusion_targets[:, dim_idx]
        )
        single_errors = np.abs(
            comparison_preds[best_single]['preds'][:n_samples, dim_idx] -
            single_targets[:, dim_idx]
        )

        # Paired t-test (lower error is better)
        t_stat, p_value = stats.ttest_rel(fusion_errors, single_errors)

        fusion_mean = np.mean(fusion_errors)
        single_mean = np.mean(single_errors)

        if p_value < 0.05:
            if fusion_mean < single_mean:
                winner = best_fusion
                significant_wins[best_fusion] += 1
            else:
                winner = best_single
                significant_wins[best_single] += 1
            sig_str = "Yes"
        else:
            winner = "Tie"
            sig_str = "No"

        print(f"{dim:<25} {p_value:>12.4f} {winner:>15} {sig_str:>12}")

    print("-"*70)
    print(f"\nSignificant wins: {best_fusion}: {significant_wins[best_fusion]}, {best_single}: {significant_wins[best_single]}")

    # Overall conclusion
    print("\n" + "="*90)
    print("STATISTICAL CONCLUSION")
    print("="*90)

    if significant_wins[best_fusion] > significant_wins[best_single]:
        print(f"\n{best_fusion} is STATISTICALLY BETTER on {significant_wins[best_fusion]}/{len(CONFIG['dimensions'])} dimensions")
    elif significant_wins[best_single] > significant_wins[best_fusion]:
        print(f"\n{best_single} is STATISTICALLY BETTER on {significant_wins[best_single]}/{len(CONFIG['dimensions'])} dimensions")
        print("WARNING: Single-modal beats fusion - fusion architecture may need debugging")
    else:
        print(f"\nNo clear statistical winner - models perform similarly")
else:
    print("\nCould not load both models for comparison")

## Step 6d: Ablation Study (OPTIONAL)

**Skip this section if Phase 2 passed and you want to proceed directly to Phase 3.**

Run ablation experiments to measure the impact of each training component:
- Huber vs MSE loss
- LDS enabled vs disabled
- FDS enabled vs disabled
- CORAL enabled vs disabled
- Bootstrap enabled vs disabled
- Modality dropout vs no dropout

This helps quantify the contribution of each training improvement.

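**Note:** This trains 7 NEW model configurations (the full configuration plus 6 single-component ablations, 3 epochs each) — roughly 3 GPU-hours per configuration, ~21 GPU-hours total on an A100.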
Only run if debugging or for research purposes.

In [None]:
# Step 6d: Ablation Study (OPTIONAL)
# Skip this cell if you want to proceed directly to Phase 3

# Set to False to skip ablation study
RUN_ABLATION = False  # <-- Set to True to run ablation experiments

if not RUN_ABLATION:
    print("="*70)
    print("ABLATION STUDY: SKIPPED")
    print("="*70)
    print("\nSet RUN_ABLATION = True to run ablation experiments")
    print("This trains 7 model configurations (3 epochs each)")
    ablation_results = None
else:
    print("="*80)
    print("ABLATION STUDY")
    print("="*80)
    
    # Use the best fusion type from Phase 2
    ABLATION_FUSION_TYPE = phase2_gate_result['best_fusion']
    print(f"Using fusion type: {ABLATION_FUSION_TYPE}")
    
    # Define ablation configurations
    # Each config modifies ONE variable from the full config
    ablation_configs = {
        'full': {
            'base_loss': 'huber',
            'lds_enabled': True,
            'fds_enabled': True,
            'coral_enabled': True,
            'bootstrap_enabled': True,
            'modality_dropout_enabled': True,
        },
        'mse_loss': {
            'base_loss': 'mse',
            'lds_enabled': True,
            'fds_enabled': True,
            'coral_enabled': True,
            'bootstrap_enabled': True,
            'modality_dropout_enabled': True,
        },
        'no_lds': {
            'base_loss': 'huber',
            'lds_enabled': False,
            'fds_enabled': True,
            'coral_enabled': True,
            'bootstrap_enabled': True,
            'modality_dropout_enabled': True,
        },
        'no_fds': {
            'base_loss': 'huber',
            'lds_enabled': True,
            'fds_enabled': False,
            'coral_enabled': True,
            'bootstrap_enabled': True,
            'modality_dropout_enabled': True,
        },
        'no_coral': {
            'base_loss': 'huber',
            'lds_enabled': True,
            'fds_enabled': True,
            'coral_enabled': False,
            'bootstrap_enabled': True,
            'modality_dropout_enabled': True,
        },
        'no_bootstrap': {
            'base_loss': 'huber',
            'lds_enabled': True,
            'fds_enabled': True,
            'coral_enabled': True,
            'bootstrap_enabled': False,
            'modality_dropout_enabled': True,
        },
        'no_modality_dropout': {
            'base_loss': 'huber',
            'lds_enabled': True,
            'fds_enabled': True,
            'coral_enabled': True,
            'bootstrap_enabled': True,
            'modality_dropout_enabled': False,
        },
    }
    
    ablation_results = {}
    
    for ablation_name, ablation_cfg in ablation_configs.items():
        print(f"\n{'='*60}")
        print(f"ABLATION: {ablation_name}")
        print(f"{'='*60}")
        
        # Create checkpoint directory
        ablation_ckpt_dir = Path(f'{CHECKPOINT_ROOT}/ablation/{ablation_name}')
        ablation_ckpt_dir.mkdir(parents=True, exist_ok=True)
        
        # Check for existing checkpoint
        resume_ckpt = find_latest_checkpoint(ablation_ckpt_dir)
        if resume_ckpt:
            print(f"Resuming from: {resume_ckpt}")
        
        # Create dataloader with appropriate modality dropout
        if ablation_cfg['modality_dropout_enabled']:
            abl_modality_dropout = {'audio_prob': 0.15, 'midi_prob': 0.15}
        else:
            abl_modality_dropout = None
        
        abl_train_loader, abl_val_loader, _ = create_dataloaders(
            train_annotation_path=CONFIG['train_path'],
            val_annotation_path=CONFIG['val_path'],
            test_annotation_path=CONFIG['test_path'],
            dimension_names=CONFIG['dimensions'],
            batch_size=CONFIG['batch_size'],
            num_workers=4,
            modality_dropout_config=abl_modality_dropout,
        )
        
        # Create model with ablation config
        model = PerformanceEvaluationModel(
            audio_dim=CONFIG['audio_dim'],
            midi_dim=CONFIG['midi_dim'],
            shared_dim=CONFIG['shared_dim'],
            aggregator_dim=512,
            num_dimensions=len(CONFIG['dimensions']),
            dimension_names=CONFIG['dimensions'],
            fusion_type=ABLATION_FUSION_TYPE,
            use_projection=CONFIG['use_projection'],
            mse_weight=CONFIG['mse_weight'],
            ranking_weight=CONFIG['ranking_weight'],
            contrastive_weight=CONFIG['contrastive_weight'],
            base_loss=ablation_cfg['base_loss'],
            huber_delta=CONFIG['huber_delta'],
            lds_enabled=ablation_cfg['lds_enabled'],
            lds_num_bins=CONFIG['lds_num_bins'],
            lds_sigma=CONFIG['lds_sigma'],
            lds_reweight_scale=CONFIG['lds_reweight_scale'],
            fds_enabled=ablation_cfg['fds_enabled'],
            fds_num_bins=CONFIG['fds_num_bins'],
            fds_momentum=CONFIG['fds_momentum'],
            fds_kernel_sigma=CONFIG['fds_kernel_sigma'],
            fds_start_epoch=CONFIG['fds_start_epoch'],
            coral_enabled=ablation_cfg['coral_enabled'],
            coral_num_classes=CONFIG['coral_num_classes'],
            coral_weight=CONFIG['coral_weight'],
            bootstrap_enabled=ablation_cfg['bootstrap_enabled'],
            bootstrap_beta=CONFIG['bootstrap_beta'],
            bootstrap_warmup_epochs=CONFIG['bootstrap_warmup_epochs'],
            backbone_lr=CONFIG['backbone_lr'],
            heads_lr=CONFIG['heads_lr'],
            warmup_steps=CONFIG['warmup_steps'],
            max_epochs=3,  # Shorter for ablation
            gradient_checkpointing=True,
        )
        
        # Fit LDS if enabled
        if ablation_cfg['lds_enabled']:
            model.fit_lds(all_train_labels)
        
        # Callbacks
        callbacks = [
            ModelCheckpoint(
                dirpath=str(ablation_ckpt_dir),
                filename=f'{ablation_name}-{{epoch:02d}}-{{val_loss:.4f}}',
                monitor='val_loss',
                mode='min',
                save_top_k=1,
                save_last=True,
            ),
            EarlyStopping(monitor='val_loss', patience=2, mode='min'),
        ]
        
        # Train
        trainer = pl.Trainer(
            max_epochs=3,
            precision='16-mixed',
            accelerator='auto',
            devices='auto',
            callbacks=callbacks,
            logger=TensorBoardLogger(save_dir='logs', name=f'ablation_{ablation_name}'),
            log_every_n_steps=50,
            gradient_clip_val=1.0,
            accumulate_grad_batches=2,
            val_check_interval=0.5,
        )
        
        trainer.fit(model, abl_train_loader, abl_val_loader, ckpt_path=resume_ckpt)
        
        # Evaluate on test set
        test_results = trainer.test(model, dataloaders=test_loader, verbose=False)
        
        # Store results
        mean_r = np.mean([test_results[0].get(f'test_pearson_{d}', 0) for d in CONFIG['dimensions']])
        mean_mae = np.mean([test_results[0].get(f'test_mae_{d}', 0) for d in CONFIG['dimensions']])
        
        ablation_results[ablation_name] = {
            'config': ablation_cfg,
            'mean_pearson': mean_r,
            'mean_mae': mean_mae,
            'val_loss': float(callbacks[0].best_model_score) if callbacks[0].best_model_score else None,
        }
        
        print(f"\nResults: r={mean_r:.3f}, MAE={mean_mae:.2f}")
        
        # Sync to Google Drive
        !rclone copy {ablation_ckpt_dir} {GDRIVE_CHECKPOINT_PATH}/ablation/{ablation_name} --progress
        
        # Cleanup
        del model, trainer
        gc.collect()
        torch.cuda.empty_cache()
    
    # Print ablation summary
    print("\n" + "="*80)
    print("ABLATION STUDY RESULTS")
    print("="*80)
    print(f"{'Config':<25} {'Pearson r':>12} {'MAE':>12} {'Delta r':>12}")
    print("-"*65)
    
    full_r = ablation_results['full']['mean_pearson']
    for name, res in ablation_results.items():
        delta = res['mean_pearson'] - full_r
        delta_str = f"{delta:+.3f}" if name != 'full' else "baseline"
        print(f"{name:<25} {res['mean_pearson']:>12.3f} {res['mean_mae']:>12.2f} {delta_str:>12}")
    
    print("\n" + "-"*65)
    print("Interpretation:")
    print("  - Negative delta = component HELPS (removing it hurts performance)")
    print("  - Positive delta = component HURTS (removing it improves performance)")
    
    # Identify most impactful components
    print("\n" + "="*80)
    print("COMPONENT IMPACT RANKING")
    print("="*80)
    
    impacts = {name: full_r - res['mean_pearson'] 
               for name, res in ablation_results.items() if name != 'full'}
    sorted_impacts = sorted(impacts.items(), key=lambda x: abs(x[1]), reverse=True)
    
    print(f"\n{'Component':<25} {'Impact':>12} {'Effect':<15}")
    print("-"*55)
    for name, impact in sorted_impacts:
        effect = "HELPS" if impact > 0 else "HURTS"
        print(f"{name:<25} {abs(impact):>12.3f} {effect:<15}")

## Step 7: Phase 3 - Contrastive Pre-training

**Prerequisite: Phase 2 must pass (fusion beats single-modal by >= 10%)**

Align MERT and MIDIBert representation spaces using InfoNCE contrastive loss with hard negative mining.

**Training Config:**
- Freeze encoders, train only projection heads
- Larger batch size (64) for more in-batch negatives
- Hard negative mining: 25% of batch from same piece, different degradation
- 15 epochs

**Success Criteria:**
- Cross-modal alignment score >= 0.6

**Gate Logic:**
- If PHASE2_PASSED is False, this cell will skip training
- After training, Phase 3 gate check determines if we proceed to Phase 4

In [None]:
if not PHASE2_PASSED:
    print("SKIPPING Phase 3: Phase 2 gate check failed")
    print("Debug fusion architecture before proceeding")
else:
    print("="*70)
    print("PHASE 3: CONTRASTIVE PRE-TRAINING")
    print("="*70)
    
    # Configuration for contrastive pre-training
    CONTRASTIVE_CONFIG = {
        'epochs': 15,
        'batch_size': 64,  # Larger batch for more in-batch negatives
        'temperature': 0.07,
        'learning_rate': 1e-4,
        'warmup_epochs': 2,
        'freeze_encoders': True,
        'use_hard_negatives': True,
        'hard_neg_ratio': 0.25,  # 25% of batch are hard negatives
        'alignment_target': 0.6,  # Phase 3 gate criterion
    }
    
    print(f"\nContrastive Config:")
    for k, v in CONTRASTIVE_CONFIG.items():
        print(f"  {k}: {v}")
    
    # Create contrastive dataloaders with hard negative mining
    from src.data.dataset import create_contrastive_dataloader
    
    print("\nCreating contrastive dataloaders with hard negative mining...")
    contrastive_train_loader = create_contrastive_dataloader(
        annotation_path=CONFIG['train_path'],
        dimension_names=CONFIG['dimensions'],
        batch_size=CONTRASTIVE_CONFIG['batch_size'],
        num_workers=4,
        use_hard_negatives=CONTRASTIVE_CONFIG['use_hard_negatives'],
        hard_neg_ratio=CONTRASTIVE_CONFIG['hard_neg_ratio'],
    )
    
    contrastive_val_loader = create_contrastive_dataloader(
        annotation_path=CONFIG['val_path'],
        dimension_names=CONFIG['dimensions'],
        batch_size=CONTRASTIVE_CONFIG['batch_size'],
        num_workers=4,
        use_hard_negatives=False,  # No hard negatives for validation
    )
    
    print(f"Train batches: {len(contrastive_train_loader)}")
    print(f"Val batches: {len(contrastive_val_loader)}")
    
    # Create model in contrastive training mode
    print("\nCreating model in contrastive-only mode...")
    contrastive_model = PerformanceEvaluationModel(
        audio_dim=CONFIG['audio_dim'],
        midi_dim=CONFIG['midi_dim'],
        shared_dim=CONFIG['shared_dim'],
        training_mode="contrastive",  # Contrastive-only mode
        modality="fusion",
        fusion_type=best_fusion,  # Use best fusion from Phase 2
        use_projection=True,
        freeze_audio_encoder=CONTRASTIVE_CONFIG['freeze_encoders'],
        gradient_checkpointing=True,
        midi_pretrained_checkpoint=midi_pretrained_local,
        contrastive_temperature=CONTRASTIVE_CONFIG['temperature'],
        contrastive_weight=1.0,
        learning_rate=CONTRASTIVE_CONFIG['learning_rate'],
        backbone_lr=0,  # Frozen encoders
        heads_lr=CONTRASTIVE_CONFIG['learning_rate'],
        warmup_steps=len(contrastive_train_loader) * CONTRASTIVE_CONFIG['warmup_epochs'],
        max_epochs=CONTRASTIVE_CONFIG['epochs'],
    )
    
    # Freeze encoders
    if CONTRASTIVE_CONFIG['freeze_encoders']:
        print("Freezing encoders, training only projection heads...")
        if contrastive_model.audio_encoder is not None:
            for param in contrastive_model.audio_encoder.parameters():
                param.requires_grad = False
        if contrastive_model.midi_encoder is not None:
            for param in contrastive_model.midi_encoder.parameters():
                param.requires_grad = False
    
    # Count trainable parameters
    trainable = sum(p.numel() for p in contrastive_model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in contrastive_model.parameters())
    print(f"Trainable parameters: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
    
    # Checkpoint directory
    contrastive_ckpt_dir = Path(f'{CHECKPOINT_ROOT}/contrastive')
    contrastive_ckpt_dir.mkdir(parents=True, exist_ok=True)
    
    # Callbacks
    contrastive_callbacks = [
        ModelCheckpoint(
            dirpath=str(contrastive_ckpt_dir),
            filename='contrastive-{epoch:02d}-{val_alignment_score:.4f}',
            monitor='val_alignment_score',
            mode='max',  # Higher alignment is better
            save_top_k=1,
            save_last=True,
        ),
        EarlyStopping(
            monitor='val_alignment_score',
            patience=5,
            mode='max',
        ),
        LearningRateMonitor(logging_interval='step'),
    ]
    
    # Trainer
    contrastive_trainer = pl.Trainer(
        max_epochs=CONTRASTIVE_CONFIG['epochs'],
        precision='16-mixed',
        accelerator='auto',
        devices='auto',
        callbacks=contrastive_callbacks,
        logger=TensorBoardLogger(save_dir='logs', name='contrastive'),
        log_every_n_steps=50,
        gradient_clip_val=1.0,
        val_check_interval=0.5,
    )
    
    # Train
    print("\nStarting contrastive pre-training...")
    contrastive_trainer.fit(contrastive_model, contrastive_train_loader, contrastive_val_loader)
    
    # Get best results
    best_alignment = contrastive_callbacks[0].best_model_score
    best_contrastive_ckpt = contrastive_callbacks[0].best_model_path
    
    print(f"\nBest alignment score: {best_alignment:.4f}")
    print(f"Best checkpoint: {best_contrastive_ckpt}")
    
    # Sync to Google Drive
    print("\nSyncing to Google Drive...")
    !rclone copy {contrastive_ckpt_dir} {GDRIVE_CHECKPOINT_PATH}/contrastive --progress
    
    # Cleanup
    del contrastive_model, contrastive_trainer
    gc.collect()
    torch.cuda.empty_cache()
    
    print("\nContrastive pre-training complete!")

In [None]:
# Phase 3 Gate Check
#
# Passes when the best validation alignment score from contrastive
# pre-training reaches CONTRASTIVE_CONFIG['alignment_target']. Handles
# `best_alignment` being a tensor, a float, None, or undefined (e.g. the
# Phase 3 cell failed); all missing cases count as a fail.
# Globals produced: PHASE3_PASSED (bool), phase3_gate_result.
print("="*70)
print("PHASE 3 GATE CHECK (TRAINING_PLAN_v2.md)")
print("="*70)

if not PHASE2_PASSED:
    print("Phase 2 did not pass - skipping Phase 3 gate check")
    PHASE3_PASSED = False
else:
    alignment_target = CONTRASTIVE_CONFIG['alignment_target']

    # Normalise to a plain float; NaN (never >= target) when unavailable.
    _raw_alignment = globals().get('best_alignment')
    best_alignment = float('nan') if _raw_alignment is None else float(_raw_alignment)

    print(f"\nTarget alignment: >= {alignment_target}")
    print(f"Achieved alignment: {best_alignment:.4f}")

    PHASE3_PASSED = bool(best_alignment >= alignment_target)

    if PHASE3_PASSED:
        print(f"\n{'='*40}")
        print("PHASE 3 GATE: PASS")
        print(f"{'='*40}")
        print(f"Cross-modal alignment ({best_alignment:.4f}) >= target ({alignment_target})")
        print("-> Proceed to Phase 4: Full training with aligned encoders")
        print(f"\nContrastive checkpoint to use: {best_contrastive_ckpt}")
    else:
        print(f"\n{'='*40}")
        print("PHASE 3 GATE: FAIL")
        print(f"{'='*40}")
        print(f"Cross-modal alignment ({best_alignment:.4f}) < target ({alignment_target})")
        print("-> Consider:")
        print("  1. More epochs")
        print("  2. Different temperature (try 0.05 or 0.1)")
        print("  3. Increase hard negative ratio")
        print("  4. Check data quality")

# Store gate result (globals() lookups mirror the original dir() checks,
# but also tolerate a None alignment score)
_stored_alignment = globals().get('best_alignment')
phase3_gate_result = {
    'passed': PHASE3_PASSED,
    'alignment_score': None if _stored_alignment is None else float(_stored_alignment),
    'target': globals().get('alignment_target', 0.6),
    'checkpoint': globals().get('best_contrastive_ckpt'),
}

## Step 8: Phase 4 - Fine-tuning with Aligned Encoders

**Prerequisite: Phase 3 must pass (alignment score >= 0.6)**

Fine-tune the best fusion model using the contrastive-pretrained encoder weights.

**Training Config:**
- Load projection heads from contrastive checkpoint
- Unfreeze top layers of encoders gradually
- Full loss function (regression + ranking + contrastive)
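- 20 epochs (encoders frozen for the first 5; top encoder layers unfrozen at epoch 5, full encoders at epoch 10)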

In [None]:
if not PHASE3_PASSED:
    print("SKIPPING Phase 4: Phase 3 gate check failed")
    print("Alignment score did not meet target. Debug contrastive training before proceeding.")
    phase4_results = None
else:
    print("="*70)
    print("PHASE 4: FINE-TUNING WITH ALIGNED ENCODERS")
    print("="*70)
    
    # Configuration for Phase 4
    PHASE4_CONFIG = {
        'epochs': 20,                # Full training: 20 epochs
        'batch_size': 8,
        'unfreeze_encoder_epoch': 5, # Start unfreezing encoders after epoch 5
        'encoder_lr_scale': 0.1,     # 10x lower LR for encoders
        'target_technical_r': 0.50,  # Phase 4 gate criterion
        'target_interpretive_r': 0.35,
        'target_fusion_improvement': 15.0,  # Must beat single-modal by 15%
    }
    
    print(f"\nPhase 4 Config:")
    for k, v in PHASE4_CONFIG.items():
        print(f"  {k}: {v}")
    
    best_fusion = phase2_gate_result['best_fusion']
    print(f"\nUsing best fusion type from Phase 2: {best_fusion}")
    print(f"Loading contrastive-pretrained weights from: {best_contrastive_ckpt}")
    
    # Create checkpoint directory
    phase4_ckpt_dir = Path(f'{CHECKPOINT_ROOT}/phase4_{best_fusion}')
    phase4_ckpt_dir.mkdir(parents=True, exist_ok=True)
    
    # Check for existing checkpoint to resume
    resume_ckpt = find_latest_checkpoint(phase4_ckpt_dir)
    if resume_ckpt:
        print(f"Found checkpoint to resume: {resume_ckpt}")
    
    # Load contrastive model to get projection head weights
    print("\nLoading contrastive-pretrained projection heads...")
    contrastive_state = torch.load(best_contrastive_ckpt, map_location='cpu')
    
    # Extract projection head weights
    projection_weights = {}
    for key, value in contrastive_state['state_dict'].items():
        if 'audio_projection' in key or 'midi_projection' in key:
            projection_weights[key] = value
    print(f"Loaded {len(projection_weights)} projection head parameters")
    
    # Create model for Phase 4
    phase4_model = PerformanceEvaluationModel(
        audio_dim=CONFIG['audio_dim'],
        midi_dim=CONFIG['midi_dim'],
        shared_dim=CONFIG['shared_dim'],
        aggregator_dim=512,
        num_dimensions=len(CONFIG['dimensions']),
        dimension_names=CONFIG['dimensions'],
        modality="fusion",
        fusion_type=best_fusion,
        use_projection=True,
        freeze_audio_encoder=False,  # Will unfreeze gradually
        gradient_checkpointing=True,
        midi_pretrained_checkpoint=midi_pretrained_local,
        # Loss weights - include contrastive to maintain alignment
        mse_weight=CONFIG['mse_weight'],
        ranking_weight=CONFIG['ranking_weight'],
        contrastive_weight=CONFIG['contrastive_weight'],
        # Base loss
        base_loss=CONFIG['base_loss'],
        huber_delta=CONFIG['huber_delta'],
        # LDS
        lds_enabled=CONFIG['lds_enabled'],
        lds_num_bins=CONFIG['lds_num_bins'],
        lds_sigma=CONFIG['lds_sigma'],
        lds_reweight_scale=CONFIG['lds_reweight_scale'],
        # FDS
        fds_enabled=CONFIG['fds_enabled'],
        fds_num_bins=CONFIG['fds_num_bins'],
        fds_momentum=CONFIG['fds_momentum'],
        fds_kernel_sigma=CONFIG['fds_kernel_sigma'],
        fds_start_epoch=CONFIG['fds_start_epoch'],
        # CORAL
        coral_enabled=CONFIG['coral_enabled'],
        coral_num_classes=CONFIG['coral_num_classes'],
        coral_weight=CONFIG['coral_weight'],
        # Bootstrap
        bootstrap_enabled=CONFIG['bootstrap_enabled'],
        bootstrap_beta=CONFIG['bootstrap_beta'],
        bootstrap_warmup_epochs=CONFIG['bootstrap_warmup_epochs'],
        # Training hyperparameters
        backbone_lr=CONFIG['backbone_lr'] * PHASE4_CONFIG['encoder_lr_scale'],  # Lower LR for encoders
        heads_lr=CONFIG['heads_lr'],
        warmup_steps=len(train_loader) * 2,
        max_epochs=PHASE4_CONFIG['epochs'],
    )
    
    # Fit LDS if enabled
    if CONFIG['lds_enabled']:
        phase4_model.fit_lds(all_train_labels)
    
    # Load projection head weights from contrastive pre-training
    missing, unexpected = phase4_model.load_state_dict(projection_weights, strict=False)
    print(f"Loaded projection heads: {len(projection_weights) - len(missing)} params")
    
    # Freeze encoders initially
    print("\nFreezing encoders for initial training...")
    if phase4_model.audio_encoder is not None:
        for param in phase4_model.audio_encoder.parameters():
            param.requires_grad = False
    if phase4_model.midi_encoder is not None:
        for param in phase4_model.midi_encoder.parameters():
            param.requires_grad = False
    
    # Count trainable parameters
    trainable = sum(p.numel() for p in phase4_model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in phase4_model.parameters())
    print(f"Initial trainable parameters: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
    
    # Callbacks
    phase4_callbacks = [
        ModelCheckpoint(
            dirpath=str(phase4_ckpt_dir),
            filename=f'phase4_{best_fusion}-{{epoch:02d}}-{{val_loss:.4f}}',
            monitor='val_loss',
            mode='min',
            save_top_k=2,
            save_last=True,
        ),
        EarlyStopping(
            monitor='val_loss',
            patience=5,  # More patience for longer training
            mode='min',
        ),
        LearningRateMonitor(logging_interval='step'),
    ]
    
    # Add staged unfreezing callback
    phase4_unfreezing_schedule = [
        {'epoch': 0, 'freeze': ['audio_encoder', 'midi_encoder'], 'unfreeze': ['projection', 'fusion', 'heads']},
        {'epoch': PHASE4_CONFIG['unfreeze_encoder_epoch'], 'unfreeze': ['audio_encoder.top_4', 'midi_encoder.top_2'], 'lr_scale': 0.1},
        {'epoch': PHASE4_CONFIG['unfreeze_encoder_epoch'] + 5, 'unfreeze': ['audio_encoder', 'midi_encoder'], 'lr_scale': 0.05},
    ]
    phase4_callbacks.append(StagedUnfreezingCallback(
        schedule=phase4_unfreezing_schedule,
        verbose=True,
    ))
    
    # Trainer
    phase4_trainer = pl.Trainer(
        max_epochs=PHASE4_CONFIG['epochs'],
        precision='16-mixed',
        accelerator='auto',
        devices='auto',
        callbacks=phase4_callbacks,
        logger=TensorBoardLogger(save_dir='logs', name=f'phase4_{best_fusion}'),
        log_every_n_steps=50,
        gradient_clip_val=1.0,
        accumulate_grad_batches=2,
        val_check_interval=0.5,
    )
    
    # Train
    print("\nStarting Phase 4 training...")
    phase4_trainer.fit(phase4_model, train_loader, val_loader, ckpt_path=resume_ckpt)
    
    # Get best checkpoint
    best_phase4_ckpt = phase4_callbacks[0].best_model_path
    best_phase4_val_loss = float(phase4_callbacks[0].best_model_score) if phase4_callbacks[0].best_model_score else None
    
    print(f"\nBest Phase 4 checkpoint: {best_phase4_ckpt}")
    print(f"Best val loss: {best_phase4_val_loss:.4f}" if best_phase4_val_loss else "N/A")
    
    # Evaluate on test set
    print("\nEvaluating on test set...")
    phase4_test_results = phase4_trainer.test(phase4_model, test_loader, ckpt_path=best_phase4_ckpt)[0]
    
    # Calculate summary metrics
    technical_dims = ['note_accuracy', 'rhythmic_stability', 'articulation_clarity', 'pedal_technique']
    interpretive_dims = ['musical_expression', 'overall_interpretation']
    
    technical_r = np.mean([phase4_test_results.get(f'test_pearson_{d}', 0) for d in technical_dims])
    interpretive_r = np.mean([phase4_test_results.get(f'test_pearson_{d}', 0) for d in interpretive_dims])
    mean_r = np.mean([phase4_test_results.get(f'test_pearson_{d}', 0) for d in CONFIG['dimensions']])
    mean_mae = np.mean([phase4_test_results.get(f'test_mae_{d}', 0) for d in CONFIG['dimensions']])
    
    # Sync to Google Drive
    print("\nSyncing Phase 4 checkpoint to Google Drive...")
    !rclone copy {phase4_ckpt_dir} {GDRIVE_CHECKPOINT_PATH}/phase4_{best_fusion} --progress
    
    # Cleanup
    del phase4_model, phase4_trainer
    gc.collect()
    torch.cuda.empty_cache()
    
    # Store results
    phase4_results = {
        'fusion_type': best_fusion,
        'checkpoint': best_phase4_ckpt,
        'val_loss': best_phase4_val_loss,
        'test_metrics': phase4_test_results,
        'technical_r': technical_r,
        'interpretive_r': interpretive_r,
        'mean_pearson': mean_r,
        'mean_mae': mean_mae,
    }
    
    print("\nPhase 4 training complete!")

In [None]:
# Phase 4 Gate Check
# Check Phase 4 test metrics against the PHASE4_CONFIG targets (technical r,
# interpretive r, % improvement over the best Phase 2 single-modal model) and
# record the outcome in PHASE4_PASSED / phase4_gate_result.
print("="*70)
print("PHASE 4 GATE CHECK (TRAINING_PLAN_v2.md)")
print("="*70)

# Defaults used when the gate is skipped; overwritten below when it runs.
# Defining them up front avoids NameErrors and stale values from earlier runs.
technical_target = 0.50
interpretive_target = 0.35
fusion_improvement_target = 15.0
actual_improvement = None

if not PHASE3_PASSED or phase4_results is None:
    print("Phase 3 did not pass or Phase 4 was skipped")
    PHASE4_PASSED = False
else:
    technical_target = PHASE4_CONFIG['target_technical_r']
    interpretive_target = PHASE4_CONFIG['target_interpretive_r']
    fusion_improvement_target = PHASE4_CONFIG['target_fusion_improvement']
    
    # Relative improvement of Phase 4 mean r over the best Phase 2 single-modal r
    phase2_best_single_r = float(phase2_gate_result['best_single_r'])
    improvement_delta = phase4_results['mean_pearson'] - phase2_best_single_r
    if phase2_best_single_r != 0:
        # abs() keeps the sign meaningful: with a negative baseline the plain
        # ratio would report a genuine improvement as a negative percentage.
        actual_improvement = improvement_delta / abs(phase2_best_single_r) * 100
    elif improvement_delta > 0:
        # A percentage over a zero baseline is unbounded; keep only its direction.
        actual_improvement = float('inf')
    elif improvement_delta < 0:
        actual_improvement = float('-inf')
    else:
        actual_improvement = 0.0
    
    print(f"\nPhase 4 Results:")
    print(f"  Technical dimensions: r = {phase4_results['technical_r']:.4f} (target: >= {technical_target})")
    print(f"  Interpretive dimensions: r = {phase4_results['interpretive_r']:.4f} (target: >= {interpretive_target})")
    print(f"  Improvement over single-modal: {actual_improvement:+.1f}% (target: >= {fusion_improvement_target}%)")
    
    technical_passed = phase4_results['technical_r'] >= technical_target
    interpretive_passed = phase4_results['interpretive_r'] >= interpretive_target
    improvement_passed = actual_improvement >= fusion_improvement_target
    
    # bool(): the r values come from np.mean, so these comparisons yield
    # numpy.bool_, which json.dump(default=str) would write as "True"/"False".
    PHASE4_PASSED = bool(technical_passed and interpretive_passed and improvement_passed)
    
    print(f"\n{'='*40}")
    if PHASE4_PASSED:
        print("PHASE 4 GATE: PASS")
        print(f"{'='*40}")
        print("All targets met!")
        print("-> Ready for Phase 5: Expert Validation")
        print(f"\nRecommended checkpoint: {phase4_results['checkpoint']}")
    else:
        print("PHASE 4 GATE: FAIL")
        print(f"{'='*40}")
        print("Targets not met:")
        if not technical_passed:
            print(f"  - Technical r ({phase4_results['technical_r']:.3f}) < target ({technical_target})")
        if not interpretive_passed:
            print(f"  - Interpretive r ({phase4_results['interpretive_r']:.3f}) < target ({interpretive_target})")
        if not improvement_passed:
            print(f"  - Improvement ({actual_improvement:.1f}%) < target ({fusion_improvement_target}%)")
        print("\n-> Consider:")
        print("  1. More training epochs")
        print("  2. Different fusion architecture")
        print("  3. Verify degradation labels quality")
        print("  4. Check data distribution")

# Store gate result
phase4_gate_result = {
    'passed': PHASE4_PASSED,
    'technical_r': phase4_results['technical_r'] if phase4_results else None,
    'interpretive_r': phase4_results['interpretive_r'] if phase4_results else None,
    'improvement_over_single_modal': actual_improvement,
    'targets': {
        'technical': technical_target,
        'interpretive': interpretive_target,
        'improvement': fusion_improvement_target,
    },
    'checkpoint': phase4_results['checkpoint'] if phase4_results else None,
}

## Step 9: Save All Results

Save comprehensive results including:
- All 5 Phase 2 models (audio_only, midi_only, crossattn, gated, concat)
- Phase 2 gate check results
- Phase 3 contrastive pre-training results (if completed)
- Phase 4 fine-tuning results (if completed)
- Ablation study results (if completed)
- Configuration and hyperparameters

All results are saved to:
- Local: `/tmp/checkpoints/`
- Google Drive: `gdrive:crescendai_checkpoints/fusion_comparison/`

In [None]:
# Step 9: Save All Results
import json
from datetime import datetime

print("="*80)
print("SAVING ALL RESULTS")
print("="*80)

# Compile comprehensive results
final_results = {
    'timestamp': datetime.now().isoformat(),
    'config': CONFIG,
    
    # Training improvements applied
    'training_improvements': {
        'base_loss': CONFIG['base_loss'],
        'huber_delta': CONFIG['huber_delta'],
        'lds_enabled': CONFIG['lds_enabled'],
        'fds_enabled': CONFIG['fds_enabled'],
        'coral_enabled': CONFIG['coral_enabled'],
        'bootstrap_enabled': CONFIG['bootstrap_enabled'],
        'modality_dropout_enabled': CONFIG['modality_dropout']['enabled'],
        'staged_unfreezing_enabled': CONFIG['staged_unfreezing']['enabled'],
    },
    
    # Phase 2: All 5 models
    'phase2_models': {},
    'phase2_gate': phase2_gate_result if 'phase2_gate_result' in dir() else None,
    
    # Phase 3: Contrastive pre-training
    'phase3_gate': phase3_gate_result if 'phase3_gate_result' in dir() else None,
    
    # Phase 4: Fine-tuning
    'phase4_results': phase4_results if 'phase4_results' in dir() else None,
    
    # Ablation study
    'ablation_results': ablation_results if 'ablation_results' in dir() and ablation_results else None,
}

# Add all Phase 2 model results
single_modal_types = ['audio_only', 'midi_only']
all_model_types = single_modal_types + CONFIG['fusion_types']

for model_type in all_model_types:
    if model_type in phase2_results:
        # Get checkpoint path
        if model_type in trained_models:
            ckpt_path = trained_models[model_type].get('best_checkpoint')
            val_loss = trained_models[model_type].get('best_val_loss')
        else:
            ckpt_path = None
            val_loss = None
        
        # Calculate summary metrics
        mean_r = np.mean([phase2_results[model_type].get(f'test_pearson_{d}', 0) for d in CONFIG['dimensions']])
        mean_mae = np.mean([phase2_results[model_type].get(f'test_mae_{d}', 0) for d in CONFIG['dimensions']])
        
        final_results['phase2_models'][model_type] = {
            'checkpoint': ckpt_path,
            'val_loss': val_loss,
            'metrics': phase2_results[model_type],
            'mean_pearson': mean_r,
            'mean_mae': mean_mae,
            'type': 'single_modal' if model_type in single_modal_types else 'fusion',
        }

# Save to local file
results_path = f'{CHECKPOINT_ROOT}/comprehensive_results.json'
with open(results_path, 'w') as f:
    json.dump(final_results, f, indent=2, default=str)

print(f"\nResults saved to: {results_path}")

# Also save ablation results separately if they exist
if ablation_results:
    ablation_path = f'{CHECKPOINT_ROOT}/ablation_results.json'
    with open(ablation_path, 'w') as f:
        json.dump(ablation_results, f, indent=2, default=str)
    print(f"Ablation results saved to: {ablation_path}")

# Sync to Google Drive
print("\nSyncing all results to Google Drive...")
!rclone copy {CHECKPOINT_ROOT} {GDRIVE_CHECKPOINT_PATH} --progress
print("Sync complete!")

# ============================================================================
# FINAL SUMMARY TABLE
# ============================================================================
print("\n" + "="*90)
print("FINAL SUMMARY")
print("="*90)

print("\n" + "-"*90)
print("PHASE 2: MODEL COMPARISON (5 Models)")
print("-"*90)
print(f"{'Model':<15} {'Mean r':>10} {'Mean MAE':>12} {'Type':<15} {'Status'}")
print("-"*90)

for model_type in all_model_types:
    if model_type in final_results['phase2_models']:
        m = final_results['phase2_models'][model_type]
        status = "BEST" if model_type == phase2_gate_result.get('best_fusion') or model_type == phase2_gate_result.get('best_single_modal') else ""
        print(f"{model_type:<15} {m['mean_pearson']:>10.4f} {m['mean_mae']:>12.2f} {m['type']:<15} {status}")

print("-"*90)

# Phase 2 gate result
if phase2_gate_result:
    print(f"\nPhase 2 Gate: {'PASS' if phase2_gate_result['passed'] else 'FAIL'}")
    print(f"  Best fusion: {phase2_gate_result['best_fusion']} (r = {phase2_gate_result['best_fusion_r']:.4f})")
    print(f"  Best single-modal: {phase2_gate_result['best_single_modal']} (r = {phase2_gate_result['best_single_r']:.4f})")
    print(f"  Improvement: {phase2_gate_result['improvement']:+.1f}% (target: >= {CONFIG['fusion_improvement_target']:.0f}%)")

# Phase 3 results
print("\n" + "-"*90)
print("PHASE 3: CONTRASTIVE PRE-TRAINING")
print("-"*90)

if 'phase3_gate_result' in dir() and phase3_gate_result and phase3_gate_result.get('alignment_score') is not None:
    print(f"Alignment score: {phase3_gate_result['alignment_score']:.4f} (target: >= {phase3_gate_result['target']})")
    print(f"Gate: {'PASS' if phase3_gate_result['passed'] else 'FAIL'}")
    if phase3_gate_result.get('checkpoint'):
        print(f"Checkpoint: {phase3_gate_result['checkpoint']}")
else:
    print("Not completed (Phase 2 did not pass or skipped)")

# Phase 4 results
print("\n" + "-"*90)
print("PHASE 4: FINE-TUNING WITH ALIGNED ENCODERS")
print("-"*90)

if 'phase4_results' in dir() and phase4_results:
    print(f"Mean Pearson r: {phase4_results['mean_pearson']:.4f}")
    print(f"Mean MAE: {phase4_results['mean_mae']:.2f}")
    
    # Improvement over Phase 2
    phase2_best = phase2_gate_result['best_fusion_r'] if phase2_gate_result else 0
    if phase2_best > 0:
        improvement = ((phase4_results['mean_pearson'] - phase2_best) / phase2_best * 100)
        print(f"Improvement over Phase 2: {improvement:+.1f}%")
    
    print(f"Checkpoint: {phase4_results['checkpoint']}")
else:
    print("Not completed (Phase 3 did not pass or skipped)")

# Ablation summary
if ablation_results:
    print("\n" + "-"*90)
    print("ABLATION STUDY SUMMARY")
    print("-"*90)
    
    full_r = ablation_results['full']['mean_pearson']
    impacts = [(name, full_r - res['mean_pearson']) 
               for name, res in ablation_results.items() if name != 'full']
    sorted_impacts = sorted(impacts, key=lambda x: abs(x[1]), reverse=True)
    
    print(f"{'Component':<25} {'Impact':>12} {'Effect'}")
    for name, impact in sorted_impacts[:5]:  # Top 5 impacts
        effect = "HELPS" if impact > 0 else "HURTS"
        print(f"{name:<25} {abs(impact):>12.3f} {effect}")

# Best model recommendation
print("\n" + "="*90)
print("RECOMMENDATION")
print("="*90)

if 'phase4_results' in dir() and phase4_results:
    print(f"\nBest model: Phase 4 fine-tuned ({phase2_gate_result['best_fusion']})")
    print(f"  Checkpoint: {phase4_results['checkpoint']}")
    print(f"  Performance: r = {phase4_results['mean_pearson']:.4f}")
elif phase2_gate_result and phase2_gate_result['passed']:
    print(f"\nBest model: {phase2_gate_result['best_fusion']} fusion")
    if phase2_gate_result['best_fusion'] in trained_models:
        print(f"  Checkpoint: {trained_models[phase2_gate_result['best_fusion']]['best_checkpoint']}")
    print(f"  Performance: r = {phase2_gate_result['best_fusion_r']:.4f}")
else:
    print(f"\nBest model: {phase2_gate_result['best_single_modal']} (single-modal)")
    print("  NOTE: Fusion did not beat single-modal - architecture needs debugging")

print(f"\nAll checkpoints synced to: {GDRIVE_CHECKPOINT_PATH}")