# Phase 2: Audio Baseline Experiments

Comprehensive experiments for the ISMIR paper. Run all cells sequentially.

## Experiments
- **B0**: Baseline re-run (MERT + MLP head, MERT layers 13-24, mean pooling)
- **A1-A3**: Baselines (linear probe, Mel-CNN, raw statistics)
- **B1a-B1d**: MERT layer ablation (layers 1-6, 7-12, 13-24, 1-24)
- **B2a-B2c**: Pooling ablation (max, attention, LSTM)
- **C1a-C1b**: Loss ablation (hybrid MSE+CCC, pure CCC)

## Requirements
- Thunder Compute A100 (80GB VRAM)
- rclone configured with `gdrive:` remote
- ~14-15 hours runtime
- ~16GB storage

---
## 1. Setup

In [None]:
# Cell 1: GPU Check
import torch

# Fail fast: the Phase 2 experiments require a CUDA device.
_has_cuda = torch.cuda.is_available()
print(f"CUDA available: {_has_cuda}")
if not _has_cuda:
    raise RuntimeError("GPU required for Phase 2 experiments")
print(f"GPU: {torch.cuda.get_device_name(0)}")
_total_bytes = torch.cuda.get_device_properties(0).total_memory
print(f"Memory: {_total_bytes / 1e9:.1f} GB")

In [None]:
# Cell 2: Install dependencies
# Python packages imported by later cells (transformers for MERT, librosa/soundfile
# for audio loading, pytorch_lightning for training, nnAudio for mel spectrograms).
!pip install transformers librosa soundfile pytorch_lightning nnAudio --quiet
# Install rclone with its install script (runs under sudo); grep keeps only the
# status line. NOTE(review): the `|| echo` branch runs whenever grep matches
# nothing -- including when the install failed -- so "rclone installed" does not
# prove success. The `rclone listremotes` call in Cell 4 is the real check.
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
# Cell 3: Imports and reproducibility
import json
import subprocess
import time
import warnings
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import librosa
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from nnAudio.features import MelSpectrogram
from scipy import stats
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformers import AutoModel, AutoProcessor

# Silence library warnings to keep notebook output readable.
# NOTE(review): this also hides numerical warnings (e.g. constant-input warnings
# from scipy/sklearn metrics); consider narrowing the filter.
warnings.filterwarnings('ignore')
# Trade float32 matmul precision for speed on the GPU.
torch.set_float32_matmul_precision('medium')

# Single global seed shared by every experiment (also seeds dataloader workers).
SEED = 42
pl.seed_everything(SEED, workers=True)

for _name, _value in (("PyTorch", torch.__version__), ("Lightning", pl.__version__), ("Seed", SEED)):
    print(f"{_name}: {_value}")

In [None]:
# Cell 4: Paths and data download
# Local working tree; everything downloaded or produced by this notebook lives here.
DATA_ROOT = Path('/tmp/phase2')
AUDIO_DIR = DATA_ROOT.joinpath('audio')
LABEL_DIR = DATA_ROOT.joinpath('labels')
MERT_CACHE_ROOT = DATA_ROOT.joinpath('mert_cache')
MEL_CACHE_DIR = DATA_ROOT.joinpath('mel_cache')
STATS_CACHE_DIR = DATA_ROOT.joinpath('stats_cache')
CHECKPOINT_ROOT = DATA_ROOT.joinpath('checkpoints')
RESULTS_DIR = DATA_ROOT.joinpath('results')
LOG_DIR = DATA_ROOT.joinpath('logs')

# Google Drive locations, addressed through the rclone `gdrive:` remote.
_GDRIVE_ROOT = 'gdrive:crescendai_data'
GDRIVE_AUDIO = f'{_GDRIVE_ROOT}/audio_baseline/percepiano_rendered'
GDRIVE_LABELS = f'{_GDRIVE_ROOT}/percepiano_labels'
GDRIVE_FOLDS = f'{_GDRIVE_ROOT}/audio_baseline/audio_fold_assignments.json'
GDRIVE_MERT_CACHE = f'{_GDRIVE_ROOT}/audio_baseline/mert_embeddings'
GDRIVE_RESULTS = f'{_GDRIVE_ROOT}/checkpoints/audio_phase2'

# Ensure every local working directory exists before anything writes to it.
for _local_dir in (AUDIO_DIR, LABEL_DIR, MERT_CACHE_ROOT, MEL_CACHE_DIR, STATS_CACHE_DIR,
                   CHECKPOINT_ROOT, RESULTS_DIR, LOG_DIR):
    _local_dir.mkdir(parents=True, exist_ok=True)

def run_rclone(cmd, description, *, check: bool = False):
    """Run an rclone command, announcing it first and reporting failures.

    Args:
        cmd: Argument list passed to ``subprocess.run`` (no shell involved).
        description: Human-readable label printed as ``"<description>..."``.
        check: When True, a non-zero exit status raises
            ``subprocess.CalledProcessError``. Defaults to False, in which case
            the failure is only printed as a warning so optional transfers do
            not abort the notebook.

    Returns:
        The ``subprocess.CompletedProcess`` with captured (text) stdout/stderr.
    """
    print(f"{description}...")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        if check:
            raise subprocess.CalledProcessError(
                result.returncode, cmd, output=result.stdout, stderr=result.stderr
            )
        print(f"Warning: {result.stderr}")
    return result

# Check rclone. `rclone listremotes` prints one configured remote per line
# (e.g. "gdrive:"), so compare whole entries: a substring test would also
# accept a differently named remote such as "old_gdrive:".
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)
if 'gdrive:' not in result.stdout.split():
    raise RuntimeError("rclone 'gdrive' remote not configured")
print("rclone: OK")

# Download audio (*.wav segments)
run_rclone(['rclone', 'copy', GDRIVE_AUDIO, str(AUDIO_DIR), '--progress'], "Downloading audio")
print(f"Audio files: {len(list(AUDIO_DIR.glob('*.wav')))}")

# Download labels
run_rclone(['rclone', 'copy', GDRIVE_LABELS, str(LABEL_DIR)], "Downloading labels")

# Download fold assignments
FOLD_FILE = DATA_ROOT / 'folds.json'
run_rclone(['rclone', 'copyto', GDRIVE_FOLDS, str(FOLD_FILE)], "Downloading folds")

# Load labels and folds. JSON is UTF-8 by spec; do not depend on the locale encoding.
LABEL_FILE = LABEL_DIR / 'label_2round_mean_reg_19_with0_rm_highstd0.json'
with open(LABEL_FILE, encoding='utf-8') as f:
    LABELS = json.load(f)
with open(FOLD_FILE, encoding='utf-8') as f:
    FOLD_ASSIGNMENTS = json.load(f)

print(f"Labels: {len(LABELS)} segments")
print(f"Folds: {list(FOLD_ASSIGNMENTS.keys())}")

In [None]:
# Cell 5: Restore existing MERT cache (L13-24)
DEFAULT_MERT_DIR = MERT_CACHE_ROOT / 'L13-24'
DEFAULT_MERT_DIR.mkdir(parents=True, exist_ok=True)

print("Checking for existing MERT cache (L13-24)...")
result = subprocess.run(['rclone', 'lsf', GDRIVE_MERT_CACHE], capture_output=True, text=True)

# `rclone lsf` lists one entry per line; only *.pt files are cached embeddings.
# A failed listing (e.g. missing remote folder) is treated as "no cache".
remote_files = (
    [name for name in result.stdout.splitlines() if name.endswith('.pt')]
    if result.returncode == 0
    else []
)

if remote_files:
    print(f"Found {len(remote_files)} cached embeddings. Restoring...")
    run_rclone(['rclone', 'copy', GDRIVE_MERT_CACHE, str(DEFAULT_MERT_DIR)], "Restoring MERT cache")
    print(f"Restored: {len(list(DEFAULT_MERT_DIR.glob('*.pt')))} embeddings")
else:
    # Also reached when the remote folder exists but holds no .pt files.
    print("No existing cache found. Will extract fresh.")

print(f"\nSetup complete. Data root: {DATA_ROOT}")

---
## 2. Shared Infrastructure

In [None]:
# Cell 6: Constants and configuration
# The 19 PercePiano perceptual dimensions. Column i of label/prediction arrays is
# treated as PERCEPIANO_DIMENSIONS[i] by the metrics code (assumed to match the
# value order in the label JSON -- confirm against the PercePiano release).
PERCEPIANO_DIMENSIONS = [
    "timing", "articulation_length", "articulation_touch",
    "pedal_amount", "pedal_clarity",
    "timbre_variety", "timbre_depth", "timbre_brightness", "timbre_loudness",
    "dynamic_range", "tempo", "space", "balance", "drama",
    "mood_valence", "mood_energy", "mood_imagination",
    "sophistication", "interpretation",
]

# Grouping of dimensions, presumably for category-level summaries.
DIMENSION_CATEGORIES = {
    "timing": ["timing"],
    "articulation": ["articulation_length", "articulation_touch"],
    "pedal": ["pedal_amount", "pedal_clarity"],
    "timbre": ["timbre_variety", "timbre_depth", "timbre_brightness", "timbre_loudness"],
    "dynamics": ["dynamic_range"],
    "tempo_space": ["tempo", "space", "balance", "drama"],
    "emotion": ["mood_valence", "mood_energy", "mood_imagination"],
    "interpretation": ["sophistication", "interpretation"],
}

# Guard against the two lists drifting apart: each dimension must appear in exactly one category.
_categorized_dims = sorted(d for dims in DIMENSION_CATEGORIES.values() for d in dims)
if _categorized_dims != sorted(PERCEPIANO_DIMENSIONS):
    raise ValueError("DIMENSION_CATEGORIES must assign every PercePiano dimension exactly once")

# Base configuration (modified per experiment)
BASE_CONFIG = {
    'input_dim': 1024,
    'hidden_dim': 512,
    'num_labels': len(PERCEPIANO_DIMENSIONS),  # 19; derived so it cannot disagree with the list
    'dropout': 0.2,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'gradient_clip_val': 1.0,
    'batch_size': 64,
    'max_epochs': 200,
    'patience': 15,
    'max_frames': 1000,
    'n_folds': 4,
    'num_workers': 2,
    'seed': SEED,
}

# Track all results
ALL_RESULTS = {}

print(f"Dimensions: {len(PERCEPIANO_DIMENSIONS)}")
print(f"Base config: batch_size={BASE_CONFIG['batch_size']}, max_epochs={BASE_CONFIG['max_epochs']}")

