# Audio Baseline for PercePiano (MERT-330M)

Train an audio baseline on MERT-330M embeddings, running on Thunder Compute.

## What This Notebook Does

1. Download pre-rendered WAV files from Google Drive
2. Extract MERT-330M embeddings (GPU required)
3. Train the model with 4-fold cross-validation
4. Evaluate and analyze results

## Target: R2 >= 0.25

## Step 1: Environment Setup

In [None]:
# Check GPU
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected. MERT extraction will be very slow.")

In [None]:
# Install rclone
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
# Install dependencies (Thunder Compute has PyTorch pre-installed)
!pip install transformers librosa soundfile pytorch_lightning --quiet

import subprocess
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)
if 'gdrive:' not in result.stdout:
    raise RuntimeError("rclone not configured. Run 'rclone config' to set up gdrive remote.")
print("rclone 'gdrive' remote: CONFIGURED")

In [None]:
# Core imports
import json
import subprocess
import warnings
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from sklearn.metrics import r2_score
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

warnings.filterwarnings('ignore')

# Reproducibility
SEED = 42
pl.seed_everything(SEED, workers=True)

print(f"PyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")
print(f"Seed: {SEED}")

## Step 2: Download Data from Google Drive

In [None]:
# Paths
DATA_ROOT = Path('/tmp/audio_baseline')
AUDIO_DIR = DATA_ROOT / 'percepiano_rendered'
MERT_CACHE_DIR = DATA_ROOT / 'mert_embeddings'
CHECKPOINT_ROOT = DATA_ROOT / 'checkpoints'
LABEL_DIR = DATA_ROOT / 'labels'

# Google Drive paths
GDRIVE_AUDIO = 'gdrive:crescendai_data/audio_baseline/percepiano_rendered'
GDRIVE_LABELS = 'gdrive:crescendai_data/percepiano_labels'
GDRIVE_FOLDS = 'gdrive:crescendai_data/audio_baseline/audio_fold_assignments.json'
GDRIVE_CHECKPOINTS = 'gdrive:crescendai_data/checkpoints/audio_baseline'
GDRIVE_MERT_CACHE = 'gdrive:crescendai_data/audio_baseline/mert_embeddings'

# Create directories
for d in [AUDIO_DIR, MERT_CACHE_DIR, CHECKPOINT_ROOT, LABEL_DIR]:
    d.mkdir(parents=True, exist_ok=True)

print(f"Data root: {DATA_ROOT}")

In [None]:
# Download pre-rendered audio from Google Drive
def run_rclone(cmd, description):
    """Execute a command (an rclone invocation) and fail loudly on error.

    Args:
        cmd: Argument list handed to ``subprocess.run``.
        description: Short label printed (followed by ``...``) before running.

    Returns:
        The ``subprocess.CompletedProcess`` with captured text stdout/stderr.

    Raises:
        RuntimeError: If the process exits with a non-zero return code; the
            message carries the captured stderr.
    """
    print(f"{description}...")
    completed = subprocess.run(cmd, capture_output=True, text=True)
    if completed.returncode == 0:
        return completed
    raise RuntimeError(f"rclone failed: {completed.stderr}")

print(f"Source: {GDRIVE_AUDIO}")
print(f"Destination: {AUDIO_DIR}")

run_rclone(
    ['rclone', 'copy', GDRIVE_AUDIO, str(AUDIO_DIR), '--progress', '-v'],
    "Downloading pre-rendered audio files"
)

wav_count = len(list(AUDIO_DIR.glob('*.wav')))
print(f"\nDownloaded {wav_count} WAV files")

if wav_count == 0:
    raise RuntimeError("No WAV files downloaded! Run prepare_audio_baseline.py locally first.")

In [None]:
# Download labels and fold assignments
run_rclone(
    ['rclone', 'copy', GDRIVE_LABELS, str(LABEL_DIR), '-v'],
    "Downloading labels"
)

# Use copyto for single file (not copy which expects directory)
FOLD_FILE = DATA_ROOT / 'audio_fold_assignments.json'
run_rclone(
    ['rclone', 'copyto', GDRIVE_FOLDS, str(FOLD_FILE), '-v'],
    "Downloading fold assignments"
)

# Verify
LABEL_FILE = LABEL_DIR / 'label_2round_mean_reg_19_with0_rm_highstd0.json'

if not LABEL_FILE.exists():
    raise FileNotFoundError(f"Label file not found: {LABEL_FILE}")

with open(LABEL_FILE) as f:
    labels = json.load(f)
print(f"Labels: {len(labels)} segments")

if not FOLD_FILE.exists():
    raise FileNotFoundError(f"Fold file not found: {FOLD_FILE}")

with open(FOLD_FILE) as f:
    fold_assignments = json.load(f)
print(f"\nFold statistics:")
for fold_name, keys in fold_assignments.items():
    print(f"  {fold_name}: {len(keys)} samples")

In [None]:
# Check for existing MERT cache (resume capability)
print("Checking for existing MERT cache on Google Drive...")
result = subprocess.run(
    ['rclone', 'lsf', GDRIVE_MERT_CACHE],
    capture_output=True, text=True
)

if result.returncode == 0 and result.stdout.strip():
    remote_files = [f for f in result.stdout.strip().split('\n') if f.endswith('.pt')]
    if remote_files:
        print(f"Found {len(remote_files)} cached embeddings. Restoring...")
        run_rclone(
            ['rclone', 'copy', GDRIVE_MERT_CACHE, str(MERT_CACHE_DIR), '-v'],
            "Restoring MERT cache"
        )
        print(f"Restored {len(list(MERT_CACHE_DIR.glob('*.pt')))} embeddings")
