# Audio Baseline for PercePiano (MERT-330M)

Train an audio baseline on MERT-330M embeddings, running on Thunder Compute.

## What This Notebook Does

1. Download pre-rendered WAV files from Google Drive
2. Extract MERT-330M embeddings (GPU required)
3. Train with 4-fold cross-validation
4. Evaluate and analyze results

## Target: R2 >= 0.25

## Step 1: Environment Setup

In [None]:
# Report whether a CUDA device is visible before any heavy work starts.
import torch

print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("WARNING: No GPU detected. MERT extraction will be very slow.")
else:
    # Describe device 0: its name and total memory in GB.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install rclone
# The installer output is filtered down to lines containing "successfully" or "already".
# NOTE: the `|| echo` fallback runs whenever grep matches nothing, so
# "rclone installed" is also printed when the installer itself failed.
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
# Install dependencies
!pip install transformers librosa soundfile torchaudio pytorch-lightning --quiet

import subprocess
result = subprocess.run(['rclone', 'listremotes'], capture_output=True, text=True)
if 'gdrive:' not in result.stdout:
    raise RuntimeError("rclone not configured. Run 'rclone config' to set up gdrive remote.")
print("rclone 'gdrive' remote: CONFIGURED")

In [None]:
# Core imports
import json
import subprocess
import warnings
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from sklearn.metrics import r2_score
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

# Suppress all warnings for the rest of the session. Note that this also hides
# deprecation and numerical warnings raised by the libraries imported above.
warnings.filterwarnings('ignore')