In [None]:
# Cell 7: MERT Extractor (configurable layers)
class MERTLayerExtractor:
    """MERT-v1-330M feature extractor that averages a slice of hidden states.

    ``outputs.hidden_states[layer_start:layer_end]`` (end exclusive) are
    averaged into one tensor per file. When ``cache_dir`` is set, results are
    stored as and reloaded from ``<cache_dir>/<audio stem>.pt``.
    """

    def __init__(self, layer_start: int = 13, layer_end: int = 25, cache_dir: Optional[Path] = None):
        self.layer_start = layer_start
        self.layer_end = layer_end
        self.target_sr = 24000
        self.cache_dir = None if not cache_dir else Path(cache_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model_id = "m-a-p/MERT-v1-330M"
        print(f"Loading MERT-v1-330M (layers {layer_start}-{layer_end-1}) on {self.device}...")
        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        backbone = AutoModel.from_pretrained(model_id, output_hidden_states=True, trust_remote_code=True)
        self.model = backbone.to(self.device)
        self.model.eval()
        print(f"Model loaded. Hidden size: {self.model.config.hidden_size}")

    def get_cache_path(self, key: str) -> Path:
        """Return the cache file for ``key`` (requires ``cache_dir`` to be set)."""
        return self.cache_dir / f"{key}.pt"

    @torch.no_grad()
    def extract_from_file(self, audio_path: Path, use_cache: bool = True) -> torch.Tensor:
        """Embed one audio file, reusing or filling the cache when enabled.

        The file is loaded as mono at 24 kHz, run through MERT, and the selected
        hidden states are averaged; the leading batch axis is squeezed away and
        the result moved to CPU.
        """
        audio_path = Path(audio_path)
        caching = bool(use_cache and self.cache_dir)
        cache_path = self.get_cache_path(audio_path.stem) if caching else None

        if cache_path is not None and cache_path.exists():
            return torch.load(cache_path, weights_only=True)

        waveform, _ = librosa.load(audio_path, sr=self.target_sr, mono=True)
        model_inputs = self.processor(waveform, sampling_rate=self.target_sr, return_tensors="pt")
        model_inputs = {name: tensor.to(self.device) for name, tensor in model_inputs.items()}

        selected = self.model(**model_inputs).hidden_states[self.layer_start:self.layer_end]
        embeddings = torch.stack(selected, dim=0).mean(dim=0).squeeze(0).cpu()

        if cache_path is not None:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            torch.save(embeddings, cache_path)

        return embeddings


def extract_mert_for_layer_range(layer_start: int, layer_end: int, audio_dir: Path,
                                  cache_dir: Path, keys: List[str]) -> int:
    """Fill ``cache_dir`` with MERT embeddings for every not-yet-cached key.

    Keys that already have ``cache_dir/<key>.pt`` are skipped. Keys without
    ``audio_dir/<key>.wav`` are reported and skipped. The (large) MERT model is
    only loaded if at least one key can actually be extracted.

    Args:
        layer_start: First hidden-state index to average (inclusive).
        layer_end: Hidden-state index to stop at (exclusive).
        audio_dir: Directory containing ``<key>.wav`` files.
        cache_dir: Destination directory for ``<key>.pt`` embeddings.
        keys: Segment identifiers that should end up in the cache.

    Returns:
        Number of embeddings actually extracted by this call (keys with missing
        audio are not counted).
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    cached = {p.stem for p in cache_dir.glob('*.pt')}
    to_extract = [k for k in keys if k not in cached]

    if not to_extract:
        print(f"All {len(keys)} embeddings already cached.")
        return 0

    audio_paths = [audio_dir / f"{k}.wav" for k in to_extract]
    present = [p for p in audio_paths if p.exists()]
    n_missing = len(audio_paths) - len(present)
    if n_missing:
        print(f"Warning: {n_missing} of {len(audio_paths)} uncached keys have no audio in {audio_dir}; skipping them.")
    if not present:
        return 0

    print(f"Extracting {len(present)} embeddings (layers {layer_start}-{layer_end-1})...")
    extractor = MERTLayerExtractor(layer_start, layer_end, cache_dir)
    try:
        for audio_path in tqdm(present, desc=f"MERT L{layer_start}-{layer_end-1}"):
            extractor.extract_from_file(audio_path)
    finally:
        # Release the model's GPU memory even if an extraction fails midway.
        del extractor
        torch.cuda.empty_cache()
    return len(present)


print("MERT extractor defined.")

In [None]:
# Cell 8: Mel spectrogram extractor (nnAudio)
class MelExtractor:
    """Mel spectrograms computed with nnAudio (on the GPU when available).

    Settings: 128 mel bins, n_fft=2048, hop=512, 20-8000 Hz, non-trainable
    filters. Each file yields a ``[n_mels, time]`` CPU tensor; when
    ``cache_dir`` is set, results are stored as and reloaded from
    ``<cache_dir>/<audio stem>.pt``.

    NOTE(review): no log/dB compression is applied; the cached values are
    whatever nnAudio's MelSpectrogram returns (power scale by default per its
    docs). Confirm this is the intended CNN input.
    """

    def __init__(self, cache_dir: Optional[Path] = None, sr: int = 24000):
        self.sr = sr
        self.cache_dir = None if not cache_dir else Path(cache_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        spectrogram = MelSpectrogram(
            sr=sr, n_fft=2048, hop_length=512, n_mels=128,
            fmin=20, fmax=8000, trainable_mel=False, trainable_STFT=False,
        )
        self.mel_spec = spectrogram.to(self.device)

    def extract_from_file(self, audio_path: Path, use_cache: bool = True) -> torch.Tensor:
        """Return the mel spectrogram of one file, using the cache when enabled."""
        audio_path = Path(audio_path)
        caching = bool(use_cache and self.cache_dir)
        cache_path = (self.cache_dir / f"{audio_path.stem}.pt") if caching else None

        if cache_path is not None and cache_path.exists():
            return torch.load(cache_path, weights_only=True)

        samples, _ = librosa.load(audio_path, sr=self.sr, mono=True)
        waveform = torch.from_numpy(samples).float().unsqueeze(0).to(self.device)  # [1, samples]

        with torch.no_grad():
            mel = self.mel_spec(waveform).squeeze(0).cpu()  # [n_mels, time]

        if cache_path is not None:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            torch.save(mel, cache_path)

        return mel


def extract_mel_spectrograms(audio_dir: Path, cache_dir: Path, keys: List[str]) -> int:
    """Fill ``cache_dir`` with mel spectrograms for every not-yet-cached key.

    Keys that already have ``cache_dir/<key>.pt`` are skipped. Keys without
    ``audio_dir/<key>.wav`` are reported and skipped. The extractor is only
    built when there is something to extract.

    Returns:
        Number of spectrograms actually extracted by this call (keys with
        missing audio are not counted).
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    cached = {p.stem for p in cache_dir.glob('*.pt')}
    to_extract = [k for k in keys if k not in cached]

    if not to_extract:
        print(f"All {len(keys)} mel spectrograms already cached.")
        return 0

    audio_paths = [audio_dir / f"{k}.wav" for k in to_extract]
    present = [p for p in audio_paths if p.exists()]
    n_missing = len(audio_paths) - len(present)
    if n_missing:
        print(f"Warning: {n_missing} of {len(audio_paths)} uncached keys have no audio in {audio_dir}; skipping them.")
    if not present:
        return 0

    print(f"Extracting {len(present)} mel spectrograms...")
    extractor = MelExtractor(cache_dir)
    for audio_path in tqdm(present, desc="Mel extraction"):
        extractor.extract_from_file(audio_path)

    return len(present)


print("Mel extractor defined.")

In [None]:
# Cell 9: Audio statistics extractor
def extract_audio_statistics(audio: np.ndarray, sr: int = 24000) -> np.ndarray:
    """Summarise a mono waveform as a 49-dim float32 hand-crafted feature vector.

    Layout, in order:
        [0:3]   RMS energy: mean, std, max
        [3:11]  spectral centroid, bandwidth, rolloff, zero-crossing rate:
                (mean, std) for each, in that order
        [11:24] 13 MFCC means
        [24:37] 13 MFCC standard deviations
        [37:49] 12 chroma-bin means
    """
    rms = librosa.feature.rms(y=audio)[0]
    centroid = librosa.feature.spectral_centroid(y=audio, sr=sr)[0]
    bandwidth = librosa.feature.spectral_bandwidth(y=audio, sr=sr)[0]
    rolloff = librosa.feature.spectral_rolloff(y=audio, sr=sr)[0]
    zcr = librosa.feature.zero_crossing_rate(audio)[0]
    mfcc = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)
    chroma = librosa.feature.chroma_stft(y=audio, sr=sr)

    energy = [rms.mean(), rms.std(), rms.max()]
    spectral = [
        stat
        for track in (centroid, bandwidth, rolloff, zcr)
        for stat in (track.mean(), track.std())
    ]
    mfcc_stats = mfcc.mean(axis=1).tolist() + mfcc.std(axis=1).tolist()
    chroma_means = chroma.mean(axis=1).tolist()

    return np.array(energy + spectral + mfcc_stats + chroma_means, dtype=np.float32)