else:
    print("No existing cache found.")

---
## Step 3: MERT Feature Extraction

- Model: m-a-p/MERT-v1-330M (~8GB VRAM)
- Layers: 12-24 averaged
- Output: 1024-dim per frame

In [None]:
import librosa
from transformers import AutoModel, AutoProcessor

class MERT330MExtractor:
    """Extract frame-level embeddings from audio with m-a-p/MERT-v1-330M.

    Each file is loaded as mono audio at ``target_sr`` (24 kHz), pushed through
    the model, and the hidden states in the half-open slice ``use_layers``
    (indices 12..24) are averaged into a single tensor. When a cache directory
    is configured, results are stored as ``<audio stem>.pt`` and reused.
    """

    def __init__(self, cache_dir=None):
        """Load the processor and model onto the GPU if available, else the CPU.

        Args:
            cache_dir: Optional directory for cached embeddings; falsy disables
                caching.
        """
        self.target_sr = 24000
        self.use_layers = (12, 25)

        self.cache_dir = None
        if cache_dir:
            self.cache_dir = Path(cache_dir)

        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

        checkpoint = "m-a-p/MERT-v1-330M"
        print(f"Loading MERT-v1-330M on {self.device}...")
        self.processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
        loaded = AutoModel.from_pretrained(
            checkpoint,
            output_hidden_states=True,
            trust_remote_code=True,
        )
        self.model = loaded.to(self.device)
        self.model.eval()
        print(f"Model loaded. Hidden size: {self.model.config.hidden_size}")

    @torch.no_grad()
    def extract_from_file(self, audio_path, use_cache=True):
        """Return the layer-averaged embedding tensor for one audio file.

        Args:
            audio_path: Path to the audio file; its stem names the cache entry.
            use_cache: When true and a cache directory is set, read an existing
                cache entry and write a new one after extraction.

        Returns:
            A CPU tensor; presumably (frames, hidden_size) after the batch
            axis is squeezed — TODO confirm against the model output.
        """
        source = Path(audio_path)

        # Caching is active only when requested AND a cache directory exists.
        caching = bool(use_cache and self.cache_dir)
        cache_file = self.cache_dir / f"{source.stem}.pt" if caching else None

        if caching and cache_file.exists():
            return torch.load(cache_file, weights_only=True)

        waveform, _ = librosa.load(source, sr=self.target_sr, mono=True)
        model_inputs = self.processor(waveform, sampling_rate=self.target_sr, return_tensors="pt")
        model_inputs = {name: tensor.to(self.device) for name, tensor in model_inputs.items()}

        first, last = self.use_layers
        selected = self.model(**model_inputs).hidden_states[first:last]
        # Average over the selected layers, then drop the leading batch axis.
        layer_mean = torch.stack(selected, dim=0).mean(dim=0)
        embeddings = layer_mean.squeeze(0).cpu()

        if caching:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            torch.save(embeddings, cache_file)

        return embeddings

In [None]:
# Extract MERT embeddings for every WAV file that has no cached .pt yet.
print("="*60)
print("MERT FEATURE EXTRACTION")
print("="*60)

audio_files = sorted(AUDIO_DIR.glob('*.wav'))
cached_files = {f.stem for f in MERT_CACHE_DIR.glob('*.pt')}
to_extract = [f for f in audio_files if f.stem not in cached_files]

print(f"Audio files: {len(audio_files)}")
print(f"Already cached: {len(cached_files)}")
print(f"To extract: {len(to_extract)}")

if to_extract:
    # The model is only loaded when there is actual work to do.
    extractor = MERT330MExtractor(cache_dir=MERT_CACHE_DIR)

    failed = []
    for audio_path in tqdm(to_extract, desc="Extracting"):
        try:
            extractor.extract_from_file(audio_path)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run,
            # but keep the reason so it can be reported below.
            failed.append((audio_path.stem, f"{type(e).__name__}: {e}"))

    # Free GPU memory held by the extractor before training starts.
    del extractor
    torch.cuda.empty_cache()

    print(f"\nExtracted: {len(to_extract) - len(failed)}")
    if failed:
        print(f"Failed: {len(failed)}")
        # Files without an embedding are dropped from the datasets later on,
        # so surface which ones failed and why instead of only a count.
        for stem, reason in failed[:10]:
            print(f"  - {stem}: {reason}")
        if len(failed) > 10:
            print(f"  ... and {len(failed) - 10} more")
else:
    print("\nAll embeddings cached!")

print(f"Total cached: {len(list(MERT_CACHE_DIR.glob('*.pt')))}")

In [None]:
# Sync MERT cache to Google Drive
run_rclone(
    ['rclone', 'copy', str(MERT_CACHE_DIR), GDRIVE_MERT_CACHE, '-v'],
    "Syncing MERT cache to Google Drive"
)
print("Done!")

---
## Step 4: Dataset and Model

In [None]:
PERCEPIANO_DIMENSIONS = [
    "timing", "articulation_length", "articulation_touch",
    "pedal_amount", "pedal_clarity",
    "timbre_variety", "timbre_depth", "timbre_brightness", "timbre_loudness",
    "dynamic_range", "tempo", "space", "balance", "drama",
    "mood_valence", "mood_energy", "mood_imagination",
    "sophistication", "interpretation",
]

