# Phase 2: Audio Baseline Experiments

Comprehensive experiments for the ISMIR paper.

## Experiments
- **B0**: Baseline re-run (MERT+MLP, L13-24, mean pool)
- **A1-A3**: Baselines (linear probe, Mel-CNN, raw statistics)
- **B1a-B1d**: Layer ablation (1-6, 7-12, 13-24, 1-24)
- **B2a-B2c**: Pooling ablation (max, attention, LSTM)
- **C1a-C1b**: Loss ablation (hybrid MSE+CCC, pure CCC)

## Requirements
- Compute: A100 (80GB VRAM)
- rclone configured with `gdrive:` remote

In [None]:
# Prepare for deterministic CUDA runs. CUBLAS_WORKSPACE_CONFIG has to be in the
# environment before any CUDA operation, hence it is set ahead of `import torch`.
# NOTE(review): this cell only sets the cuBLAS workspace variable; it does not
# itself call torch.use_deterministic_algorithms() - confirm that is enabled
# elsewhere if bit-exact reproducibility is required.
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Report the name and total memory (in units of 1e9 bytes) of device 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    # Stop the notebook here when no CUDA device is visible.
    raise RuntimeError("GPU required")

In [None]:
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
# Install dependencies and clone repo
!pip install transformers librosa soundfile pytorch_lightning nnAudio --quiet

# Clone the repo
import os
REPO_DIR = '/tmp/crescendai'
if os.path.exists(REPO_DIR):
    !cd {REPO_DIR} && git pull origin main
else:
    !git clone https://github.com/jai-dhiman/crescendai.git {REPO_DIR}

print(f"Repo: {REPO_DIR}")

In [None]:
# 3. Setup imports
import sys
sys.path.insert(0, f'{REPO_DIR}/model/src')

import json
import subprocess
import warnings
from pathlib import Path

import numpy as np
import pytorch_lightning as pl

# Import from our package
from audio_experiments import PERCEPIANO_DIMENSIONS, BASE_CONFIG, SEED
from audio_experiments.extractors import (
    extract_mert_for_layer_range,
    extract_mel_spectrograms,
    extract_statistics_for_all,
)
from audio_experiments.models import BaseMERTModel, LinearProbeModel, MelCNNModel, StatsMLPModel
from audio_experiments.training import (
    run_4fold_mert_experiment,
    run_4fold_mel_experiment,
    run_4fold_stats_experiment,
    restore_all_from_gdrive,
    should_run_experiment,
    sync_experiment_to_gdrive,
    print_experiment_status,
)

warnings.filterwarnings('ignore')
torch.set_float32_matmul_precision('medium')
pl.seed_everything(SEED, workers=True)

print(f"PyTorch: {torch.__version__}")
print(f"Imports: OK")

In [None]:
# 4. Setup paths and download data
DATA_ROOT = Path('/tmp/phase2')
AUDIO_DIR = DATA_ROOT / 'audio'
LABEL_DIR = DATA_ROOT / 'labels'
MERT_CACHE_ROOT = DATA_ROOT / 'mert_cache'
MEL_CACHE_DIR = DATA_ROOT / 'mel_cache'
STATS_CACHE_DIR = DATA_ROOT / 'stats_cache'
CHECKPOINT_ROOT = DATA_ROOT / 'checkpoints'
RESULTS_DIR = DATA_ROOT / 'results'
LOG_DIR = DATA_ROOT / 'logs'

GDRIVE_AUDIO = 'gdrive:crescendai_data/audio_baseline/percepiano_rendered'
GDRIVE_LABELS = 'gdrive:crescendai_data/percepiano_labels'
GDRIVE_FOLDS = 'gdrive:crescendai_data/audio_baseline/audio_fold_assignments.json'
GDRIVE_MERT_CACHE = 'gdrive:crescendai_data/audio_baseline/mert_embeddings'
GDRIVE_RESULTS = 'gdrive:crescendai_data/checkpoints/audio_phase2'

for d in [AUDIO_DIR, LABEL_DIR, MERT_CACHE_ROOT, MEL_CACHE_DIR, STATS_CACHE_DIR,
          CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR]:
    d.mkdir(parents=True, exist_ok=True)

def run_rclone(cmd, desc):
    """Run an rclone command, announcing it first and reporting failures.

    Args:
        cmd: Argument list handed to ``subprocess.run`` (no shell involved).
        desc: Short label printed before the command runs and reused in the
            failure warning.

    Returns:
        True when the command exited with status 0, False otherwise. Failures
        are reported but not raised, so best-effort syncs keep going.
    """
    print(f"{desc}...")
    # Output stays captured to keep the notebook quiet; it is only surfaced
    # when the command fails, so that failed transfers are no longer silent.
    result = subprocess.run(cmd, capture_output=True, text=True, errors='replace')
    if result.returncode != 0:
        detail = (result.stderr or result.stdout or '').strip()
        print(f"WARNING: {desc} failed (exit code {result.returncode})")
        if detail:
            print(detail)
        return False
    return True

# Check rclone: the `gdrive:` remote must be listed exactly. A plain substring
# test on stdout would also accept a differently named remote such as `mygdrive:`.
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)
if 'gdrive:' not in {line.strip() for line in result.stdout.splitlines()}:
    raise RuntimeError("rclone 'gdrive' not configured")

# Download data
run_rclone(['rclone', 'copy', GDRIVE_AUDIO, str(AUDIO_DIR), '--progress'], "Downloading audio")
run_rclone(['rclone', 'copy', GDRIVE_LABELS, str(LABEL_DIR)], "Downloading labels")

FOLD_FILE = DATA_ROOT / 'folds.json'
run_rclone(['rclone', 'copyto', GDRIVE_FOLDS, str(FOLD_FILE)], "Downloading folds")

# Load labels and folds
LABEL_FILE = LABEL_DIR / 'label_2round_mean_reg_19_with0_rm_highstd0.json'

# Fail with an actionable message when a download did not produce its file,
# instead of a bare "No such file" from open() further down.
for required_file in (LABEL_FILE, FOLD_FILE):
    if not required_file.is_file():
        raise FileNotFoundError(
            f"{required_file} is missing - the rclone download above probably failed"
        )

# JSON is read as UTF-8 explicitly so the result does not depend on the locale.
with open(LABEL_FILE, encoding='utf-8') as f:
    LABELS = json.load(f)
with open(FOLD_FILE, encoding='utf-8') as f:
    FOLD_ASSIGNMENTS = json.load(f)

ALL_KEYS = list(LABELS.keys())
print(f"Audio: {len(list(AUDIO_DIR.glob('*.wav')))} files")
print(f"Labels: {len(LABELS)} segments")

In [None]:
# 5. Restore MERT cache and completed experiments from GDrive
DEFAULT_MERT_DIR = MERT_CACHE_ROOT / 'L13-24'
DEFAULT_MERT_DIR.mkdir(parents=True, exist_ok=True)

result = subprocess.run(['rclone', 'lsf', GDRIVE_MERT_CACHE], capture_output=True, text=True)
if result.returncode == 0 and '.pt' in result.stdout:
    print("Restoring MERT cache...")
    run_rclone(['rclone', 'copy', GDRIVE_MERT_CACHE, str(DEFAULT_MERT_DIR)], "Restoring cache")
    print(f"Restored: {len(list(DEFAULT_MERT_DIR.glob('*.pt')))} embeddings")