# Record the library versions in the notebook output.
print(f"PyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

## Step 2: Download Data from Google Drive

In [None]:
# Local working directories, all under a single scratch root
DATA_ROOT = Path('/tmp/audio_baseline')
AUDIO_DIR = DATA_ROOT.joinpath('percepiano_rendered')
MERT_CACHE_DIR = DATA_ROOT.joinpath('mert_embeddings')
CHECKPOINT_ROOT = DATA_ROOT.joinpath('checkpoints')
LABEL_DIR = DATA_ROOT.joinpath('labels')

# rclone locations on the "gdrive" remote
GDRIVE_AUDIO = 'gdrive:crescendai_data/audio_baseline/percepiano_rendered'
GDRIVE_LABELS = 'gdrive:crescendai_data/percepiano_labels'
GDRIVE_FOLDS = 'gdrive:crescendai_data/audio_baseline/audio_fold_assignments.json'
GDRIVE_CHECKPOINTS = 'gdrive:crescendai_data/checkpoints/audio_baseline'
GDRIVE_MERT_CACHE = 'gdrive:crescendai_data/audio_baseline/mert_embeddings'

# Make sure every local directory exists before anything is written into it.
for local_dir in (AUDIO_DIR, MERT_CACHE_DIR, CHECKPOINT_ROOT, LABEL_DIR):
    local_dir.mkdir(parents=True, exist_ok=True)

print(f"Data root: {DATA_ROOT}")

In [None]:
# Download pre-rendered audio from Google Drive
print("Downloading pre-rendered audio files...")
print(f"  Source: {GDRIVE_AUDIO}")
print(f"  Destination: {AUDIO_DIR}")

# capture_output=False lets rclone's --progress output stream into the cell.
download = subprocess.run(
    ['rclone', 'copy', GDRIVE_AUDIO, str(AUDIO_DIR), '--progress'],
    capture_output=False
)
if download.returncode != 0:
    # A failed or interrupted copy can leave only part of the dataset on disk;
    # say so instead of silently continuing with whatever arrived.
    print(f"WARNING: rclone exited with code {download.returncode}; the audio download may be incomplete.")

# Count what is actually present locally (includes files from earlier runs).
wav_count = len(list(AUDIO_DIR.glob('*.wav')))
print(f"\nDownloaded {wav_count} WAV files")

if wav_count == 0:
    raise RuntimeError("No WAV files downloaded! Run prepare_audio_baseline.py locally first.")

In [None]:
# Fetch the label directory and the fold-assignment file, then load both.
print("Downloading labels...")
label_copy_cmd = ['rclone', 'copy', GDRIVE_LABELS, str(LABEL_DIR), '--progress']
subprocess.run(label_copy_cmd, capture_output=False)

print("\nDownloading fold assignments...")
fold_copy_cmd = ['rclone', 'copy', GDRIVE_FOLDS, str(DATA_ROOT), '--progress']
subprocess.run(fold_copy_cmd, capture_output=False)

# Expected local locations of the two downloaded files
LABEL_FILE = LABEL_DIR / 'label_2round_mean_reg_19_with0_rm_highstd0.json'
FOLD_FILE = DATA_ROOT / 'audio_fold_assignments.json'

# Labels: loaded from a JSON object; later cells index it by segment key.
if not LABEL_FILE.exists():
    raise FileNotFoundError(f"Label file not found: {LABEL_FILE}")
labels = json.loads(LABEL_FILE.read_text())
print(f"Labels: {len(labels)} segments")

# Fold assignments: JSON object mapping a fold name to its list of keys.
if not FOLD_FILE.exists():
    raise FileNotFoundError(f"Fold file not found: {FOLD_FILE}")
fold_assignments = json.loads(FOLD_FILE.read_text())

print(f"\nFold statistics:")
for fold_name, fold_keys in fold_assignments.items():
    print(f"  {fold_name}: {len(fold_keys)} samples")

In [None]:
# Check for existing MERT cache (resume capability)
print("Checking for existing MERT cache on Google Drive...")
result = subprocess.run(
    ['rclone', 'lsf', GDRIVE_MERT_CACHE],
    capture_output=True, text=True
)

# Only .pt entries count as cached embeddings. A failed listing is treated
# the same as an empty cache.
remote_files = []
if result.returncode == 0:
    remote_files = [f for f in result.stdout.splitlines() if f.endswith('.pt')]

if remote_files:
    print(f"Found {len(remote_files)} cached embeddings. Restoring...")
    subprocess.run(
        ['rclone', 'copy', GDRIVE_MERT_CACHE, str(MERT_CACHE_DIR), '--progress'],
        capture_output=False
    )
    print(f"Restored {len(list(MERT_CACHE_DIR.glob('*.pt')))} embeddings")
else:
    # Also reached when the remote directory exists but holds no .pt files,
    # a case that previously produced no message at all.
    print("No existing cache found.")

---
## Step 3: MERT Feature Extraction

- Model: m-a-p/MERT-v1-330M (~8GB VRAM)
- Layers: 12-24 averaged
- Output: 1024-dim per frame

In [None]:
import librosa
from transformers import AutoModel, AutoProcessor

class MERT330MExtractor:
    """Extract frame-level embeddings from audio files with m-a-p/MERT-v1-330M.

    Audio is loaded as mono at ``target_sr`` (24 kHz), passed through the
    model, and the hidden states selected by ``use_layers`` are averaged.
    Results can be cached on disk as ``<audio stem>.pt``.
    """

    def __init__(self, cache_dir=None):
        """Load the processor and model onto CUDA when available, else CPU.

        Args:
            cache_dir: Directory for cached embeddings; ``None`` disables caching.
        """
        self.target_sr = 24000
        # Half-open slice bounds into ``outputs.hidden_states``: indices 12..24.
        self.use_layers = (12, 25)
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Loading MERT-v1-330M on {self.device}...")
        self.processor = AutoProcessor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            "m-a-p/MERT-v1-330M",
            output_hidden_states=True,
            trust_remote_code=True,
        ).to(self.device)
        self.model.eval()
        print(f"Model loaded. Hidden size: {self.model.config.hidden_size}")

    @torch.no_grad()
    def extract_from_file(self, audio_path, use_cache=True):
        """Return the layer-averaged embedding tensor for one audio file.

        Args:
            audio_path: Path to the audio file; its stem names the cache entry.
            use_cache: When true and a cache directory is set, read an existing
                ``<stem>.pt`` instead of recomputing, and store new results.

        Returns:
            A CPU tensor with the batch axis squeezed out
            (presumably ``(frames, hidden_size)`` — TODO confirm).
        """
        audio_path = Path(audio_path)

        if use_cache and self.cache_dir:
            cache_path = self.cache_dir / f"{audio_path.stem}.pt"
            if cache_path.exists():
                return torch.load(cache_path, weights_only=True)

        audio, _ = librosa.load(audio_path, sr=self.target_sr, mono=True)
        inputs = self.processor(audio, sampling_rate=self.target_sr, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        outputs = self.model(**inputs)
        hidden_states = outputs.hidden_states[self.use_layers[0]:self.use_layers[1]]
        # Average across the selected layers, then drop the batch axis.
        embeddings = torch.stack(hidden_states, dim=0).mean(dim=0).squeeze(0).cpu()

        if use_cache and self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            # Write under a temporary name and rename afterwards, so an
            # interrupted run never leaves a truncated "<stem>.pt" behind that
            # a later run would accept as a valid cache entry.
            tmp_path = cache_path.with_name(cache_path.name + ".tmp")
            torch.save(embeddings, tmp_path)
            tmp_path.replace(cache_path)

        return embeddings

In [None]:
# Extract MERT embeddings
print("="*60)
print("MERT FEATURE EXTRACTION")
print("="*60)

# Only files without a cached "<stem>.pt" need to go through the model.
audio_files = sorted(AUDIO_DIR.glob('*.wav'))
cached_files = set(f.stem for f in MERT_CACHE_DIR.glob('*.pt'))
to_extract = [f for f in audio_files if f.stem not in cached_files]

print(f"Audio files: {len(audio_files)}")
print(f"Already cached: {len(cached_files)}")
print(f"To extract: {len(to_extract)}")

if to_extract:
    extractor = MERT330MExtractor(cache_dir=MERT_CACHE_DIR)

    failed = []
    for audio_path in tqdm(to_extract, desc="Extracting"):
        try:
            extractor.extract_from_file(audio_path)
        except Exception as e:
            # Best effort: one bad file must not abort the whole run, but the
            # error is kept so it can be reported below.
            failed.append((audio_path.stem, str(e)))

    # Release the model before training starts.
    del extractor
    torch.cuda.empty_cache()

    print(f"\nExtracted: {len(to_extract) - len(failed)}")
    if failed:
        print(f"Failed: {len(failed)}")
        # Name each failed file and its error; these segments have no
        # embedding and will be missing from the datasets built later.
        for stem, error in failed:
            print(f"  {stem}: {error}")
else:
    print("\nAll embeddings cached!")

print(f"Total cached: {len(list(MERT_CACHE_DIR.glob('*.pt')))}")

In [None]:
# Sync MERT cache to Google Drive
print("Syncing MERT cache to Google Drive...")
cache_sync = subprocess.run(
    ['rclone', 'copy', str(MERT_CACHE_DIR), GDRIVE_MERT_CACHE, '--progress'],
    capture_output=False
)
# Only report success when rclone actually succeeded; the local copy lives
# under /tmp, so an unnoticed failed upload could lose the embeddings.
if cache_sync.returncode == 0:
    print("Done!")
else:
    print(f"WARNING: rclone exited with code {cache_sync.returncode}; MERT cache may not be fully synced.")

---
## Step 4: Dataset and Model

In [None]:
# Names of the 19 label dimensions. The evaluation code pairs position i in
# this list with column i of the label/prediction arrays, so order matters.
PERCEPIANO_DIMENSIONS = [
    "timing",
    "articulation_length",
    "articulation_touch",
    "pedal_amount",
    "pedal_clarity",
    "timbre_variety",
    "timbre_depth",
    "timbre_brightness",
    "timbre_loudness",
    "dynamic_range",
    "tempo",
    "space",
    "balance",
    "drama",
    "mood_valence",
    "mood_energy",
    "mood_imagination",
    "sophistication",
    "interpretation",
]

# Grouping of the dimensions into broader categories for reporting.
DIMENSION_CATEGORIES = dict(
    timing=["timing"],
    articulation=["articulation_length", "articulation_touch"],
    pedal=["pedal_amount", "pedal_clarity"],
    timbre=["timbre_variety", "timbre_depth", "timbre_brightness", "timbre_loudness"],
    dynamics=["dynamic_range"],
    tempo_space=["tempo", "space", "balance", "drama"],
    emotion=["mood_valence", "mood_energy", "mood_imagination"],
    interpretation=["sophistication", "interpretation"],
)


class AudioPercePianoDataset(Dataset):
    """Dataset of cached MERT embeddings paired with PercePiano labels.

    A sample is kept only if its key is in the selected split, has a cached
    ``<key>.pt`` file, and has an entry in ``labels``.
    """

    def __init__(self, mert_cache_dir, labels, fold_assignments, fold_id, mode, max_frames=1000, n_folds=4):
        """Select the keys for one split and pair each with its label tensor.

        Args:
            mert_cache_dir: Directory containing ``<key>.pt`` embedding files.
            labels: Mapping from key to a sequence of label values; only the
                first 19 values are used.
            fold_assignments: Mapping with ``"fold_<i>"`` and ``"test"``
                entries, each a list of keys.
            fold_id: Index of the held-out validation fold.
            mode: ``"test"`` uses the ``"test"`` entry, ``"val"`` uses
                ``fold_<fold_id>``, anything else uses all other folds.
            max_frames: Embeddings longer than this are truncated on load.
            n_folds: Number of ``fold_<i>`` entries to draw training keys from.
        """
        self.mert_cache_dir = Path(mert_cache_dir)
        self.max_frames = max_frames

        available = {p.stem for p in self.mert_cache_dir.glob('*.pt')}

        if mode == "test":
            valid_keys = set(fold_assignments.get("test", []))
        elif mode == "val":
            valid_keys = set(fold_assignments.get(f"fold_{fold_id}", []))
        else:
            valid_keys = set()
            for i in range(n_folds):
                if i != fold_id:
                    valid_keys.update(fold_assignments.get(f"fold_{i}", []))

        # Sort the keys: set iteration order for strings can differ between
        # interpreter runs, which would make sample indices non-reproducible.
        self.samples = [(k, torch.tensor(labels[k][:19], dtype=torch.float32))
                        for k in sorted(valid_keys) if k in available and k in labels]
        print(f"{mode} (fold {fold_id}): {len(self.samples)} samples")

    def __len__(self):
        """Return the number of usable samples in this split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one embedding from disk, truncated to at most ``max_frames`` frames.

        Returns:
            Dict with ``"embeddings"``, ``"labels"``, ``"key"`` and ``"length"``
            (the frame count after truncation).
        """
        key, label = self.samples[idx]
        emb = torch.load(self.mert_cache_dir / f"{key}.pt", weights_only=True)
        if emb.shape[0] > self.max_frames:
            emb = emb[:self.max_frames]
        return {"embeddings": emb, "labels": label, "key": key, "length": emb.shape[0]}


def collate_fn(batch):
    """Merge dataset items into one padded batch.

    Args:
        batch: List of dicts with ``"embeddings"``, ``"labels"``, ``"key"``
            and ``"length"`` entries.

    Returns:
        Dict with zero-padded ``"embeddings"`` (batch first), a boolean
        ``"attention_mask"`` that is True on real frames, stacked ``"labels"``
        and the list of ``"keys"``.
    """
    keys = [item["key"] for item in batch]
    lengths = torch.tensor([item["length"] for item in batch])
    stacked_labels = torch.stack([item["labels"] for item in batch])

    # Pad along the frame axis up to the longest sequence in the batch.
    padded = pad_sequence([item["embeddings"] for item in batch], batch_first=True)

    # Position t of row i is valid exactly when t < lengths[i].
    frame_idx = torch.arange(padded.shape[1])
    mask = frame_idx[None, :] < lengths[:, None]

    return {
        "embeddings": padded,
        "attention_mask": mask,
        "labels": stacked_labels,
        "keys": keys,
    }


print(f"Dataset defined. {len(PERCEPIANO_DIMENSIONS)} dimensions.")

In [None]:
class AudioPercePianoModel(pl.LightningModule):
    """Pooling + MLP regression head over frame-level embeddings.

    Frames are pooled into one vector (masked mean or learned attention), then
    a two-hidden-layer MLP with a final Sigmoid produces ``num_labels`` outputs.
    Training uses MSE loss; validation additionally logs R2 once per epoch.
    """

    def __init__(self, input_dim=1024, hidden_dim=512, num_labels=19, dropout=0.2,
                 learning_rate=1e-4, weight_decay=1e-5, pooling="mean"):
        super().__init__()
        self.save_hyperparameters()
        self.lr, self.wd, self.pooling = learning_rate, weight_decay, pooling

        # The frame-scoring network only exists when attention pooling is selected.
        if pooling == "attention":
            self.attn = nn.Sequential(
                nn.Linear(input_dim, 256),
                nn.Tanh(),
                nn.Linear(256, 1),
            )

        head_layers = [
            nn.Linear(input_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_labels), nn.Sigmoid(),
        ]
        self.clf = nn.Sequential(*head_layers)
        self.loss_fn = nn.MSELoss()
        # Per-batch validation predictions/labels, drained at the end of each epoch.
        self.val_outputs = []

    def forward(self, x, mask=None):
        """Pool ``x`` over dim 1 and return the head's predictions.

        Args:
            x: Frame embeddings, indexed (batch, frame, feature).
            mask: Optional boolean tensor, True on valid frames.
        """
        if self.pooling == "attention":
            frame_scores = self.attn(x).squeeze(-1)
            if mask is not None:
                # Padded frames receive zero weight after the softmax.
                frame_scores = frame_scores.masked_fill(~mask, float('-inf'))
            weights = torch.softmax(frame_scores, dim=-1).unsqueeze(-1)
            pooled = (x * weights).sum(1)
        elif self.pooling == "mean" and mask is not None:
            valid = mask.unsqueeze(-1).float()
            # The clamp guards against division by zero for all-padding rows.
            pooled = (x * valid).sum(1) / valid.sum(1).clamp(min=1)
        else:
            # Unmasked mean, also used for any unrecognised pooling name.
            pooled = x.mean(1)
        return self.clf(pooled)

    def training_step(self, batch, idx):
        """Compute and log the MSE loss for one training batch."""
        preds = self(batch["embeddings"], batch["attention_mask"])
        loss = self.loss_fn(preds, batch["labels"])
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, idx):
        """Log the batch MSE and stash predictions for the epoch-level R2."""
        preds = self(batch["embeddings"], batch["attention_mask"])
        val_loss = self.loss_fn(preds, batch["labels"])
        self.log("val_loss", val_loss, prog_bar=True)
        self.val_outputs.append({"p": preds.cpu(), "l": batch["labels"].cpu()})

    def on_validation_epoch_end(self):
        """Log R2 over every validation batch of the epoch, then reset the buffer."""
        if not self.val_outputs:
            return
        preds = torch.cat([out["p"] for out in self.val_outputs]).numpy()
        targets = torch.cat([out["l"] for out in self.val_outputs]).numpy()
        self.log("val_r2", r2_score(targets, preds), prog_bar=True)
        self.val_outputs.clear()

    def configure_optimizers(self):
        """Return AdamW with a per-epoch cosine-annealing schedule (T_max=100)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
        }


print("Model defined.")

---
## Step 5: Training

In [None]:
# Select PyTorch's 'medium' float32 matmul precision setting.
torch.set_float32_matmul_precision('medium')

# Hyperparameters shared by the training and evaluation cells.
CONFIG = dict(
    input_dim=1024,
    hidden_dim=512,
    num_labels=19,
    dropout=0.2,
    pooling='mean',
    learning_rate=1e-4,
    weight_decay=1e-5,
    batch_size=16,
    max_epochs=100,
    patience=15,
    max_frames=1000,
    n_folds=4,
)

print("Config:")
for name, value in CONFIG.items():
    print(f"  {name}: {value}")

In [None]:
# Restore checkpoints from Google Drive when the remote listing shows any.
print("Checking for existing checkpoints...")
result = subprocess.run(['rclone', 'lsf', GDRIVE_CHECKPOINTS], capture_output=True, text=True)

# A failed listing and an empty listing both mean there is nothing to restore.
nothing_to_restore = result.returncode != 0 or not result.stdout.strip()
if nothing_to_restore:
    print("No existing checkpoints.")
else:
    print(f"Found checkpoints. Restoring...")
    restore_cmd = ['rclone', 'copy', GDRIVE_CHECKPOINTS, str(CHECKPOINT_ROOT), '--progress']
    subprocess.run(restore_cmd)

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


def _best_score_from_checkpoint(ckpt):
    """Return the ModelCheckpoint ``best_model_score`` stored in a loaded checkpoint, or None.

    Lightning keys the saved callback state by each callback's ``state_key``
    (for example ``"ModelCheckpoint{'monitor': 'val_r2', ...}"``), and older
    releases keyed it by the callback class. Matching on the class name inside
    ``str(key)`` covers both; a plain ``'ModelCheckpoint'`` lookup finds nothing.
    """
    for state_key, state in ckpt.get('callbacks', {}).items():
        if 'ModelCheckpoint' in str(state_key) and isinstance(state, dict):
            score = state.get('best_model_score')
            if score is not None:
                return float(score)
    return None


print("="*70)
print("4-FOLD CROSS-VALIDATION")
print("="*70)

fold_results = {}

for fold in range(CONFIG['n_folds']):
    ckpt_path = CHECKPOINT_ROOT / f'fold{fold}_best.ckpt'

    # Resume: a fold whose best checkpoint already exists is not retrained.
    if ckpt_path.exists():
        # weights_only=False: the checkpoint holds callback state, not just tensors.
        ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
        r2 = _best_score_from_checkpoint(ckpt)
        if r2 is None:
            # Keep the previous 0.0 fallback, but make it visible because it
            # pulls the reported average down.
            print(f"Fold {fold}: WARNING best_model_score not found in checkpoint; reporting 0.0")
            r2 = 0.0
        fold_results[fold] = r2
        print(f"Fold {fold}: SKIP (exists) R2={fold_results[fold]:+.4f}")
        continue

    print(f"\nFold {fold}: Training...")

    train_ds = AudioPercePianoDataset(MERT_CACHE_DIR, labels, fold_assignments, fold, "train", CONFIG['max_frames'])
    val_ds = AudioPercePianoDataset(MERT_CACHE_DIR, labels, fold_assignments, fold, "val", CONFIG['max_frames'])

    train_dl = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True, collate_fn=collate_fn, num_workers=4, pin_memory=True)
    val_dl = DataLoader(val_ds, batch_size=CONFIG['batch_size'], shuffle=False, collate_fn=collate_fn, num_workers=4, pin_memory=True)

    model = AudioPercePianoModel(
        CONFIG['input_dim'], CONFIG['hidden_dim'], CONFIG['num_labels'],
        CONFIG['dropout'], CONFIG['learning_rate'], CONFIG['weight_decay'], CONFIG['pooling']
    )

    # Keep only the best checkpoint by validation R2; stop early once it stalls.
    callbacks = [
        ModelCheckpoint(dirpath=CHECKPOINT_ROOT, filename=f'fold{fold}_best', monitor='val_r2', mode='max', save_top_k=1),
        EarlyStopping(monitor='val_r2', mode='max', patience=CONFIG['patience'], verbose=True),
    ]

    trainer = pl.Trainer(
        max_epochs=CONFIG['max_epochs'],
        callbacks=callbacks,
        accelerator='auto',
        devices=1,
        enable_progress_bar=True,
    )

    trainer.fit(model, train_dl, val_dl)

    # best_model_score is None when no checkpoint was ever saved.
    best_score = callbacks[0].best_model_score
    fold_results[fold] = float(best_score) if best_score is not None else 0.0
    print(f"Fold {fold} Best R2: {fold_results[fold]:+.4f}")

    del model, trainer
    torch.cuda.empty_cache()

# Summary
print("\n" + "="*70)
print("RESULTS")
print("="*70)
for f, r2 in sorted(fold_results.items()):
    print(f"  Fold {f}: {r2:+.4f}")
avg = np.mean(list(fold_results.values()))
std = np.std(list(fold_results.values()))
print(f"  Average: {avg:+.4f} +/- {std:.4f}")
print(f"  Target: >= 0.25")

In [None]:
# Sync checkpoints
print("Syncing checkpoints to Google Drive...")
ckpt_sync = subprocess.run(['rclone', 'copy', str(CHECKPOINT_ROOT), GDRIVE_CHECKPOINTS, '--progress'])
# Only report success when rclone actually succeeded; the local checkpoints
# live under /tmp, so an unnoticed failed upload could lose them.
if ckpt_sync.returncode == 0:
    print("Done!")
else:
    print(f"WARNING: rclone exited with code {ckpt_sync.returncode}; checkpoints may not be fully synced.")

---
## Step 6: Evaluation

In [None]:
print("="*70)
print("PER-DIMENSION ANALYSIS")
print("="*70)

# Pool each fold's held-out predictions so every sample is scored exactly once,
# by the model that did not train on it.
all_preds, all_labels = [], []
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for fold in range(CONFIG['n_folds']):
    ckpt_path = CHECKPOINT_ROOT / f'fold{fold}_best.ckpt'
    if not ckpt_path.exists():
        # Say so: a skipped fold means the pooled scores cover fewer samples.
        print(f"Fold {fold}: checkpoint not found ({ckpt_path}), skipping")
        continue

    model = AudioPercePianoModel.load_from_checkpoint(ckpt_path).to(device).eval()
    val_ds = AudioPercePianoDataset(MERT_CACHE_DIR, labels, fold_assignments, fold, "val", CONFIG['max_frames'])
    val_dl = DataLoader(val_ds, batch_size=CONFIG['batch_size'], collate_fn=collate_fn, num_workers=0)

    with torch.no_grad():
        for batch in val_dl:
            preds = model(batch["embeddings"].to(device), batch["attention_mask"].to(device))
            all_preds.append(preds.cpu().numpy())
            all_labels.append(batch["labels"].numpy())

if not all_preds:
    # np.vstack([]) would fail with an opaque "need at least one array" error.
    raise RuntimeError(
        f"No predictions collected: no usable fold checkpoints in {CHECKPOINT_ROOT}. "
        "Run the training cell first."
    )

all_preds = np.vstack(all_preds)
all_labels = np.vstack(all_labels)

print(f"\nSamples: {len(all_preds)}")
print(f"Overall R2: {r2_score(all_labels, all_preds):+.4f}")

print("\nPer-dimension R2:")
dim_r2 = {}
for i, d in enumerate(PERCEPIANO_DIMENSIONS):
    # Column i of the label/prediction arrays is paired with dimension name i.
    r2 = r2_score(all_labels[:, i], all_preds[:, i])
    dim_r2[d] = r2

# Best-predicted dimensions first.
for d, r2 in sorted(dim_r2.items(), key=lambda x: x[1], reverse=True):
    print(f"  {d:<25} {r2:+.4f}")

In [None]:
# Mean per-dimension R2 within each category.
print("\nCategory Analysis:")
category_r2 = {
    cat: np.mean([dim_r2[d] for d in dims])
    for cat, dims in DIMENSION_CATEGORIES.items()
}
for cat, cat_r2 in category_r2.items():
    print(f"  {cat:<15} {cat_r2:+.4f}")

# Compare the categories the hypothesis singles out.
print("\nHypothesis:")
timbre_r2 = category_r2['timbre']
pedal_r2 = category_r2['pedal']
print(f"  Timbre (audio advantage): {timbre_r2:+.4f}")
print(f"  Pedal (audio advantage):  {pedal_r2:+.4f}")
print(f"  Timing (symbolic better): {dim_r2['timing']:+.4f}")

In [None]:
# Save results
results = {
    "fold_results": fold_results,
    "avg_r2": float(avg),
    "std_r2": float(std),
    "per_dimension_r2": {k: float(v) for k, v in dim_r2.items()},
    "overall_r2": float(r2_score(all_labels, all_preds)),
}

results_path = CHECKPOINT_ROOT / "results.json"
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

# Only claim the upload succeeded when rclone reports success.
results_sync = subprocess.run(['rclone', 'copy', str(results_path), GDRIVE_CHECKPOINTS])
if results_sync.returncode == 0:
    print("Results saved and synced!")
else:
    print(f"Results saved to {results_path}, but rclone exited with code {results_sync.returncode}; not synced to Google Drive.")

In [None]:
# Final report: average cross-validation R2 against the 0.25 target.
banner = "="*70
print(banner)
print("SUMMARY")
print(banner)
print(f"Average R2: {avg:+.4f} +/- {std:.4f}")
print(f"Target: >= 0.25")

met_target = avg >= 0.25
print(f"Status: {'PASS' if met_target else 'BELOW TARGET'}")

# NOTE(review): the go/no-go threshold here (0.20) is lower than the 0.25
# target printed above, so an average between the two prints "BELOW TARGET"
# and then "PASSED" — confirm this is intentional.
cleared_phase_a = avg >= 0.20
if cleared_phase_a:
    print("\nPhase A validation PASSED. Proceed to Phase B (Pianoteq).")
else:
    print("\nNext steps:")
    for suggestion in ("  - Try attention pooling", "  - Try LSTM on MERT frames"):
        print(suggestion)
print(banner)