DIMENSION_CATEGORIES = {
    "timing": ["timing"],
    "articulation": ["articulation_length", "articulation_touch"],
    "pedal": ["pedal_amount", "pedal_clarity"],
    "timbre": ["timbre_variety", "timbre_depth", "timbre_brightness", "timbre_loudness"],
    "dynamics": ["dynamic_range"],
    "tempo_space": ["tempo", "space", "balance", "drama"],
    "emotion": ["mood_valence", "mood_energy", "mood_imagination"],
    "interpretation": ["sophistication", "interpretation"],
}


class AudioPercePianoDataset(Dataset):
    """Cached MERT embeddings paired with PercePiano label vectors for one split.

    Args:
        mert_cache_dir: Directory holding one ``<key>.pt`` embedding per segment.
        labels: Mapping from segment key to a sequence of label values; only the
            first ``num_labels`` entries are used.
        fold_assignments: Mapping with keys ``"fold_<i>"`` (and optionally
            ``"test"``) to lists of segment keys.
        fold_id: Index of the held-out (validation) fold.
        mode: ``"test"`` uses the ``"test"`` list, ``"val"`` uses fold
            ``fold_id``, anything else uses every other fold (training).
        max_frames: Embeddings longer than this are truncated to their first
            ``max_frames`` frames.
        n_folds: Number of ``fold_<i>`` entries that make up the CV split.
        num_labels: Number of leading label values kept per segment.

    Keys missing from the cache or from ``labels`` are skipped silently.
    """

    def __init__(self, mert_cache_dir, labels, fold_assignments, fold_id, mode,
                 max_frames=1000, n_folds=4, num_labels=19):
        self.mert_cache_dir = Path(mert_cache_dir)
        self.max_frames = max_frames

        available = {p.stem for p in self.mert_cache_dir.glob('*.pt')}

        if mode == "test":
            valid_keys = set(fold_assignments.get("test", []))
        elif mode == "val":
            valid_keys = set(fold_assignments.get(f"fold_{fold_id}", []))
        else:
            valid_keys = set()
            for i in range(n_folds):
                if i != fold_id:
                    valid_keys.update(fold_assignments.get(f"fold_{i}", []))

        # Iterate in sorted order: the iteration order of a set of strings
        # varies between interpreter runs (hash randomisation), which would
        # make sample order, and therefore seeded shuffling, non-reproducible.
        self.samples = [
            (k, torch.tensor(labels[k][:num_labels], dtype=torch.float32))
            for k in sorted(valid_keys)
            if k in available and k in labels
        ]
        print(f"{mode} (fold {fold_id}): {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one embedding from disk and return it with its label and key."""
        key, label = self.samples[idx]
        emb = torch.load(self.mert_cache_dir / f"{key}.pt", weights_only=True)
        if emb.shape[0] > self.max_frames:
            emb = emb[:self.max_frames]
        return {"embeddings": emb, "labels": label, "key": key, "length": emb.shape[0]}


def collate_fn(batch):
    """Merge dataset items into one zero-padded batch with a validity mask.

    Args:
        batch: List of dicts with ``"embeddings"``, ``"labels"``, ``"key"`` and
            ``"length"`` entries.

    Returns:
        Dict with ``"embeddings"`` (sequences padded along their first axis and
        stacked batch-first), ``"attention_mask"`` (bool, True on real frames),
        ``"labels"`` (stacked label tensors) and ``"keys"`` (list of keys).
    """
    sequences, targets, frame_counts, names = [], [], [], []
    for item in batch:
        sequences.append(item["embeddings"])
        targets.append(item["labels"])
        frame_counts.append(item["length"])
        names.append(item["key"])

    stacked_labels = torch.stack(targets)
    lengths = torch.tensor(frame_counts)
    padded = pad_sequence(sequences, batch_first=True)

    # Position t of row b is valid exactly when t < lengths[b].
    positions = torch.arange(padded.shape[1])
    mask = positions[None, :] < lengths[:, None]

    return {
        "embeddings": padded,
        "attention_mask": mask,
        "labels": stacked_labels,
        "keys": names,
    }


print(f"Dataset defined. {len(PERCEPIANO_DIMENSIONS)} dimensions.")

In [None]:
class AudioPercePianoModel(pl.LightningModule):
    """MLP regression head on top of pooled frame embeddings.

    The frame axis is collapsed by mean or learned attention pooling, and the
    pooled vector goes through a GELU/Dropout MLP ending in a sigmoid. Training
    minimises MSE; validation additionally logs ``val_r2`` over the whole epoch.

    Args:
        input_dim: Size of each frame embedding.
        hidden_dim: Width of the two hidden layers.
        num_labels: Number of regression outputs.
        dropout: Dropout probability after each hidden layer.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        pooling: ``"mean"`` or ``"attention"``; any other value falls back to
            an unmasked mean over the frame axis.
        max_epochs: ``T_max`` of the cosine learning-rate schedule.
    """

    def __init__(self, input_dim=1024, hidden_dim=512, num_labels=19, dropout=0.2,
                 learning_rate=1e-4, weight_decay=1e-5, pooling="mean", max_epochs=100):
        super().__init__()
        self.save_hyperparameters()
        self.lr, self.wd = learning_rate, weight_decay
        self.pooling, self.max_epochs = pooling, max_epochs

        # The attention scorer only exists when it is actually used.
        if pooling == "attention":
            self.attn = nn.Sequential(nn.Linear(input_dim, 256), nn.Tanh(), nn.Linear(256, 1))

        # Two hidden blocks followed by the sigmoid output layer.
        head = []
        width = input_dim
        for _ in range(2):
            head += [nn.Linear(width, hidden_dim), nn.GELU(), nn.Dropout(dropout)]
            width = hidden_dim
        head += [nn.Linear(hidden_dim, num_labels), nn.Sigmoid()]
        self.clf = nn.Sequential(*head)

        self.loss_fn = nn.MSELoss()
        self.val_outputs = []

    def forward(self, x, mask=None):
        """Pool ``x`` over dim 1 and return the head's predictions.

        ``mask`` (True on valid frames) is honoured by the ``"mean"`` and
        ``"attention"`` pooling modes only.
        """
        if self.pooling == "attention":
            logits = self.attn(x).squeeze(-1)
            if mask is not None:
                # Padded frames receive zero weight after the softmax.
                logits = logits.masked_fill(~mask, float('-inf'))
            weights = torch.softmax(logits, dim=-1).unsqueeze(-1)
            return self.clf((x * weights).sum(1))

        if self.pooling == "mean" and mask is not None:
            keep = mask.unsqueeze(-1).float()
            frame_sum = (x * keep).sum(1)
            # clamp avoids dividing by zero for a fully masked row
            frame_count = keep.sum(1).clamp(min=1)
            return self.clf(frame_sum / frame_count)

        # Plain mean: "mean" pooling without a mask, or an unrecognised mode.
        return self.clf(x.mean(1))

    def training_step(self, batch, idx):
        predictions = self(batch["embeddings"], batch["attention_mask"])
        loss = self.loss_fn(predictions, batch["labels"])
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, idx):
        preds = self(batch["embeddings"], batch["attention_mask"])
        batch_loss = self.loss_fn(preds, batch["labels"])
        self.log("val_loss", batch_loss, prog_bar=True)
        # Kept on the CPU so R2 can be computed over the full epoch afterwards.
        self.val_outputs.append({"p": preds.cpu(), "l": batch["labels"].cpu()})

    def on_validation_epoch_end(self):
        if not self.val_outputs:
            return
        epoch_preds = torch.cat([out["p"] for out in self.val_outputs]).numpy()
        epoch_labels = torch.cat([out["l"] for out in self.val_outputs]).numpy()
        self.log("val_r2", r2_score(epoch_labels, epoch_preds), prog_bar=True)
        self.val_outputs.clear()

    def configure_optimizers(self):
        """AdamW with a per-epoch cosine schedule annealing down to 1e-6."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.max_epochs, eta_min=1e-6
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
        }


print("Model defined.")

---
## Step 5: Training

In [None]:
torch.set_float32_matmul_precision('medium')

CONFIG = {
    # Model
    'input_dim': 1024,
    'hidden_dim': 512,
    'num_labels': 19,
    'dropout': 0.2,
    'pooling': 'mean',
    
    # Optimizer
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'gradient_clip_val': 1.0,
    
    # Training
    'batch_size': 64,  # A100 80GB can handle larger batches
    'max_epochs': 200,
    'patience': 15,
    'max_frames': 1000,
    'n_folds': 4,
    'num_workers': 2,  # Safe for 4 vCPUs
}

print("Config:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

In [None]:
# Check for existing checkpoints (resume capability)
print("Checking for existing checkpoints on Google Drive...")
result = subprocess.run(['rclone', 'lsf', GDRIVE_CHECKPOINTS], capture_output=True, text=True)

if result.returncode == 0 and result.stdout.strip():
    print(f"Found checkpoints. Restoring...")
    run_rclone(
        ['rclone', 'copy', GDRIVE_CHECKPOINTS, str(CHECKPOINT_ROOT), '-v'],
        "Restoring checkpoints"
    )
    restored = list(CHECKPOINT_ROOT.glob('*.ckpt'))
    print(f"Restored {len(restored)} checkpoints")
else:
    print("No existing checkpoints found.")

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger

def extract_best_r2_from_checkpoint(ckpt_path):
    """Extract the best monitored score from a PyTorch Lightning checkpoint.

    Scans the checkpoint's ``callbacks`` mapping for an entry whose key
    mentions ``ModelCheckpoint`` and returns its ``best_model_score``.

    Args:
        ckpt_path: Path to the ``.ckpt`` file.

    Returns:
        The score as a float, or ``None`` when the checkpoint has no callback
        state, no ModelCheckpoint entry, or no recorded score.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; only call
    # this on checkpoints produced by this notebook.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)

    # 'callbacks' may be absent or stored as None.
    callbacks = ckpt.get('callbacks') or {}
    for key, state in callbacks.items():
        # Keys are not guaranteed to be strings, so compare on their string
        # form instead of letting `in` raise TypeError on a non-iterable key.
        if 'ModelCheckpoint' not in str(key):
            continue
        score = state.get('best_model_score') if isinstance(state, dict) else None
        if score is not None:
            return float(score)
    return None