ALL_RESULTS = {}

ALL_EXPERIMENT_IDS = [
    'B0_baseline', 'A1_linear_probe', 'A2_mel_cnn', 'A3_raw_stats',
    'B1a_layers_1-6', 'B1b_layers_7-12', 'B1c_layers_13-24', 'B1d_layers_1-24',
    'B2a_max_pool', 'B2b_attention_pool', 'B2c_lstm_pool',
    'C1a_hybrid_loss', 'C1b_pure_ccc',
]

print("\nChecking GDrive for completed experiments...")
restored = restore_all_from_gdrive(
    GDRIVE_RESULTS,
    RESULTS_DIR,
    CHECKPOINT_ROOT,
    ALL_RESULTS,
)

# Cache completed experiments to avoid repeated GDrive calls
from audio_experiments.training import get_completed_experiments
COMPLETED_CACHE = get_completed_experiments(GDRIVE_RESULTS)

print_experiment_status(ALL_EXPERIMENT_IDS, COMPLETED_CACHE)

---
## Experiments

In [None]:
# B0: Baseline
if should_run_experiment('B0_baseline', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_mert_model(cfg):
        return BaseMERTModel(
            input_dim=cfg['input_dim'], hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'], learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'], pooling=cfg.get('pooling', 'mean'),
            loss_type=cfg.get('loss_type', 'mse'), max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['B0_baseline'] = run_4fold_mert_experiment(
        'B0_baseline', 'MERT+MLP, L13-24, mean pooling',
        make_mert_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'B0_baseline', ALL_RESULTS['B0_baseline'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# A1: Linear Probe
if should_run_experiment('A1_linear_probe', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    # Ensure embeddings exist (reuses B0's extraction if already done)
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_linear_probe(cfg):
        return LinearProbeModel(
            input_dim=cfg['input_dim'], learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'], max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['A1_linear_probe'] = run_4fold_mert_experiment(
        'A1_linear_probe', 'Linear probe on MERT',
        make_linear_probe, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'A1_linear_probe', ALL_RESULTS['A1_linear_probe'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# A2: Mel-CNN
if should_run_experiment('A2_mel_cnn', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    extract_mel_spectrograms(AUDIO_DIR, MEL_CACHE_DIR, ALL_KEYS)

    ALL_RESULTS['A2_mel_cnn'] = run_4fold_mel_experiment(
        'A2_mel_cnn', '4-layer CNN on mel spectrograms',
        MEL_CACHE_DIR, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'A2_mel_cnn', ALL_RESULTS['A2_mel_cnn'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# A3: Raw Statistics
if should_run_experiment('A3_raw_stats', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    extract_statistics_for_all(AUDIO_DIR, STATS_CACHE_DIR, ALL_KEYS)

    ALL_RESULTS['A3_raw_stats'] = run_4fold_stats_experiment(
        'A3_raw_stats', 'MLP on audio statistics (49-dim)',
        STATS_CACHE_DIR, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'A3_raw_stats', ALL_RESULTS['A3_raw_stats'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# B1a: Layer Ablation - Early Layers (1-6)
if should_run_experiment('B1a_layers_1-6', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    cache_dir = MERT_CACHE_ROOT / 'L1-6'
    extract_mert_for_layer_range(1, 7, AUDIO_DIR, cache_dir, ALL_KEYS)

    def make_mert_model(cfg):
        return BaseMERTModel(
            input_dim=cfg['input_dim'], hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'], learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'], pooling=cfg.get('pooling', 'mean'),
            loss_type=cfg.get('loss_type', 'mse'), max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['B1a_layers_1-6'] = run_4fold_mert_experiment(
        'B1a_layers_1-6', 'MERT layers 1-6 (early)',
        make_mert_model, cache_dir, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'B1a_layers_1-6', ALL_RESULTS['B1a_layers_1-6'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# B1b: Layer Ablation - Mid Layers (7-12)
if should_run_experiment('B1b_layers_7-12', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    cache_dir = MERT_CACHE_ROOT / 'L7-12'
    extract_mert_for_layer_range(7, 13, AUDIO_DIR, cache_dir, ALL_KEYS)

    def make_mert_model(cfg):
        return BaseMERTModel(
            input_dim=cfg['input_dim'], hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'], learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'], pooling=cfg.get('pooling', 'mean'),
            loss_type=cfg.get('loss_type', 'mse'), max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['B1b_layers_7-12'] = run_4fold_mert_experiment(
        'B1b_layers_7-12', 'MERT layers 7-12 (mid)',
        make_mert_model, cache_dir, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'B1b_layers_7-12', ALL_RESULTS['B1b_layers_7-12'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# B1c: Layer Ablation - Late Layers (13-24)
if should_run_experiment('B1c_layers_13-24', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    cache_dir = MERT_CACHE_ROOT / 'L13-24'
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, cache_dir, ALL_KEYS)

    def make_mert_model(cfg):
        return BaseMERTModel(
            input_dim=cfg['input_dim'], hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'], learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'], pooling=cfg.get('pooling', 'mean'),
            loss_type=cfg.get('loss_type', 'mse'), max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['B1c_layers_13-24'] = run_4fold_mert_experiment(
        'B1c_layers_13-24', 'MERT layers 13-24 (late)',
        make_mert_model, cache_dir, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'B1c_layers_13-24', ALL_RESULTS['B1c_layers_13-24'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# B1d: Layer Ablation - All Layers (1-24)
if should_run_experiment('B1d_layers_1-24', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    cache_dir = MERT_CACHE_ROOT / 'L1-24'
    extract_mert_for_layer_range(1, 25, AUDIO_DIR, cache_dir, ALL_KEYS)

    def make_mert_model(cfg):
        return BaseMERTModel(
            input_dim=cfg['input_dim'], hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'], learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'], pooling=cfg.get('pooling', 'mean'),
            loss_type=cfg.get('loss_type', 'mse'), max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['B1d_layers_1-24'] = run_4fold_mert_experiment(
        'B1d_layers_1-24', 'MERT all layers 1-24',
        make_mert_model, cache_dir, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'B1d_layers_1-24', ALL_RESULTS['B1d_layers_1-24'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# B2a: Pooling Ablation - Max Pooling
if should_run_experiment('B2a_max_pool', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    # Ensure L13-24 embeddings exist
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    cfg = BASE_CONFIG.copy()
    cfg['pooling'] = 'max'

    def make_max_pool_model(cfg=cfg):
        return BaseMERTModel(
            input_dim=cfg['input_dim'], hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'], learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'], pooling=cfg['pooling'],
            loss_type='mse', max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['B2a_max_pool'] = run_4fold_mert_experiment(
        'B2a_max_pool', 'MERT + max pooling',
        make_max_pool_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        cfg, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'B2a_max_pool', ALL_RESULTS['B2a_max_pool'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# B2b: Pooling Ablation - Attention Pooling
if should_run_experiment('B2b_attention_pool', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    # Ensure L13-24 embeddings exist
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    cfg = BASE_CONFIG.copy()
    cfg['pooling'] = 'attention'

    def make_attention_pool_model(cfg=cfg):
        return BaseMERTModel(
            input_dim=cfg['input_dim'], hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'], learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'], pooling=cfg['pooling'],
            loss_type='mse', max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['B2b_attention_pool'] = run_4fold_mert_experiment(
        'B2b_attention_pool', 'MERT + attention pooling',
        make_attention_pool_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        cfg, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'B2b_attention_pool', ALL_RESULTS['B2b_attention_pool'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# B2c: Pooling Ablation - Bi-LSTM Pooling
if should_run_experiment('B2c_lstm_pool', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    # Ensure L13-24 embeddings exist
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    cfg = BASE_CONFIG.copy()
    cfg['pooling'] = 'lstm'

    def make_lstm_pool_model(cfg=cfg):
        return BaseMERTModel(
            input_dim=cfg['input_dim'], hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'], learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'], pooling=cfg['pooling'],
            loss_type='mse', max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['B2c_lstm_pool'] = run_4fold_mert_experiment(
        'B2c_lstm_pool', 'MERT + Bi-LSTM pooling',
        make_lstm_pool_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        cfg, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'B2c_lstm_pool', ALL_RESULTS['B2c_lstm_pool'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# C1a: Loss Ablation - Hybrid MSE + CCC Loss
if should_run_experiment('C1a_hybrid_loss', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    # Ensure L13-24 embeddings exist
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    cfg = BASE_CONFIG.copy()
    cfg['loss_type'] = 'hybrid'

    def make_hybrid_loss_model(cfg=cfg):
        return BaseMERTModel(
            input_dim=cfg['input_dim'], hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'], learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'], pooling='mean',
            loss_type=cfg['loss_type'], max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['C1a_hybrid_loss'] = run_4fold_mert_experiment(
        'C1a_hybrid_loss', 'MERT + MSE + 0.5*CCC loss',
        make_hybrid_loss_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        cfg, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'C1a_hybrid_loss', ALL_RESULTS['C1a_hybrid_loss'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# C1b: Loss Ablation - Pure CCC Loss
if should_run_experiment('C1b_pure_ccc', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, COMPLETED_CACHE):
    # Ensure L13-24 embeddings exist
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    cfg = BASE_CONFIG.copy()
    cfg['loss_type'] = 'ccc'

    def make_ccc_loss_model(cfg=cfg):
        return BaseMERTModel(
            input_dim=cfg['input_dim'], hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'], learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'], pooling='mean',
            loss_type=cfg['loss_type'], max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['C1b_pure_ccc'] = run_4fold_mert_experiment(
        'C1b_pure_ccc', 'MERT + pure CCC loss',
        make_ccc_loss_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        cfg, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'C1b_pure_ccc', ALL_RESULTS['C1b_pure_ccc'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

---
## Results Summary

In [None]:
# Print the Phase 2 results table: one row per experiment that has results,
# grouped as in `exp_order`, with the difference to the B0 baseline.
print("="*80)
print("PHASE 2 RESULTS SUMMARY")
print("="*80)

# None (rather than 0) when B0 has no result, so the "vs B0" column is not
# filled with differences against a fictitious baseline of 0.
baseline_r2 = ALL_RESULTS.get('B0_baseline', {}).get('summary', {}).get('avg_r2')

print(f"{'Experiment':<25} {'Avg R2':>10} {'95% CI':>20} {'vs B0':>10} {'Disp':>8}")
print("-"*80)

# A None entry draws a separator line between experiment groups.
exp_order = [
    'B0_baseline', None,
    'A1_linear_probe', 'A2_mel_cnn', 'A3_raw_stats', None,
    'B1a_layers_1-6', 'B1b_layers_7-12', 'B1c_layers_13-24', 'B1d_layers_1-24', None,
    'B2a_max_pool', 'B2b_attention_pool', 'B2c_lstm_pool', None,
    'C1a_hybrid_loss', 'C1b_pure_ccc',
]

for exp_id in exp_order:
    if exp_id is None:
        print("-"*80)
        continue
    if exp_id not in ALL_RESULTS:
        continue

    s = ALL_RESULTS[exp_id]['summary']
    ci = s.get('r2_ci_95', [0, 0])
    # Right-align the CI in the 20-character header column so that the
    # "vs B0" and "Disp" values sit under their headings.
    ci_str = f"[{ci[0]:.3f}, {ci[1]:.3f}]"

    if exp_id == 'B0_baseline':
        diff_str = '---'
    elif baseline_r2 is None:
        # No baseline available: a numeric difference would be meaningless.
        diff_str = 'n/a'
    else:
        diff_str = f"{s['avg_r2'] - baseline_r2:+.3f}"

    print(f"{exp_id:<25} {s['avg_r2']:>10.4f} {ci_str:>20} {diff_str:>10} {s.get('dispersion_ratio', 0):>8.2f}")

print("="*80)

In [None]:
# Safety sync: write every Phase 2 result to one JSON file, then copy results
# and checkpoints to Google Drive.
with open(RESULTS_DIR / 'phase2_all_results.json', 'w') as f:
    json.dump(ALL_RESULTS, f, indent=2)

print("Final sync to Google Drive...")
run_rclone(['rclone', 'copy', str(RESULTS_DIR), GDRIVE_RESULTS], "Syncing results")
run_rclone(['rclone', 'copy', str(CHECKPOINT_ROOT), f"{GDRIVE_RESULTS}/checkpoints"], "Syncing checkpoints")

# Entries without a 'summary' are left out of the status report (the same
# filter the combined Phase 2 + Phase 3 sync uses), so a single incomplete
# result cannot crash this cell with a KeyError.
print_experiment_status(ALL_EXPERIMENT_IDS, {k: v['summary']['avg_r2'] for k, v in ALL_RESULTS.items() if 'summary' in v})
print("Done! Results at:", GDRIVE_RESULTS)

---
## Phase 3: Advanced Architecture Experiments

Based on research recommendations for improving audio-only R2 toward 0.50+:
- **D1a/D1b**: Statistical pooling (mean+std, mean+std+min+max)
- **D2a/D2b**: Uncertainty-weighted loss (mean pool, attention pool)
- **D3**: Dimension-specific heads (BiLSTM for timing, MLP for rest)
- **D4**: Multi-layer MERT concat ([6,9,12])
- **D5**: Transformer pooling (2-layer encoder before attention pool)
- **D6**: Multi-scale temporal pooling

Expected gains: +0.02-0.05 R2 cumulative from best configurations.

In [None]:
# Phase 3 model imports
from audio_experiments.models import (
    StatsPoolingModel,
    UncertaintyWeightedModel,
    DimensionSpecificModel,
    TransformerPoolingModel,
    MultiScalePoolingModel,
    MultiLayerMERTModel,
)
from audio_experiments.extractors import extract_mert_multilayer_concat

# Phase 3 experiment IDs
PHASE3_EXPERIMENT_IDS = [
    'D1a_stats_mean_std', 'D1b_stats_full',
    'D2a_uncertainty_mean', 'D2b_uncertainty_attn',
    'D3_dimension_heads',
    'D4_multilayer_6_9_12',
    'D5_transformer_pool',
    'D6_multiscale_pool',
]

# Extend ALL_EXPERIMENT_IDS with the ids it does not contain yet. The list is
# mutated in place, so a plain extend() would append duplicates every time this
# cell is re-run and every later status report would list each id repeatedly.
_new_phase3_ids = [exp_id for exp_id in PHASE3_EXPERIMENT_IDS if exp_id not in ALL_EXPERIMENT_IDS]
ALL_EXPERIMENT_IDS.extend(_new_phase3_ids)

# Check for completed experiments
PHASE3_COMPLETED = get_completed_experiments(GDRIVE_RESULTS)
print_experiment_status(PHASE3_EXPERIMENT_IDS, PHASE3_COMPLETED)

In [None]:
# D1a: Statistical Pooling (mean + std)
if should_run_experiment('D1a_stats_mean_std', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE3_COMPLETED):
    # Ensure L13-24 embeddings exist
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_stats_model(cfg):
        """Build the mean+std statistics-pooling model from an experiment config."""
        return StatsPoolingModel(
            # Read the embedding width from the config, as the other Phase 3
            # factories do, instead of hard-coding 1024.
            # NOTE(review): assumes BASE_CONFIG['input_dim'] is 1024 - confirm.
            input_dim=cfg['input_dim'],
            hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'],
            learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'],
            pooling_stats='mean_std',
            max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['D1a_stats_mean_std'] = run_4fold_mert_experiment(
        'D1a_stats_mean_std', 'MERT + stats pooling (mean+std)',
        make_stats_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    # Push the finished experiment to Google Drive straight away.
    sync_experiment_to_gdrive(
        'D1a_stats_mean_std', ALL_RESULTS['D1a_stats_mean_std'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# D1b: Statistical Pooling (mean + std + min + max)
if should_run_experiment('D1b_stats_full', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE3_COMPLETED):
    # Ensure L13-24 embeddings exist
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_stats_full_model(cfg):
        """Build the mean+std+min+max statistics-pooling model from an experiment config."""
        return StatsPoolingModel(
            # Read the embedding width from the config, as the other Phase 3
            # factories do, instead of hard-coding 1024.
            # NOTE(review): assumes BASE_CONFIG['input_dim'] is 1024 - confirm.
            input_dim=cfg['input_dim'],
            hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'],
            learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'],
            pooling_stats='mean_std_min_max',
            max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['D1b_stats_full'] = run_4fold_mert_experiment(
        'D1b_stats_full', 'MERT + stats pooling (mean+std+min+max)',
        make_stats_full_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    # Push the finished experiment to Google Drive straight away.
    sync_experiment_to_gdrive(
        'D1b_stats_full', ALL_RESULTS['D1b_stats_full'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# D2a: Uncertainty-Weighted Loss (mean pooling)
if should_run_experiment('D2a_uncertainty_mean', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE3_COMPLETED):
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_uncertainty_model(cfg):
        return UncertaintyWeightedModel(
            input_dim=cfg['input_dim'],
            hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'],
            learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'],
            pooling='mean',
            max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['D2a_uncertainty_mean'] = run_4fold_mert_experiment(
        'D2a_uncertainty_mean', 'MERT + uncertainty-weighted loss (mean pool)',
        make_uncertainty_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'D2a_uncertainty_mean', ALL_RESULTS['D2a_uncertainty_mean'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# D2b: Uncertainty-Weighted Loss (attention pooling)
if should_run_experiment('D2b_uncertainty_attn', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE3_COMPLETED):
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_uncertainty_attn_model(cfg):
        return UncertaintyWeightedModel(
            input_dim=cfg['input_dim'],
            hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'],
            learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'],
            pooling='attention',
            max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['D2b_uncertainty_attn'] = run_4fold_mert_experiment(
        'D2b_uncertainty_attn', 'MERT + uncertainty-weighted loss (attention pool)',
        make_uncertainty_attn_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'D2b_uncertainty_attn', ALL_RESULTS['D2b_uncertainty_attn'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# D3: Dimension-Specific Heads (BiLSTM for timing, MLP for rest)
if should_run_experiment('D3_dimension_heads', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE3_COMPLETED):
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_dimension_heads_model(cfg):
        return DimensionSpecificModel(
            input_dim=cfg['input_dim'],
            hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'],
            learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'],
            lstm_hidden=256,
            max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['D3_dimension_heads'] = run_4fold_mert_experiment(
        'D3_dimension_heads', 'MERT + dimension-specific heads (BiLSTM timing, MLP rest)',
        make_dimension_heads_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'D3_dimension_heads', ALL_RESULTS['D3_dimension_heads'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# D4: Multi-Layer MERT Concat [6, 9, 12]
if should_run_experiment('D4_multilayer_6_9_12', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE3_COMPLETED):
    multilayer_cache = MERT_CACHE_ROOT / 'L6-9-12-concat'
    extract_mert_multilayer_concat([6, 9, 12], AUDIO_DIR, multilayer_cache, ALL_KEYS)

    cfg = BASE_CONFIG.copy()
    cfg['input_dim'] = 1024 * 3  # 3 layers concatenated = 3072

    def make_multilayer_model(cfg=cfg):
        return MultiLayerMERTModel(
            input_dim=cfg['input_dim'],
            hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'],
            learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'],
            pooling='attention',
            max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['D4_multilayer_6_9_12'] = run_4fold_mert_experiment(
        'D4_multilayer_6_9_12', 'MERT concat layers [6,9,12] + attention pool',
        make_multilayer_model, multilayer_cache, LABELS, FOLD_ASSIGNMENTS,
        cfg, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'D4_multilayer_6_9_12', ALL_RESULTS['D4_multilayer_6_9_12'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# D5: Transformer Pooling (2-layer encoder before attention pool)
if should_run_experiment('D5_transformer_pool', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE3_COMPLETED):
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_transformer_model(cfg):
        return TransformerPoolingModel(
            input_dim=cfg['input_dim'],
            hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'],
            learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'],
            num_heads=8,
            num_layers=2,
            pooling='attention',
            max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['D5_transformer_pool'] = run_4fold_mert_experiment(
        'D5_transformer_pool', 'MERT + 2-layer transformer + attention pool',
        make_transformer_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'D5_transformer_pool', ALL_RESULTS['D5_transformer_pool'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

In [None]:
# D6: Multi-Scale Temporal Pooling
if should_run_experiment('D6_multiscale_pool', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE3_COMPLETED):
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_multiscale_model(cfg):
        return MultiScalePoolingModel(
            input_dim=cfg['input_dim'],
            hidden_dim=cfg['hidden_dim'],
            dropout=cfg['dropout'],
            learning_rate=cfg['learning_rate'],
            weight_decay=cfg['weight_decay'],
            scales=(4, 8, 16, 32),
            max_epochs=cfg['max_epochs'],
        )

    ALL_RESULTS['D6_multiscale_pool'] = run_4fold_mert_experiment(
        'D6_multiscale_pool', 'MERT + multi-scale pooling (4,8,16,32 frames)',
        make_multiscale_model, DEFAULT_MERT_DIR, LABELS, FOLD_ASSIGNMENTS,
        BASE_CONFIG, CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR
    )
    sync_experiment_to_gdrive(
        'D6_multiscale_pool', ALL_RESULTS['D6_multiscale_pool'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

---
## Phase 3 Results Summary

In [None]:
# Phase 3 Results Summary: compare every Phase 3 experiment that has results
# against the best Phase 2 run.
print("="*80)
print("PHASE 3 RESULTS SUMMARY")
print("="*80)

# Get Phase 2 baseline for comparison
baseline_r2 = ALL_RESULTS.get('B0_baseline', {}).get('summary', {}).get('avg_r2', 0)

# Phase 2 ids are the entries of ALL_EXPERIMENT_IDS listed before the first
# Phase 3 id (Phase 3 ids are appended after them when they are registered).
_phase3_id_set = set(PHASE3_EXPERIMENT_IDS)
_phase2_ids = []
for exp_id in ALL_EXPERIMENT_IDS:
    if exp_id in _phase3_id_set:
        break
    _phase2_ids.append(exp_id)

# "Best Phase 2" is taken over every Phase 2 experiment with results rather
# than a hand-picked subset, so a stronger Phase 2 run cannot be overlooked and
# make the Phase 3 improvement look larger than it is. baseline_r2 stays in the
# candidate list so the value is defined even when nothing has been loaded.
best_phase2 = max(
    [baseline_r2]
    + [
        ALL_RESULTS[exp_id].get('summary', {}).get('avg_r2', 0)
        for exp_id in _phase2_ids
        if exp_id in ALL_RESULTS
    ]
)

print(f"\nPhase 2 Baseline (B0): {baseline_r2:.4f}")
print(f"Best Phase 2: {best_phase2:.4f}")
print()

print(f"{'Experiment':<25} {'Avg R2':>10} {'95% CI':>20} {'vs Best P2':>12}")
print("-"*75)

for exp_id in PHASE3_EXPERIMENT_IDS:
    if exp_id not in ALL_RESULTS:
        # Fall back to a result file left on disk by an earlier session.
        result_file = RESULTS_DIR / f'{exp_id}.json'
        if result_file.exists():
            with open(result_file) as f:
                ALL_RESULTS[exp_id] = json.load(f)
        else:
            continue

    r = ALL_RESULTS[exp_id]
    s = r['summary']
    ci = s.get('r2_ci_95', [0, 0])
    # Right-align the CI in the 20-character header column so the
    # "vs Best P2" value sits under its heading.
    ci_str = f"[{ci[0]:.3f}, {ci[1]:.3f}]"
    diff = s['avg_r2'] - best_phase2
    print(f"{exp_id:<25} {s['avg_r2']:>10.4f} {ci_str:>20} {diff:>+12.4f}")

print("="*75)

# Find best Phase 3 experiment. Tuples are (avg_r2, exp_id), so max() picks the
# highest R2 and breaks ties by experiment id.
phase3_results = [(ALL_RESULTS.get(exp_id, {}).get('summary', {}).get('avg_r2', 0), exp_id)
                  for exp_id in PHASE3_EXPERIMENT_IDS if exp_id in ALL_RESULTS]
if phase3_results:
    best_p3 = max(phase3_results)
    print(f"\nBest Phase 3: {best_p3[1]} (R2={best_p3[0]:.4f})")
    print(f"Improvement over Phase 2: {best_p3[0] - best_phase2:+.4f}")

In [None]:
# Per-dimension R2 breakdown for the highest-scoring Phase 3 experiment
from audio_experiments import DIMENSION_CATEGORIES

# (avg_r2, exp_id) for every Phase 3 experiment that has results in memory;
# a missing summary/avg_r2 counts as 0.
phase3_results = [
    (ALL_RESULTS[exp_id].get('summary', {}).get('avg_r2', 0), exp_id)
    for exp_id in PHASE3_EXPERIMENT_IDS
    if exp_id in ALL_RESULTS
]

if phase3_results:
    best_p3 = max(phase3_results)
    best_exp = ALL_RESULTS[best_p3[1]]
    per_dim = best_exp.get('per_dimension', {})

    print(f"\nPer-Dimension R2 for {best_p3[1]}")
    print("-"*50)

    for category, dims in DIMENSION_CATEGORIES.items():
        # Dimensions absent from the results are reported (and averaged) as 0.
        cat_r2s = [per_dim.get(d, {}).get('r2', 0) for d in dims]
        cat_avg = np.mean(cat_r2s) if cat_r2s else 0
        print(f"\n{category.upper()} (avg: {cat_avg:.3f})")
        for dim, r2 in zip(dims, cat_r2s):
            print(f"  {dim:<25} {r2:.4f}")

    # Banner followed by the two dimensions singled out as the improvement target.
    print(
        "\n" + "="*50,
        "TIMING DIMENSIONS (key target for improvement)",
        "="*50,
        sep="\n",
    )
    timing_r2 = per_dim.get('timing', {}).get('r2', 0)
    tempo_r2 = per_dim.get('tempo', {}).get('r2', 0)
    print(f"timing: {timing_r2:.4f}")
    print(f"tempo:  {tempo_r2:.4f}")

In [None]:
# Final sync of all results (Phase 2 + Phase 3):
# write the combined JSON locally, copy results and checkpoints to Drive,
# then print the status table.
all_results_combined = {
    exp_id: ALL_RESULTS[exp_id]
    for exp_id in ALL_EXPERIMENT_IDS
    if exp_id in ALL_RESULTS
}

with open(RESULTS_DIR / 'all_results_combined.json', 'w') as f:
    json.dump(all_results_combined, f, indent=2)

print("Final sync to Google Drive...")
run_rclone(
    ['rclone', 'copy', str(RESULTS_DIR), GDRIVE_RESULTS],
    "Syncing results",
)
run_rclone(
    ['rclone', 'copy', str(CHECKPOINT_ROOT), f"{GDRIVE_RESULTS}/checkpoints"],
    "Syncing checkpoints",
)

# Status table is fed the average R2 of every result that carries a summary.
print_experiment_status(
    ALL_EXPERIMENT_IDS,
    {k: v['summary']['avg_r2'] for k, v in ALL_RESULTS.items() if 'summary' in v},
)
print("\nDone! Results at:", GDRIVE_RESULTS)

---
## Phase 3.5: MuQ Experiments

MuQ is a self-supervised music representation model from Tencent AI Lab (released on Hugging Face under the OpenMuQ organization), trained with Mel residual vector quantization targets.
It is similar to MERT but trained with a different objective, so it may capture complementary features.

- **D7**: MuQ baseline (mean pooling)
- **D8**: MuQ with stats pooling (mean+std)
- **D9**: MERT+MuQ ensemble (average predictions)
- **D9b**: MERT+MuQ early fusion (concatenate embeddings)

Expected gain: roughly +0.03 to +0.05 R2 from ensemble/fusion.

In [None]:
# MuQ imports and setup
from audio_experiments.extractors import MuQExtractor, extract_muq_embeddings
from audio_experiments.models import MuQBaseModel, MuQStatsModel, MERTMuQEnsemble, MERTMuQConcatModel

# MuQ cache directory
MUQ_CACHE_DIR = DATA_ROOT / 'muq_cache'
MUQ_CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Phase 3.5 experiment IDs
PHASE35_EXPERIMENT_IDS = [
    'D7_muq_baseline',
    'D8_muq_stats',
    'D9_mert_muq_ensemble',
    'D9b_mert_muq_concat',
]

# Register the new IDs only once. Notebook cells get re-run, and a plain
# extend() would append the same IDs again on every execution.
ALL_EXPERIMENT_IDS.extend(
    [exp_id for exp_id in PHASE35_EXPERIMENT_IDS if exp_id not in ALL_EXPERIMENT_IDS]
)
PHASE35_COMPLETED = get_completed_experiments(GDRIVE_RESULTS)
print_experiment_status(PHASE35_EXPERIMENT_IDS, PHASE35_COMPLETED)

In [None]:
# D7: MuQ baseline -- mean pooling over MuQ embeddings
if should_run_experiment('D7_muq_baseline', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE35_COMPLETED):
    # Extract MuQ embeddings into MUQ_CACHE_DIR before training
    extract_muq_embeddings(AUDIO_DIR, MUQ_CACHE_DIR, ALL_KEYS)

    def make_muq_baseline(cfg):
        """Build a mean-pooling MuQBaseModel from the experiment config ``cfg``."""
        # Hyperparameters copied verbatim from the config.
        from_cfg = ('hidden_dim', 'dropout', 'learning_rate', 'weight_decay', 'max_epochs')
        return MuQBaseModel(
            input_dim=1024,  # MuQ hidden size
            pooling='mean',
            **{name: cfg[name] for name in from_cfg},
        )

    ALL_RESULTS['D7_muq_baseline'] = run_4fold_mert_experiment(
        'D7_muq_baseline',
        'MuQ baseline with mean pooling',
        make_muq_baseline,
        MUQ_CACHE_DIR,
        LABELS,
        FOLD_ASSIGNMENTS,
        BASE_CONFIG,
        CHECKPOINT_ROOT,
        RESULTS_DIR,
        LOG_DIR,
    )
    sync_experiment_to_gdrive(
        'D7_muq_baseline',
        ALL_RESULTS['D7_muq_baseline'],
        RESULTS_DIR,
        CHECKPOINT_ROOT,
        GDRIVE_RESULTS,
        ALL_RESULTS,
    )

In [None]:
# D8: MuQ with statistics pooling (mean + std)
if should_run_experiment('D8_muq_stats', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE35_COMPLETED):
    # Make sure the MuQ embeddings are available in MUQ_CACHE_DIR
    extract_muq_embeddings(AUDIO_DIR, MUQ_CACHE_DIR, ALL_KEYS)

    def make_muq_stats(cfg):
        """Build a MuQStatsModel with 'mean_std' pooling from the experiment config ``cfg``."""
        # Hyperparameters copied verbatim from the config.
        from_cfg = ('hidden_dim', 'dropout', 'learning_rate', 'weight_decay', 'max_epochs')
        return MuQStatsModel(
            input_dim=1024,
            pooling_stats='mean_std',
            **{name: cfg[name] for name in from_cfg},
        )

    ALL_RESULTS['D8_muq_stats'] = run_4fold_mert_experiment(
        'D8_muq_stats',
        'MuQ with stats pooling (mean+std)',
        make_muq_stats,
        MUQ_CACHE_DIR,
        LABELS,
        FOLD_ASSIGNMENTS,
        BASE_CONFIG,
        CHECKPOINT_ROOT,
        RESULTS_DIR,
        LOG_DIR,
    )
    sync_experiment_to_gdrive(
        'D8_muq_stats',
        ALL_RESULTS['D8_muq_stats'],
        RESULTS_DIR,
        CHECKPOINT_ROOT,
        GDRIVE_RESULTS,
        ALL_RESULTS,
    )

In [None]:
# D9: MERT+MuQ Ensemble (late fusion - average predictions)
# Note: This requires a custom training loop since we need both MERT and MuQ embeddings

if should_run_experiment('D9_mert_muq_ensemble', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE35_COMPLETED):
    from audio_experiments.data import DualEmbeddingDataset, dual_collate_fn
    from torch.utils.data import DataLoader

    # Ensure both embeddings exist
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)
    extract_muq_embeddings(AUDIO_DIR, MUQ_CACHE_DIR, ALL_KEYS)

    def run_ensemble_experiment():
        """Train and evaluate the MERT+MuQ ensemble with 4-fold cross-validation.

        For each fold, a fresh MERTMuQEnsemble is trained on the other three
        folds, the checkpoint with the best ``val_r2`` is reloaded, and its
        predictions on the held-out fold are collected.

        Returns:
            dict with
              - 'summary': {'avg_r2': R2 computed over the pooled validation
                predictions of all folds (not the mean of the fold R2s),
                'fold_r2s': list of per-fold R2 values}
              - 'description': short text label of the experiment.

        Raises:
            RuntimeError: if training a fold did not produce a checkpoint.
        """
        from sklearn.metrics import r2_score
        import numpy as np

        # Model and batches are both moved to this device for evaluation.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        def make_loader(keys, shuffle):
            # Dataset + loader over paired MERT/MuQ embeddings for the given keys.
            dataset = DualEmbeddingDataset(
                DEFAULT_MERT_DIR, MUQ_CACHE_DIR, LABELS, keys,
                max_frames=BASE_CONFIG['max_frames']
            )
            return DataLoader(
                dataset, batch_size=BASE_CONFIG['batch_size'],
                shuffle=shuffle, num_workers=2, collate_fn=dual_collate_fn
            )

        all_preds, all_labels = [], []
        fold_metrics = []

        for fold in range(4):
            print(f"\nFold {fold + 1}/4")
            val_keys = [k for k, f in FOLD_ASSIGNMENTS.items() if f == fold]
            train_keys = [k for k, f in FOLD_ASSIGNMENTS.items() if f != fold]

            train_dl = make_loader(train_keys, shuffle=True)
            val_dl = make_loader(val_keys, shuffle=False)

            # Create model
            model = MERTMuQEnsemble(
                input_dim=1024,
                hidden_dim=BASE_CONFIG['hidden_dim'],
                dropout=BASE_CONFIG['dropout'],
                learning_rate=BASE_CONFIG['learning_rate'],
                weight_decay=BASE_CONFIG['weight_decay'],
                pooling='attention',
                fusion_weight=0.5,
                max_epochs=BASE_CONFIG['max_epochs'],
            )

            # Setup trainer
            ckpt_dir = CHECKPOINT_ROOT / 'D9_mert_muq_ensemble' / f'fold{fold}'
            ckpt_dir.mkdir(parents=True, exist_ok=True)

            # Keep a handle on the checkpoint callback so the best checkpoint of
            # THIS run can be located afterwards.
            checkpoint_cb = pl.callbacks.ModelCheckpoint(
                dirpath=ckpt_dir, filename='best',
                monitor='val_r2', mode='max', save_top_k=1
            )
            trainer = pl.Trainer(
                max_epochs=BASE_CONFIG['max_epochs'],
                callbacks=[
                    checkpoint_cb,
                    pl.callbacks.EarlyStopping(
                        monitor='val_r2', mode='max',
                        patience=BASE_CONFIG['patience']
                    ),
                ],
                logger=pl.loggers.CSVLogger(LOG_DIR, name='D9_mert_muq_ensemble', version=f'fold{fold}'),
                accelerator='auto', devices=1,
                gradient_clip_val=BASE_CONFIG['gradient_clip_val'],
                enable_progress_bar=True, deterministic=True,
            )

            trainer.fit(model, train_dl, val_dl)

            # Load best and evaluate. Use the path recorded by the callback rather
            # than globbing 'best*.ckpt': if the directory already holds a
            # checkpoint from an earlier run, Lightning writes a versioned file
            # (e.g. best-v1.ckpt) and a glob may return the stale one.
            best_path = checkpoint_cb.best_model_path
            if not best_path:
                raise RuntimeError(
                    f"No checkpoint was saved for D9_mert_muq_ensemble fold {fold}"
                )
            model = MERTMuQEnsemble.load_from_checkpoint(best_path)
            # Device placement after load_from_checkpoint differs between
            # Lightning versions, so place the model explicitly.
            model = model.to(device)
            model.eval()

            preds, labels = [], []
            with torch.no_grad():
                for batch in val_dl:
                    batch = {
                        k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batch.items()
                    }
                    pred = model(
                        batch['mert_embeddings'], batch['muq_embeddings'],
                        batch.get('mert_mask'), batch.get('muq_mask')
                    )
                    preds.append(pred.cpu())
                    labels.append(batch['labels'].cpu())

            preds = torch.cat(preds).numpy()
            labels = torch.cat(labels).numpy()
            # Plain float keeps the result JSON-serializable regardless of
            # the scalar type r2_score returns.
            fold_r2 = float(r2_score(labels, preds))

            print(f"Fold {fold + 1} R2: {fold_r2:.4f}")
            fold_metrics.append({'fold': fold, 'r2': fold_r2})
            all_preds.append(preds)
            all_labels.append(labels)

        # Pool the held-out predictions of all folds for the overall score.
        all_preds = np.concatenate(all_preds)
        all_labels = np.concatenate(all_labels)
        avg_r2 = float(r2_score(all_labels, all_preds))

        return {
            'summary': {
                'avg_r2': avg_r2,
                'fold_r2s': [m['r2'] for m in fold_metrics],
            },
            'description': 'MERT+MuQ late fusion ensemble',
        }

    ALL_RESULTS['D9_mert_muq_ensemble'] = run_ensemble_experiment()
    sync_experiment_to_gdrive(
        'D9_mert_muq_ensemble', ALL_RESULTS['D9_mert_muq_ensemble'],
        RESULTS_DIR, CHECKPOINT_ROOT, GDRIVE_RESULTS, ALL_RESULTS
    )

---
## Phase 3.6: Contrastive Auxiliary Loss Experiments

These experiments add a contrastive learning objective as an auxiliary loss during training.
The idea is to learn representations in which performances with similar ratings lie closer together in embedding space.

- **D10a**: Contrastive with lambda=0.05 (light regularization)
- **D10b**: Contrastive with lambda=0.1 (moderate regularization)
- **D10c**: Contrastive with lambda=0.2 (strong regularization)
- **D10d**: Contrastive warmup (lambda decays from 0.5 to 0.05)

Expected gain: roughly +0.02 to +0.04 R2 from improved representation structure.

In [None]:
# Contrastive model imports and setup
from audio_experiments.models import ContrastiveAuxiliaryModel, ContrastiveWarmupModel

# Phase 3.6 experiment IDs
PHASE36_EXPERIMENT_IDS = [
    'D10a_contrastive_0.05',
    'D10b_contrastive_0.1',
    'D10c_contrastive_0.2',
    'D10d_contrastive_warmup',
]

# Register the new IDs only once. Notebook cells get re-run, and a plain
# extend() would append the same IDs again on every execution.
ALL_EXPERIMENT_IDS.extend(
    [exp_id for exp_id in PHASE36_EXPERIMENT_IDS if exp_id not in ALL_EXPERIMENT_IDS]
)
PHASE36_COMPLETED = get_completed_experiments(GDRIVE_RESULTS)
print_experiment_status(PHASE36_EXPERIMENT_IDS, PHASE36_COMPLETED)

In [None]:
# D10a: contrastive auxiliary loss, lambda = 0.05
if should_run_experiment('D10a_contrastive_0.05', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE36_COMPLETED):
    # MERT embeddings for layer range (13, 25); presumably end-exclusive,
    # i.e. layers 13-24 as in the B0 baseline -- confirm against the extractor.
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_contrastive_model_005(cfg):
        """Build a ContrastiveAuxiliaryModel with contrastive_lambda=0.05 from ``cfg``."""
        # Hyperparameters copied verbatim from the config.
        from_cfg = (
            'input_dim', 'hidden_dim', 'dropout',
            'learning_rate', 'weight_decay', 'max_epochs',
        )
        return ContrastiveAuxiliaryModel(
            contrastive_lambda=0.05,
            temperature=0.07,
            pooling='attention',
            contrastive_type='supervised',
            **{name: cfg[name] for name in from_cfg},
        )

    ALL_RESULTS['D10a_contrastive_0.05'] = run_4fold_mert_experiment(
        'D10a_contrastive_0.05',
        'MERT + contrastive auxiliary loss (lambda=0.05)',
        make_contrastive_model_005,
        DEFAULT_MERT_DIR,
        LABELS,
        FOLD_ASSIGNMENTS,
        BASE_CONFIG,
        CHECKPOINT_ROOT,
        RESULTS_DIR,
        LOG_DIR,
    )
    sync_experiment_to_gdrive(
        'D10a_contrastive_0.05',
        ALL_RESULTS['D10a_contrastive_0.05'],
        RESULTS_DIR,
        CHECKPOINT_ROOT,
        GDRIVE_RESULTS,
        ALL_RESULTS,
    )

In [None]:
# D10b: contrastive auxiliary loss, lambda = 0.1
if should_run_experiment('D10b_contrastive_0.1', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE36_COMPLETED):
    # MERT embeddings for layer range (13, 25); presumably end-exclusive,
    # i.e. layers 13-24 as in the B0 baseline -- confirm against the extractor.
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_contrastive_model_01(cfg):
        """Build a ContrastiveAuxiliaryModel with contrastive_lambda=0.1 from ``cfg``."""
        # Hyperparameters copied verbatim from the config.
        from_cfg = (
            'input_dim', 'hidden_dim', 'dropout',
            'learning_rate', 'weight_decay', 'max_epochs',
        )
        return ContrastiveAuxiliaryModel(
            contrastive_lambda=0.1,
            temperature=0.07,
            pooling='attention',
            contrastive_type='supervised',
            **{name: cfg[name] for name in from_cfg},
        )

    ALL_RESULTS['D10b_contrastive_0.1'] = run_4fold_mert_experiment(
        'D10b_contrastive_0.1',
        'MERT + contrastive auxiliary loss (lambda=0.1)',
        make_contrastive_model_01,
        DEFAULT_MERT_DIR,
        LABELS,
        FOLD_ASSIGNMENTS,
        BASE_CONFIG,
        CHECKPOINT_ROOT,
        RESULTS_DIR,
        LOG_DIR,
    )
    sync_experiment_to_gdrive(
        'D10b_contrastive_0.1',
        ALL_RESULTS['D10b_contrastive_0.1'],
        RESULTS_DIR,
        CHECKPOINT_ROOT,
        GDRIVE_RESULTS,
        ALL_RESULTS,
    )

In [None]:
# D10c: contrastive auxiliary loss, lambda = 0.2
if should_run_experiment('D10c_contrastive_0.2', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE36_COMPLETED):
    # MERT embeddings for layer range (13, 25); presumably end-exclusive,
    # i.e. layers 13-24 as in the B0 baseline -- confirm against the extractor.
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_contrastive_model_02(cfg):
        """Build a ContrastiveAuxiliaryModel with contrastive_lambda=0.2 from ``cfg``."""
        # Hyperparameters copied verbatim from the config.
        from_cfg = (
            'input_dim', 'hidden_dim', 'dropout',
            'learning_rate', 'weight_decay', 'max_epochs',
        )
        return ContrastiveAuxiliaryModel(
            contrastive_lambda=0.2,
            temperature=0.07,
            pooling='attention',
            contrastive_type='supervised',
            **{name: cfg[name] for name in from_cfg},
        )

    ALL_RESULTS['D10c_contrastive_0.2'] = run_4fold_mert_experiment(
        'D10c_contrastive_0.2',
        'MERT + contrastive auxiliary loss (lambda=0.2)',
        make_contrastive_model_02,
        DEFAULT_MERT_DIR,
        LABELS,
        FOLD_ASSIGNMENTS,
        BASE_CONFIG,
        CHECKPOINT_ROOT,
        RESULTS_DIR,
        LOG_DIR,
    )
    sync_experiment_to_gdrive(
        'D10c_contrastive_0.2',
        ALL_RESULTS['D10c_contrastive_0.2'],
        RESULTS_DIR,
        CHECKPOINT_ROOT,
        GDRIVE_RESULTS,
        ALL_RESULTS,
    )

In [None]:
# D10d: contrastive warmup -- lambda scheduled from 0.5 down to 0.05
if should_run_experiment('D10d_contrastive_warmup', CHECKPOINT_ROOT, RESULTS_DIR, GDRIVE_RESULTS, PHASE36_COMPLETED):
    # MERT embeddings for layer range (13, 25); presumably end-exclusive,
    # i.e. layers 13-24 as in the B0 baseline -- confirm against the extractor.
    extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

    def make_contrastive_warmup_model(cfg):
        """Build a ContrastiveWarmupModel (lambda 0.5 -> 0.05) from ``cfg``."""
        # Hyperparameters copied verbatim from the config.
        from_cfg = (
            'input_dim', 'hidden_dim', 'dropout',
            'learning_rate', 'weight_decay', 'max_epochs',
        )
        return ContrastiveWarmupModel(
            contrastive_lambda_start=0.5,
            contrastive_lambda_end=0.05,
            temperature=0.07,
            pooling='attention',
            **{name: cfg[name] for name in from_cfg},
        )

    ALL_RESULTS['D10d_contrastive_warmup'] = run_4fold_mert_experiment(
        'D10d_contrastive_warmup',
        'MERT + contrastive warmup (0.5 -> 0.05)',
        make_contrastive_warmup_model,
        DEFAULT_MERT_DIR,
        LABELS,
        FOLD_ASSIGNMENTS,
        BASE_CONFIG,
        CHECKPOINT_ROOT,
        RESULTS_DIR,
        LOG_DIR,
    )
    sync_experiment_to_gdrive(
        'D10d_contrastive_warmup',
        ALL_RESULTS['D10d_contrastive_warmup'],
        RESULTS_DIR,
        CHECKPOINT_ROOT,
        GDRIVE_RESULTS,
        ALL_RESULTS,
    )

---
## Phase 3.5 + 3.6 Results Summary

In [None]:
# Phase 3.5 + 3.6 Results Summary
# Prints average R2 and the delta against the best earlier result for every
# MuQ / contrastive experiment, then saves and syncs all results.
print("="*80)
print("PHASE 3.5 + 3.6 RESULTS SUMMARY (MuQ + Contrastive)")
print("="*80)

all_new_exps = PHASE35_EXPERIMENT_IDS + PHASE36_EXPERIMENT_IDS

# Experiments that were skipped in this session are not in ALL_RESULTS yet.
# Load their saved JSON from RESULTS_DIR (same fallback as the Phase 3 summary)
# so they are not silently left out of the table.
for exp_id in all_new_exps:
    if exp_id not in ALL_RESULTS:
        result_file = RESULTS_DIR / f'{exp_id}.json'
        if result_file.exists():
            with open(result_file) as f:
                ALL_RESULTS[exp_id] = json.load(f)

# Get best previous result for comparison (B0 baseline vs. best Phase 3 run);
# missing results count as 0.
best_previous = max(
    ALL_RESULTS.get('B0_baseline', {}).get('summary', {}).get('avg_r2', 0),
    max([ALL_RESULTS.get(exp_id, {}).get('summary', {}).get('avg_r2', 0)
         for exp_id in PHASE3_EXPERIMENT_IDS if exp_id in ALL_RESULTS], default=0)
)

print(f"\nBest Previous R2: {best_previous:.4f}")
print()

print(f"{'Experiment':<30} {'Avg R2':>10} {'vs Best':>12}")
print("-"*55)

# One table section per experiment family; rows share a single format.
new_results = []
for heading, exp_ids in (
    ("\nMuQ Experiments:", PHASE35_EXPERIMENT_IDS),
    ("\nContrastive Experiments:", PHASE36_EXPERIMENT_IDS),
):
    print(heading)
    for exp_id in exp_ids:
        if exp_id not in ALL_RESULTS:
            continue
        s = ALL_RESULTS[exp_id].get('summary', {})
        if 'avg_r2' not in s:
            # Malformed/partial result: report it instead of raising KeyError.
            print(f"  {exp_id:<28} {'n/a':>10} {'n/a':>12}")
            continue
        diff = s['avg_r2'] - best_previous
        print(f"  {exp_id:<28} {s['avg_r2']:>10.4f} {diff:>+12.4f}")
        new_results.append((s['avg_r2'], exp_id))

print("="*55)

# Find overall best (tuples compare by R2 first, then by experiment id)
if new_results:
    best_new = max(new_results)
    print(f"\nBest New Experiment: {best_new[1]} (R2={best_new[0]:.4f})")
    print(f"Improvement over Previous Best: {best_new[0] - best_previous:+.4f}")

# Final sync
with open(RESULTS_DIR / 'all_results_with_muq_contrastive.json', 'w') as f:
    json.dump(ALL_RESULTS, f, indent=2)

print("\n\nSyncing to Google Drive...")
run_rclone(['rclone', 'copy', str(RESULTS_DIR), GDRIVE_RESULTS], "Syncing results")
run_rclone(['rclone', 'copy', str(CHECKPOINT_ROOT), f"{GDRIVE_RESULTS}/checkpoints"], "Syncing checkpoints")
print("Done!")