def extract_statistics_for_all(audio_dir: Path, cache_dir: Path, keys: List[str]) -> int:
    """Cache 49-dim audio statistics (as ``<key>.pt`` tensors) for uncached keys.

    Keys already present in ``cache_dir`` are skipped. Keys without
    ``audio_dir/<key>.wav`` are reported and skipped.

    Returns:
        Number of feature vectors actually extracted by this call (keys with
        missing audio are not counted).
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    cached = {p.stem for p in cache_dir.glob('*.pt')}
    to_extract = [k for k in keys if k not in cached]

    if not to_extract:
        print(f"All {len(keys)} statistics already cached.")
        return 0

    audio_paths = [audio_dir / f"{k}.wav" for k in to_extract]
    present = [p for p in audio_paths if p.exists()]
    n_missing = len(audio_paths) - len(present)
    if n_missing:
        print(f"Warning: {n_missing} of {len(audio_paths)} uncached keys have no audio in {audio_dir}; skipping them.")
    if not present:
        return 0

    print(f"Extracting {len(present)} audio statistics...")

    for audio_path in tqdm(present, desc="Stats extraction"):
        audio, sr = librosa.load(audio_path, sr=24000, mono=True)
        stats_arr = extract_audio_statistics(audio, sr)
        torch.save(torch.from_numpy(stats_arr), cache_dir / f"{audio_path.stem}.pt")

    return len(present)


print("Statistics extractor defined.")

In [None]:
# Cell 10: Dataset classes
class MERTDataset(Dataset):
    """Cached MERT embeddings paired with the first 19 PercePiano label values.

    Split selection (``fold_assignments`` maps ``"test"`` and ``"fold_<i>"`` to key lists):
      - ``"test"``: keys under ``"test"``
      - ``"val"``:  keys under ``"fold_<fold_id>"``
      - anything else (train): union of ``"fold_<i>"`` for ``i < n_folds``, ``i != fold_id``

    Only keys with both a cached ``<key>.pt`` file and a label entry are kept.
    Keys are sorted so sample order does not depend on ``set`` iteration order,
    which differs between interpreter runs under string hash randomization and
    would otherwise defeat seeded shuffling.
    """

    def __init__(self, cache_dir: Path, labels: dict, fold_assignments: dict,
                 fold_id: int, mode: str, max_frames: int = 1000, n_folds: int = 4):
        self.cache_dir = Path(cache_dir)
        self.max_frames = max_frames

        available = {p.stem for p in self.cache_dir.glob('*.pt')}

        if mode == "test":
            valid_keys = set(fold_assignments.get("test", []))
        elif mode == "val":
            valid_keys = set(fold_assignments.get(f"fold_{fold_id}", []))
        else:  # train: every fold except the validation fold
            valid_keys = set()
            for i in range(n_folds):
                if i != fold_id:
                    valid_keys.update(fold_assignments.get(f"fold_{i}", []))

        self.samples = [(k, torch.tensor(labels[k][:19], dtype=torch.float32))
                        for k in sorted(valid_keys) if k in available and k in labels]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        key, label = self.samples[idx]
        emb = torch.load(self.cache_dir / f"{key}.pt", weights_only=True)
        # Keep at most max_frames along dim 0 (assumed [time, feature], as the MERT extractor saves it).
        if emb.shape[0] > self.max_frames:
            emb = emb[:self.max_frames]
        return {"embeddings": emb, "labels": label, "key": key, "length": emb.shape[0]}


class MelDataset(Dataset):
    """Cached mel spectrograms paired with the first 19 PercePiano label values.

    Split selection is the same as for the MERT dataset: ``"test"`` keys,
    ``"fold_<fold_id>"`` for ``"val"``, otherwise the union of the other
    ``"fold_<i>"`` lists for ``i < n_folds``. Keys are sorted so sample order is
    reproducible across interpreter runs (set iteration order is not).
    """

    def __init__(self, cache_dir: Path, labels: dict, fold_assignments: dict,
                 fold_id: int, mode: str, max_frames: int = 2000, n_folds: int = 4):
        self.cache_dir = Path(cache_dir)
        self.max_frames = max_frames

        available = {p.stem for p in self.cache_dir.glob('*.pt')}

        if mode == "test":
            valid_keys = set(fold_assignments.get("test", []))
        elif mode == "val":
            valid_keys = set(fold_assignments.get(f"fold_{fold_id}", []))
        else:  # train: every fold except the validation fold
            valid_keys = set()
            for i in range(n_folds):
                if i != fold_id:
                    valid_keys.update(fold_assignments.get(f"fold_{i}", []))

        self.samples = [(k, torch.tensor(labels[k][:19], dtype=torch.float32))
                        for k in sorted(valid_keys) if k in available and k in labels]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        key, label = self.samples[idx]
        mel = torch.load(self.cache_dir / f"{key}.pt", weights_only=True)  # [n_mels, T]
        # Truncate along time (dim 1) to at most max_frames.
        if mel.shape[1] > self.max_frames:
            mel = mel[:, :self.max_frames]
        return {"mel": mel, "labels": label, "key": key, "length": mel.shape[1]}


class StatsDataset(Dataset):
    """Cached hand-crafted audio statistics paired with the first 19 label values.

    Split selection is the same as for the MERT dataset: ``"test"`` keys,
    ``"fold_<fold_id>"`` for ``"val"``, otherwise the union of the other
    ``"fold_<i>"`` lists for ``i < n_folds``. Keys are sorted so sample order is
    reproducible across interpreter runs (set iteration order is not).
    """

    def __init__(self, cache_dir: Path, labels: dict, fold_assignments: dict,
                 fold_id: int, mode: str, n_folds: int = 4):
        self.cache_dir = Path(cache_dir)

        available = {p.stem for p in self.cache_dir.glob('*.pt')}

        if mode == "test":
            valid_keys = set(fold_assignments.get("test", []))
        elif mode == "val":
            valid_keys = set(fold_assignments.get(f"fold_{fold_id}", []))
        else:  # train: every fold except the validation fold
            valid_keys = set()
            for i in range(n_folds):
                if i != fold_id:
                    valid_keys.update(fold_assignments.get(f"fold_{i}", []))

        self.samples = [(k, torch.tensor(labels[k][:19], dtype=torch.float32))
                        for k in sorted(valid_keys) if k in available and k in labels]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        key, label = self.samples[idx]
        # Local name avoids shadowing the module-level `scipy.stats` import.
        features = torch.load(self.cache_dir / f"{key}.pt", weights_only=True)
        return {"features": features, "labels": label, "key": key}


def mert_collate_fn(batch):
    """Pad variable-length embedding sequences into one batch with a validity mask.

    Returns a dict with zero-padded ``"embeddings"`` (batch-first), a boolean
    ``"attention_mask"`` that is True on real frames, stacked ``"labels"``,
    the list of ``"keys"`` and integer ``"lengths"``.
    """
    lengths = torch.tensor([item["length"] for item in batch])
    padded = pad_sequence([item["embeddings"] for item in batch], batch_first=True)
    frame_index = torch.arange(padded.shape[1])
    return {
        "embeddings": padded,
        "attention_mask": frame_index[None, :] < lengths[:, None],
        "labels": torch.stack([item["labels"] for item in batch]),
        "keys": [item["key"] for item in batch],
        "lengths": lengths,
    }


def mel_collate_fn(batch):
    """Right-pad ``[n_mels, T]`` spectrograms with zeros into ``[B, n_mels, T_max]``.

    ``n_mels`` is taken from the inputs (the extractor currently produces 128),
    so the collate no longer hard-codes the mel resolution. Also returns stacked
    ``"labels"``, the ``"keys"`` list and original frame ``"lengths"`` so a model
    can ignore the padding.
    """
    mels = [b["mel"] for b in batch]  # each [n_mels, T]
    labels = torch.stack([b["labels"] for b in batch])
    lengths = torch.tensor([b["length"] for b in batch])
    n_mels = mels[0].shape[0]
    max_len = max(m.shape[1] for m in mels)
    padded = torch.zeros(len(mels), n_mels, max_len)
    for i, m in enumerate(mels):
        padded[i, :, :m.shape[1]] = m
    return {"mel": padded, "labels": labels, "keys": [b["key"] for b in batch], "lengths": lengths}


def stats_collate_fn(batch):
    """Stack fixed-size feature vectors and labels; keep segment keys as a list."""
    keys = [item["key"] for item in batch]
    return {
        "features": torch.stack([item["features"] for item in batch]),
        "labels": torch.stack([item["labels"] for item in batch]),
        "keys": keys,
    }


print("Dataset classes defined.")

In [None]:
# Cell 11: Model classes

# Loss functions
def ccc_loss(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Return ``1 - CCC`` averaged over output columns.

    Lin's concordance correlation coefficient is computed per column over the
    batch axis (dim 0) with population (biased) variances:
        CCC = 2 * cov(p, t) / (var(p) + var(t) + (mean(p) - mean(t))**2 + eps)
    """
    mu_pred = pred.mean(dim=0)
    mu_target = target.mean(dim=0)
    var_pred = pred.var(dim=0, unbiased=False)
    var_target = target.var(dim=0, unbiased=False)
    cov = ((pred - mu_pred) * (target - mu_target)).mean(dim=0)
    denominator = var_pred + var_target + (mu_pred - mu_target) ** 2 + eps
    concordance = (2 * cov) / denominator
    return (1 - concordance).mean()