# Create log directory
LOG_DIR = CHECKPOINT_ROOT / 'logs'
LOG_DIR.mkdir(parents=True, exist_ok=True)

print("="*70)
print("4-FOLD CROSS-VALIDATION")
print("="*70)

# fold index -> best validation R2 (read from an existing checkpoint or trained now)
fold_results = {}

for fold in range(CONFIG['n_folds']):
    ckpt_path = CHECKPOINT_ROOT / f'fold{fold}_best.ckpt'

    # Resume support: a finished fold is skipped when its score can be recovered.
    if ckpt_path.exists():
        r2 = extract_best_r2_from_checkpoint(ckpt_path)
        if r2 is not None:
            fold_results[fold] = r2
            print(f"Fold {fold}: SKIP (exists) R2={fold_results[fold]:+.4f}")
            continue
        print(f"Fold {fold}: Checkpoint exists but R2 not found, will retrain")
        # Remove the unusable file first. ModelCheckpoint does not overwrite an
        # existing file of the same name; it saves to 'fold{N}_best-v1.ckpt'
        # instead, and evaluation would then load this stale checkpoint.
        ckpt_path.unlink()

    print(f"\nFold {fold}: Training...")

    train_ds = AudioPercePianoDataset(MERT_CACHE_DIR, labels, fold_assignments, fold, "train", CONFIG['max_frames'])
    val_ds = AudioPercePianoDataset(MERT_CACHE_DIR, labels, fold_assignments, fold, "val", CONFIG['max_frames'])

    train_dl = DataLoader(
        train_ds, batch_size=CONFIG['batch_size'], shuffle=True,
        collate_fn=collate_fn, num_workers=CONFIG['num_workers'], pin_memory=True
    )
    val_dl = DataLoader(
        val_ds, batch_size=CONFIG['batch_size'], shuffle=False,
        collate_fn=collate_fn, num_workers=CONFIG['num_workers'], pin_memory=True
    )

    model = AudioPercePianoModel(
        CONFIG['input_dim'], CONFIG['hidden_dim'], CONFIG['num_labels'],
        CONFIG['dropout'], CONFIG['learning_rate'], CONFIG['weight_decay'],
        CONFIG['pooling'], CONFIG['max_epochs']
    )

    # Checkpointing and early stopping both track validation R2 (higher is better).
    callbacks = [
        ModelCheckpoint(dirpath=CHECKPOINT_ROOT, filename=f'fold{fold}_best', monitor='val_r2', mode='max', save_top_k=1),
        EarlyStopping(monitor='val_r2', mode='max', patience=CONFIG['patience'], verbose=True),
    ]

    # version='' keeps the log at LOG_DIR/fold{N}/metrics.csv without a version subfolder.
    logger = CSVLogger(save_dir=LOG_DIR, name=f'fold{fold}', version='')

    trainer = pl.Trainer(
        max_epochs=CONFIG['max_epochs'],
        callbacks=callbacks,
        logger=logger,
        accelerator='auto',
        devices=1,
        gradient_clip_val=CONFIG['gradient_clip_val'],
        enable_progress_bar=True,
        deterministic=True,
        log_every_n_steps=10,
    )

    trainer.fit(model, train_dl, val_dl)

    # best_model_score is None when no checkpoint was ever saved; record 0 then.
    best_score = callbacks[0].best_model_score
    fold_results[fold] = float(best_score) if best_score is not None else 0.0
    print(f"Fold {fold} Best R2: {fold_results[fold]:+.4f}")

    # Release GPU memory before the next fold.
    del model, trainer
    torch.cuda.empty_cache()

