# Definitive Experiments

## Parts
1. **M1a-M1d, M2**: MuQ Layer Ablation (find optimal layers; M2 tests the last hidden state alone)
2. **F8-F11**: MuQ + Symbolic Fusion
3. **D9a-D9c**: MERT + MuQ Audio Fusion
4. **X2-X3**: Cross-Dataset Validation (ASAP, PSyllabus)
5. **S3-S4**: Statistical Rigor (Bootstrap, Bonferroni)
6. **A3-A7**: Analysis (Error correlation, dimensions, calibration)
7. **Export**: Save all results to GDrive

## Requirements
- Compute: A100 (80GB VRAM)
- rclone configured with `gdrive:` remote
- External datasets: ASAP, PSyllabus

In [None]:
# Cell 1: CUDA setup (must be before any CUDA operations)
import os

# Set the cuBLAS workspace size before torch is imported so it is already in
# the environment when CUDA is first touched.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import torch

cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")

# Nothing in this notebook can run on CPU, so fail fast.
if not cuda_ready:
    raise RuntimeError("GPU required")

gpu_props = torch.cuda.get_device_properties(0)
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Cell 2: Install rclone
# Pipes the official install script into `sudo bash` and keeps only the status
# lines that mention "successfully" or "already".
# NOTE(review): when grep matches nothing (for example because the install
# failed), the `||` fallback still prints "rclone installed". The later
# `rclone listremotes` check in Cell 5 is what really proves rclone works.
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
# Cell 3: Install dependencies and clone repo
# Installs the Python packages used by the notebook (quietly), then makes sure
# a checkout of the project exists at REPO_DIR.
!pip install transformers librosa soundfile pytorch_lightning nnAudio scipy scikit-learn muq requests tqdm --quiet

import os
# Local checkout location; Cell 4 adds `{REPO_DIR}/model/src` to sys.path.
REPO_DIR = '/tmp/crescendai'
if os.path.exists(REPO_DIR):
    # Existing checkout: refresh it from the `main` branch.
    !cd {REPO_DIR} && git pull origin main
else:
    # First run on this machine: clone the repository.
    !git clone https://github.com/jai-dhiman/crescendai.git {REPO_DIR}

print(f"Repo: {REPO_DIR}")

In [None]:
# Cell 4: Imports
import sys
sys.path.insert(0, f'{REPO_DIR}/model/src')

import json
import subprocess
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import torch
import pytorch_lightning as pl
from scipy import stats
from sklearn.metrics import r2_score

from audio_experiments import PERCEPIANO_DIMENSIONS, DIMENSION_CATEGORIES, BASE_CONFIG, SEED
from audio_experiments.extractors import (
    extract_mert_for_layer_range,
    extract_muq_embeddings,
)
from audio_experiments.models import (
    MuQStatsModel,
    MuQBaseModel,
    MERTMuQEnsemble,
    MERTMuQConcatModel,
    AsymmetricGatedFusion,
)
from audio_experiments.training import (
    run_4fold_mert_experiment,
    run_4fold_dual_experiment,
    restore_all_from_gdrive,
    should_run_experiment,
    sync_experiment_to_gdrive,
    get_completed_experiments,
    print_experiment_status,
    # Fusion runners
    run_simple_fusion_experiment,
    run_weighted_fusion_experiment,
    run_ridge_fusion_experiment,
    run_confidence_fusion_experiment,
    run_error_correlation_experiment,
    save_fusion_experiment,
    # Statistics
    bootstrap_r2_extended,
    bootstrap_r2_comparison,
    paired_ttest_per_sample,
    wilcoxon_test,
    cohens_d,
    bonferroni_correction,
    fdr_correction,
    # Fusion strategies
    simple_average_fusion,
    weighted_fusion_grid_search,
    compute_error_correlation,
    compute_per_dimension_comparison,
)
from audio_experiments.training.sync import numpy_serializer

# Global runtime settings applied once, right after the imports.

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
# Select the 'medium' float32 matmul precision mode.
# NOTE(review): presumably a speed/precision trade-off chosen for the A100 — confirm.
torch.set_float32_matmul_precision('medium')
# Seed via Lightning using the project-wide SEED; workers=True extends the
# seeding to dataloader workers.
pl.seed_everything(SEED, workers=True)

print(f"PyTorch: {torch.__version__}")
print(f"Imports: OK")

In [None]:
# Cell 5: Path configuration
# Every local artefact of this notebook lives under DATA_ROOT.
DATA_ROOT = Path('/tmp/definitive_experiments')
AUDIO_DIR = DATA_ROOT / 'audio'
LABEL_DIR = DATA_ROOT / 'labels'
# Embedding caches; the MuQ cells create one sub-folder per layer range here.
MUQ_CACHE_ROOT = DATA_ROOT / 'muq_cache'
MERT_CACHE_ROOT = DATA_ROOT / 'mert_cache'
CHECKPOINT_ROOT = DATA_ROOT / 'checkpoints'
RESULTS_DIR = DATA_ROOT / 'results'
LOG_DIR = DATA_ROOT / 'logs'
FIGURES_DIR = RESULTS_DIR / 'figures'

# Cross-dataset directories
ASAP_DIR = DATA_ROOT / 'asap'
PSYLLABUS_DIR = DATA_ROOT / 'psyllabus'

# GDrive paths
# All of these are rclone remote paths on the `gdrive:` remote.
GDRIVE_AUDIO = 'gdrive:crescendai_data/audio_baseline/percepiano_rendered'
GDRIVE_LABELS = 'gdrive:crescendai_data/percepiano_labels'
GDRIVE_FOLDS = 'gdrive:crescendai_data/percepiano_fold_assignments.json'
GDRIVE_MERT_CACHE = 'gdrive:crescendai_data/audio_baseline/mert_embeddings/L7-12'
GDRIVE_MUQ_CACHE = 'gdrive:crescendai_data/audio_baseline/muq_embeddings'
GDRIVE_RESULTS = 'gdrive:crescendai_data/checkpoints/definitive_experiments'
GDRIVE_SYMBOLIC = 'gdrive:crescendai_data/checkpoints/aligned_fusion/symbolic_predictions.json'

# Create the whole local directory tree up front so later cells can write
# without checking for existence.
for d in [AUDIO_DIR, LABEL_DIR, MUQ_CACHE_ROOT, MERT_CACHE_ROOT, CHECKPOINT_ROOT,
          RESULTS_DIR, LOG_DIR, FIGURES_DIR, ASAP_DIR, PSYLLABUS_DIR]:
    d.mkdir(parents=True, exist_ok=True)

def run_rclone(cmd, desc=""):
    """Run a command (normally an rclone invocation) and fail loudly on error.

    Args:
        cmd: Argument list passed to ``subprocess.run``.
        desc: Optional label; printed as ``"<desc>..."`` before the command
            runs and included in the error message.

    Returns:
        The ``subprocess.CompletedProcess`` with captured text stdout/stderr.

    Raises:
        RuntimeError: If the command exits with a non-zero status.
    """
    if desc:
        print(f"{desc}...")

    completed = subprocess.run(cmd, capture_output=True, text=True)
    if completed.returncode == 0:
        return completed

    raise RuntimeError(f"rclone failed: {desc}\nCommand: {' '.join(cmd)}\nStderr: {completed.stderr}")

# Check rclone
# Verify that rclone runs and that a remote named exactly `gdrive:` exists.
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)
if result.returncode != 0:
    # Report a failing rclone as such instead of as a missing remote.
    raise RuntimeError(f"rclone listremotes failed: {result.stderr}")

# `rclone listremotes` prints one remote per line. Compare whole lines so a
# remote such as `my_gdrive:` is not mistaken for `gdrive:`.
configured_remotes = {line.strip() for line in result.stdout.splitlines()}
if 'gdrive:' not in configured_remotes:
    raise RuntimeError("rclone 'gdrive' not configured")

print(f"Data root: {DATA_ROOT}")
print(f"GDrive results: {GDRIVE_RESULTS}")

In [None]:
# Cell 6: Download data
# Pull audio, labels and the fold-assignment file from GDrive, then build the
# lookup structures used by every experiment below.
FOLD_FILE = DATA_ROOT / 'folds.json'

run_rclone(['rclone', 'copy', GDRIVE_AUDIO, str(AUDIO_DIR), '--progress'], "Downloading audio")
run_rclone(['rclone', 'copy', GDRIVE_LABELS, str(LABEL_DIR)], "Downloading labels")
run_rclone(['rclone', 'copyto', GDRIVE_FOLDS, str(FOLD_FILE)], "Downloading folds")

# Load labels and folds
LABEL_FILE = LABEL_DIR / 'label_2round_mean_reg_19_with0_rm_highstd0.json'
with open(LABEL_FILE) as f:
    LABELS = json.load(f)
with open(FOLD_FILE) as f:
    FOLD_ASSIGNMENTS = json.load(f)

# Map each sample key to its fold id (folds 0-3). A missing "fold_<i>" entry
# simply contributes no keys.
FOLD_BY_KEY = {
    key: fold_id
    for fold_id in range(4)
    for key in FOLD_ASSIGNMENTS.get(f"fold_{fold_id}", [])
}

ALL_KEYS = sorted(FOLD_BY_KEY)
fold_sizes = [len(FOLD_ASSIGNMENTS.get(f'fold_{i}', [])) for i in range(4)]
wav_files = list(AUDIO_DIR.glob('*.wav'))
print(f"Samples per fold: {fold_sizes}")
print(f"Total samples: {len(ALL_KEYS)}")
print(f"Audio files: {len(wav_files)}")

In [None]:
# Cell 7: Initialize results tracking
# In-memory store of experiment results, keyed by experiment id.
ALL_RESULTS = {}

# Get completed experiments from GDrive
# The returned collection is passed to should_run_experiment() by later cells.
print("Checking GDrive for completed experiments...")
COMPLETED_CACHE = get_completed_experiments(GDRIVE_RESULTS)
print(f"Found {len(COMPLETED_CACHE)} completed experiments")

# Define experiment IDs
EXPERIMENT_IDS = [
    # Part 1: MuQ Layer Ablation (MuQ exposes 13 hidden states, indices 0-12:
    # index 0 is the initial embedding, 1-12 are the transformer layers)
    'M1a_muq_L1-4',
    'M1b_muq_L5-8',
    'M1c_muq_L9-12',
    'M1d_muq_L1-12',
    'M2_muq_last_hidden',  # Test last_hidden_state (replicates D8 from phase 2)
    # Part 2: MuQ + Symbolic Fusion
    'F8_muq_symbolic_simple',
    'F9_muq_symbolic_weighted',
    'F10_muq_symbolic_ridge',
    'F11_muq_symbolic_confidence',
    # Part 3: MERT + MuQ Fusion
    'D9a_mert_muq_ensemble',
    'D9b_mert_muq_concat',
    'D9c_mert_muq_gated',
    # Part 4: Cross-Dataset Validation
    'X2_asap_multiperformer',
    'X3_psyllabus_difficulty',
    # Part 5: Statistics
    'S3_bootstrap_all',
    'S4_significance_tests',
    # Part 6: Analysis
    'A3_error_correlation',
    'A4_dimension_breakdown',
    'A5_failure_cases',
    'A6_calibration',
    'A7_gate_visualization',
]

print_experiment_status(EXPERIMENT_IDS, COMPLETED_CACHE)

---
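## Part 1: MuQ Layer Ablation (M1a-M1d, M2)

Find the optimal MuQ layer range. MuQ has 12 transformer layers; its `hidden_states` output has 13 entries (indices 0-12), where index 0 is the initial embedding and indices 1-12 are the transformer layers. M2 additionally evaluates the last hidden state on its own.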

In [None]:
# Cell 9: MuQ Layer Configurations
# MuQ has 13 hidden states (indices 0-12): index 0 is initial embedding, 1-12 are transformer layers
# Each entry gives a layer range; the names and descriptions treat `layer_end`
# as exclusive (e.g. start=1, end=5 is described as "layers 1-4").
# `None`/`None` means "use the last hidden state" instead of a layer range.
MUQ_LAYER_CONFIGS = {
    'M1a_muq_L1-4': {'layer_start': 1, 'layer_end': 5, 'desc': 'MuQ layers 1-4 (early acoustic)'},
    'M1b_muq_L5-8': {'layer_start': 5, 'layer_end': 9, 'desc': 'MuQ layers 5-8 (mid perceptual)'},
    'M1c_muq_L9-12': {'layer_start': 9, 'layer_end': 13, 'desc': 'MuQ layers 9-12 (late semantic)'},
    'M1d_muq_L1-12': {'layer_start': 1, 'layer_end': 13, 'desc': 'MuQ all layers 1-12'},
    'M2_muq_last_hidden': {'layer_start': None, 'layer_end': None, 'desc': 'MuQ last hidden state only (like D8)'},
}

# MuQ Stats pooling config (proven best in prior experiments)
# Starts from the shared BASE_CONFIG and overrides the model hyperparameters.
MUQ_CONFIG = {
    **BASE_CONFIG,
    'input_dim': 1024,
    'hidden_dim': 512,
    'dropout': 0.2,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'pooling_stats': 'mean_std',  # 2x input dim
}

def make_muq_stats_model(cfg):
    """Build a ``MuQStatsModel`` from the hyperparameters stored in ``cfg``.

    ``cfg`` must contain every key listed below; each is forwarded to the
    model constructor as a keyword argument of the same name.
    """
    forwarded_keys = (
        'input_dim',
        'hidden_dim',
        'dropout',
        'learning_rate',
        'weight_decay',
        'pooling_stats',
        'max_epochs',
    )
    hyperparams = {name: cfg[name] for name in forwarded_keys}
    return MuQStatsModel(**hyperparams)

# List every ablation config with the layers it covers. The printed range is
# inclusive, so the upper bound is layer_end - 1.
print("MuQ layer ablation configs ready")
for exp_id, cfg in MUQ_LAYER_CONFIGS.items():
    layer_label = (
        "last hidden state"
        if cfg['layer_start'] is None
        else f"layers {cfg['layer_start']}-{cfg['layer_end']-1}"
    )
    print(f"  {exp_id}: {layer_label}")

In [None]:
# Cell 10: M1a - MuQ Layers 1-4
# Extract embeddings for this layer range, train the 4-fold experiment and
# sync the outcome to GDrive — unless the experiment is already done.
exp_id = 'M1a_muq_L1-4'
cfg = MUQ_LAYER_CONFIGS[exp_id]
needs_run = should_run_experiment(
    exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE
)

if needs_run:
    first_layer, stop_layer = cfg['layer_start'], cfg['layer_end']

    # One cache folder per layer range, named after the inclusive range.
    cache_dir = MUQ_CACHE_ROOT / f"L{first_layer}-{stop_layer - 1}"
    cache_dir.mkdir(parents=True, exist_ok=True)

    extract_muq_embeddings(
        AUDIO_DIR, cache_dir, ALL_KEYS,
        layer_start=first_layer,
        layer_end=stop_layer,
    )

    fold_results = run_4fold_mert_experiment(
        exp_id=exp_id,
        description=cfg['desc'],
        model_factory=make_muq_stats_model,
        mert_cache_dir=cache_dir,
        labels=LABELS,
        fold_assignments=FOLD_ASSIGNMENTS,
        config=MUQ_CONFIG,
        checkpoint_root=CHECKPOINT_ROOT,
        results_dir=RESULTS_DIR,
        log_dir=LOG_DIR,
    )
    ALL_RESULTS[exp_id] = fold_results
    sync_experiment_to_gdrive(exp_id, fold_results, RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 11: M1b - MuQ Layers 5-8
# Same pipeline as the other ablation cells: extract, train 4 folds, sync.
exp_id = 'M1b_muq_L5-8'
cfg = MUQ_LAYER_CONFIGS[exp_id]
needs_run = should_run_experiment(
    exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE
)

if needs_run:
    first_layer, stop_layer = cfg['layer_start'], cfg['layer_end']

    # One cache folder per layer range, named after the inclusive range.
    cache_dir = MUQ_CACHE_ROOT / f"L{first_layer}-{stop_layer - 1}"
    cache_dir.mkdir(parents=True, exist_ok=True)

    extract_muq_embeddings(
        AUDIO_DIR, cache_dir, ALL_KEYS,
        layer_start=first_layer,
        layer_end=stop_layer,
    )

    fold_results = run_4fold_mert_experiment(
        exp_id=exp_id,
        description=cfg['desc'],
        model_factory=make_muq_stats_model,
        mert_cache_dir=cache_dir,
        labels=LABELS,
        fold_assignments=FOLD_ASSIGNMENTS,
        config=MUQ_CONFIG,
        checkpoint_root=CHECKPOINT_ROOT,
        results_dir=RESULTS_DIR,
        log_dir=LOG_DIR,
    )
    ALL_RESULTS[exp_id] = fold_results
    sync_experiment_to_gdrive(exp_id, fold_results, RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 12: M1c - MuQ Layers 9-12
# Same pipeline as the other ablation cells: extract, train 4 folds, sync.
exp_id = 'M1c_muq_L9-12'
cfg = MUQ_LAYER_CONFIGS[exp_id]
needs_run = should_run_experiment(
    exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE
)

if needs_run:
    first_layer, stop_layer = cfg['layer_start'], cfg['layer_end']

    # One cache folder per layer range, named after the inclusive range.
    cache_dir = MUQ_CACHE_ROOT / f"L{first_layer}-{stop_layer - 1}"
    cache_dir.mkdir(parents=True, exist_ok=True)

    extract_muq_embeddings(
        AUDIO_DIR, cache_dir, ALL_KEYS,
        layer_start=first_layer,
        layer_end=stop_layer,
    )

    fold_results = run_4fold_mert_experiment(
        exp_id=exp_id,
        description=cfg['desc'],
        model_factory=make_muq_stats_model,
        mert_cache_dir=cache_dir,
        labels=LABELS,
        fold_assignments=FOLD_ASSIGNMENTS,
        config=MUQ_CONFIG,
        checkpoint_root=CHECKPOINT_ROOT,
        results_dir=RESULTS_DIR,
        log_dir=LOG_DIR,
    )
    ALL_RESULTS[exp_id] = fold_results
    sync_experiment_to_gdrive(exp_id, fold_results, RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 13: M1d - MuQ All Layers
# Same pipeline as the other ablation cells: extract, train 4 folds, sync.
exp_id = 'M1d_muq_L1-12'
cfg = MUQ_LAYER_CONFIGS[exp_id]
needs_run = should_run_experiment(
    exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE
)

if needs_run:
    first_layer, stop_layer = cfg['layer_start'], cfg['layer_end']

    # One cache folder per layer range, named after the inclusive range.
    cache_dir = MUQ_CACHE_ROOT / f"L{first_layer}-{stop_layer - 1}"
    cache_dir.mkdir(parents=True, exist_ok=True)

    extract_muq_embeddings(
        AUDIO_DIR, cache_dir, ALL_KEYS,
        layer_start=first_layer,
        layer_end=stop_layer,
    )

    fold_results = run_4fold_mert_experiment(
        exp_id=exp_id,
        description=cfg['desc'],
        model_factory=make_muq_stats_model,
        mert_cache_dir=cache_dir,
        labels=LABELS,
        fold_assignments=FOLD_ASSIGNMENTS,
        config=MUQ_CONFIG,
        checkpoint_root=CHECKPOINT_ROOT,
        results_dir=RESULTS_DIR,
        log_dir=LOG_DIR,
    )
    ALL_RESULTS[exp_id] = fold_results
    sync_experiment_to_gdrive(exp_id, fold_results, RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell: M2 - MuQ Last Hidden State (replicates D8 from phase 2)
# D8_muq_stats achieved R2=0.560 using last_hidden_state - let's verify this result
exp_id = 'M2_muq_last_hidden'
cfg = MUQ_LAYER_CONFIGS[exp_id]
needs_run = should_run_experiment(
    exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE
)

if needs_run:
    # Dedicated cache folder for last_hidden_state embeddings.
    cache_dir = MUQ_CACHE_ROOT / 'last_hidden'
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Best effort: reuse the cache D8 uploaded to GDrive, if there is one.
    download_cmd = ['rclone', 'copy', GDRIVE_MUQ_CACHE, str(cache_dir)]
    try:
        run_rclone(download_cmd, "Downloading MuQ cache (last_hidden)")
    except RuntimeError as e:
        print(f"No existing cache on GDrive, will extract fresh: {e}")

    # No layer range is passed, which selects last_hidden_state.
    extract_muq_embeddings(AUDIO_DIR, cache_dir, ALL_KEYS)

    # Train with stats pooling (same as D8).
    fold_results = run_4fold_mert_experiment(
        exp_id=exp_id,
        description=cfg['desc'],
        model_factory=make_muq_stats_model,
        mert_cache_dir=cache_dir,
        labels=LABELS,
        fold_assignments=FOLD_ASSIGNMENTS,
        config=MUQ_CONFIG,
        checkpoint_root=CHECKPOINT_ROOT,
        results_dir=RESULTS_DIR,
        log_dir=LOG_DIR,
    )
    ALL_RESULTS[exp_id] = fold_results
    sync_experiment_to_gdrive(exp_id, fold_results, RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 14: MuQ Layer Ablation Summary
# Prints one row per ablation experiment and records the best one in
# best_muq_exp / BEST_MUQ_CONFIG / BEST_MUQ_CACHE for the fusion cells below.
print("\n" + "="*70)
print("MuQ LAYER ABLATION RESULTS")
print("="*70)
print(f"{'Experiment':<25} {'Layers':<12} {'R2':>10} {'Std':>10}")
print("-"*70)

best_muq_exp = None
# R2 can be negative, so start below every possible score. Starting at 0
# would leave best_muq_exp unset when all runs score <= 0.
best_muq_r2 = float('-inf')
# Always defined, so later cells never hit a NameError when nothing was found.
BEST_MUQ_CONFIG = None
BEST_MUQ_CACHE = None

for exp_id in ['M1a_muq_L1-4', 'M1b_muq_L5-8', 'M1c_muq_L9-12', 'M1d_muq_L1-12', 'M2_muq_last_hidden']:
    # Load from disk if not in memory
    if exp_id not in ALL_RESULTS:
        result_file = RESULTS_DIR / f"{exp_id}.json"
        if result_file.exists():
            with open(result_file) as f:
                ALL_RESULTS[exp_id] = json.load(f)

    # Experiments with no result anywhere are left out of the table.
    if exp_id not in ALL_RESULTS:
        continue

    r = ALL_RESULTS[exp_id]
    r2 = r['summary']['avg_r2']
    std = r['summary']['std_r2']
    cfg = MUQ_LAYER_CONFIGS[exp_id]
    layers = 'last_hidden' if cfg['layer_start'] is None else f"{cfg['layer_start']}-{cfg['layer_end']-1}"
    print(f"{exp_id:<25} {layers:<12} {r2:>10.4f} {std:>10.4f}")

    if r2 > best_muq_r2:
        best_muq_r2 = r2
        best_muq_exp = exp_id

print("-"*70)
if best_muq_exp is not None:
    print(f"BEST: {best_muq_exp} (R2={best_muq_r2:.4f})")
    BEST_MUQ_CONFIG = MUQ_LAYER_CONFIGS[best_muq_exp]
    # Mirror the cache folder naming used by the ablation cells.
    if BEST_MUQ_CONFIG['layer_start'] is None:
        BEST_MUQ_CACHE = MUQ_CACHE_ROOT / 'last_hidden'
    else:
        BEST_MUQ_CACHE = MUQ_CACHE_ROOT / f"L{BEST_MUQ_CONFIG['layer_start']}-{BEST_MUQ_CONFIG['layer_end']-1}"
else:
    print("No MuQ layer ablation results found")

---
## Part 2: MuQ + Symbolic Fusion (F8-F11)

In [None]:
# Cell 16: Load Symbolic Predictions
# Fetch the symbolic model's prediction file from GDrive and load it; the
# result is aligned with the MuQ predictions by sample key in Cell 18.
SYMBOLIC_PRED_FILE = DATA_ROOT / 'symbolic_predictions.json'
# `copyto` copies the single remote file to exactly this local path.
run_rclone(['rclone', 'copyto', GDRIVE_SYMBOLIC, str(SYMBOLIC_PRED_FILE)], "Downloading symbolic predictions")

with open(SYMBOLIC_PRED_FILE) as f:
    SYMBOLIC_PREDICTIONS = json.load(f)

print(f"Loaded symbolic predictions for {len(SYMBOLIC_PREDICTIONS)} samples")

In [None]:
# Cell 17: Generate MuQ Predictions
def generate_muq_predictions(checkpoint_dir: Path, cache_dir: Path, fold_assignments: Dict, labels: Dict) -> Dict[str, List[float]]:
    """Generate CV predictions from trained MuQ models.

    For each of the 4 folds, loads ``fold{N}_best.ckpt`` from
    ``checkpoint_dir`` and runs it over that fold's "val" split of a
    ``MERTDataset`` built on ``cache_dir``.

    Args:
        checkpoint_dir: Directory holding the per-fold checkpoints.
        cache_dir: Directory of cached embeddings the models were trained on.
        fold_assignments: Fold assignment mapping, passed through to the dataset.
        labels: Label mapping, passed through to the dataset.

    Returns:
        A dict mapping sample key to that sample's prediction as a list of
        floats. A fold whose checkpoint is missing is skipped with a warning,
        so the result may cover only part of the data.
    """
    from audio_experiments.data import MERTDataset, mert_collate_fn
    from torch.utils.data import DataLoader

    predictions = {}
    device = torch.device('cuda')

    for fold in range(4):
        ckpt_path = checkpoint_dir / f"fold{fold}_best.ckpt"
        if not ckpt_path.exists():
            print(f"Warning: checkpoint not found: {ckpt_path}")
            continue

        model = MuQStatsModel.load_from_checkpoint(ckpt_path)
        model = model.to(device).eval()

        # The dataset selects this fold's validation samples itself.
        val_ds = MERTDataset(cache_dir, labels, fold_assignments, fold, "val", max_frames=1000)
        val_dl = DataLoader(val_ds, batch_size=32, shuffle=False, collate_fn=mert_collate_fn)

        with torch.no_grad():
            for batch in val_dl:
                pred = model(batch['embeddings'].to(device), batch['attention_mask'].to(device))
                for key, p in zip(batch['keys'], pred.cpu().numpy()):
                    predictions[key] = p.tolist()

        # Free GPU memory before the next fold's model is loaded.
        del model
        torch.cuda.empty_cache()

    return predictions

# Generate predictions from best MuQ model
# With no winning ablation experiment there is nothing to predict with, so
# fall back to an empty mapping.
if not best_muq_exp:
    print("WARNING: No MuQ model trained yet")
    MUQ_PREDICTIONS = {}
else:
    print(f"Generating MuQ predictions from {best_muq_exp}...")
    best_ckpt_dir = CHECKPOINT_ROOT / best_muq_exp
    MUQ_PREDICTIONS = generate_muq_predictions(
        best_ckpt_dir, BEST_MUQ_CACHE, FOLD_ASSIGNMENTS, LABELS
    )
    print(f"Generated predictions for {len(MUQ_PREDICTIONS)} samples")

In [None]:
# Cell 18: Align predictions
# Keep only samples that have a MuQ prediction, a symbolic prediction and a
# label, in sorted key order so all three arrays share the same row order.
shared_keys = set(MUQ_PREDICTIONS.keys()).intersection(
    SYMBOLIC_PREDICTIONS.keys(), LABELS.keys()
)
FUSION_KEYS = sorted(shared_keys)
print(f"Aligned samples: {len(FUSION_KEYS)}")

# Create aligned arrays (labels are cut to their first 19 entries)
MUQ_ARR = np.array([MUQ_PREDICTIONS[sample] for sample in FUSION_KEYS])
SYMBOLIC_ARR = np.array([SYMBOLIC_PREDICTIONS[sample] for sample in FUSION_KEYS])
LABELS_ARR = np.array([LABELS[sample][:19] for sample in FUSION_KEYS])

for array_name, array in (("MuQ", MUQ_ARR), ("Symbolic", SYMBOLIC_ARR), ("Labels", LABELS_ARR)):
    print(f"{array_name} shape: {array.shape}")

In [None]:
# Cell 19: F8 - Simple Average Fusion
# Run, save locally and sync to GDrive unless the experiment is already done.
exp_id = 'F8_muq_symbolic_simple'
needs_run = should_run_experiment(
    exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE
)

if needs_run:
    fusion_result = run_simple_fusion_experiment(
        exp_id, MUQ_ARR, SYMBOLIC_ARR, LABELS_ARR, n_bootstrap=10000
    )
    ALL_RESULTS[exp_id] = fusion_result
    save_fusion_experiment(exp_id, fusion_result, RESULTS_DIR, ALL_RESULTS)
    sync_experiment_to_gdrive(exp_id, fusion_result, RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 20: F9 - Weighted Fusion
# Besides the aligned arrays, this runner also receives the fold map and keys.
exp_id = 'F9_muq_symbolic_weighted'
needs_run = should_run_experiment(
    exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE
)

if needs_run:
    fusion_result = run_weighted_fusion_experiment(
        exp_id, MUQ_ARR, SYMBOLIC_ARR, LABELS_ARR, FOLD_BY_KEY, FUSION_KEYS, n_bootstrap=10000
    )
    ALL_RESULTS[exp_id] = fusion_result
    save_fusion_experiment(exp_id, fusion_result, RESULTS_DIR, ALL_RESULTS)
    sync_experiment_to_gdrive(exp_id, fusion_result, RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 21: F10 - Ridge Stacking
# Besides the aligned arrays, this runner also receives the fold map and keys.
exp_id = 'F10_muq_symbolic_ridge'
needs_run = should_run_experiment(
    exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE
)

if needs_run:
    fusion_result = run_ridge_fusion_experiment(
        exp_id, MUQ_ARR, SYMBOLIC_ARR, LABELS_ARR, FOLD_BY_KEY, FUSION_KEYS, n_bootstrap=10000
    )
    ALL_RESULTS[exp_id] = fusion_result
    save_fusion_experiment(exp_id, fusion_result, RESULTS_DIR, ALL_RESULTS)
    sync_experiment_to_gdrive(exp_id, fusion_result, RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 22: F11 - Confidence Weighted
# Run, save locally and sync to GDrive unless the experiment is already done.
exp_id = 'F11_muq_symbolic_confidence'
needs_run = should_run_experiment(
    exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE
)

if needs_run:
    fusion_result = run_confidence_fusion_experiment(
        exp_id, MUQ_ARR, SYMBOLIC_ARR, LABELS_ARR, n_bootstrap=10000
    )
    ALL_RESULTS[exp_id] = fusion_result
    save_fusion_experiment(exp_id, fusion_result, RESULTS_DIR, ALL_RESULTS)
    sync_experiment_to_gdrive(exp_id, fusion_result, RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

---
## Part 3: MERT + MuQ Audio Fusion (D9a-D9c)

Test if two audio encoders provide complementary information.

In [None]:
# Cell 24: Extract MERT embeddings (layers 7-12, best from prior ablation)
MERT_CACHE = MERT_CACHE_ROOT / 'L7-12'
MERT_CACHE.mkdir(parents=True, exist_ok=True)

# Best effort: seed the local cache from GDrive. A failed copy is not fatal
# because the extraction step below runs regardless.
mert_download_cmd = ['rclone', 'copy', GDRIVE_MERT_CACHE, str(MERT_CACHE)]
try:
    run_rclone(mert_download_cmd, "Downloading MERT cache")
except RuntimeError as e:
    print(f"No existing cache on GDrive, will extract fresh: {e}")

# Extract any missing
extract_mert_for_layer_range(7, 13, AUDIO_DIR, MERT_CACHE, ALL_KEYS)
mert_files = list(MERT_CACHE.glob('*.pt'))
print(f"MERT embeddings ready: {len(mert_files)} files")

In [None]:
# Cell 25: Ensure MuQ embeddings for D9 fusion experiments
# Use the best MuQ config - either from layer ablation or last_hidden_state (D8 approach)
if not best_muq_exp:
    # No ablation results yet - use last_hidden_state as default (D8's approach)
    MUQ_CACHE = MUQ_CACHE_ROOT / 'last_hidden'
    MUQ_CACHE.mkdir(parents=True, exist_ok=True)
    try:
        run_rclone(['rclone', 'copy', GDRIVE_MUQ_CACHE, str(MUQ_CACHE)], "Downloading MuQ cache (last_hidden)")
    except RuntimeError as e:
        print(f"No cache on GDrive, will extract fresh: {e}")
    extract_muq_embeddings(AUDIO_DIR, MUQ_CACHE, ALL_KEYS)
elif BEST_MUQ_CONFIG['layer_start'] is None:
    # The ablation winner is the last_hidden_state experiment.
    MUQ_CACHE = MUQ_CACHE_ROOT / 'last_hidden'
    MUQ_CACHE.mkdir(parents=True, exist_ok=True)
    try:
        run_rclone(['rclone', 'copy', GDRIVE_MUQ_CACHE, str(MUQ_CACHE)], "Downloading MuQ cache")
    except RuntimeError as e:
        print(f"No cache on GDrive (may already be extracted locally): {e}")
    extract_muq_embeddings(AUDIO_DIR, MUQ_CACHE, ALL_KEYS)
else:
    # The ablation winner is a specific layer range; reuse its cache folder.
    MUQ_CACHE = BEST_MUQ_CACHE
    MUQ_CACHE.mkdir(parents=True, exist_ok=True)
    extract_muq_embeddings(
        AUDIO_DIR, MUQ_CACHE, ALL_KEYS,
        layer_start=BEST_MUQ_CONFIG['layer_start'],
        layer_end=BEST_MUQ_CONFIG['layer_end'],
    )

muq_files = list(MUQ_CACHE.glob('*.pt'))
print(f"MuQ cache: {MUQ_CACHE}")
print(f"MuQ embeddings: {len(muq_files)} files")

In [None]:
# Cell 26: D9a - MERT+MuQ Ensemble (Late Fusion)
exp_id = 'D9a_mert_muq_ensemble'

# Hyperparameters shared by all three D9 experiments: BASE_CONFIG plus the
# model-specific overrides below.
DUAL_CONFIG = dict(
    BASE_CONFIG,
    input_dim=1024,
    hidden_dim=512,
    dropout=0.2,
    learning_rate=1e-4,
    weight_decay=1e-5,
    fusion_weight=0.5,
)


def make_ensemble_model(cfg):
    """Build a ``MERTMuQEnsemble`` with attention pooling from ``cfg``."""
    forwarded_keys = (
        'input_dim',
        'hidden_dim',
        'dropout',
        'learning_rate',
        'weight_decay',
        'fusion_weight',
        'max_epochs',
    )
    hyperparams = {name: cfg[name] for name in forwarded_keys}
    return MERTMuQEnsemble(pooling='attention', **hyperparams)


needs_run = should_run_experiment(
    exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE
)

if needs_run:
    fold_results = run_4fold_dual_experiment(
        exp_id=exp_id,
        description='MERT+MuQ late fusion ensemble',
        model_factory=make_ensemble_model,
        mert_cache_dir=MERT_CACHE,
        muq_cache_dir=MUQ_CACHE,
        labels=LABELS,
        fold_assignments=FOLD_ASSIGNMENTS,
        config=DUAL_CONFIG,
        checkpoint_root=CHECKPOINT_ROOT,
        results_dir=RESULTS_DIR,
        log_dir=LOG_DIR,
    )
    ALL_RESULTS[exp_id] = fold_results
    sync_experiment_to_gdrive(exp_id, fold_results, RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 27: D9b - MERT+MuQ Concat (Early Fusion)
exp_id = 'D9b_mert_muq_concat'


def make_concat_model(cfg):
    """Build a ``MERTMuQConcatModel`` with attention pooling from ``cfg``.

    Both encoders get the same embedding width, ``cfg['input_dim']``.
    """
    embed_dim = cfg['input_dim']
    forwarded_keys = ('hidden_dim', 'dropout', 'learning_rate', 'weight_decay', 'max_epochs')
    hyperparams = {name: cfg[name] for name in forwarded_keys}
    return MERTMuQConcatModel(
        mert_dim=embed_dim,
        muq_dim=embed_dim,
        pooling='attention',
        **hyperparams,
    )


needs_run = should_run_experiment(
    exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE
)

if needs_run:
    fold_results = run_4fold_dual_experiment(
        exp_id=exp_id,
        description='MERT+MuQ early fusion concat',
        model_factory=make_concat_model,
        mert_cache_dir=MERT_CACHE,
        muq_cache_dir=MUQ_CACHE,
        labels=LABELS,
        fold_assignments=FOLD_ASSIGNMENTS,
        config=DUAL_CONFIG,
        checkpoint_root=CHECKPOINT_ROOT,
        results_dir=RESULTS_DIR,
        log_dir=LOG_DIR,
    )
    ALL_RESULTS[exp_id] = fold_results
    sync_experiment_to_gdrive(exp_id, fold_results, RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 28: D9c - MERT+MuQ Gated Fusion
exp_id = 'D9c_mert_muq_gated'


def make_gated_model(cfg):
    """Build an ``AsymmetricGatedFusion`` model with attention pooling from ``cfg``.

    ``cfg['input_dim']`` is used for both encoder widths and
    ``cfg['hidden_dim']`` for both ``mert_hidden`` and ``shared_dim``.
    """
    embed_dim = cfg['input_dim']
    hidden_width = cfg['hidden_dim']
    forwarded_keys = ('dropout', 'learning_rate', 'weight_decay', 'max_epochs')
    hyperparams = {name: cfg[name] for name in forwarded_keys}
    return AsymmetricGatedFusion(
        mert_dim=embed_dim,
        muq_dim=embed_dim,
        mert_hidden=hidden_width,
        shared_dim=hidden_width,
        pooling='attention',
        **hyperparams,
    )


needs_run = should_run_experiment(
    exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE
)

if needs_run:
    fold_results = run_4fold_dual_experiment(
        exp_id=exp_id,
        description='MERT+MuQ asymmetric gated fusion',
        model_factory=make_gated_model,
        mert_cache_dir=MERT_CACHE,
        muq_cache_dir=MUQ_CACHE,
        labels=LABELS,
        fold_assignments=FOLD_ASSIGNMENTS,
        config=DUAL_CONFIG,
        checkpoint_root=CHECKPOINT_ROOT,
        results_dir=RESULTS_DIR,
        log_dir=LOG_DIR,
    )
    ALL_RESULTS[exp_id] = fold_results
    sync_experiment_to_gdrive(exp_id, fold_results, RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 29: Extract gate weights from D9c
# Loads the fold-0 D9c checkpoint, runs one batch of fold-0 samples through
# it and records the per-dimension MERT gate weight in GATE_WEIGHTS.
# GATE_WEIGHTS is an empty dict when the checkpoint does not exist.
exp_id = 'D9c_mert_muq_gated'
ckpt_path = CHECKPOINT_ROOT / exp_id / 'fold0_best.ckpt'

if ckpt_path.exists():
    from audio_experiments.data import DualEmbeddingDataset, dual_collate_fn
    from torch.utils.data import DataLoader

    model = AsymmetricGatedFusion.load_from_checkpoint(ckpt_path)
    model = model.to('cuda').eval()

    # Get sample batch for gate extraction (at most the first 32 fold-0 keys)
    val_keys = FOLD_ASSIGNMENTS.get('fold_0', [])[:32]
    ds = DualEmbeddingDataset(MERT_CACHE, MUQ_CACHE, LABELS, val_keys, max_frames=1000)
    dl = DataLoader(ds, batch_size=32, collate_fn=dual_collate_fn)
    batch = next(iter(dl))

    # Pure inference: disable autograd, as generate_muq_predictions does, so
    # no graph is built or kept alive on the GPU.
    with torch.no_grad():
        gate_info = model.get_learned_gates(
            batch['mert_embeddings'].cuda(),
            batch['muq_embeddings'].cuda(),
            batch['mert_mask'].cuda(),
            batch['muq_mask'].cuda(),
        )

    # Store gate weights per dimension
    GATE_WEIGHTS = {
        dim: float(gate_info['mert_weight_per_dim'][i])
        for i, dim in enumerate(PERCEPIANO_DIMENSIONS)
    }

    print("\nLearned Gate Weights (higher = more MERT):")
    for dim, weight in sorted(GATE_WEIGHTS.items(), key=lambda x: -x[1]):
        print(f"  {dim:<25}: {weight:.3f}")

    # Save to results (in memory only; attached only if D9c ran this session)
    if exp_id in ALL_RESULTS:
        ALL_RESULTS[exp_id]['gate_weights'] = GATE_WEIGHTS

    del model
    torch.cuda.empty_cache()
else:
    print(f"Checkpoint not found: {ckpt_path}")
    GATE_WEIGHTS = {}

---
## Part 4: Cross-Dataset Validation (X2-X3)

Validate model generalization on external datasets:
- **X2**: ASAP Multi-Performer Analysis (variance across performers of same piece)
- **X3**: PSyllabus Difficulty Correlation (correlation with difficulty levels 1-11)

In [None]:
# Cell 32a: ASAP Dataset Setup
# Downloads ASAP metadata and links audio from MAESTRO v2.0.0
# This cell can take 1-2 hours due to MAESTRO download (~115GB)

import zipfile
import urllib.request
import shutil
from tqdm import tqdm

# Local checkout of the ASAP dataset repository (cloned in Step 1).
ASAP_REPO = ASAP_DIR / 'asap-dataset'
# Where the extracted MAESTRO v2.0.0 tree is expected after Step 2.
MAESTRO_DIR = ASAP_DIR / 'maestro-v2.0.0'
# Download target for the archive; Step 2 deletes it after extraction.
MAESTRO_ZIP = ASAP_DIR / 'maestro-v2.0.0.zip'
MAESTRO_URL = 'https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0.zip'

class DownloadProgressBar(tqdm):
    """tqdm bar whose ``update_to`` can be passed as an ``urlretrieve`` report hook."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar to ``b * bsize`` bytes.

        Args:
            b: Number of blocks transferred so far.
            bsize: Size of one block in bytes.
            tsize: Total size in bytes; adopted as the bar's total when given.
        """
        if tsize is not None:
            self.total = tsize
        bytes_so_far = b * bsize
        # tqdm.update() takes an increment, so pass the distance from the
        # bar's current position.
        self.update(bytes_so_far - self.n)

def download_with_progress(url: str, output_path: Path, desc: str) -> None:
    """Download file with progress bar.

    The data is written to a sibling ``<name>.part`` file and moved to
    ``output_path`` only after the transfer finishes. An interrupted download
    therefore never leaves a truncated file at ``output_path``, which callers
    that test ``output_path.exists()`` would take for a complete one.

    Args:
        url: Source URL.
        output_path: Final destination of the downloaded file.
        desc: Label shown on the progress bar.

    Raises:
        Any exception raised by ``urllib.request.urlretrieve`` is re-raised
        after the partial file has been removed.
    """
    output_path = Path(output_path)
    partial_path = output_path.with_name(output_path.name + '.part')

    try:
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=desc) as t:
            urllib.request.urlretrieve(url, partial_path, reporthook=t.update_to)
    except BaseException:
        # Covers KeyboardInterrupt too: urlretrieve cannot resume, so a
        # leftover fragment is useless.
        partial_path.unlink(missing_ok=True)
        raise

    partial_path.replace(output_path)

# Step 1: Clone ASAP repository
# An existing checkout is reused as-is; otherwise clone it and stop on failure.
if ASAP_REPO.exists():
    print(f"ASAP repo exists: {ASAP_REPO}")
else:
    print("Cloning ASAP repository...")
    clone_cmd = ['git', 'clone', 'https://github.com/fosfrancesco/asap-dataset.git', str(ASAP_REPO)]
    result = subprocess.run(clone_cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Failed to clone ASAP: {result.stderr}")
    print(f"Cloned to: {ASAP_REPO}")

# Step 2: Download MAESTRO v2.0.0 if not exists
# The existence of MAESTRO_DIR is the only "already done" marker, so a failed
# extraction must not leave a partial tree behind.
if not MAESTRO_DIR.exists():
    if not MAESTRO_ZIP.exists():
        print(f"\nDownloading MAESTRO v2.0.0 (~115GB)...")
        print("This will take a while. Go get some coffee.")
        download_with_progress(MAESTRO_URL, MAESTRO_ZIP, "MAESTRO v2.0.0")

    # Extract MAESTRO
    print(f"\nExtracting MAESTRO archive...")
    try:
        with zipfile.ZipFile(MAESTRO_ZIP, 'r') as zf:
            members = zf.infolist()
            # Total uncompressed size drives the progress bar.
            total_size = sum(member.file_size for member in members)

            with tqdm(total=total_size, unit='B', unit_scale=True, desc="Extracting") as pbar:
                for member in members:
                    zf.extract(member, ASAP_DIR)
                    pbar.update(member.file_size)
    except BaseException as err:
        # Remove the half-extracted tree so the next run extracts again
        # instead of reporting "MAESTRO exists".
        shutil.rmtree(MAESTRO_DIR, ignore_errors=True)
        if isinstance(err, zipfile.BadZipFile):
            # A corrupt or truncated archive would fail the same way on every
            # run; delete it so the next run downloads a fresh copy.
            MAESTRO_ZIP.unlink(missing_ok=True)
            raise RuntimeError(
                f"MAESTRO archive is corrupt and was removed: {MAESTRO_ZIP}. "
                "Re-run this cell to download it again."
            ) from err
        raise

    print(f"Extracted to: {MAESTRO_DIR}")

    # Clean up zip to save space
    print("Removing zip file to save space...")
    MAESTRO_ZIP.unlink()
else:
    print(f"MAESTRO exists: {MAESTRO_DIR}")

# Step 3: Run ASAP initialize_dataset.py to link audio
import sys

asap_metadata_csv = ASAP_REPO / 'metadata.csv'
asap_init_script = ASAP_REPO / 'initialize_dataset.py'

# Check if audio already linked by looking for wav files
existing_wavs = list(ASAP_REPO.rglob('*.wav'))
if len(existing_wavs) < 100:  # Expect ~500+ wav files if properly initialized
    print(f"\nLinking MAESTRO audio to ASAP ({len(existing_wavs)} wav files found)...")
    # Run the script with the kernel's own interpreter. A bare 'python' may be
    # missing from PATH or point at an installation without the notebook's
    # packages.
    result = subprocess.run(
        [sys.executable, str(asap_init_script), '-m', str(MAESTRO_DIR)],
        cwd=ASAP_REPO,
        capture_output=True, text=True
    )
    if result.returncode != 0:
        # Deliberately non-fatal: the recount below shows how many files
        # were linked anyway.
        print(f"Warning: initialize_dataset.py returned error: {result.stderr}")

    # Recount wav files
    existing_wavs = list(ASAP_REPO.rglob('*.wav'))
    print(f"After initialization: {len(existing_wavs)} wav files")
else:
    print(f"ASAP audio already linked: {len(existing_wavs)} wav files")

# Step 4: Load and parse ASAP metadata
import pandas as pd

# Without the metadata file nothing below can work, so stop here.
if not asap_metadata_csv.exists():
    raise FileNotFoundError(f"ASAP metadata not found: {asap_metadata_csv}")

asap_df = pd.read_csv(asap_metadata_csv)
print(f"\nASAP metadata loaded: {len(asap_df)} performances")

# Filter to performances with audio
has_audio = asap_df['audio_performance'].notna()
asap_with_audio = asap_df[has_audio]
print(f"Performances with audio: {len(asap_with_audio)}")

# Count performances per piece and keep pieces with at least 5 of them.
# NOTE(review): pieces are identified by the 'title' column alone — confirm
# that titles are unique across composers in metadata.csv.
piece_counts = asap_with_audio.groupby('title').size()
multi_performer_pieces = piece_counts[piece_counts >= 5]
print(f"Pieces with 5+ performers: {len(multi_performer_pieces)}")

# Store for use in X2 experiment
ASAP_METADATA = asap_df
ASAP_WITH_AUDIO = asap_with_audio
ASAP_MULTI_PERFORMER = multi_performer_pieces

In [None]:
# Cell 32: X2 - ASAP Multi-Performer Analysis
# Analyzes variance in predictions across different performers playing the same piece
# If the model captures performance quality (not just piece characteristics),
# we expect meaningful variance across performers of the same piece.
#
# The result dict is stored in ALL_RESULTS[exp_id] and then passed to
# save_fusion_experiment / sync_experiment_to_gdrive.

exp_id = 'X2_asap_multiperformer'

if should_run_experiment(exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    print("="*70)
    print("X2: ASAP MULTI-PERFORMER ANALYSIS")
    print("="*70)
    
    # Verify ASAP data is available
    if 'ASAP_WITH_AUDIO' not in dir() or 'ASAP_MULTI_PERFORMER' not in dir():
        raise RuntimeError("Run ASAP setup cell first (Cell 32a)")
    
    if len(ASAP_MULTI_PERFORMER) == 0:
        raise RuntimeError("No multi-performer pieces found in ASAP dataset")
    
    # Get best MuQ configuration from layer ablation
    # (layer_end is treated as exclusive: the messages below print layer_end - 1)
    if best_muq_exp is None:
        print("Warning: No best MuQ experiment found. Using default layers 1-12.")
        muq_layer_start, muq_layer_end = 1, 13
    else:
        muq_layer_start = BEST_MUQ_CONFIG['layer_start']
        muq_layer_end = BEST_MUQ_CONFIG['layer_end']
    
    print(f"Using MuQ layers {muq_layer_start}-{muq_layer_end-1}")
    
    # Setup cache directory for ASAP embeddings (one sub-directory per layer range)
    asap_muq_cache = ASAP_DIR / 'muq_cache' / f"L{muq_layer_start}-{muq_layer_end-1}"
    asap_muq_cache.mkdir(parents=True, exist_ok=True)
    
    # Load trained MuQ model (use fold 0 for inference)
    ckpt_path = CHECKPOINT_ROOT / best_muq_exp / 'fold0_best.ckpt' if best_muq_exp else None
    if ckpt_path is None or not ckpt_path.exists():
        # Try to find any available checkpoint
        # NOTE(review): a fallback checkpoint may have been trained on a different
        # layer range than muq_layer_start/muq_layer_end selected above (the
        # experiment names suggest L9-12, L5-8, L1-4) — confirm the embeddings
        # extracted below are compatible with whichever checkpoint is picked.
        for exp in ['M1d_muq_L1-12', 'M1c_muq_L9-12', 'M1b_muq_L5-8', 'M1a_muq_L1-4']:
            fallback_path = CHECKPOINT_ROOT / exp / 'fold0_best.ckpt'
            if fallback_path.exists():
                ckpt_path = fallback_path
                print(f"Using fallback checkpoint: {exp}")
                break
    
    if ckpt_path is None or not ckpt_path.exists():
        raise FileNotFoundError("No trained MuQ checkpoint found. Run MuQ experiments first.")
    
    model = MuQStatsModel.load_from_checkpoint(ckpt_path)
    model = model.to('cuda').eval()
    
    # Process each multi-performer piece
    piece_results = {}
    all_performances_processed = 0
    
    for piece_title in tqdm(ASAP_MULTI_PERFORMER.index, desc="Processing pieces"):
        # Get all performances of this piece with audio
        piece_perfs = ASAP_WITH_AUDIO[ASAP_WITH_AUDIO['title'] == piece_title]
        
        if len(piece_perfs) < 5:
            continue
        
        piece_predictions = []
        piece_keys = []
        
        for _, perf_row in piece_perfs.iterrows():
            # Construct audio path
            # ASAP stores audio_performance as relative path from repo root
            audio_rel_path = perf_row['audio_performance']
            if pd.isna(audio_rel_path):
                continue
            
            audio_path = ASAP_REPO / audio_rel_path
            if not audio_path.exists():
                # Try alternate path construction
                audio_path = ASAP_REPO / perf_row['folder'] / Path(audio_rel_path).name
            
            # Performances whose audio cannot be located are skipped silently
            if not audio_path.exists():
                continue
            
            # The file stem doubles as the embedding cache key
            key = audio_path.stem
            emb_path = asap_muq_cache / f"{key}.pt"
            
            # Extract embedding if not cached
            if not emb_path.exists():
                try:
                    extract_muq_embeddings(
                        audio_path.parent, asap_muq_cache, [key],
                        layer_start=muq_layer_start, layer_end=muq_layer_end
                    )
                except Exception as e:
                    print(f"  Warning: Failed to extract {key}: {e}")
                    continue
            
            # Load embedding and predict
            if emb_path.exists():
                try:
                    with torch.no_grad():
                        # assumes the cached tensor is (frames, dim), so dim 1 after
                        # unsqueeze is the frame axis covered by the mask — TODO confirm
                        emb = torch.load(emb_path).unsqueeze(0).cuda()
                        mask = torch.ones(1, emb.shape[1], dtype=torch.bool).cuda()
                        pred = model(emb, mask).cpu().numpy()[0]
                        piece_predictions.append(pred)
                        piece_keys.append(key)
                except Exception as e:
                    print(f"  Warning: Failed to predict {key}: {e}")
                    continue
        
        # Compute statistics for this piece (needs at least two successful predictions)
        if len(piece_predictions) >= 2:
            piece_predictions = np.array(piece_predictions)
            
            # Mean prediction across all performers
            mean_pred = piece_predictions.mean(axis=0)  # [19]
            
            # Standard deviation across performers for each dimension
            std_per_dim = piece_predictions.std(axis=0)  # [19]
            
            # Overall mean prediction (single scalar)
            mean_overall = float(mean_pred.mean())
            
            # Overall std (std of mean predictions across performers)
            std_overall = float(piece_predictions.mean(axis=1).std())
            
            piece_results[piece_title] = {
                'n_performances': len(piece_predictions),
                'performer_keys': piece_keys,
                'mean_pred_overall': mean_overall,
                'std_pred_overall': std_overall,
                'mean_pred_per_dim': mean_pred.tolist(),
                'std_pred_per_dim': std_per_dim.tolist(),
                'per_dim_analysis': {
                    dim: {
                        'mean': float(mean_pred[i]),
                        'std': float(std_per_dim[i]),
                    }
                    for i, dim in enumerate(PERCEPIANO_DIMENSIONS)
                },
            }
            
            all_performances_processed += len(piece_predictions)
    
    # Compute aggregate statistics
    all_stds = [v['std_pred_overall'] for v in piece_results.values()]
    mean_intra_piece_std = np.mean(all_stds) if all_stds else 0
    median_intra_piece_std = np.median(all_stds) if all_stds else 0
    
    # Per-dimension analysis: which dimensions show most/least variance across performers?
    # (the values collected here are per-piece standard deviations)
    dim_variances = {dim: [] for dim in PERCEPIANO_DIMENSIONS}
    for piece_data in piece_results.values():
        for i, dim in enumerate(PERCEPIANO_DIMENSIONS):
            dim_variances[dim].append(piece_data['std_pred_per_dim'][i])
    
    dim_mean_variance = {
        dim: float(np.mean(variances)) if variances else 0
        for dim, variances in dim_variances.items()
    }
    
    # Sort dimensions by variance (high to low)
    sorted_dims = sorted(dim_mean_variance.items(), key=lambda x: -x[1])
    
    ALL_RESULTS[exp_id] = {
        'exp_id': exp_id,
        'n_pieces': len(piece_results),
        'n_performances_total': all_performances_processed,
        'mean_intra_piece_std': float(mean_intra_piece_std),
        'median_intra_piece_std': float(median_intra_piece_std),
        # Cast to a builtin bool: comparing a NumPy scalar yields numpy.bool_, which
        # the json module cannot serialise (S4 casts its flags the same way).
        'meaningful_variation': bool(mean_intra_piece_std > 0.05),
        'high_variance_dimensions': [d for d, v in sorted_dims[:5]],
        'low_variance_dimensions': [d for d, v in sorted_dims[-5:]],
        'dimension_variance': dim_mean_variance,
        'piece_details': piece_results,
    }
    
    print(f"\nASAP Multi-Performer Results:")
    print(f"  Pieces analyzed: {len(piece_results)}")
    print(f"  Total performances: {all_performances_processed}")
    print(f"  Mean intra-piece std: {mean_intra_piece_std:.4f}")
    print(f"  Median intra-piece std: {median_intra_piece_std:.4f}")
    print(f"  Meaningful variation (std > 0.05): {ALL_RESULTS[exp_id]['meaningful_variation']}")
    print(f"\n  Highest variance dimensions:")
    for dim, var in sorted_dims[:5]:
        print(f"    {dim}: {var:.4f}")
    print(f"\n  Lowest variance dimensions:")
    for dim, var in sorted_dims[-5:]:
        print(f"    {dim}: {var:.4f}")
    
    save_fusion_experiment(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, ALL_RESULTS)
    sync_experiment_to_gdrive(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)
    
    # Release the model's GPU memory before the next experiment cell
    del model
    torch.cuda.empty_cache()
else:
    print(f"SKIP {exp_id}: already completed")

In [None]:
# Cell 33a: PSyllabus Dataset Setup
# Downloads metadata from Zenodo and audio from YouTube via yt-dlp
# This cell can take several hours due to YouTube rate limiting
# Target: ~50 samples per difficulty level (1-11) = ~550 total

import sys
import time
import random

# Install yt-dlp if not available
try:
    import yt_dlp
except ImportError:
    print("Installing yt-dlp...")
    # Run pip through the notebook's own interpreter so the package lands in the
    # environment the following `import yt_dlp` actually reads from.
    subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', 'yt-dlp'], check=True)
    import yt_dlp

PSYLLABUS_METADATA_URL = 'https://zenodo.org/records/14794592/files/new_clean_data.json?download=1'
PSYLLABUS_METADATA_FILE = PSYLLABUS_DIR / 'new_clean_data.json'
PSYLLABUS_AUDIO_DIR = PSYLLABUS_DIR / 'audio'
PSYLLABUS_CHECKPOINT_FILE = PSYLLABUS_DIR / 'download_checkpoint.json'

PSYLLABUS_AUDIO_DIR.mkdir(parents=True, exist_ok=True)

# Step 1: Download metadata from Zenodo
if not PSYLLABUS_METADATA_FILE.exists():
    print("Downloading PSyllabus metadata from Zenodo...")
    download_with_progress(PSYLLABUS_METADATA_URL, PSYLLABUS_METADATA_FILE, "PSyllabus metadata")
else:
    print(f"PSyllabus metadata exists: {PSYLLABUS_METADATA_FILE}")

# Step 2: Load and parse metadata
# JSON text is UTF-8; do not depend on the platform's default encoding.
with open(PSYLLABUS_METADATA_FILE, encoding='utf-8') as f:
    psyllabus_raw = json.load(f)

# Parse entries - structure may vary, handle different formats
psyllabus_entries = []
if isinstance(psyllabus_raw, list):
    psyllabus_entries = psyllabus_raw
elif isinstance(psyllabus_raw, dict):
    # Could be keyed by ID or have a 'data' field
    if 'data' in psyllabus_raw:
        psyllabus_entries = psyllabus_raw['data']
    else:
        # Assume keys are IDs; values that are not dicts cannot be unpacked
        # into an entry, so they are skipped instead of raising TypeError.
        psyllabus_entries = [
            {'id': k, **v}
            for k, v in psyllabus_raw.items()
            if isinstance(v, dict)
        ]

print(f"PSyllabus entries loaded: {len(psyllabus_entries)}")

# Step 3: Extract YouTube URLs and difficulty levels
def extract_youtube_id(entry):
    """Extract a YouTube video ID from the first usable field of *entry*.

    The fields 'youtube_id', 'video_id', 'url', 'youtube_url' and 'link' are
    tried in that order. Missing or falsy values are skipped.

    - A value that mentions youtube.com / youtu.be is parsed as a URL: the
      ``v`` query parameter wins, otherwise the path segment following
      ``youtu.be/``, ``/embed/`` or ``/shorts/`` is used. Query strings and
      fragments never leak into the returned ID.
    - Any other value is treated as a bare ID and returned as a stripped
      string, unless it contains ``/`` (a non-YouTube URL or path), in which
      case the next field is tried.

    Returns:
        The video ID as a string, or None when no field yields one.
    """
    from urllib.parse import parse_qs, urlparse

    for field in ['youtube_id', 'video_id', 'url', 'youtube_url', 'link']:
        if field not in entry:
            continue
        val = entry[field]
        if not val:
            continue

        # Work on a string copy so non-str values (e.g. ints) are handled uniformly
        text = str(val).strip()
        if not text:
            continue

        if 'youtube.com' not in text and 'youtu.be' not in text:
            # Bare IDs never contain '/'; anything that does is some other URL/path
            if '/' in text:
                continue
            return text

        try:
            parsed = urlparse(text)
        except ValueError:
            # Malformed URL: fall through to the next candidate field
            continue

        # Standard watch URLs carry the ID in the 'v' query parameter
        video_ids = parse_qs(parsed.query).get('v')
        if video_ids:
            return video_ids[0]

        # Short/embed/shorts URLs carry the ID as a path segment. netloc + path
        # also covers URLs written without a scheme (everything lands in path).
        location = parsed.netloc + parsed.path
        for marker in ('youtu.be/', '/embed/', '/shorts/'):
            if marker in location:
                candidate = location.split(marker, 1)[1].split('/', 1)[0]
                if candidate:
                    return candidate
    return None

def extract_difficulty(entry):
    """Extract a difficulty level (1-11) from the first usable field of *entry*.

    The fields 'difficulty', 'level', 'grade' and 'difficulty_level' are tried
    in that order. A value is accepted when it converts to an integer in the
    range 1..11; otherwise the next field is tried.

    Conversion rules:
    - anything ``int()`` accepts is used as-is (floats are truncated, as before);
    - strings holding a whole number in float notation (e.g. ``"7.0"``) are
      accepted;
    - inf/nan, non-numeric strings and unsupported types are skipped rather
      than raising.

    Returns:
        The difficulty as an int, or None when no field yields a valid level.
    """
    for field in ['difficulty', 'level', 'grade', 'difficulty_level']:
        if field not in entry:
            continue
        val = entry[field]
        if val is None:
            continue

        try:
            diff = int(val)
        except (ValueError, TypeError, OverflowError):
            # int() rejects "7.0" (ValueError) and inf (OverflowError); retry via
            # float and accept only finite whole numbers.
            try:
                as_float = float(val)
            except (ValueError, TypeError):
                continue
            if not as_float.is_integer():  # False for inf and nan as well
                continue
            diff = int(as_float)

        if 1 <= diff <= 11:
            return diff
    return None

# Parse entries: keep only records that yield both a video ID and a valid level
parsed_entries = []
for entry in psyllabus_entries:
    yt_id = extract_youtube_id(entry)
    difficulty = extract_difficulty(entry)

    if not (yt_id and difficulty):
        continue

    # 'name' is the fallback when the record has no 'title'
    title = entry.get('title', entry.get('name', ''))
    parsed_entries.append({
        'youtube_id': yt_id,
        'difficulty': difficulty,
        'composer': entry.get('composer', ''),
        'title': title,
    })

print(f"Parsed entries with YouTube ID and difficulty: {len(parsed_entries)}")

# Bucket by difficulty level (every level 1-11 gets a list, even when empty)
by_difficulty = {level: [] for level in range(1, 12)}
for entry in parsed_entries:
    by_difficulty[entry['difficulty']].append(entry)

print("Distribution by difficulty level:")
for diff in sorted(by_difficulty):
    entries = by_difficulty[diff]
    print(f"  Level {diff:2d}: {len(entries):4d} entries")

# Step 4: Stratified sampling - select ~50 per level
TARGET_PER_LEVEL = 50
sampled_entries = []

random.seed(42)  # Fixed seed so the sample is reproducible
for diff in range(1, 12):
    available = by_difficulty[diff]
    take_all = len(available) <= TARGET_PER_LEVEL
    # Levels at or below the target are taken whole; larger ones are subsampled
    sampled_entries.extend(
        available if take_all else random.sample(available, TARGET_PER_LEVEL)
    )

print(f"\nSampled {len(sampled_entries)} entries for download")

# Step 5: Download audio from YouTube
# Resume support: start from the on-disk checkpoint when one exists
download_checkpoint = {'completed': [], 'failed': []}
if PSYLLABUS_CHECKPOINT_FILE.exists():
    with open(PSYLLABUS_CHECKPOINT_FILE) as f:
        download_checkpoint = json.load(f)
    print(f"Loaded checkpoint: {len(download_checkpoint['completed'])} completed, {len(download_checkpoint['failed'])} failed")

completed_ids = set(download_checkpoint['completed'])
failed_ids = set(download_checkpoint['failed'])

# Anything recorded as completed or failed is not attempted again
attempted_ids = completed_ids | failed_ids
to_download = [e for e in sampled_entries if e['youtube_id'] not in attempted_ids]
print(f"Entries to download: {len(to_download)}")

def download_youtube_audio(youtube_id: str, output_dir: Path, max_retries: int = 3) -> bool:
    """Download one YouTube video's audio as ``<output_dir>/<youtube_id>.wav``.

    The download is attempted up to *max_retries* times. An attempt counts as
    successful only when the expected ``.wav`` file exists and is larger than
    1000 bytes. After an attempt that raised, the function sleeps
    ``2 ** attempt`` seconds before the next one (no sleep after the last).

    Args:
        youtube_id: Video ID, inserted into the watch URL and the file name.
        output_dir: Directory that receives the ``.wav`` file.
        max_retries: Maximum number of download attempts.

    Returns:
        True when a valid ``.wav`` file was produced, False otherwise. Errors
        are never propagated; the last one is reported in a warning line.
    """
    output_template = str(output_dir / f"{youtube_id}.%(ext)s")
    wav_path = output_dir / f"{youtube_id}.wav"
    
    ydl_opts = {
        'format': 'bestaudio/best',
        'extractaudio': True,
        'audioformat': 'wav',
        'outtmpl': output_template,
        'quiet': True,
        'no_warnings': True,
        # The post-processor converts the downloaded stream to wav
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
            'preferredquality': '192',
        }],
    }
    
    # Remember why the most recent attempt failed so it can be reported once
    last_error = None
    
    for attempt in range(max_retries):
        try:
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                ydl.download([f'https://www.youtube.com/watch?v={youtube_id}'])
            
            # Verify file exists
            if wav_path.exists() and wav_path.stat().st_size > 1000:
                return True
            last_error = "output file missing or too small"
        except Exception as e:
            # Best-effort download: keep going, but do not lose the reason
            last_error = e
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)  # Exponential backoff
    
    print(f"  Warning: Failed to download {youtube_id}: {last_error}")
    return False

# Download with rate limiting
# Counters cover this run only; download_checkpoint accumulates across runs.
newly_completed = 0
newly_failed = 0

for i, entry in enumerate(tqdm(to_download, desc="Downloading YouTube audio")):
    yt_id = entry['youtube_id']
    
    # Check if already downloaded
    # (a wav larger than 1000 bytes is counted as completed; this path skips
    # both the periodic checkpoint write and the rate-limit sleep below)
    wav_path = PSYLLABUS_AUDIO_DIR / f"{yt_id}.wav"
    if wav_path.exists() and wav_path.stat().st_size > 1000:
        download_checkpoint['completed'].append(yt_id)
        newly_completed += 1
        continue
    
    # Download
    success = download_youtube_audio(yt_id, PSYLLABUS_AUDIO_DIR)
    
    # Record the outcome; IDs recorded as failed are excluded from to_download
    # when this cell is run again.
    if success:
        download_checkpoint['completed'].append(yt_id)
        newly_completed += 1
    else:
        download_checkpoint['failed'].append(yt_id)
        newly_failed += 1
    
    # Save checkpoint every 10 downloads
    # (keyed on the loop index, so it only fires on iterations that reach this point)
    if (i + 1) % 10 == 0:
        with open(PSYLLABUS_CHECKPOINT_FILE, 'w') as f:
            json.dump(download_checkpoint, f)
    
    # Rate limiting: random delay between downloads
    time.sleep(random.uniform(2.0, 4.0))

# Final checkpoint save
with open(PSYLLABUS_CHECKPOINT_FILE, 'w') as f:
    json.dump(download_checkpoint, f)

# Summary
print(f"\nDownload Summary:")
print(f"  Newly completed: {newly_completed}")
print(f"  Newly failed: {newly_failed}")
print(f"  Total completed: {len(download_checkpoint['completed'])}")
print(f"  Total failed: {len(download_checkpoint['failed'])}")

# Count successful downloads by difficulty
# (based on the files present on disk, not on the checkpoint lists)
successful_by_diff = {i: 0 for i in range(1, 12)}
for entry in sampled_entries:
    wav_path = PSYLLABUS_AUDIO_DIR / f"{entry['youtube_id']}.wav"
    if wav_path.exists() and wav_path.stat().st_size > 1000:
        successful_by_diff[entry['difficulty']] += 1

print("\nSuccessful downloads by difficulty:")
for diff, count in sorted(successful_by_diff.items()):
    print(f"  Level {diff:2d}: {count:3d}")

# Store for use in X3 experiment
PSYLLABUS_SAMPLED = sampled_entries
PSYLLABUS_SUCCESS_BY_DIFF = successful_by_diff

In [None]:
# Cell 33: X3 - PSyllabus Difficulty Correlation
# Correlates model predictions with ground-truth difficulty levels (1-11)
# Difficulty is a proxy for required skill level, so we expect weak-to-moderate
# positive correlation (rho > 0.2) if the model captures performance quality.
#
# The result dict is stored in ALL_RESULTS[exp_id]; it is saved and synced only
# when at least 10 samples were processed.

exp_id = 'X3_psyllabus_difficulty'

if should_run_experiment(exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    print("="*70)
    print("X3: PSYLLABUS DIFFICULTY CORRELATION")
    print("="*70)
    
    # Verify PSyllabus data is available
    if 'PSYLLABUS_SAMPLED' not in dir():
        raise RuntimeError("Run PSyllabus setup cell first (Cell 33a)")
    
    # Get best MuQ configuration
    # (layer_end is treated as exclusive: the messages below print layer_end - 1)
    if best_muq_exp is None:
        print("Warning: No best MuQ experiment found. Using default layers 1-12.")
        muq_layer_start, muq_layer_end = 1, 13
    else:
        muq_layer_start = BEST_MUQ_CONFIG['layer_start']
        muq_layer_end = BEST_MUQ_CONFIG['layer_end']
    
    print(f"Using MuQ layers {muq_layer_start}-{muq_layer_end-1}")
    
    # Setup cache directory (one sub-directory per layer range)
    psyllabus_muq_cache = PSYLLABUS_DIR / 'muq_cache' / f"L{muq_layer_start}-{muq_layer_end-1}"
    psyllabus_muq_cache.mkdir(parents=True, exist_ok=True)
    
    # Load trained MuQ model
    ckpt_path = CHECKPOINT_ROOT / best_muq_exp / 'fold0_best.ckpt' if best_muq_exp else None
    if ckpt_path is None or not ckpt_path.exists():
        # NOTE(review): a fallback checkpoint may have been trained on a different
        # layer range than muq_layer_start/muq_layer_end selected above — confirm
        # the extracted embeddings are compatible with whichever one is picked.
        for exp in ['M1d_muq_L1-12', 'M1c_muq_L9-12', 'M1b_muq_L5-8', 'M1a_muq_L1-4']:
            fallback_path = CHECKPOINT_ROOT / exp / 'fold0_best.ckpt'
            if fallback_path.exists():
                ckpt_path = fallback_path
                print(f"Using fallback checkpoint: {exp}")
                break
    
    if ckpt_path is None or not ckpt_path.exists():
        raise FileNotFoundError("No trained MuQ checkpoint found. Run MuQ experiments first.")
    
    model = MuQStatsModel.load_from_checkpoint(ckpt_path)
    model = model.to('cuda').eval()
    
    # Process each downloaded audio file
    # difficulties / predictions / prediction_details stay index-aligned: all three
    # are appended together, once per successfully predicted sample.
    difficulties = []
    predictions = []
    prediction_details = []
    
    for entry in tqdm(PSYLLABUS_SAMPLED, desc="Processing PSyllabus audio"):
        yt_id = entry['youtube_id']
        difficulty = entry['difficulty']
        
        # Check if audio exists
        wav_path = PSYLLABUS_AUDIO_DIR / f"{yt_id}.wav"
        if not wav_path.exists() or wav_path.stat().st_size < 1000:
            continue
        
        key = yt_id
        emb_path = psyllabus_muq_cache / f"{key}.pt"
        
        # Extract embedding if not cached
        if not emb_path.exists():
            try:
                extract_muq_embeddings(
                    PSYLLABUS_AUDIO_DIR, psyllabus_muq_cache, [key],
                    layer_start=muq_layer_start, layer_end=muq_layer_end
                )
            except Exception as e:
                print(f"  Warning: Failed to extract {key}: {e}")
                continue
        
        # Load embedding and predict
        if emb_path.exists():
            try:
                with torch.no_grad():
                    # assumes the cached tensor is (frames, dim), so dim 1 after
                    # unsqueeze is the frame axis covered by the mask — TODO confirm
                    emb = torch.load(emb_path).unsqueeze(0).cuda()
                    mask = torch.ones(1, emb.shape[1], dtype=torch.bool).cuda()
                    pred = model(emb, mask).cpu().numpy()[0]
                    
                    # Store results (the scalar score is the mean over all output dimensions)
                    mean_pred = float(pred.mean())
                    predictions.append(mean_pred)
                    difficulties.append(difficulty)
                    
                    prediction_details.append({
                        'youtube_id': yt_id,
                        'difficulty': difficulty,
                        'mean_prediction': mean_pred,
                        'per_dim_prediction': pred.tolist(),
                        'composer': entry.get('composer', ''),
                        'title': entry.get('title', ''),
                    })
            except Exception as e:
                print(f"  Warning: Failed to predict {key}: {e}")
                continue
    
    print(f"\nProcessed {len(predictions)} samples")
    
    if len(predictions) < 10:
        # Recorded in memory only: this branch neither saves nor syncs the result
        print("ERROR: Insufficient data for correlation analysis (need at least 10 samples)")
        ALL_RESULTS[exp_id] = {
            'exp_id': exp_id,
            'error': 'Insufficient data',
            'n_samples': len(predictions),
        }
    else:
        # Build the arrays once; every analysis below indexes into them
        all_diffs = np.array(difficulties)
        pred_means = np.array(predictions)
        all_preds = np.array([d['per_dim_prediction'] for d in prediction_details])
        
        # Compute Spearman correlation
        rho, p_value = stats.spearmanr(difficulties, predictions)
        
        # Compute correlation by difficulty range (low/mid/high)
        low_mask = all_diffs <= 4
        mid_mask = (all_diffs >= 5) & (all_diffs <= 7)
        high_mask = all_diffs >= 8
        
        # Ranges with fewer than 5 samples are omitted entirely
        range_analysis = {}
        for name, mask in [('low_1-4', low_mask), ('mid_5-7', mid_mask), ('high_8-11', high_mask)]:
            if mask.sum() >= 5:
                range_preds = pred_means[mask]
                range_diffs = all_diffs[mask]
                if len(set(range_diffs)) > 1:  # Need variance for correlation
                    r, p = stats.spearmanr(range_diffs, range_preds)
                    range_analysis[name] = {
                        'n': int(mask.sum()),
                        'rho': float(r),
                        'p_value': float(p),
                        'mean_prediction': float(range_preds.mean()),
                    }
                else:
                    range_analysis[name] = {
                        'n': int(mask.sum()),
                        'mean_prediction': float(range_preds.mean()),
                    }
        
        # Per-dimension correlation analysis
        per_dim_correlation = {}
        for i, dim in enumerate(PERCEPIANO_DIMENSIONS):
            dim_preds = all_preds[:, i]
            r, p = stats.spearmanr(all_diffs, dim_preds)
            per_dim_correlation[dim] = {
                'rho': float(r),
                'p_value': float(p),
                # builtin bool: numpy.bool_ is not serialisable by the json module
                'significant': bool(p < 0.05),
            }
        
        # Sort dimensions by correlation strength
        sorted_dims = sorted(per_dim_correlation.items(), key=lambda x: -abs(x[1]['rho']))
        
        # Mean prediction by difficulty level (levels with no samples are left out)
        mean_by_difficulty = {}
        for diff in range(1, 12):
            diff_mask = all_diffs == diff
            if diff_mask.sum() > 0:
                mean_by_difficulty[diff] = float(pred_means[diff_mask].mean())
        
        ALL_RESULTS[exp_id] = {
            'exp_id': exp_id,
            'n_samples': len(predictions),
            'spearman_rho': float(rho),
            'p_value': float(p_value),
            # Comparisons on NumPy scalars yield numpy.bool_; cast to builtin bool so
            # the flags are JSON-serialisable (S4 casts its flags the same way).
            'significant': bool(p_value < 0.05),
            'weak_positive': bool(rho > 0.2),
            'moderate_positive': bool(rho > 0.4),
            'difficulty_range': [int(min(difficulties)), int(max(difficulties))],
            'mean_by_difficulty': mean_by_difficulty,
            'range_analysis': range_analysis,
            'per_dimension_correlation': per_dim_correlation,
            'strongest_correlating_dims': [d for d, v in sorted_dims[:5]],
            'sample_details': prediction_details[:20],  # Store first 20 for inspection
        }
        
        print(f"\nPSyllabus Difficulty Correlation Results:")
        print(f"  Samples: {len(predictions)}")
        print(f"  Difficulty range: {min(difficulties)}-{max(difficulties)}")
        print(f"  Spearman rho: {rho:.4f}")
        print(f"  p-value: {p_value:.2e}")
        print(f"  Significant (p < 0.05): {p_value < 0.05}")
        print(f"  Weak positive (rho > 0.2): {rho > 0.2}")
        print(f"  Moderate positive (rho > 0.4): {rho > 0.4}")
        
        print(f"\n  Mean prediction by difficulty level:")
        for diff in sorted(mean_by_difficulty.keys()):
            print(f"    Level {diff:2d}: {mean_by_difficulty[diff]:.4f}")
        
        print(f"\n  Strongest correlating dimensions:")
        for dim, corr in sorted_dims[:5]:
            print(f"    {dim}: rho={corr['rho']:.4f}, p={corr['p_value']:.2e}")
        
        save_fusion_experiment(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, ALL_RESULTS)
        sync_experiment_to_gdrive(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)
    
    # Release the model's GPU memory before the next experiment cell
    del model
    torch.cuda.empty_cache()
else:
    print(f"SKIP {exp_id}: already completed")

---
## Part 5: Statistical Rigor (S3-S4)

Bootstrap CIs and significance tests for all comparisons.

In [None]:
# Cell 35: S3 - Bootstrap CIs for all comparisons
# Collects bootstrap comparisons in `bootstrap_results` and stores them under
# ALL_RESULTS[exp_id]; sections whose prediction arrays are empty are left out.
exp_id = 'S3_bootstrap_all'

if should_run_experiment(exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    bootstrap_results = {}

    # MuQ vs Symbolic (needs predictions from both models)
    if len(MUQ_ARR) > 0 and len(SYMBOLIC_ARR) > 0:
        print("Computing MuQ vs Symbolic bootstrap...")
        muq_vs_sym = bootstrap_r2_comparison(
            LABELS_ARR, MUQ_ARR, SYMBOLIC_ARR, n_bootstrap=10000
        )
        bootstrap_results['muq_vs_symbolic'] = muq_vs_sym
        print(f"  MuQ: {muq_vs_sym['r2_a']:.4f}")
        print(f"  Symbolic: {muq_vs_sym['r2_b']:.4f}")
        print(f"  Diff: {muq_vs_sym['difference']:.4f}")
        print(f"  MuQ significantly better: {muq_vs_sym['a_significantly_better']}")

    # MuQ confidence interval on its own
    if len(MUQ_ARR) > 0:
        print("\nComputing MuQ bootstrap CIs...")
        muq_ci = bootstrap_r2_extended(LABELS_ARR, MUQ_ARR, n_bootstrap=10000)
        bootstrap_results['muq_ci'] = muq_ci
        muq_overall = muq_ci['overall']
        print(f"  R2: {muq_overall['r2']:.4f}")
        print(f"  95% CI: [{muq_overall['ci_lower']:.4f}, {muq_overall['ci_upper']:.4f}]")

    # The record is written even when both sections above were skipped
    ALL_RESULTS[exp_id] = {
        'exp_id': exp_id,
        'n_bootstrap': 10000,
        'comparisons': bootstrap_results,
    }

    save_fusion_experiment(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, ALL_RESULTS)
    sync_experiment_to_gdrive(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 36: S4 - Significance Tests
# Compares MuQ against Symbolic predictions with a paired t-test, a Wilcoxon
# test and Cohen's d, then repeats the paired t-test per dimension and applies
# a Bonferroni correction. Results are stored under ALL_RESULTS[exp_id].
exp_id = 'S4_significance_tests'

if should_run_experiment(exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    significance_results = {}
    
    if len(MUQ_ARR) > 0 and len(SYMBOLIC_ARR) > 0:
        # Paired t-test
        ttest = paired_ttest_per_sample(LABELS_ARR, MUQ_ARR, SYMBOLIC_ARR)
        significance_results['paired_ttest'] = ttest
        print(f"Paired t-test: t={ttest['t_stat']:.4f}, p={ttest['p_value']:.2e}")
        
        # Wilcoxon
        wilcox = wilcoxon_test(LABELS_ARR, MUQ_ARR, SYMBOLIC_ARR)
        significance_results['wilcoxon'] = wilcox
        print(f"Wilcoxon: stat={wilcox['stat']:.4f}, p={wilcox['p_value']:.2e}")
        
        # Cohen's d
        d = cohens_d(LABELS_ARR, MUQ_ARR, SYMBOLIC_ARR)
        significance_results['cohens_d'] = d
        print(f"Cohen's d: {d:.4f}")
        
        # Per-dimension tests with Bonferroni
        # The number of tests follows PERCEPIANO_DIMENSIONS, which also keys the
        # result dict below, so the two can never get out of step.
        n_dims = len(PERCEPIANO_DIMENSIONS)
        per_dim_p = []
        for i in range(n_dims):
            # i:i+1 keeps the column axis, so each call receives 2-D slices
            t = paired_ttest_per_sample(
                LABELS_ARR[:, i:i+1],
                MUQ_ARR[:, i:i+1],
                SYMBOLIC_ARR[:, i:i+1]
            )
            per_dim_p.append(t['p_value'])
        
        bonf_corrected, bonf_sig = bonferroni_correction(np.array(per_dim_p))
        significance_results['per_dimension'] = {
            dim: {
                # Plain Python scalars, matching the casts on the other two fields
                'raw_p': float(per_dim_p[i]),
                'corrected_p': float(bonf_corrected[i]),
                'significant': bool(bonf_sig[i]),
            }
            for i, dim in enumerate(PERCEPIANO_DIMENSIONS)
        }
        
        print(f"\nBonferroni correction: {sum(bonf_sig)}/{n_dims} dimensions significant")
    
    # The record is written even when the tests above were skipped (empty 'tests')
    ALL_RESULTS[exp_id] = {
        'exp_id': exp_id,
        'tests': significance_results,
    }
    
    save_fusion_experiment(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, ALL_RESULTS)
    sync_experiment_to_gdrive(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

---
## Part 6: Analysis (A3-A7)

In [None]:
# Cell 38: A3 - Error Correlation Analysis
# Delegates to run_error_correlation_experiment and stores its return value
# under ALL_RESULTS[exp_id] before saving and syncing it.
exp_id = 'A3_error_correlation'

# Evaluated left to right: the run/skip check is consulted first, and the
# experiment only proceeds when both prediction arrays are non-empty.
if (
    should_run_experiment(exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE)
    and len(MUQ_ARR) > 0
    and len(SYMBOLIC_ARR) > 0
):
    ALL_RESULTS[exp_id] = run_error_correlation_experiment(
        exp_id, MUQ_ARR, SYMBOLIC_ARR, LABELS_ARR
    )
    save_fusion_experiment(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, ALL_RESULTS)
    sync_experiment_to_gdrive(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 39: A4 - Per-Dimension Breakdown
# For every dimension, compares the R2 of MuQ, Symbolic and their simple-average
# fusion, names a winner between MuQ and Symbolic, and tallies wins per category.
exp_id = 'A4_dimension_breakdown'

if should_run_experiment(exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    if len(MUQ_ARR) > 0 and len(SYMBOLIC_ARR) > 0:
        fused = simple_average_fusion(MUQ_ARR, SYMBOLIC_ARR)

        dim_comparison = {}
        for i, dim in enumerate(PERCEPIANO_DIMENSIONS):
            y_true = LABELS_ARR[:, i]
            muq_r2 = r2_score(y_true, MUQ_ARR[:, i])
            symbolic_r2 = r2_score(y_true, SYMBOLIC_ARR[:, i])
            fusion_r2 = r2_score(y_true, fused[:, i])

            # First category that lists this dimension; None when no category does
            category = next(
                (cat for cat, dims in DIMENSION_CATEGORIES.items() if dim in dims),
                None,
            )

            # MuQ wins only on a strictly higher R2; ties go to 'symbolic'
            if muq_r2 > symbolic_r2:
                winner = 'muq'
            else:
                winner = 'symbolic'

            dim_comparison[dim] = {
                'muq_r2': float(muq_r2),
                'symbolic_r2': float(symbolic_r2),
                'fusion_r2': float(fusion_r2),
                'winner': winner,
                'muq_advantage': float(muq_r2 - symbolic_r2),
                'category': category,
            }

        # Tally winners per category
        category_summary = {}
        for cat in DIMENSION_CATEGORIES:
            cat_winners = [
                v['winner'] for v in dim_comparison.values() if v['category'] == cat
            ]
            muq_wins = cat_winners.count('muq')
            category_summary[cat] = {
                'total': len(cat_winners),
                'muq_wins': muq_wins,
                'symbolic_wins': len(cat_winners) - muq_wins,
            }

        all_winners = [v['winner'] for v in dim_comparison.values()]
        ALL_RESULTS[exp_id] = {
            'exp_id': exp_id,
            'per_dimension': dim_comparison,
            'category_summary': category_summary,
            'muq_total_wins': all_winners.count('muq'),
        }

        print(f"\nDimension Breakdown:")
        print(f"  MuQ wins: {ALL_RESULTS[exp_id]['muq_total_wins']}/19 dimensions")

        save_fusion_experiment(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, ALL_RESULTS)
        sync_experiment_to_gdrive(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

In [None]:
# Cell 40: A5 - Failure Cases
# Lists the samples with the largest MuQ prediction error (per-sample MSE over
# all dimensions) together with the dimensions that contribute most to it.
exp_id = 'A5_failure_cases'

if should_run_experiment(exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    if len(MUQ_ARR) > 0:
        # Compute per-sample MSE (mean over axis 1)
        mse_per_sample = ((LABELS_ARR - MUQ_ARR) ** 2).mean(axis=1)

        # np.argsort is ascending, so the last 10 entries are the largest
        # errors; reverse them so the list is ordered worst-first. Without the
        # reversal, failure_cases[:3] below would report the 10th-8th worst
        # samples under the label "Worst samples".
        worst_indices = np.argsort(mse_per_sample)[-10:][::-1]

        failure_cases = []
        for idx in worst_indices:
            # Absolute per-dimension error for this sample, largest first.
            dim_errors = np.abs(LABELS_ARR[idx] - MUQ_ARR[idx])
            worst_dims = np.argsort(dim_errors)[-3:][::-1]

            failure_cases.append({
                'key': FUSION_KEYS[idx],
                'mse': float(mse_per_sample[idx]),
                'worst_dimensions': [PERCEPIANO_DIMENSIONS[i] for i in worst_dims],
                'predicted': MUQ_ARR[idx].tolist(),
                'actual': LABELS_ARR[idx].tolist(),
            })

        ALL_RESULTS[exp_id] = {
            'exp_id': exp_id,
            'n_samples': len(mse_per_sample),
            'mean_mse': float(mse_per_sample.mean()),
            'max_mse': float(mse_per_sample.max()),
            'failure_cases': failure_cases,
        }

        print("\nFailure Case Analysis:")
        print(f"  Mean MSE: {ALL_RESULTS[exp_id]['mean_mse']:.4f}")
        print(f"  Max MSE: {ALL_RESULTS[exp_id]['max_mse']:.4f}")
        # failure_cases is worst-first, so this is the top three by MSE.
        print(f"  Worst samples: {[f['key'] for f in failure_cases[:3]]}")

        save_fusion_experiment(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, ALL_RESULTS)
        sync_experiment_to_gdrive(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)
    else:
        # Make the skip visible instead of silently producing no result.
        print(f"Skipping {exp_id}: MuQ predictions not available")

In [None]:
# Cell 41: A6 - Calibration
# Compares mean predicted vs. mean actual value inside equal-width bins of the
# MuQ predictions, and reports how spread out predictions are vs. the labels.
exp_id = 'A6_calibration'

if should_run_experiment(exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    if len(MUQ_ARR) > 0:
        # 10 equal-width bins over [0, 1] (not quantile-based deciles)
        n_bins = 10
        calibration = []

        # Flatten so all samples and dimensions share one calibration curve
        preds_flat = MUQ_ARR.flatten()
        labels_flat = LABELS_ARR.flatten()

        # Bin by predicted values
        bins = np.linspace(0, 1, n_bins + 1)
        for i in range(n_bins):
            lower, upper = bins[i], bins[i+1]
            if i == n_bins - 1:
                # Close the last bin on the right so a prediction of exactly
                # 1.0 is counted instead of falling outside every bin.
                mask = (preds_flat >= lower) & (preds_flat <= upper)
            else:
                mask = (preds_flat >= lower) & (preds_flat < upper)
            if mask.sum() > 0:
                mean_predicted = preds_flat[mask].mean()
                mean_actual = labels_flat[mask].mean()
                calibration.append({
                    'bin': i,
                    'bin_range': [float(lower), float(upper)],
                    'count': int(mask.sum()),
                    'mean_predicted': float(mean_predicted),
                    'mean_actual': float(mean_actual),
                    # Positive error = over-prediction within this bin
                    'error': float(mean_predicted - mean_actual),
                })

        # Predictions outside [0, 1] (or NaN) match no bin; record how many
        # were left out so the omission is visible in the saved result.
        n_unbinned = int(preds_flat.size - sum(b['count'] for b in calibration))

        # Dispersion ratio: std of predictions relative to std of labels
        # (std taken over all elements of each array); 0 if labels are constant
        pred_std = MUQ_ARR.std()
        label_std = LABELS_ARR.std()
        dispersion_ratio = pred_std / label_std if label_std > 0 else 0

        ALL_RESULTS[exp_id] = {
            'exp_id': exp_id,
            'calibration_bins': calibration,
            'n_unbinned': n_unbinned,
            'dispersion_ratio': float(dispersion_ratio),
            'pred_std': float(pred_std),
            'label_std': float(label_std),
        }

        print("\nCalibration Analysis:")
        print(f"  Dispersion ratio: {dispersion_ratio:.4f}")
        print(f"  Prediction std: {pred_std:.4f}")
        print(f"  Label std: {label_std:.4f}")
        if n_unbinned > 0:
            print(f"  Predictions outside [0, 1] (not binned): {n_unbinned}")

        save_fusion_experiment(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, ALL_RESULTS)
        sync_experiment_to_gdrive(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)
    else:
        # Make the skip visible instead of silently producing no result.
        print(f"Skipping {exp_id}: MuQ predictions not available")

In [None]:
# Cell 42: A7 - Gate Weight Visualization
# Summarizes GATE_WEIGHTS: ranks dimensions by weight, averages the weights
# per category, and stores/syncs the summary.
exp_id = 'A7_gate_visualization'

if should_run_experiment(exp_id, CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    if not GATE_WEIGHTS:
        print("No gate weights available (D9c not trained)")
    else:
        # Highest weight first. The result keys below label the top of this
        # ranking "MERT-preferred" and the bottom "MuQ-preferred".
        ranked = sorted(GATE_WEIGHTS.items(), key=lambda item: -item[1])
        ranked_names = [name for name, _weight in ranked]

        # Per-category view; a dimension missing from GATE_WEIGHTS is treated
        # as 0.5.
        category_gates = {}
        for cat, dims in DIMENSION_CATEGORIES.items():
            weights_in_cat = [GATE_WEIGHTS.get(d, 0.5) for d in dims]
            weights_by_dim = {d: GATE_WEIGHTS.get(d, 0.5) for d in dims}
            category_gates[cat] = {
                'mean_mert_weight': float(np.mean(weights_in_cat)),
                'dimensions': weights_by_dim,
            }

        gate_summary = {
            'exp_id': exp_id,
            'gate_weights': GATE_WEIGHTS,
            'mert_preferred_dims': ranked_names[:5],
            'muq_preferred_dims': ranked_names[-5:],
            'category_summary': category_gates,
            'mean_gate': float(np.mean(list(GATE_WEIGHTS.values()))),
        }
        ALL_RESULTS[exp_id] = gate_summary

        print(f"\nGate Weight Analysis:")
        print(f"  Mean gate (0.5=balanced): {gate_summary['mean_gate']:.3f}")
        print(f"  MERT-preferred: {gate_summary['mert_preferred_dims']}")
        print(f"  MuQ-preferred: {gate_summary['muq_preferred_dims']}")

        save_fusion_experiment(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, ALL_RESULTS)
        sync_experiment_to_gdrive(exp_id, ALL_RESULTS[exp_id], RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS)

---
## Part 7: Results Export

In [None]:
# Cell 44: Export all results
# Backfills ALL_RESULTS from per-experiment JSON files, writes one aggregate
# JSON file, then copies the results directory to the GDrive remote.
banner = "=" * 70
print("\n" + banner)
print("EXPORTING RESULTS")
print(banner)

# Pick up results that exist on disk but are not in memory
for exp_id in EXPERIMENT_IDS:
    if exp_id in ALL_RESULTS:
        continue
    result_file = RESULTS_DIR / f"{exp_id}.json"
    if not result_file.exists():
        continue
    with open(result_file) as fh:
        ALL_RESULTS[exp_id] = json.load(fh)

# Write everything collected so far into a single aggregate file
aggregate_file = RESULTS_DIR / 'definitive_all_results.json'
with open(aggregate_file, 'w') as fh:
    json.dump(ALL_RESULTS, fh, indent=2, default=numpy_serializer)
print(f"Saved: {aggregate_file}")

# Copy the results directory to the GDrive remote via rclone
rclone_cmd = ['rclone', 'copy', str(RESULTS_DIR), GDRIVE_RESULTS]
run_rclone(rclone_cmd, "Syncing results to GDrive")
print(f"Synced to: {GDRIVE_RESULTS}")

In [None]:
# Cell 45: Final Summary
# Prints a console digest of whatever results are present in ALL_RESULTS.
rule = "=" * 70
print("\n" + rule)
print("DEFINITIVE EXPERIMENTS SUMMARY")
print(rule)

# Parts 1-3 share one layout: a heading, a divider, then one R2 line per
# experiment that has a result. The third field names the key that must be
# present in the result: 'summary' results read summary['avg_r2'], the
# fusion results read 'overall_r2' directly.
r2_sections = [
    (
        "\nPart 1: MuQ Layer Ablation",
        ['M1a_muq_L1-4', 'M1b_muq_L5-8', 'M1c_muq_L9-12', 'M1d_muq_L1-12', 'M2_muq_last_hidden'],
        'summary',
    ),
    (
        "\nPart 2: MuQ + Symbolic Fusion",
        ['F8_muq_symbolic_simple', 'F9_muq_symbolic_weighted', 'F10_muq_symbolic_ridge', 'F11_muq_symbolic_confidence'],
        'overall_r2',
    ),
    (
        "\nPart 3: MERT + MuQ Audio Fusion",
        ['D9a_mert_muq_ensemble', 'D9b_mert_muq_concat', 'D9c_mert_muq_gated'],
        'summary',
    ),
]
for heading, section_ids, required_key in r2_sections:
    print(heading)
    print("-" * 40)
    for exp_id in section_ids:
        # Experiments without a stored result of the expected shape are skipped
        if exp_id not in ALL_RESULTS or required_key not in ALL_RESULTS[exp_id]:
            continue
        if required_key == 'summary':
            r2 = ALL_RESULTS[exp_id]['summary']['avg_r2']
        else:
            r2 = ALL_RESULTS[exp_id]['overall_r2']
        print(f"  {exp_id}: R2={r2:.4f}")

# Part 4: Cross-Dataset Validation (missing fields fall back to 0 / False)
print("\nPart 4: Cross-Dataset Validation")
print("-" * 40)
if 'X2_asap_multiperformer' in ALL_RESULTS:
    r = ALL_RESULTS['X2_asap_multiperformer']
    print("  ASAP Multi-Performer:")
    print(f"    Pieces analyzed: {r.get('n_pieces', 0)}")
    print(f"    Intra-piece std: {r.get('mean_intra_piece_std', 0):.4f}")
    print(f"    Meaningful variation: {r.get('meaningful_variation', False)}")
if 'X3_psyllabus_difficulty' in ALL_RESULTS:
    r = ALL_RESULTS['X3_psyllabus_difficulty']
    print("  PSyllabus Difficulty:")
    print(f"    Samples: {r.get('n_samples', 0)}")
    print(f"    Spearman rho: {r.get('spearman_rho', 0):.4f}")
    print(f"    Significant: {r.get('significant', False)}")

# Completion stats: how many of the planned experiments have a result
completed = len([e for e in EXPERIMENT_IDS if e in ALL_RESULTS])
print(f"\nCompleted: {completed}/{len(EXPERIMENT_IDS)} experiments")
print(rule)