class BaseMERTModel(pl.LightningModule):
    """Regression head over frame-level MERT embeddings with configurable pooling and loss.

    ``forward(x, mask, lengths)`` takes padded embeddings ``x`` (assumed
    ``[B, T, input_dim]``, as produced by ``mert_collate_fn``), an optional
    boolean ``mask`` (True on real frames) and optional ``lengths``, pools
    over time, and returns ``num_labels`` sigmoid scores from a 2-hidden-layer
    MLP.

    Pooling:
        "mean"      masked mean over frames
        "max"       masked max over frames
        "attention" learned per-frame scores, masked softmax, weighted sum
        "lstm"      BiLSTM (packed when ``lengths`` is given) + attention pooling
    Training loss:
        "mse", "ccc" (1 - CCC), "hybrid" (MSE + 0.5 * (1 - CCC)).
    Validation always logs MSE as ``val_loss`` (so early stopping is comparable
    across loss ablations) and sklearn R² over the epoch as ``val_r2``.

    Raises:
        ValueError: if ``pooling`` or ``loss_type`` is not one of the values
            above. Rejecting them avoids an unknown value silently behaving like
            mean pooling / MSE and mislabelling an ablation run.
    """

    POOLING_TYPES = ("mean", "max", "attention", "lstm")
    LOSS_TYPES = ("mse", "ccc", "hybrid")

    def __init__(self, input_dim=1024, hidden_dim=512, num_labels=19, dropout=0.2,
                 learning_rate=1e-4, weight_decay=1e-5, pooling="mean",
                 loss_type="mse", max_epochs=200):
        if pooling not in self.POOLING_TYPES:
            raise ValueError(f"Unknown pooling {pooling!r}; expected one of {self.POOLING_TYPES}")
        if loss_type not in self.LOSS_TYPES:
            raise ValueError(f"Unknown loss_type {loss_type!r}; expected one of {self.LOSS_TYPES}")
        super().__init__()
        self.save_hyperparameters()
        self.lr = learning_rate
        self.wd = weight_decay
        self.pooling = pooling
        self.loss_type = loss_type
        self.max_epochs = max_epochs

        # Attention pooling
        if pooling == "attention":
            self.attn = nn.Sequential(
                nn.Linear(input_dim, 256), nn.Tanh(), nn.Linear(256, 1)
            )

        # LSTM pooling
        if pooling == "lstm":
            self.lstm = nn.LSTM(input_dim, hidden_dim // 2, batch_first=True,
                                bidirectional=True, num_layers=1)
            self.lstm_attn = nn.Sequential(
                nn.Linear(hidden_dim, 128), nn.Tanh(), nn.Linear(128, 1)
            )
            input_dim = hidden_dim  # BiLSTM output dim = 2 * (hidden_dim // 2)

        # MLP head
        self.clf = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_labels), nn.Sigmoid(),
        )

        self.mse_loss = nn.MSELoss()
        self.val_outputs = []  # per-batch {"p": preds, "l": labels}, flushed each epoch

    def pool(self, x, mask=None, lengths=None):
        """Reduce ``x`` over the time axis (dim 1) according to ``self.pooling``."""
        if self.pooling == "mean":
            if mask is not None:
                m = mask.unsqueeze(-1).float()
                return (x * m).sum(1) / m.sum(1).clamp(min=1)
            return x.mean(1)
        elif self.pooling == "max":
            if mask is not None:
                x = x.masked_fill(~mask.unsqueeze(-1), float('-inf'))
            return x.max(1)[0]
        elif self.pooling == "attention":
            scores = self.attn(x).squeeze(-1)
            if mask is not None:
                scores = scores.masked_fill(~mask, float('-inf'))
            w = torch.softmax(scores, dim=-1).unsqueeze(-1)
            return (x * w).sum(1)
        elif self.pooling == "lstm":
            if lengths is not None:
                packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
                lstm_out, _ = self.lstm(packed)
                x, _ = pad_packed_sequence(lstm_out, batch_first=True)
            else:
                x, _ = self.lstm(x)
            # Attention over LSTM outputs
            scores = self.lstm_attn(x).squeeze(-1)
            if mask is not None:
                # pad_packed_sequence pads only to the longest length, which can be shorter than the mask.
                if mask.shape[1] > x.shape[1]:
                    mask = mask[:, :x.shape[1]]
                scores = scores.masked_fill(~mask, float('-inf'))
            w = torch.softmax(scores, dim=-1).unsqueeze(-1)
            return (x * w).sum(1)
        raise ValueError(f"Unknown pooling {self.pooling!r}; expected one of {self.POOLING_TYPES}")

    def forward(self, x, mask=None, lengths=None):
        pooled = self.pool(x, mask, lengths)
        return self.clf(pooled)

    def compute_loss(self, pred, target):
        """Training loss selected by ``self.loss_type``."""
        if self.loss_type == "mse":
            return self.mse_loss(pred, target)
        elif self.loss_type == "ccc":
            return ccc_loss(pred, target)
        elif self.loss_type == "hybrid":
            return self.mse_loss(pred, target) + 0.5 * ccc_loss(pred, target)
        raise ValueError(f"Unknown loss_type {self.loss_type!r}; expected one of {self.LOSS_TYPES}")

    def training_step(self, batch, idx):
        pred = self(batch["embeddings"], batch.get("attention_mask"), batch.get("lengths"))
        loss = self.compute_loss(pred, batch["labels"])
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, idx):
        pred = self(batch["embeddings"], batch.get("attention_mask"), batch.get("lengths"))
        self.log("val_loss", self.mse_loss(pred, batch["labels"]), prog_bar=True)
        self.val_outputs.append({"p": pred.cpu(), "l": batch["labels"].cpu()})

    def on_validation_epoch_end(self):
        if self.val_outputs:
            p = torch.cat([x["p"] for x in self.val_outputs]).numpy()
            l = torch.cat([x["l"] for x in self.val_outputs]).numpy()
            self.log("val_r2", r2_score(l, p), prog_bar=True)
            self.val_outputs.clear()

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=self.max_epochs, eta_min=1e-6)
        return {"optimizer": opt, "lr_scheduler": {"scheduler": sch, "interval": "epoch"}}


class LinearProbeModel(pl.LightningModule):
    """Linear-probe baseline: masked mean pooling over frames, one linear layer, sigmoid.

    Trained with MSE. Validation logs MSE as ``val_loss`` and sklearn R² over
    the whole epoch as ``val_r2``. Optimiser: AdamW with per-epoch cosine
    annealing down to 1e-6 over ``max_epochs``.
    """

    def __init__(self, input_dim=1024, num_labels=19, learning_rate=1e-4,
                 weight_decay=1e-5, max_epochs=200):
        super().__init__()
        self.save_hyperparameters()
        self.lr = learning_rate
        self.wd = weight_decay
        self.max_epochs = max_epochs

        self.linear = nn.Linear(input_dim, num_labels)
        self.loss_fn = nn.MSELoss()
        self.val_outputs = []  # per-batch {"p": preds, "l": labels}, flushed each epoch

    def forward(self, x, mask=None, lengths=None):
        # `lengths` is accepted (presumably for call-compatibility with the MLP models) but unused.
        if mask is None:
            pooled = x.mean(1)
        else:
            frame_weights = mask.unsqueeze(-1).float()
            pooled = (x * frame_weights).sum(1) / frame_weights.sum(1).clamp(min=1)
        return torch.sigmoid(self.linear(pooled))

    def training_step(self, batch, idx):
        preds = self(batch["embeddings"], batch.get("attention_mask"))
        loss = self.loss_fn(preds, batch["labels"])
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, idx):
        preds = self(batch["embeddings"], batch.get("attention_mask"))
        targets = batch["labels"]
        self.log("val_loss", self.loss_fn(preds, targets), prog_bar=True)
        self.val_outputs.append({"p": preds.cpu(), "l": targets.cpu()})

    def on_validation_epoch_end(self):
        if not self.val_outputs:
            return
        all_preds = torch.cat([out["p"] for out in self.val_outputs]).numpy()
        all_targets = torch.cat([out["l"] for out in self.val_outputs]).numpy()
        self.log("val_r2", r2_score(all_targets, all_preds), prog_bar=True)
        self.val_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.max_epochs, eta_min=1e-6)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}


class MelCNNModel(pl.LightningModule):
    """4-layer CNN on mel spectrograms (``[B, n_mels, T]``) with an MLP regression head.

    The three ``MaxPool2d(2)`` stages shrink time by a factor of 8 (floor at
    each stage). When ``lengths`` (valid input frames per example, as returned
    by ``mel_collate_fn``) is given, the final global average is restricted to
    feature-map columns derived from real frames. Otherwise the zero padding
    added by the collate would dilute the pooled features, making a clip's
    prediction depend on which clips share its batch. Without ``lengths``, the
    original ``AdaptiveAvgPool2d`` over the whole map is used.

    NOTE(review): padded frames still affect BatchNorm statistics in training,
    and the cached mels are fed without log compression. Confirm both are intended.
    """

    # Total temporal downsampling of the three MaxPool2d(2) stages in ``self.conv``.
    _TIME_STRIDE = 8

    def __init__(self, hidden_dim=512, num_labels=19, dropout=0.2,
                 learning_rate=1e-4, weight_decay=1e-5, max_epochs=200):
        super().__init__()
        self.save_hyperparameters()
        self.lr = learning_rate
        self.wd = weight_decay
        self.max_epochs = max_epochs

        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        self.head = nn.Sequential(
            nn.Linear(256, hidden_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_labels), nn.Sigmoid(),
        )

        self.loss_fn = nn.MSELoss()
        self.val_outputs = []  # per-batch {"p": preds, "l": labels}, flushed each epoch

    def forward(self, mel, lengths=None):
        x = mel.unsqueeze(1)  # [B, 1, n_mels, T]
        if lengths is None:
            return self.head(self.conv(x).flatten(1))  # global average over the whole map

        feats = self.conv[:-1](x)  # [B, 256, F', T'] -- everything except the global pool
        t_out = feats.shape[-1]
        valid = torch.div(lengths.to(feats.device), self._TIME_STRIDE, rounding_mode="floor")
        valid = valid.clamp(min=1, max=t_out)  # keep at least one column for very short clips
        time_mask = torch.arange(t_out, device=feats.device)[None, :] < valid[:, None]  # [B, T']
        weights = time_mask[:, None, None, :].to(feats.dtype)  # [B, 1, 1, T']
        pooled = (feats * weights).sum(dim=(2, 3)) / (weights.sum(dim=(2, 3)) * feats.shape[2])
        return self.head(pooled)  # [B, num_labels]

    def training_step(self, batch, idx):
        pred = self(batch["mel"], batch.get("lengths"))
        loss = self.loss_fn(pred, batch["labels"])
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, idx):
        pred = self(batch["mel"], batch.get("lengths"))
        self.log("val_loss", self.loss_fn(pred, batch["labels"]), prog_bar=True)
        self.val_outputs.append({"p": pred.cpu(), "l": batch["labels"].cpu()})

    def on_validation_epoch_end(self):
        if self.val_outputs:
            p = torch.cat([x["p"] for x in self.val_outputs]).numpy()
            l = torch.cat([x["l"] for x in self.val_outputs]).numpy()
            self.log("val_r2", r2_score(l, p), prog_bar=True)
            self.val_outputs.clear()

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=self.max_epochs, eta_min=1e-6)
        return {"optimizer": opt, "lr_scheduler": {"scheduler": sch, "interval": "epoch"}}


class StatsMLPModel(pl.LightningModule):
    """Two-hidden-layer MLP over hand-crafted audio statistics (49 features by default).

    Input ``[B, input_dim]`` (``stats_collate_fn`` "features"), output
    ``[B, num_labels]`` sigmoid scores. Trained with MSE; validation logs MSE
    as ``val_loss`` and sklearn R² over the epoch as ``val_r2``.

    NOTE(review): features enter the MLP unnormalised even though their scales
    differ widely (e.g. spectral centroid in Hz vs. RMS energy). Consider
    standardising with training-fold statistics before reporting this baseline.
    """

    def __init__(self, input_dim=49, hidden_dim=256, num_labels=19, dropout=0.2,
                 learning_rate=1e-4, weight_decay=1e-5, max_epochs=200):
        super().__init__()
        self.save_hyperparameters()
        self.lr = learning_rate
        self.wd = weight_decay
        self.max_epochs = max_epochs

        # Two (Linear -> GELU -> Dropout) stages, then Linear -> Sigmoid.
        layers = []
        width_in = input_dim
        for _ in range(2):
            layers += [nn.Linear(width_in, hidden_dim), nn.GELU(), nn.Dropout(dropout)]
            width_in = hidden_dim
        layers += [nn.Linear(hidden_dim, num_labels), nn.Sigmoid()]
        self.mlp = nn.Sequential(*layers)

        self.loss_fn = nn.MSELoss()
        self.val_outputs = []  # per-batch {"p": preds, "l": labels}, flushed each epoch

    def forward(self, x):
        return self.mlp(x)

    def training_step(self, batch, idx):
        loss = self.loss_fn(self(batch["features"]), batch["labels"])
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, idx):
        preds = self(batch["features"])
        targets = batch["labels"]
        self.log("val_loss", self.loss_fn(preds, targets), prog_bar=True)
        self.val_outputs.append({"p": preds.cpu(), "l": targets.cpu()})

    def on_validation_epoch_end(self):
        if not self.val_outputs:
            return
        all_preds = torch.cat([out["p"] for out in self.val_outputs]).numpy()
        all_targets = torch.cat([out["l"] for out in self.val_outputs]).numpy()
        self.log("val_r2", r2_score(all_targets, all_preds), prog_bar=True)
        self.val_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.max_epochs, eta_min=1e-6)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}