# Summary
print("\n" + "="*70)
print("RESULTS")
print("="*70)
for f, r2 in sorted(fold_results.items()):
    print(f"  Fold {f}: {r2:+.4f}")
avg = np.mean(list(fold_results.values()))
std = np.std(list(fold_results.values()))
print(f"  Average: {avg:+.4f} +/- {std:.4f}")
print(f"  Target: >= 0.25")

In [None]:
# Sync checkpoints to Google Drive (critical for ephemeral storage)
run_rclone(
    ['rclone', 'copy', str(CHECKPOINT_ROOT), GDRIVE_CHECKPOINTS, '-v'],
    "Syncing checkpoints to Google Drive"
)
print("Done!")

---
## Step 6: Evaluation

In [None]:
from scipy import stats
from sklearn.metrics import mean_absolute_error, mean_squared_error
import pandas as pd

print("="*70)
print("COMPREHENSIVE EVALUATION")
print("="*70)

# ============================================================================
# COLLECT PREDICTIONS
# ============================================================================
all_preds, all_labels, all_keys = [], [], []
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for fold in range(CONFIG['n_folds']):
    ckpt_path = CHECKPOINT_ROOT / f'fold{fold}_best.ckpt'
    if not ckpt_path.exists():
        print(f"Warning: Fold {fold} checkpoint not found, skipping")
        continue
    
    model = AudioPercePianoModel.load_from_checkpoint(ckpt_path).to(device).eval()
    val_ds = AudioPercePianoDataset(MERT_CACHE_DIR, labels, fold_assignments, fold, "val", CONFIG['max_frames'])
    val_dl = DataLoader(val_ds, batch_size=CONFIG['batch_size'], collate_fn=collate_fn, num_workers=0)
    
    with torch.no_grad():
        for batch in val_dl:
            preds = model(batch["embeddings"].to(device), batch["attention_mask"].to(device))
            if torch.isnan(preds).any() or torch.isinf(preds).any():
                print(f"WARNING: NaN/Inf detected in fold {fold} predictions!")
            all_preds.append(preds.cpu().numpy())
            all_labels.append(batch["labels"].numpy())
            all_keys.extend(batch["keys"])
    
    del model
    torch.cuda.empty_cache()

if not all_preds:
    raise RuntimeError("No predictions collected - no checkpoints found")

all_preds = np.vstack(all_preds)
all_labels = np.vstack(all_labels)
all_keys = np.array(all_keys)

print(f"\nTotal samples evaluated: {len(all_preds)}")

# ============================================================================
# BASELINE COMPARISONS
# ============================================================================
print("\n" + "="*70)
print("BASELINE COMPARISONS")
print("="*70)

mean_baseline_preds = np.tile(all_labels.mean(axis=0), (len(all_labels), 1))
mean_baseline_r2 = r2_score(all_labels, mean_baseline_preds)

np.random.seed(SEED)
random_baseline_preds = np.random.rand(*all_labels.shape)
random_baseline_r2 = r2_score(all_labels, random_baseline_preds)

print(f"Random Baseline R2:     {random_baseline_r2:+.4f}")
print(f"Mean Predictor R2:      {mean_baseline_r2:+.4f} (always 0 by definition)")
print(f"Our Model R2:           {r2_score(all_labels, all_preds):+.4f}")
print(f"Improvement over mean:  {r2_score(all_labels, all_preds) - mean_baseline_r2:+.4f}")

# ============================================================================
# OVERALL METRICS
# ============================================================================
print("\n" + "="*70)
print("OVERALL METRICS")
print("="*70)

overall_r2 = r2_score(all_labels, all_preds)
overall_mae = mean_absolute_error(all_labels, all_preds)
overall_rmse = np.sqrt(mean_squared_error(all_labels, all_preds))

pearson_corrs = []
for i in range(all_labels.shape[1]):
    corr, _ = stats.pearsonr(all_labels[:, i], all_preds[:, i])
    pearson_corrs.append(corr)
avg_pearson = np.mean(pearson_corrs)

print(f"R2 Score:           {overall_r2:+.4f}")
print(f"MAE:                {overall_mae:.4f}")
print(f"RMSE:               {overall_rmse:.4f}")
print(f"Avg Pearson Corr:   {avg_pearson:.4f}")

# ============================================================================
# BOOTSTRAP CONFIDENCE INTERVALS
# ============================================================================
print("\n" + "="*70)
print("BOOTSTRAP CONFIDENCE INTERVALS (95%)")
print("="*70)

def bootstrap_r2(y_true, y_pred, n_bootstrap=1000, seed=42):
    """Bootstrap percentile interval for the R2 score.

    Rows of ``y_true`` / ``y_pred`` are resampled with replacement
    ``n_bootstrap`` times and R2 is recomputed on each resample.

    Args:
        y_true: Array of targets, indexed by sample along axis 0.
        y_pred: Array of predictions aligned with ``y_true``.
        n_bootstrap: Number of resamples.
        seed: Seed for the resampling generator.

    Returns:
        Array ``[2.5th percentile, median, 97.5th percentile]`` of the
        bootstrapped R2 scores.
    """
    # A private RandomState yields the same draws as seeding the global
    # generator, without resetting global NumPy random state on every call.
    rng = np.random.RandomState(seed)
    n_samples = len(y_true)
    r2_scores = []
    for _ in range(n_bootstrap):
        idx = rng.choice(n_samples, n_samples, replace=True)
        r2_scores.append(r2_score(y_true[idx], y_pred[idx]))
    return np.percentile(r2_scores, [2.5, 50, 97.5])

overall_ci = bootstrap_r2(all_labels, all_preds)
print(f"Overall R2: {overall_ci[1]:+.4f} [{overall_ci[0]:+.4f}, {overall_ci[2]:+.4f}]")

print("\nPer-dimension (selected):")
for d in ['timing', 'timbre_brightness', 'dynamic_range', 'pedal_amount', 'interpretation']:
    i = PERCEPIANO_DIMENSIONS.index(d)
    ci = bootstrap_r2(all_labels[:, i], all_preds[:, i])
    print(f"  {d:<22} {ci[1]:+.4f} [{ci[0]:+.4f}, {ci[2]:+.4f}]")

# ============================================================================
# CATEGORY ANALYSIS (PER-DIMENSION)
# ============================================================================
print("\n" + "="*70)
print("CATEGORY ANALYSIS")
print("="*70)

dim_metrics = {}
for i, d in enumerate(PERCEPIANO_DIMENSIONS):
    y_true, y_pred = all_labels[:, i], all_preds[:, i]
    pearson, p_val = stats.pearsonr(y_true, y_pred)
    dim_metrics[d] = {
        'r2': r2_score(y_true, y_pred),
        'mae': mean_absolute_error(y_true, y_pred),
        'rmse': np.sqrt(mean_squared_error(y_true, y_pred)),
        'pearson': pearson, 'p_value': p_val,
        'label_mean': y_true.mean(), 'label_std': y_true.std(),
        'pred_mean': y_pred.mean(), 'pred_std': y_pred.std(),
    }

for cat, dims in DIMENSION_CATEGORIES.items():
    cat_r2 = np.mean([dim_metrics[d]['r2'] for d in dims])
    cat_pearson = np.mean([dim_metrics[d]['pearson'] for d in dims])
    print(f"\n{cat.upper()}")
    print(f"  Category R2: {cat_r2:+.4f}  |  Pearson: {cat_pearson:.4f}")
    for d in dims:
        m = dim_metrics[d]
        sig = "***" if m['p_value'] < 0.001 else "**" if m['p_value'] < 0.01 else "*" if m['p_value'] < 0.05 else ""
        print(f"    {d:<22} R2={m['r2']:+.4f}  MAE={m['mae']:.3f}  r={m['pearson']:.3f}{sig}")

# ============================================================================
# PREDICTION DISTRIBUTION ANALYSIS
# ============================================================================
print("\n" + "="*70)
print("PREDICTION DISTRIBUTION ANALYSIS")
print("="*70)

print(f"\n{'Dimension':<25} {'Lbl Mean':>8} {'Lbl Std':>8} {'Prd Mean':>8} {'Prd Std':>8} {'Bias':>7}")
print("-" * 70)