print("Model classes defined.")

In [None]:
# Cell 12: Training and evaluation functions
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger


def bootstrap_r2(y_true: np.ndarray, y_pred: np.ndarray, n_bootstrap: int = 1000,
                 seed: Optional[int] = None) -> Tuple[float, float, float]:
    """Percentile bootstrap of R² (sklearn ``r2_score``) over resampled rows.

    Rows are resampled with replacement ``n_bootstrap`` times. The result is
    ``(2.5th, 50th, 97.5th)`` percentiles of the resampled scores, i.e. the
    lower bound, median and upper bound of a 95% interval.

    Args:
        y_true: Ground truth, first axis indexes samples.
        y_pred: Predictions aligned with ``y_true``.
        n_bootstrap: Number of bootstrap resamples.
        seed: RNG seed; defaults to the notebook-wide ``SEED``.

    A private ``np.random.RandomState`` is used. It produces the same draws as
    seeding the global NumPy RNG and calling ``np.random.choice``, but leaves
    the global RNG state untouched for the rest of the notebook.
    """
    rng = np.random.RandomState(SEED if seed is None else seed)
    n_samples = len(y_true)
    r2_scores = []
    for _ in range(n_bootstrap):
        idx = rng.choice(n_samples, n_samples, replace=True)
        r2_scores.append(r2_score(y_true[idx], y_pred[idx]))
    low, median, high = np.percentile(r2_scores, [2.5, 50, 97.5])
    return float(low), float(median), float(high)


def compute_comprehensive_metrics(all_preds: np.ndarray, all_labels: np.ndarray) -> dict:
    """Compute pooled and per-dimension regression metrics for an experiment.

    Args:
        all_preds: 2-D predictions, one row per sample. Column ``i`` corresponds
            to ``PERCEPIANO_DIMENSIONS[i]``.
        all_labels: Ground-truth values with the same layout as ``all_preds``.

    Returns:
        Dict with:
        - ``overall_r2``, ``overall_mae``, ``overall_rmse``: computed over all columns.
        - ``r2_ci_95``: bootstrap 2.5/97.5 percentiles of R2.
        - ``dispersion_ratio``: mean per-column prediction std divided by mean
          per-column label std; 0 when the labels have no spread.
        - ``per_dimension``: R2, MAE, Pearson r / p-value, and label/prediction
          mean and std for each named dimension.
    """
    overall_r2 = r2_score(all_labels, all_preds)
    overall_mae = mean_absolute_error(all_labels, all_preds)
    overall_rmse = np.sqrt(mean_squared_error(all_labels, all_preds))

    # Bootstrap CI
    ci = bootstrap_r2(all_labels, all_preds)

    # Per-dimension metrics
    per_dim = {}
    for i, dim in enumerate(PERCEPIANO_DIMENSIONS):
        y_true, y_pred = all_labels[:, i], all_preds[:, i]
        pearson, p_val = stats.pearsonr(y_true, y_pred)
        per_dim[dim] = {
            "r2": float(r2_score(y_true, y_pred)),
            "mae": float(mean_absolute_error(y_true, y_pred)),
            "pearson": float(pearson),
            "p_value": float(p_val),
            "label_mean": float(y_true.mean()),
            "label_std": float(y_true.std()),
            "pred_mean": float(y_pred.mean()),
            "pred_std": float(y_pred.std()),
        }

    # Dispersion ratio over every label column. Using the real column count,
    # not a hard-coded 19, keeps the ratio correct if the label set changes.
    # std uses ddof=0, matching the per-dimension stds above.
    avg_label_std = all_labels.std(axis=0).mean()
    avg_pred_std = all_preds.std(axis=0).mean()
    dispersion_ratio = avg_pred_std / avg_label_std if avg_label_std > 0 else 0

    return {
        "overall_r2": float(overall_r2),
        "r2_ci_95": [float(ci[0]), float(ci[2])],
        "overall_mae": float(overall_mae),
        "overall_rmse": float(overall_rmse),
        "dispersion_ratio": float(dispersion_ratio),
        "per_dimension": per_dim,
    }


def experiment_completed(exp_id: str, checkpoint_dir: Path, n_folds: int = 4) -> bool:
    """Return True if every fold checkpoint of *exp_id* exists on disk.

    Args:
        exp_id: Experiment identifier (sub-directory of ``checkpoint_dir``).
        checkpoint_dir: Root directory holding one sub-directory per experiment.
        n_folds: Number of folds expected; fold ``i`` must have
            ``fold{i}_best.ckpt``. Defaults to 4 for existing callers.

    Returns:
        False if the experiment directory is missing or any fold checkpoint
        is absent, True otherwise.
    """
    exp_dir = checkpoint_dir / exp_id
    if not exp_dir.exists():
        return False
    return all((exp_dir / f"fold{i}_best.ckpt").exists() for i in range(n_folds))


def load_existing_results(exp_id: str, results_dir: Path) -> Optional[dict]:
    """Return the saved ``<exp_id>.json`` results from *results_dir*, or None if absent."""
    candidate = results_dir / f"{exp_id}.json"
    if not candidate.exists():
        return None
    with candidate.open() as fh:
        return json.load(fh)