distribution_issues = []
for i, d in enumerate(PERCEPIANO_DIMENSIONS):
    lbl_mean, lbl_std = all_labels[:, i].mean(), all_labels[:, i].std()
    pred_mean, pred_std = all_preds[:, i].mean(), all_preds[:, i].std()
    bias = pred_mean - lbl_mean
    std_ratio = pred_std / lbl_std if lbl_std > 0 else 0
    
    if std_ratio < 0.5:
        distribution_issues.append(f"{d}: under-dispersed (std ratio: {std_ratio:.2f})")
    elif std_ratio > 1.5:
        distribution_issues.append(f"{d}: over-dispersed (std ratio: {std_ratio:.2f})")
    if abs(bias) > 0.1:
        distribution_issues.append(f"{d}: significant bias ({bias:+.3f})")
    
    print(f"{d:<25} {lbl_mean:>8.3f} {lbl_std:>8.3f} {pred_mean:>8.3f} {pred_std:>8.3f} {bias:>+7.3f}")

avg_pred_std = np.mean([all_preds[:, i].std() for i in range(19)])
avg_lbl_std = np.mean([all_labels[:, i].std() for i in range(19)])

print(f"\nMean Regression Check:")
print(f"  Avg label std: {avg_lbl_std:.4f} | Avg pred std: {avg_pred_std:.4f} | Ratio: {avg_pred_std/avg_lbl_std:.2%}")
if avg_pred_std / avg_lbl_std < 0.7:
    print("  WARNING: Model may be suffering from mean regression")

if distribution_issues:
    print(f"\nDistribution Issues ({len(distribution_issues)}):")
    for issue in distribution_issues[:5]:
        print(f"  - {issue}")
    if len(distribution_issues) > 5:
        print(f"  ... and {len(distribution_issues) - 5} more")

# ============================================================================
# PER-SAMPLE ERROR ANALYSIS
# ============================================================================
print("\n" + "="*70)
print("PER-SAMPLE ERROR ANALYSIS")
print("="*70)

sample_mse = np.mean((all_preds - all_labels) ** 2, axis=1)
sample_mae = np.mean(np.abs(all_preds - all_labels), axis=1)
sorted_idx = np.argsort(sample_mse)

print("\nHARDEST SAMPLES (highest MSE):")
for i in sorted_idx[-5:][::-1]:
    dim_errors = np.abs(all_preds[i] - all_labels[i])
    worst_dims = np.argsort(dim_errors)[-3:][::-1]
    worst_str = ", ".join([f"{PERCEPIANO_DIMENSIONS[d]}({dim_errors[d]:.2f})" for d in worst_dims])
    print(f"  {all_keys[i][:35]:<35} MSE={sample_mse[i]:.4f} | {worst_str}")

print(f"\nError Distribution: MSE mean={sample_mse.mean():.4f} std={sample_mse.std():.4f}")
print(f"  Outliers: {(sample_mse > 2*sample_mse.mean()).sum()} samples > 2x mean, {(sample_mse > 3*sample_mse.mean()).sum()} > 3x mean")

# ============================================================================
# TRAINING CURVE ANALYSIS
# ============================================================================
print("\n" + "="*70)
print("TRAINING CURVE ANALYSIS")
print("="*70)

def _logged_values(frame, column):
    """Return the non-NaN values of a logged metric, or an empty Series if absent."""
    if column not in frame.columns:
        return pd.Series(dtype=float)
    return frame[column].dropna()


# fold index -> summary statistics read from that fold's CSV training log
training_summary = {}
for fold in range(CONFIG['n_folds']):
    log_file = LOG_DIR / f'fold{fold}' / 'metrics.csv'
    if not log_file.exists():
        print(f"Fold {fold}: No log found")
        continue

    df = pd.read_csv(log_file)
    # Each metric is only filled on the rows where it was logged.
    val_r2 = _logged_values(df, 'val_r2')
    train_loss = _logged_values(df, 'train_loss')
    val_loss = _logged_values(df, 'val_loss')

    if len(val_r2) > 0:
        # idxmax() is a row label of the CSV, not an epoch number (training
        # steps occupy rows too), so read the epoch from the 'epoch' column.
        best_row = val_r2.idxmax()
        if 'epoch' in df.columns and pd.notna(df.loc[best_row, 'epoch']):
            best_epoch = int(df.loc[best_row, 'epoch'])
        else:
            # No epoch column: fall back to the position among validation rows.
            best_epoch = int(val_r2.index.get_loc(best_row))

        training_summary[fold] = {
            'epochs': len(val_r2),
            'best_epoch': best_epoch,
            'best_r2': float(val_r2.max()),
            'final_r2': float(val_r2.iloc[-1]),
            'final_train_loss': float(train_loss.iloc[-1]) if len(train_loss) > 0 else None,
            'final_val_loss': float(val_loss.iloc[-1]) if len(val_loss) > 0 else None,
        }
        s = training_summary[fold]
        gap_str = ""
        # Compare against None explicitly so a loss of exactly 0.0 still counts.
        if s['final_train_loss'] is not None and s['final_val_loss'] is not None:
            gap = s['final_val_loss'] - s['final_train_loss']
            gap_str = f" | gap={gap:.4f}" if gap > 0.01 else ""
        print(f"Fold {fold}: {s['epochs']} epochs, best R2={s['best_r2']:+.4f} @ epoch {s['best_epoch']}{gap_str}")

if training_summary:
    print(f"\nAvg epochs to convergence: {np.mean([s['epochs'] for s in training_summary.values()]):.1f}")

# ============================================================================
# FINAL SUMMARY
# ============================================================================
print("\n" + "="*70)
print("FINAL SUMMARY")
print("="*70)

print(f"\nModel: MERT-330M + MLP ({CONFIG['pooling']} pooling)")
print(f"Training: {CONFIG['n_folds']}-fold CV, batch={CONFIG['batch_size']}, seed={SEED}")

print(f"\n--- PERFORMANCE ---")
print(f"Average R2:     {avg:+.4f} +/- {std:.4f}")
print(f"95% CI:         [{overall_ci[0]:+.4f}, {overall_ci[2]:+.4f}]")
print(f"MAE:            {overall_mae:.4f}")
print(f"Pearson:        {avg_pearson:.4f}")

print(f"\n--- VS BASELINES ---")
print(f"Random:         {random_baseline_r2:+.4f}")
print(f"Mean predictor: {mean_baseline_r2:+.4f}")
print(f"Our model:      {overall_r2:+.4f} (+{overall_r2 - mean_baseline_r2:.4f})")

print(f"\n--- TARGET ---")
print(f"Target: R2 >= 0.25")
print(f"Status: {'PASS' if avg >= 0.25 else 'BELOW TARGET'}")

if avg < 0.25:
    print("\nSuggested improvements:")
    if avg_pred_std / avg_lbl_std < 0.7:
        print("  - Mean regression detected: try correlation loss")
    print("  - Try attention pooling: CONFIG['pooling'] = 'attention'")
    print("  - Try LSTM/Transformer on MERT frames")

print("\n" + "="*70)

In [None]:
# Save comprehensive results
print("\n" + "="*70)
print("SAVING RESULTS")
print("="*70)

# Compute bootstrap CI for overall
overall_ci = bootstrap_r2(all_labels, all_preds)

results = {
    "summary": {
        "avg_r2": float(avg),
        "std_r2": float(std),
        "overall_r2": float(overall_r2),
        "r2_ci_95": [float(overall_ci[0]), float(overall_ci[2])],
        "overall_mae": float(overall_mae),
        "overall_rmse": float(overall_rmse),
        "avg_pearson": float(avg_pearson),
        "n_samples": int(len(all_preds)),
        "target": 0.25,
        "target_met": bool(avg >= 0.25),
    },
    "fold_results": {str(k): float(v) for k, v in fold_results.items()},
    "per_dimension": {
        d: {k: float(v) if isinstance(v, (np.floating, float)) else v 
            for k, v in m.items()}
        for d, m in dim_metrics.items()
    },
    "baselines": {
        "random_r2": float(random_baseline_r2),
        "mean_predictor_r2": float(mean_baseline_r2),
        "improvement_over_mean": float(overall_r2 - mean_baseline_r2),
    },
    "distribution_analysis": {
        "avg_label_std": float(avg_lbl_std),
        "avg_pred_std": float(avg_pred_std),
        "dispersion_ratio": float(avg_pred_std / avg_lbl_std),
        "issues": distribution_issues,
    },
    "training_summary": training_summary,
    "hardest_samples": [
        {"key": all_keys[i], "mse": float(sample_mse[i]), "mae": float(sample_mae[i])}
        for i in sorted_idx[-10:][::-1]
    ],
    "config": CONFIG,
    "seed": SEED,
}

# Save to JSON
results_file = CHECKPOINT_ROOT / "results_comprehensive.json"
with open(results_file, 'w') as f:
    json.dump(results, f, indent=2)
print(f"Saved: {results_file}")

# Sync to Google Drive
run_rclone(
    ['rclone', 'copy', str(results_file), GDRIVE_CHECKPOINTS, '-v'],
    "Syncing comprehensive results"
)

# Also sync training logs
run_rclone(
    ['rclone', 'copy', str(LOG_DIR), f"{GDRIVE_CHECKPOINTS}/logs", '-v'],
    "Syncing training logs"
)

print("\nAll results saved and synced!")

In [None]:
print("="*70)
print("FINAL SUMMARY")
print("="*70)

print(f"\nModel: MERT-330M + MLP ({CONFIG['pooling']} pooling)")
print(f"Training: {CONFIG['n_folds']}-fold CV, {CONFIG['max_epochs']} max epochs, batch={CONFIG['batch_size']}")

print(f"\n--- PERFORMANCE ---")
print(f"Average R2:     {avg:+.4f} +/- {std:.4f}")
print(f"95% CI:         [{overall_ci[0]:+.4f}, {overall_ci[2]:+.4f}]")
print(f"Overall R2:     {overall_r2:+.4f}")
print(f"MAE:            {overall_mae:.4f}")
print(f"Pearson:        {avg_pearson:.4f}")

print(f"\n--- VS BASELINES ---")
print(f"Random:         {random_baseline_r2:+.4f}")
print(f"Mean predictor: {mean_baseline_r2:+.4f}")
print(f"Improvement:    {overall_r2 - mean_baseline_r2:+.4f}")

print(f"\n--- TARGET ---")
print(f"Target R2:      >= 0.25")
print(f"Status:         {'PASS' if avg >= 0.25 else 'BELOW TARGET'}")

if avg >= 0.25:
    print("\nAudio baseline validation PASSED!")
    print("Audio features provide meaningful signal for piano performance evaluation.")
    print("\nNext steps:")
    print("  1. Proceed to Phase B (Pianoteq rendering)")
    print("  2. Compare audio vs symbolic baseline performance")
    print("  3. Design fusion architecture")
else:
    print("\nBelow target. Analysis suggests:")
    if avg_pred_std / avg_lbl_std < 0.7:
        print("  - Mean regression issue: try different loss (e.g., correlation loss)")
    if any('under-dispersed' in issue for issue in distribution_issues):
        print("  - Predictions lack variance: try attention pooling or temporal modeling")
    print("\nPotential improvements:")
    print("  - Try attention pooling (set CONFIG['pooling'] = 'attention')")
    print("  - Try LSTM/Transformer on MERT frames")
    print("  - Increase model capacity (hidden_dim)")
    print("  - Add multi-task auxiliary losses")

print("\n" + "="*70)