def run_4fold_mert_experiment(
    exp_id: str,
    description: str,
    model_factory: Callable,
    mert_cache_dir: Path,
    config: dict,
) -> dict:
    """Run k-fold cross-validation for a MERT-embedding experiment.

    For each fold, either train a fresh model (keeping the best checkpoint by
    ``val_r2``) or, when ``fold{k}_best.ckpt`` already exists, reload it. The
    best checkpoint is then evaluated on that fold's validation split. Per-fold
    R2 and pooled metrics are saved to ``RESULTS_DIR/<exp_id>.json``.

    Args:
        exp_id: Experiment identifier; names the checkpoint dir, the CSV log
            dir and the results JSON.
        description: Free-text description stored in the results.
        model_factory: ``config -> LightningModule``. The class of the returned
            model is also used to reload checkpoints.
        mert_cache_dir: Embedding cache directory passed to ``MERTDataset``.
        config: Settings. This function reads ``n_folds``, ``max_frames``,
            ``batch_size``, ``num_workers``, ``max_epochs``, ``patience`` and
            ``gradient_clip_val``; the whole dict is stored in the results.

    Returns:
        The results dict written to JSON. If the experiment is already
        complete, returns the previously saved dict instead.

    Raises:
        RuntimeError: If every fold was skipped for lack of data.
    """
    exp_checkpoint_dir = CHECKPOINT_ROOT / exp_id
    exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Skip only when both the JSON summary and every fold checkpoint exist.
    existing = load_existing_results(exp_id, RESULTS_DIR)
    if existing and experiment_completed(exp_id, CHECKPOINT_ROOT):
        print(f"SKIP {exp_id}: already completed (R2={existing['summary']['avg_r2']:.4f})")
        return existing

    print(f"\n{'='*70}")
    print(f"EXPERIMENT: {exp_id}")
    print(f"Description: {description}")
    print(f"{'='*70}")

    start_time = time.time()
    fold_results = {}
    all_preds, all_labels = [], []

    for fold in range(config['n_folds']):
        ckpt_path = exp_checkpoint_dir / f"fold{fold}_best.ckpt"

        # Create datasets
        train_ds = MERTDataset(mert_cache_dir, LABELS, FOLD_ASSIGNMENTS, fold, "train", config['max_frames'])
        val_ds = MERTDataset(mert_cache_dir, LABELS, FOLD_ASSIGNMENTS, fold, "val", config['max_frames'])

        if len(train_ds) == 0 or len(val_ds) == 0:
            print(f"Fold {fold}: No data available, skipping")
            continue

        train_dl = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True,
                              collate_fn=mert_collate_fn, num_workers=config['num_workers'], pin_memory=True)
        val_dl = DataLoader(val_ds, batch_size=config['batch_size'], shuffle=False,
                            collate_fn=mert_collate_fn, num_workers=config['num_workers'], pin_memory=True)

        if ckpt_path.exists():
            print(f"Fold {fold}: Loading existing checkpoint")
            # The factory is called only to learn which LightningModule class to restore.
            model = type(model_factory(config)).load_from_checkpoint(ckpt_path)
        else:
            print(f"Fold {fold}: Training ({len(train_ds)} train, {len(val_ds)} val)")
            model = model_factory(config)

            callbacks = [
                ModelCheckpoint(dirpath=exp_checkpoint_dir, filename=f'fold{fold}_best',
                                monitor='val_r2', mode='max', save_top_k=1),
                EarlyStopping(monitor='val_r2', mode='max', patience=config['patience'], verbose=True),
            ]

            trainer = pl.Trainer(
                max_epochs=config['max_epochs'],
                callbacks=callbacks,
                logger=CSVLogger(save_dir=LOG_DIR, name=exp_id, version=f'fold{fold}'),
                accelerator='auto', devices=1,
                gradient_clip_val=config['gradient_clip_val'],
                enable_progress_bar=True,
                deterministic=True,
                log_every_n_steps=10,
            )

            trainer.fit(model, train_dl, val_dl)

            # Reload the best (not the last) epoch's weights for evaluation.
            model = type(model).load_from_checkpoint(ckpt_path)

        # Evaluate the best checkpoint on this fold's validation split.
        fold_preds, fold_labels = [], []
        model.eval().to('cuda')
        with torch.no_grad():
            for batch in val_dl:
                pred = model(batch["embeddings"].cuda(), batch["attention_mask"].cuda(), batch.get("lengths"))
                fold_preds.append(pred.cpu().numpy())
                fold_labels.append(batch["labels"].numpy())

        fold_preds = np.vstack(fold_preds)
        fold_labels = np.vstack(fold_labels)
        # Score every fold the same way, whether it was trained now or resumed
        # from disk, so avg/std R2 are genuine per-fold statistics. This is
        # presumably the same quantity the checkpoint callback monitored as
        # val_r2; NOTE(review): confirm against the MERT model's val_r2 code.
        fold_results[fold] = float(r2_score(fold_labels, fold_preds))
        print(f"Fold {fold}: val R2 = {fold_results[fold]:.4f}")
        all_preds.append(fold_preds)
        all_labels.append(fold_labels)

        del model
        torch.cuda.empty_cache()

    if not all_preds:
        raise RuntimeError(f"{exp_id}: no fold had data to evaluate")

    # Aggregate results
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    metrics = compute_comprehensive_metrics(all_preds, all_labels)

    fold_r2s = list(fold_results.values())
    avg_r2 = np.mean(fold_r2s)
    std_r2 = np.std(fold_r2s)

    results = {
        "experiment_id": exp_id,
        "description": description,
        "config": config,
        "summary": {
            "avg_r2": float(avg_r2),
            "std_r2": float(std_r2),
            "r2_ci_95": metrics['r2_ci_95'],
            "overall_r2": metrics['overall_r2'],
            "overall_mae": metrics['overall_mae'],
            "dispersion_ratio": metrics['dispersion_ratio'],
        },
        "fold_results": {str(k): float(v) for k, v in fold_results.items()},
        "per_dimension": metrics['per_dimension'],
        "training_time_seconds": time.time() - start_time,
    }

    # Save results
    with open(RESULTS_DIR / f"{exp_id}.json", 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\n{exp_id} COMPLETE: R2={avg_r2:.4f} +/- {std_r2:.4f}, CI=[{metrics['r2_ci_95'][0]:.4f}, {metrics['r2_ci_95'][1]:.4f}]")

    return results


def run_4fold_mel_experiment(exp_id: str, description: str, config: dict) -> dict:
    """Run k-fold cross-validation for the Mel-CNN baseline.

    For each fold, either train a fresh ``MelCNNModel`` (keeping the best
    checkpoint by ``val_r2``) or reload ``fold{k}_best.ckpt`` if it exists.
    The best checkpoint is then evaluated on that fold's validation split.
    Per-fold R2 and pooled metrics are saved to ``RESULTS_DIR/<exp_id>.json``.

    Args:
        exp_id: Experiment identifier; names the checkpoint dir, log dir and JSON.
        description: Free-text description stored in the results.
        config: Settings. This function reads ``n_folds``, ``batch_size``,
            ``num_workers``, ``hidden_dim``, ``learning_rate``, ``weight_decay``,
            ``max_epochs``, ``patience`` and ``gradient_clip_val``.

    Returns:
        The results dict written to JSON. If the experiment is already
        complete, returns the previously saved dict instead.

    Raises:
        RuntimeError: If every fold was skipped for lack of data.
    """
    exp_checkpoint_dir = CHECKPOINT_ROOT / exp_id
    exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Skip only when both the JSON summary and every fold checkpoint exist.
    existing = load_existing_results(exp_id, RESULTS_DIR)
    if existing and experiment_completed(exp_id, CHECKPOINT_ROOT):
        print(f"SKIP {exp_id}: already completed (R2={existing['summary']['avg_r2']:.4f})")
        return existing

    print(f"\n{'='*70}")
    print(f"EXPERIMENT: {exp_id}")
    print(f"Description: {description}")
    print(f"{'='*70}")

    start_time = time.time()
    fold_results = {}
    all_preds, all_labels = [], []

    for fold in range(config['n_folds']):
        ckpt_path = exp_checkpoint_dir / f"fold{fold}_best.ckpt"

        train_ds = MelDataset(MEL_CACHE_DIR, LABELS, FOLD_ASSIGNMENTS, fold, "train")
        val_ds = MelDataset(MEL_CACHE_DIR, LABELS, FOLD_ASSIGNMENTS, fold, "val")

        # Same guard as the MERT runner: an empty split cannot be trained/evaluated.
        if len(train_ds) == 0 or len(val_ds) == 0:
            print(f"Fold {fold}: No data available, skipping")
            continue

        train_dl = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True,
                              collate_fn=mel_collate_fn, num_workers=config['num_workers'], pin_memory=True)
        val_dl = DataLoader(val_ds, batch_size=config['batch_size'], shuffle=False,
                            collate_fn=mel_collate_fn, num_workers=config['num_workers'], pin_memory=True)

        if ckpt_path.exists():
            print(f"Fold {fold}: Loading existing checkpoint")
            model = MelCNNModel.load_from_checkpoint(ckpt_path)
        else:
            print(f"Fold {fold}: Training")
            model = MelCNNModel(
                hidden_dim=config['hidden_dim'],
                learning_rate=config['learning_rate'],
                weight_decay=config['weight_decay'],
                max_epochs=config['max_epochs'],
            )

            callbacks = [
                ModelCheckpoint(dirpath=exp_checkpoint_dir, filename=f'fold{fold}_best',
                                monitor='val_r2', mode='max', save_top_k=1),
                EarlyStopping(monitor='val_r2', mode='max', patience=config['patience'], verbose=True),
            ]

            trainer = pl.Trainer(
                max_epochs=config['max_epochs'],
                callbacks=callbacks,
                logger=CSVLogger(save_dir=LOG_DIR, name=exp_id, version=f'fold{fold}'),
                accelerator='auto', devices=1,
                gradient_clip_val=config['gradient_clip_val'],
                enable_progress_bar=True,
                deterministic=True,
            )

            trainer.fit(model, train_dl, val_dl)
            # Reload the best (not the last) epoch's weights for evaluation.
            model = MelCNNModel.load_from_checkpoint(ckpt_path)

        # Evaluate the best checkpoint on this fold's validation split.
        fold_preds, fold_labels = [], []
        model.eval().to('cuda')
        with torch.no_grad():
            for batch in val_dl:
                pred = model(batch["mel"].cuda())
                fold_preds.append(pred.cpu().numpy())
                fold_labels.append(batch["labels"].numpy())

        fold_preds = np.vstack(fold_preds)
        fold_labels = np.vstack(fold_labels)
        # Score trained and resumed folds identically so avg/std R2 are real
        # per-fold statistics rather than a copy of the pooled R2.
        fold_results[fold] = float(r2_score(fold_labels, fold_preds))
        print(f"Fold {fold}: val R2 = {fold_results[fold]:.4f}")
        all_preds.append(fold_preds)
        all_labels.append(fold_labels)

        del model
        torch.cuda.empty_cache()

    if not all_preds:
        raise RuntimeError(f"{exp_id}: no fold had data to evaluate")

    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    metrics = compute_comprehensive_metrics(all_preds, all_labels)

    fold_r2s = list(fold_results.values())
    avg_r2 = np.mean(fold_r2s)
    std_r2 = np.std(fold_r2s)

    results = {
        "experiment_id": exp_id,
        "description": description,
        "config": config,
        "summary": {
            "avg_r2": float(avg_r2),
            "std_r2": float(std_r2),
            "r2_ci_95": metrics['r2_ci_95'],
            "overall_r2": metrics['overall_r2'],
            "overall_mae": metrics['overall_mae'],
            "dispersion_ratio": metrics['dispersion_ratio'],
        },
        "fold_results": {str(k): float(v) for k, v in fold_results.items()},
        "per_dimension": metrics['per_dimension'],
        "training_time_seconds": time.time() - start_time,
    }

    with open(RESULTS_DIR / f"{exp_id}.json", 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\n{exp_id} COMPLETE: R2={avg_r2:.4f}")
    return results


def run_4fold_stats_experiment(exp_id: str, description: str, config: dict) -> dict:
    """Run k-fold cross-validation for the hand-crafted-statistics MLP baseline.

    For each fold, either train a fresh ``StatsMLPModel`` or reload
    ``fold{k}_best.ckpt`` if it exists. The model uses input_dim=49 and
    hidden_dim=256, both fixed here. The best checkpoint is then evaluated on
    that fold's validation split. Per-fold R2 and pooled metrics are saved to
    ``RESULTS_DIR/<exp_id>.json``.

    Args:
        exp_id: Experiment identifier; names the checkpoint dir, log dir and JSON.
        description: Free-text description stored in the results.
        config: Settings. This function reads ``n_folds``, ``batch_size``,
            ``num_workers``, ``learning_rate``, ``weight_decay``,
            ``max_epochs``, ``patience`` and ``gradient_clip_val``.

    Returns:
        The results dict written to JSON. If the experiment is already
        complete, returns the previously saved dict instead.

    Raises:
        RuntimeError: If every fold was skipped for lack of data.
    """
    exp_checkpoint_dir = CHECKPOINT_ROOT / exp_id
    exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Skip only when both the JSON summary and every fold checkpoint exist.
    existing = load_existing_results(exp_id, RESULTS_DIR)
    if existing and experiment_completed(exp_id, CHECKPOINT_ROOT):
        print(f"SKIP {exp_id}: already completed (R2={existing['summary']['avg_r2']:.4f})")
        return existing

    print(f"\n{'='*70}")
    print(f"EXPERIMENT: {exp_id}")
    print(f"Description: {description}")
    print(f"{'='*70}")

    start_time = time.time()
    fold_results = {}
    all_preds, all_labels = [], []

    for fold in range(config['n_folds']):
        ckpt_path = exp_checkpoint_dir / f"fold{fold}_best.ckpt"

        train_ds = StatsDataset(STATS_CACHE_DIR, LABELS, FOLD_ASSIGNMENTS, fold, "train")
        val_ds = StatsDataset(STATS_CACHE_DIR, LABELS, FOLD_ASSIGNMENTS, fold, "val")

        # Same guard as the MERT runner: an empty split cannot be trained/evaluated.
        if len(train_ds) == 0 or len(val_ds) == 0:
            print(f"Fold {fold}: No data available, skipping")
            continue

        train_dl = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True,
                              collate_fn=stats_collate_fn, num_workers=config['num_workers'], pin_memory=True)
        val_dl = DataLoader(val_ds, batch_size=config['batch_size'], shuffle=False,
                            collate_fn=stats_collate_fn, num_workers=config['num_workers'], pin_memory=True)

        if ckpt_path.exists():
            print(f"Fold {fold}: Loading existing checkpoint")
            model = StatsMLPModel.load_from_checkpoint(ckpt_path)
        else:
            print(f"Fold {fold}: Training")
            model = StatsMLPModel(
                input_dim=49,
                hidden_dim=256,
                learning_rate=config['learning_rate'],
                weight_decay=config['weight_decay'],
                max_epochs=config['max_epochs'],
            )

            callbacks = [
                ModelCheckpoint(dirpath=exp_checkpoint_dir, filename=f'fold{fold}_best',
                                monitor='val_r2', mode='max', save_top_k=1),
                EarlyStopping(monitor='val_r2', mode='max', patience=config['patience'], verbose=True),
            ]

            trainer = pl.Trainer(
                max_epochs=config['max_epochs'],
                callbacks=callbacks,
                logger=CSVLogger(save_dir=LOG_DIR, name=exp_id, version=f'fold{fold}'),
                accelerator='auto', devices=1,
                gradient_clip_val=config['gradient_clip_val'],
                enable_progress_bar=True,
                deterministic=True,
            )

            trainer.fit(model, train_dl, val_dl)
            # Reload the best (not the last) epoch's weights for evaluation.
            model = StatsMLPModel.load_from_checkpoint(ckpt_path)

        # Evaluate the best checkpoint on this fold's validation split.
        fold_preds, fold_labels = [], []
        model.eval().to('cuda')
        with torch.no_grad():
            for batch in val_dl:
                pred = model(batch["features"].cuda())
                fold_preds.append(pred.cpu().numpy())
                fold_labels.append(batch["labels"].numpy())

        fold_preds = np.vstack(fold_preds)
        fold_labels = np.vstack(fold_labels)
        # Score trained and resumed folds identically so avg/std R2 are real
        # per-fold statistics rather than a copy of the pooled R2.
        fold_results[fold] = float(r2_score(fold_labels, fold_preds))
        print(f"Fold {fold}: val R2 = {fold_results[fold]:.4f}")
        all_preds.append(fold_preds)
        all_labels.append(fold_labels)

        del model
        torch.cuda.empty_cache()

    if not all_preds:
        raise RuntimeError(f"{exp_id}: no fold had data to evaluate")

    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    metrics = compute_comprehensive_metrics(all_preds, all_labels)

    fold_r2s = list(fold_results.values())
    avg_r2 = np.mean(fold_r2s)
    std_r2 = np.std(fold_r2s)

    results = {
        "experiment_id": exp_id,
        "description": description,
        "config": config,
        "summary": {
            "avg_r2": float(avg_r2),
            "std_r2": float(std_r2),
            "r2_ci_95": metrics['r2_ci_95'],
            "overall_r2": metrics['overall_r2'],
            "overall_mae": metrics['overall_mae'],
            "dispersion_ratio": metrics['dispersion_ratio'],
        },
        "fold_results": {str(k): float(v) for k, v in fold_results.items()},
        "per_dimension": metrics['per_dimension'],
        "training_time_seconds": time.time() - start_time,
    }

    with open(RESULTS_DIR / f"{exp_id}.json", 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\n{exp_id} COMPLETE: R2={avg_r2:.4f}")
    return results


# Notebook progress marker: confirms the training/evaluation helpers cell ran.
print("Training functions defined.")

---
## 3. Experiments

In [None]:
# Cell 13: B0 - Baseline Re-run
# Every labelled performance is a candidate for feature extraction.
ALL_KEYS = list(LABELS)

# Populate the default MERT cache (layers 13-24, per the B0 description) if needed.
extract_mert_for_layer_range(13, 25, AUDIO_DIR, DEFAULT_MERT_DIR, ALL_KEYS)

# Baseline settings: shared defaults + mean pooling + plain MSE.
config_b0 = {**BASE_CONFIG, 'pooling': 'mean', 'loss_type': 'mse'}


def model_factory_b0(cfg):
    """Build a ``BaseMERTModel`` whose pooling and loss type are read from *cfg*.

    Also used by the layer-ablation cells (B1a-B1d).
    """
    arg_names = ('input_dim', 'hidden_dim', 'dropout', 'learning_rate',
                 'weight_decay', 'pooling', 'loss_type', 'max_epochs')
    return BaseMERTModel(**{name: cfg[name] for name in arg_names})


ALL_RESULTS['B0_baseline'] = run_4fold_mert_experiment(
    exp_id='B0_baseline',
    description='MERT+MLP, L13-24, mean pooling (baseline)',
    model_factory=model_factory_b0,
    mert_cache_dir=DEFAULT_MERT_DIR,
    config=config_b0,
)

In [None]:
# Cell 14: A1 - Linear Probe
# Same MERT features as B0, but a linear head; shared defaults unchanged.
config_a1 = dict(BASE_CONFIG)


def model_factory_a1(cfg):
    """Build a ``LinearProbeModel`` from the relevant entries of *cfg*."""
    arg_names = ('input_dim', 'learning_rate', 'weight_decay', 'max_epochs')
    return LinearProbeModel(**{name: cfg[name] for name in arg_names})


ALL_RESULTS['A1_linear_probe'] = run_4fold_mert_experiment(
    exp_id='A1_linear_probe',
    description='Linear probe on MERT (L13-24)',
    model_factory=model_factory_a1,
    mert_cache_dir=DEFAULT_MERT_DIR,
    config=config_a1,
)

In [None]:
# Cell 15: A2 - Mel-CNN (extraction)
# Build the mel-spectrogram cache consumed by the Mel-CNN baseline, then
# report how many cached .pt files exist.
print("Extracting mel spectrograms...")
extract_mel_spectrograms(AUDIO_DIR, MEL_CACHE_DIR, ALL_KEYS)
print("Mel cache: {} files".format(sum(1 for _ in MEL_CACHE_DIR.glob('*.pt'))))

In [None]:
# Cell 16: A2 - Mel-CNN (training)
# The Mel-CNN baseline uses the shared defaults unchanged.
config_a2 = dict(BASE_CONFIG)

ALL_RESULTS['A2_mel_cnn'] = run_4fold_mel_experiment(
    exp_id='A2_mel_cnn',
    description='4-layer CNN on mel spectrograms',
    config=config_a2,
)

In [None]:
# Cell 17: A3 - Raw Statistics (extraction)
# Build the audio-statistics cache for the A3 baseline, then report how many
# cached .pt files exist.
print("Extracting audio statistics...")
extract_statistics_for_all(AUDIO_DIR, STATS_CACHE_DIR, ALL_KEYS)
print("Stats cache: {} files".format(sum(1 for _ in STATS_CACHE_DIR.glob('*.pt'))))

In [None]:
# Cell 18: A3 - Raw Statistics (training)
# The statistics MLP uses the shared defaults unchanged.
config_a3 = dict(BASE_CONFIG)

ALL_RESULTS['A3_raw_stats'] = run_4fold_stats_experiment(
    exp_id='A3_raw_stats',
    description='MLP on hand-crafted audio statistics (49-dim)',
    config=config_a3,
)

In [None]:
# Cell 19: B1a - Layers 1-6
# Early layers get their own cache; the (1, 7) range presumably means an
# exclusive end index, matching the "Layers 1-6" label.
L1_6_DIR = MERT_CACHE_ROOT / 'L1-6'
extract_mert_for_layer_range(1, 7, AUDIO_DIR, L1_6_DIR, ALL_KEYS)

config_b1a = {**BASE_CONFIG, 'pooling': 'mean', 'loss_type': 'mse'}

ALL_RESULTS['B1a_layers_1-6'] = run_4fold_mert_experiment(
    exp_id='B1a_layers_1-6',
    description='MERT layers 1-6 (early/acoustic)',
    model_factory=model_factory_b0,
    mert_cache_dir=L1_6_DIR,
    config=config_b1a,
)

In [None]:
# Cell 20: B1b - Layers 7-12
# Middle layers get their own cache directory.
L7_12_DIR = MERT_CACHE_ROOT / 'L7-12'
extract_mert_for_layer_range(7, 13, AUDIO_DIR, L7_12_DIR, ALL_KEYS)

config_b1b = {**BASE_CONFIG, 'pooling': 'mean', 'loss_type': 'mse'}

ALL_RESULTS['B1b_layers_7-12'] = run_4fold_mert_experiment(
    exp_id='B1b_layers_7-12',
    description='MERT layers 7-12 (mid)',
    model_factory=model_factory_b0,
    mert_cache_dir=L7_12_DIR,
    config=config_b1b,
)

In [None]:
# Cell 21: B1c - Layers 13-24 (sanity check, should match B0)
# Same cache, factory and settings as B0, run under a separate experiment id.
config_b1c = {**BASE_CONFIG, 'pooling': 'mean', 'loss_type': 'mse'}

ALL_RESULTS['B1c_layers_13-24'] = run_4fold_mert_experiment(
    exp_id='B1c_layers_13-24',
    description='MERT layers 13-24 (late/semantic)',
    model_factory=model_factory_b0,
    mert_cache_dir=DEFAULT_MERT_DIR,
    config=config_b1c,
)

In [None]:
# Cell 22: B1d - All Layers 1-24
# The full layer range gets its own cache directory.
L1_24_DIR = MERT_CACHE_ROOT / 'L1-24'
extract_mert_for_layer_range(1, 25, AUDIO_DIR, L1_24_DIR, ALL_KEYS)

config_b1d = {**BASE_CONFIG, 'pooling': 'mean', 'loss_type': 'mse'}

ALL_RESULTS['B1d_layers_1-24'] = run_4fold_mert_experiment(
    exp_id='B1d_layers_1-24',
    description='MERT all layers 1-24',
    model_factory=model_factory_b0,
    mert_cache_dir=L1_24_DIR,
    config=config_b1d,
)

In [None]:
# Cell 23: B2a - Max Pooling
config_b2a = {**BASE_CONFIG, 'pooling': 'max', 'loss_type': 'mse'}


def model_factory_max(cfg):
    """Build a ``BaseMERTModel`` with pooling fixed to 'max'; other settings come from *cfg*."""
    shared = {name: cfg[name] for name in
              ('input_dim', 'hidden_dim', 'dropout', 'learning_rate', 'weight_decay')}
    return BaseMERTModel(**shared, pooling='max',
                         loss_type=cfg['loss_type'], max_epochs=cfg['max_epochs'])


ALL_RESULTS['B2a_max_pool'] = run_4fold_mert_experiment(
    exp_id='B2a_max_pool',
    description='MERT + max pooling',
    model_factory=model_factory_max,
    mert_cache_dir=DEFAULT_MERT_DIR,
    config=config_b2a,
)

In [None]:
# Cell 24: B2b - Attention Pooling
config_b2b = {**BASE_CONFIG, 'pooling': 'attention', 'loss_type': 'mse'}


def model_factory_attn(cfg):
    """Build a ``BaseMERTModel`` with pooling fixed to 'attention'; other settings come from *cfg*."""
    shared = {name: cfg[name] for name in
              ('input_dim', 'hidden_dim', 'dropout', 'learning_rate', 'weight_decay')}
    return BaseMERTModel(**shared, pooling='attention',
                         loss_type=cfg['loss_type'], max_epochs=cfg['max_epochs'])


ALL_RESULTS['B2b_attention_pool'] = run_4fold_mert_experiment(
    exp_id='B2b_attention_pool',
    description='MERT + attention pooling',
    model_factory=model_factory_attn,
    mert_cache_dir=DEFAULT_MERT_DIR,
    config=config_b2b,
)

In [None]:
# Cell 25: B2c - LSTM Pooling
config_b2c = {**BASE_CONFIG, 'pooling': 'lstm', 'loss_type': 'mse'}


def model_factory_lstm(cfg):
    """Build a ``BaseMERTModel`` with pooling fixed to 'lstm'; other settings come from *cfg*."""
    shared = {name: cfg[name] for name in
              ('input_dim', 'hidden_dim', 'dropout', 'learning_rate', 'weight_decay')}
    return BaseMERTModel(**shared, pooling='lstm',
                         loss_type=cfg['loss_type'], max_epochs=cfg['max_epochs'])


ALL_RESULTS['B2c_lstm_pool'] = run_4fold_mert_experiment(
    exp_id='B2c_lstm_pool',
    description='MERT + Bi-LSTM + attention pooling',
    model_factory=model_factory_lstm,
    mert_cache_dir=DEFAULT_MERT_DIR,
    config=config_b2c,
)

In [None]:
# Cell 26: C1a - Hybrid Loss (MSE + CCC)
config_c1a = {**BASE_CONFIG, 'pooling': 'mean', 'loss_type': 'hybrid'}


def model_factory_hybrid(cfg):
    """Build a mean-pooled ``BaseMERTModel`` with loss type fixed to 'hybrid'."""
    shared = {name: cfg[name] for name in
              ('input_dim', 'hidden_dim', 'dropout', 'learning_rate', 'weight_decay')}
    return BaseMERTModel(**shared, pooling='mean',
                         loss_type='hybrid', max_epochs=cfg['max_epochs'])


ALL_RESULTS['C1a_hybrid_loss'] = run_4fold_mert_experiment(
    exp_id='C1a_hybrid_loss',
    description='MERT + MSE + 0.5*CCC loss',
    model_factory=model_factory_hybrid,
    mert_cache_dir=DEFAULT_MERT_DIR,
    config=config_c1a,
)

In [None]:
# Cell 27: C1b - Pure CCC Loss
config_c1b = {**BASE_CONFIG, 'pooling': 'mean', 'loss_type': 'ccc'}


def model_factory_ccc(cfg):
    """Build a mean-pooled ``BaseMERTModel`` with loss type fixed to 'ccc'."""
    shared = {name: cfg[name] for name in
              ('input_dim', 'hidden_dim', 'dropout', 'learning_rate', 'weight_decay')}
    return BaseMERTModel(**shared, pooling='mean',
                         loss_type='ccc', max_epochs=cfg['max_epochs'])


ALL_RESULTS['C1b_pure_ccc'] = run_4fold_mert_experiment(
    exp_id='C1b_pure_ccc',
    description='MERT + pure CCC loss',
    model_factory=model_factory_ccc,
    mert_cache_dir=DEFAULT_MERT_DIR,
    config=config_c1b,
)

---
## 4. Final Summary

In [None]:
# Cell 28: Aggregate and display results
# One row per experiment with these columns:
# - mean of the per-fold R2 values;
# - bootstrap 95% CI of the pooled R2;
# - difference to the B0 baseline;
# - dispersion ratio (mean prediction std / mean label std).
print("="*80)
print("PHASE 2 RESULTS SUMMARY")
print("="*80)

# Get baseline R2 for comparison. If B0 has not been run, the "vs B0"
# column shows '---' instead of a misleading difference against 0.
has_baseline = 'B0_baseline' in ALL_RESULTS
baseline_r2 = ALL_RESULTS.get('B0_baseline', {}).get('summary', {}).get('avg_r2', 0)

print(f"{'Experiment':<25} {'Avg R2':>10} {'95% CI':>20} {'vs B0':>10} {'Disp':>8}")
print("-"*80)

# Order of experiments; (None, text) entries are printed verbatim as separators.
exp_order = [
    ('B0_baseline', 'B0: Baseline'),
    (None, '-' * 60),
    ('A1_linear_probe', 'A1: Linear Probe'),
    ('A2_mel_cnn', 'A2: Mel-CNN'),
    ('A3_raw_stats', 'A3: Raw Statistics'),
    (None, '-' * 60),
    ('B1a_layers_1-6', 'B1a: Layers 1-6'),
    ('B1b_layers_7-12', 'B1b: Layers 7-12'),
    ('B1c_layers_13-24', 'B1c: Layers 13-24'),
    ('B1d_layers_1-24', 'B1d: Layers 1-24'),
    (None, '-' * 60),
    ('B2a_max_pool', 'B2a: Max Pooling'),
    ('B2b_attention_pool', 'B2b: Attention Pool'),
    ('B2c_lstm_pool', 'B2c: LSTM Pool'),
    (None, '-' * 60),
    ('C1a_hybrid_loss', 'C1a: Hybrid Loss'),
    ('C1b_pure_ccc', 'C1b: Pure CCC'),
]

for exp_id, label in exp_order:
    if exp_id is None:
        print(label)
        continue

    if exp_id not in ALL_RESULTS:
        print(f"{label:<25} {'N/A':>10}")
        continue

    s = ALL_RESULTS[exp_id]['summary']
    avg_r2 = s['avg_r2']
    ci = s.get('r2_ci_95', [0, 0])
    disp = s.get('dispersion_ratio', 0)
    if exp_id == 'B0_baseline' or not has_baseline:
        diff_str = '---'
    else:
        diff_str = f"{avg_r2 - baseline_r2:+.3f}"
    # Pad the CI to the same 20-char width as its header so columns line up.
    ci_str = f"[{ci[0]:.3f}, {ci[1]:.3f}]"

    print(f"{label:<25} {avg_r2:>10.4f} {ci_str:>20} {diff_str:>10} {disp:>8.2f}")

print("="*80)

In [None]:
# Cell 29: Key findings
# Each finding is printed only when the experiments it compares are available.
print("\nKEY FINDINGS")
print("="*80)

# MLP necessary? (linear probe vs MLP head on the same MERT features)
if 'A1_linear_probe' in ALL_RESULTS and 'B0_baseline' in ALL_RESULTS:
    linear_r2 = ALL_RESULTS['A1_linear_probe']['summary']['avg_r2']
    mlp_r2 = ALL_RESULTS['B0_baseline']['summary']['avg_r2']
    print(f"MLP necessary: Linear R2={linear_r2:.4f} vs MLP R2={mlp_r2:.4f} (diff={mlp_r2-linear_r2:+.4f})")

# MERT necessary? (mel-spectrogram CNN vs MERT features)
if 'A2_mel_cnn' in ALL_RESULTS and 'B0_baseline' in ALL_RESULTS:
    mel_r2 = ALL_RESULTS['A2_mel_cnn']['summary']['avg_r2']
    mert_r2 = ALL_RESULTS['B0_baseline']['summary']['avg_r2']
    print(f"MERT necessary: Mel-CNN R2={mel_r2:.4f} vs MERT R2={mert_r2:.4f} (diff={mert_r2-mel_r2:+.4f})")

# Best layers (highest mean fold R2)
layer_exps = ['B1a_layers_1-6', 'B1b_layers_7-12', 'B1c_layers_13-24', 'B1d_layers_1-24']
layer_results = [(e, ALL_RESULTS[e]['summary']['avg_r2']) for e in layer_exps if e in ALL_RESULTS]
if layer_results:
    best_layer = max(layer_results, key=lambda x: x[1])
    print(f"Best layers: {best_layer[0]} (R2={best_layer[1]:.4f})")

# Best pooling (highest mean fold R2)
pool_exps = ['B0_baseline', 'B2a_max_pool', 'B2b_attention_pool', 'B2c_lstm_pool']
pool_results = [(e, ALL_RESULTS[e]['summary']['avg_r2']) for e in pool_exps if e in ALL_RESULTS]
if pool_results:
    best_pool = max(pool_results, key=lambda x: x[1])
    print(f"Best pooling: {best_pool[0]} (R2={best_pool[1]:.4f})")

# Best loss (dispersion): the ratio is pred std / label std, so 1.0 means the
# predictions spread exactly like the labels. Under-dispersion (<1) and
# over-dispersion (>1) are both worse, so pick the ratio closest to 1 rather
# than the largest one.
loss_exps = ['B0_baseline', 'C1a_hybrid_loss', 'C1b_pure_ccc']
loss_results = [(e, ALL_RESULTS[e]['summary'].get('dispersion_ratio', 0)) for e in loss_exps if e in ALL_RESULTS]
if loss_results:
    best_loss = min(loss_results, key=lambda x: abs(x[1] - 1.0))
    baseline_disp = ALL_RESULTS.get('B0_baseline', {}).get('summary', {}).get('dispersion_ratio', 0)
    print(f"Best loss for dispersion: {best_loss[0]} (disp={best_loss[1]:.2f} vs baseline {baseline_disp:.2f}; ideal 1.00)")

print("="*80)

In [None]:
# Cell 30: Save and sync results
# Write every experiment's results into one aggregate JSON file.
all_results_file = RESULTS_DIR / 'phase2_all_results.json'
with open(all_results_file, 'w') as out_f:
    json.dump(ALL_RESULTS, out_f, indent=2)
print(f"Saved: {all_results_file}")

# Mirror results, checkpoints and logs to Google Drive via rclone, in that order.
print("\nSyncing results to Google Drive...")
for local_dir, remote_dir, step_msg in (
    (RESULTS_DIR, GDRIVE_RESULTS, "Syncing results"),
    (CHECKPOINT_ROOT, f"{GDRIVE_RESULTS}/checkpoints", "Syncing checkpoints"),
    (LOG_DIR, f"{GDRIVE_RESULTS}/logs", "Syncing logs"),
):
    run_rclone(['rclone', 'copy', str(local_dir), remote_dir, '-v'], step_msg)

# Also upload every per-layer-range MERT cache directory for reproducibility.
for cache_subdir in MERT_CACHE_ROOT.iterdir():
    if not cache_subdir.is_dir():
        continue
    run_rclone(
        ['rclone', 'copy', str(cache_subdir), f"{GDRIVE_RESULTS}/mert_cache/{cache_subdir.name}", '-v'],
        f"Syncing MERT cache {cache_subdir.name}",
    )

banner = "=" * 80
print("\n" + banner)
print("PHASE 2 EXPERIMENTS COMPLETE")
print(f"Results: {GDRIVE_RESULTS}")
print(banner)