# Audio Baseline for PercePiano (MERT-330M)

This notebook establishes the **first audio baseline** for piano performance evaluation using the PercePiano dataset with MERT embeddings.

## Research Question

> Can audio representations (MERT) capture perceptual dimensions that symbolic MIDI representations miss?

## Hypothesis

Based on our research (RESEARCH_v4.md), we expect audio to excel on:
- **Timbre dimensions** (variety, depth, brightness, loudness) - MIDI cannot capture overtones/harmonics
- **Pedal dimensions** (amount, clarity) - Pedal effects are acoustic phenomena
- **Emotional dimensions** (mood_energy, mood_imagination) - Require timbral context

## Attribution

> **PercePiano: Piano Performance Evaluation Dataset with Multi-level Perceptual Features**  
> Park, Kim et al., Scientific Reports (Nature Portfolio), 2024  
> Paper: https://pmc.ncbi.nlm.nih.gov/articles/PMC11450231/

> **MERT: Acoustic Music Understanding Model with Large-Scale Self-supervised Training**  
> Li et al., ICLR 2024  
> HuggingFace: https://huggingface.co/m-a-p/MERT-v1-330M

## Success Criteria and Targets

### Phase 1: Audio Rendering
| Metric | Target | Notes |
|--------|--------|-------|
| Files rendered | 1,202 | All PercePiano segments |
| Audio quality | 44.1 kHz, 16-bit | Standard CD quality |
| Rendering time | < 2 hours | Batch processing |
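
As a rough check on the rendering-time target, the per-file budget follows from simple arithmetic. The sketch below assumes 4 parallel workers (the `max_workers` value used later in this notebook) and a perfectly even load, which is optimistic; `per_file_budget_seconds` is a helper name introduced here for illustration.

```python
def per_file_budget_seconds(total_files: int, budget_hours: float, workers: int = 1) -> float:
    """Average wall-clock seconds each render may take while staying within the budget."""
    total_seconds = budget_hours * 3600
    # With N workers running concurrently, each file may take N times longer
    return total_seconds * workers / total_files


sequential = per_file_budget_seconds(1202, 2.0)
parallel = per_file_budget_seconds(1202, 2.0, workers=4)
print(f"Sequential budget: {sequential:.2f} s per file")
print(f"With 4 workers:    {parallel:.2f} s per file")
```

Both figures are tighter than the 60-second per-file timeout used by the rendering function, so the timeout alone does not guarantee the 2-hour target.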

### Phase 2: MERT Extraction
| Metric | Target | Notes |
|--------|--------|-------|
| Embeddings extracted | 1,202 | All segments |
| Embedding dimension | 1024 | MERT-330M hidden size |
| GPU memory | < 8 GB | Fits on T4/consumer GPU |

### Phase 3: Baseline Model
| Metric | Target | Stretch | Notes |
|--------|--------|---------|-------|
| Overall R2 | >= 0.25 | >= 0.35 | Competitive with symbolic |
| Timbre dims R2 | >= 0.30 | >= 0.40 | Audio should excel here |
| Pedal dims R2 | >= 0.25 | >= 0.35 | Pedal is acoustic phenomenon |
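
The R2 targets above are read here as the standard coefficient of determination, which is what `sklearn.metrics.r2_score` (used later in the notebook) computes for a single output. A minimal pure-Python sketch of that definition, with `r2` as an illustrative helper name:

```python
def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


targets = [1.0, 2.0, 3.0]
print(r2(targets, [1.0, 2.0, 3.0]))  # perfect predictions
print(r2(targets, [2.0, 2.0, 2.0]))  # always predicting the mean
print(r2(targets, [3.0, 2.0, 1.0]))  # worse than predicting the mean
```

An R2 of 0 means the model does no better than predicting the mean label, so a target of 0.25 corresponds to explaining a quarter of the label variance. Negative values are possible.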

### Phase 4: Comparison to Symbolic
| Dimension Category | Expected Winner | Hypothesis |
|-------------------|-----------------|------------|
| Timing, Articulation | Symbolic | Precise onset/offset in MIDI |
| Timbre (4 dims) | **Audio** | MIDI has no timbre info |
| Pedal (2 dims) | **Audio** | Acoustic resonance effects |
| Dynamics | Tie | Both capture velocity/loudness |
| Emotion/Mood | Audio | Requires timbral context |
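
One way to score the hypotheses in this table is to average per-dimension R2 within each category and compare the two modalities. The sketch below is illustrative only: the R2 values are made-up placeholders, not results, and `category_means`, `pick_winner`, and the 0.02 tie margin are assumptions introduced here.

```python
CATEGORIES = {
    "pedal": ["pedal_amount", "pedal_clarity"],
    "timbre": ["timbre_variety", "timbre_depth", "timbre_brightness", "timbre_loudness"],
}


def category_means(per_dim_r2, categories):
    """Average per-dimension R2 within each category."""
    return {
        name: sum(per_dim_r2[d] for d in dims) / len(dims)
        for name, dims in categories.items()
    }


def pick_winner(audio_r2, symbolic_r2, tie_margin=0.02):
    """Return 'audio', 'symbolic', or 'tie' (tie_margin is an assumed threshold)."""
    diff = audio_r2 - symbolic_r2
    if abs(diff) <= tie_margin:
        return "tie"
    return "audio" if diff > 0 else "symbolic"


# PLACEHOLDER numbers for illustration only; these are not measured results
audio = {"pedal_amount": 0.30, "pedal_clarity": 0.20,
         "timbre_variety": 0.40, "timbre_depth": 0.40,
         "timbre_brightness": 0.20, "timbre_loudness": 0.20}
symbolic = {"pedal_amount": 0.25, "pedal_clarity": 0.25,
            "timbre_variety": 0.10, "timbre_depth": 0.10,
            "timbre_brightness": 0.10, "timbre_loudness": 0.10}

audio_means = category_means(audio, CATEGORIES)
symbolic_means = category_means(symbolic, CATEGORIES)
winners = {c: pick_winner(audio_means[c], symbolic_means[c]) for c in CATEGORIES}
print(winners)
```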

---
## Step 1: Environment Setup

In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected. MERT extraction will be slow.")

In [None]:
# Install dependencies
import os
import subprocess
import sys

# Check if running in Colab/remote
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Install uv
    !curl -LsSf https://astral.sh/uv/install.sh | sh
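    # The uv installer has used ~/.local/bin (newer releases) and ~/.cargo/bin (older ones)
    home = os.environ['HOME']
    os.environ['PATH'] = f"{home}/.local/bin:{home}/.cargo/bin:{os.environ['PATH']}"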
    
    # Clone repository
    if not os.path.exists('/tmp/crescendai'):
        !git clone https://github.com/Jai-Dhiman/crescendai.git /tmp/crescendai
    %cd /tmp/crescendai/model
    !git pull
    !uv pip install --system -e .

# Install audio-specific dependencies
!pip install transformers librosa soundfile torchaudio rich

# Check FluidSynth installation
result = subprocess.run(['which', 'fluidsynth'], capture_output=True, text=True)
if result.returncode == 0:
    print(f"FluidSynth found: {result.stdout.strip()}")
else:
    print("FluidSynth not found. Install with:")
    print("  macOS: brew install fluidsynth")
    print("  Ubuntu: apt-get install fluidsynth")
    print("  Colab: apt-get install -y fluidsynth")

In [None]:
# Core imports
import json
import pickle
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import r2_score
from rich.console import Console
from rich.table import Table
from tqdm.auto import tqdm

warnings.filterwarnings('ignore')
console = Console()

print(f"PyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")

---
## Step 2: Configure Paths

In [None]:
from pathlib import Path
import subprocess

# Detect environment
if IN_COLAB:
    PROJECT_ROOT = Path('/tmp/crescendai/model')
else:
    # Local development
    PROJECT_ROOT = Path('.').resolve().parent
    if not (PROJECT_ROOT / 'src').exists():
        PROJECT_ROOT = Path('/Users/jdhiman/Documents/crescendai/model')

# Data paths
PERCEPIANO_ROOT = PROJECT_ROOT / 'data' / 'raw' / 'PercePiano'
MIDI_DIR = PERCEPIANO_ROOT / 'virtuoso' / 'data' / 'all_2rounds'
LABEL_FILE = PERCEPIANO_ROOT / 'label_2round_mean_reg_19_with0_rm_highstd0.json'

# Audio output paths
AUDIO_DIR = PROJECT_ROOT / 'data' / 'audio' / 'percepiano_rendered'
MERT_CACHE_DIR = PROJECT_ROOT / 'data' / 'cache' / 'mert_embeddings'
SOUNDFONT_PATH = PROJECT_ROOT / 'data' / 'soundfonts' / 'SalamanderGrandPiano.sf2'

# Checkpoint paths
CHECKPOINT_ROOT = Path('/tmp/checkpoints/audio_baseline')
LOG_ROOT = Path('/tmp/logs/audio_baseline')

# Symbolic baseline results (for comparison)
SYMBOLIC_RESULTS_PATH = PROJECT_ROOT / 'data' / 'cache' / 'symbolic_baseline_results.json'

# Create directories
AUDIO_DIR.mkdir(parents=True, exist_ok=True)
MERT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
LOG_ROOT.mkdir(parents=True, exist_ok=True)
SOUNDFONT_PATH.parent.mkdir(parents=True, exist_ok=True)

print("="*60)
print("AUDIO BASELINE CONFIGURATION")
print("="*60)
print(f"Project root: {PROJECT_ROOT}")
print(f"MIDI source: {MIDI_DIR}")
print(f"Audio output: {AUDIO_DIR}")
print(f"MERT cache: {MERT_CACHE_DIR}")
print(f"Soundfont: {SOUNDFONT_PATH}")

# Verify paths
if MIDI_DIR.exists():
    midi_count = len(list(MIDI_DIR.glob('*.mid')))
    print(f"\nMIDI files found: {midi_count}")
else:
    print(f"\n[ERROR] MIDI directory not found: {MIDI_DIR}")

if LABEL_FILE.exists():
    with open(LABEL_FILE) as f:
        labels = json.load(f)
    print(f"Labels found: {len(labels)} segments")
else:
    print(f"[ERROR] Label file not found: {LABEL_FILE}")

---
## Step 3: Audio Rendering Pipeline

### Goal: Render all 1,202 MIDI segments to high-quality audio

We use FluidSynth with the Salamander Grand Piano soundfont:
- **Why FluidSynth?** Free, scriptable, good quality for research
- **Why Salamander?** Best free piano soundfont, recorded from Yamaha C5
- **Alternative**: Pianoteq (better half-pedal, costs $$$)

### Success Criteria
- All 1,202 segments rendered
- 44.1 kHz, 16-bit WAV format
- < 2 hours total rendering time
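
The FluidSynth invocation for a single file can be sketched as a small command builder. This is a minimal sketch, assuming the standard FluidSynth CLI flags (`-n`, `-i`, `-F`, `-r`, `-g`) and its default 16-bit file output; `build_fluidsynth_command` and the file names are hypothetical. Options are placed before the soundfont and MIDI paths because some FluidSynth builds ignore options that follow positional arguments.

```python
from pathlib import Path


def build_fluidsynth_command(midi_path, wav_path, soundfont_path,
                             sample_rate=44100, gain=0.8):
    """Build the argument list for rendering one MIDI file to WAV."""
    return [
        "fluidsynth",
        "-n",                      # no MIDI input driver
        "-i",                      # no interactive shell
        "-F", str(wav_path),       # fast-render to this file
        "-r", str(sample_rate),    # output sample rate in Hz
        "-g", str(gain),           # master gain (below 1.0 to limit clipping)
        str(soundfont_path),       # positional: soundfont
        str(midi_path),            # positional: MIDI file
    ]


cmd = build_fluidsynth_command(Path("a.mid"), Path("a.wav"), Path("piano.sf2"))
print(" ".join(cmd))
```

The resulting list can be passed directly to `subprocess.run`.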

In [None]:
# Download Salamander Grand Piano soundfont if not present
import urllib.request
import tarfile

SOUNDFONT_URL = "https://freepats.zenvoid.org/Piano/SalamanderGrandPiano/SalamanderGrandPianoV3+20161209.tar.xz"

if not SOUNDFONT_PATH.exists():
    print("Downloading Salamander Grand Piano soundfont...")
    print("(This is ~400MB, may take a few minutes)")
    
    tar_path = SOUNDFONT_PATH.parent / "salamander.tar.xz"
    
    # Download
    urllib.request.urlretrieve(SOUNDFONT_URL, tar_path)
    print(f"Downloaded to {tar_path}")
    
    # Extract
    print("Extracting...")
    subprocess.run(['tar', '-xf', str(tar_path), '-C', str(SOUNDFONT_PATH.parent)], check=True)
    
    # Find the .sf2 file
    sf2_files = list(SOUNDFONT_PATH.parent.rglob('*.sf2'))
    if sf2_files:
        # Move to expected location
        sf2_files[0].rename(SOUNDFONT_PATH)
        print(f"Soundfont ready: {SOUNDFONT_PATH}")
    else:
        print("[ERROR] No .sf2 file found in archive")
    
    # Cleanup
    tar_path.unlink(missing_ok=True)
else:
    print(f"Soundfont already exists: {SOUNDFONT_PATH}")
    print(f"Size: {SOUNDFONT_PATH.stat().st_size / 1e6:.1f} MB")

In [None]:
import subprocess
import tempfile
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

def render_midi_to_wav(
    midi_path: Path,
    wav_path: Path,
    soundfont_path: Path,
    sample_rate: int = 44100,
) -> bool:
    """
    Render a MIDI file to WAV using FluidSynth.
    
    Args:
        midi_path: Path to input MIDI file
        wav_path: Path to output WAV file
        soundfont_path: Path to .sf2 soundfont
        sample_rate: Output sample rate (default 44100)
    
    Returns:
        True if successful, False otherwise
    """
    try:
        wav_path.parent.mkdir(parents=True, exist_ok=True)
        
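        result = subprocess.run(
            [
                'fluidsynth',
                '-ni',                    # No MIDI input, no interactive shell
                '-F', str(wav_path),      # Render to this output file
                '-r', str(sample_rate),   # Sample rate
                '-g', '0.8',              # Gain (avoid clipping)
                # Options go before the positional arguments; some FluidSynth
                # builds ignore options that follow the soundfont/MIDI paths.
                str(soundfont_path),      # Soundfont
                str(midi_path),           # MIDI file
            ],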
            capture_output=True,
            text=True,
            timeout=60,  # 1 minute timeout per file
        )
        
        if result.returncode != 0:
            print(f"Error rendering {midi_path.name}: {result.stderr}")
            return False
        
        return wav_path.exists()
    
    except subprocess.TimeoutExpired:
        print(f"Timeout rendering {midi_path.name}")
        return False
    except Exception as e:
        print(f"Exception rendering {midi_path.name}: {e}")
        return False


def batch_render_midi(
    midi_dir: Path,
    output_dir: Path,
    soundfont_path: Path,
    label_keys: List[str],
    max_workers: int = 4,
    skip_existing: bool = True,
) -> Tuple[int, int]:
    """
    Batch render MIDI files to WAV.
    
    Args:
        midi_dir: Directory containing MIDI files
        output_dir: Directory for output WAV files
        soundfont_path: Path to soundfont
        label_keys: List of segment keys (to match labels)
        max_workers: Number of parallel workers
        skip_existing: Skip files that already exist
    
    Returns:
        Tuple of (successful, failed) counts
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    
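    # Build list of files to render, tracking skipped and missing files separately
    to_render = []
    already_rendered = 0
    missing_midi = 0
    for key in label_keys:
        midi_path = midi_dir / f"{key}.mid"
        wav_path = output_dir / f"{key}.wav"
        
        if skip_existing and wav_path.exists():
            already_rendered += 1
            continue
        
        if midi_path.exists():
            to_render.append((midi_path, wav_path))
        else:
            missing_midi += 1
            print(f"MIDI not found: {key}")
    
    print(f"Files to render: {len(to_render)}")
    print(f"Already rendered: {already_rendered}")
    print(f"Missing MIDI files: {missing_midi}")
    
    if not to_render:
        return already_rendered, missing_midi
    
    # Render in parallel. Previously rendered files count as successes
    # and missing MIDI files count as failures.
    successful = already_rendered
    failed = missing_midi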
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(
                render_midi_to_wav, midi_path, wav_path, soundfont_path
            ): midi_path.stem
            for midi_path, wav_path in to_render
        }
        
        for future in tqdm(as_completed(futures), total=len(futures), desc="Rendering"):
            if future.result():
                successful += 1
            else:
                failed += 1
    
    return successful, failed

print("Rendering functions defined.")

In [None]:
# Render all MIDI files
print("="*60)
print("AUDIO RENDERING")
print("="*60)

# Get all labeled segment keys
with open(LABEL_FILE) as f:
    labels = json.load(f)
label_keys = list(labels.keys())
print(f"Total labeled segments: {len(label_keys)}")

# Check soundfont
if not SOUNDFONT_PATH.exists():
    raise FileNotFoundError(f"Soundfont not found: {SOUNDFONT_PATH}")

# Check FluidSynth
result = subprocess.run(['which', 'fluidsynth'], capture_output=True)
if result.returncode != 0:
    raise RuntimeError("FluidSynth not installed. Run: brew install fluidsynth")

# Render
successful, failed = batch_render_midi(
    midi_dir=MIDI_DIR,
    output_dir=AUDIO_DIR,
    soundfont_path=SOUNDFONT_PATH,
    label_keys=label_keys,
    max_workers=4,
    skip_existing=True,
)

print("\n" + "="*60)
print("RENDERING COMPLETE")
print("="*60)
print(f"Successful: {successful}")
print(f"Failed: {failed}")
print(f"Total WAV files: {len(list(AUDIO_DIR.glob('*.wav')))}")

if failed > 0:
    print(f"\n[WARNING] {failed} files failed to render. Check logs above.")

---
## Step 4: MERT Feature Extraction

### Goal: Extract MERT-330M embeddings for all audio segments

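MERT (Music undERstanding model with large-scale self-supervised Training) is pretrained on 160K hours of music with two teacher signals:
- **Acoustic teacher**: RVQ-VAE (EnCodec) for acoustic features
- **Musical teacher**: CQT for pitch/harmonic structure

We average hidden states 12-24, roughly the upper half of the model's 24 transformer layers, on the working assumption that higher layers carry more performance-level information.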

### Success Criteria
- All 1,202 embeddings extracted and cached
- GPU memory usage < 8GB
- Extraction time < 30 minutes
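
The layer selection described above can be sketched without any model. With `output_hidden_states=True`, a 24-layer Hugging Face encoder returns 25 hidden states: index 0 is the pre-transformer embedding output and indices 1-24 are the transformer layers. The toy example below uses nested Python lists in place of tensors, and `mean_over_layers` is an illustrative helper, not part of any library.

```python
def mean_over_layers(hidden_states, start, stop):
    """Average hidden_states[start:stop] element-wise.

    hidden_states: list of layers, each a list of frames, each a list of floats.
    Returns a list of frames with the same [num_frames][dim] shape.
    """
    selected = hidden_states[start:stop]
    num_frames = len(selected[0])
    dim = len(selected[0][0])
    return [
        [sum(layer[t][h] for layer in selected) / len(selected) for h in range(dim)]
        for t in range(num_frames)
    ]


# Toy input: 25 hidden states, 3 frames, 2 features; state i is filled with the value i
hidden_states = [[[float(i)] * 2 for _ in range(3)] for i in range(25)]

pooled = mean_over_layers(hidden_states, 12, 25)
print(len(hidden_states[12:25]), "states averaged")
print(pooled[0])
```

Because state `i` holds the constant `i`, the result is the mean of 12 through 24, which confirms that the slice `[12:25]` covers 13 states ending at the final layer.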

In [None]:
import librosa
import torch
from transformers import AutoModel, AutoProcessor

class MERTExtractor:
    """
    Extract embeddings from MERT-330M for audio segments.
    
    Uses an unweighted mean of hidden states 12-24. A learned weighted sum
    over layers, as in SUPERB-style probing, is a possible later extension.
    Caches embeddings to disk to avoid re-extraction.
    """
    
    def __init__(
        self,
        model_name: str = "m-a-p/MERT-v1-330M",
        device: str = "auto",
        cache_dir: Optional[Path] = None,
        target_sr: int = 24000,  # MERT native sample rate
        use_layers: Tuple[int, int] = (12, 25),  # Slice hidden_states[12:25]; index 0 is the pre-transformer embedding output
    ):
        self.target_sr = target_sr
        self.use_layers = use_layers
        self.cache_dir = cache_dir
        
        # Device selection
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        
        print(f"Loading MERT model: {model_name}")
        print(f"Device: {self.device}")
        
        # Load model and processor
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            model_name,
            output_hidden_states=True,
            trust_remote_code=True,
        ).to(self.device)
        self.model.eval()
        
        # Get hidden size
        self.hidden_size = self.model.config.hidden_size
        print(f"Hidden size: {self.hidden_size}")
        print(f"Using layers: {use_layers[0]}-{use_layers[1]-1}")
    
    def load_audio(self, audio_path: Path) -> torch.Tensor:
        """Load and resample audio to target sample rate."""
        audio, sr = librosa.load(audio_path, sr=self.target_sr, mono=True)
        return torch.from_numpy(audio).float()
    
    @torch.no_grad()
    def extract(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Extract MERT embeddings from audio.
        
        Args:
            audio: Audio tensor [num_samples] at target_sr
        
        Returns:
            Embeddings tensor [num_frames, hidden_size]
        """
        # Process audio
        inputs = self.processor(
            audio.numpy(),
            sampling_rate=self.target_sr,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        # Forward pass
        outputs = self.model(**inputs)
        
        # Get hidden states from specified layers
        hidden_states = outputs.hidden_states[self.use_layers[0]:self.use_layers[1]]
        
        # Stack and average across layers
        stacked = torch.stack(hidden_states, dim=0)  # [num_layers, B, T, H]
        embeddings = stacked.mean(dim=0).squeeze(0)  # [T, H]
        
        return embeddings.cpu()
    
    def extract_from_file(
        self,
        audio_path: Path,
        use_cache: bool = True,
    ) -> torch.Tensor:
        """
        Extract embeddings from audio file, with optional caching.
        """
        # Check cache
        if use_cache and self.cache_dir is not None:
            cache_path = self.cache_dir / f"{audio_path.stem}.pt"
            if cache_path.exists():
                return torch.load(cache_path)
        
        # Load and extract
        audio = self.load_audio(audio_path)
        embeddings = self.extract(audio)
        
        # Cache
        if use_cache and self.cache_dir is not None:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            torch.save(embeddings, cache_path)
        
        return embeddings


print("MERTExtractor class defined.")

In [None]:
# Extract MERT embeddings for all audio files
print("="*60)
print("MERT FEATURE EXTRACTION")
print("="*60)

# Initialize extractor
extractor = MERTExtractor(
    model_name="m-a-p/MERT-v1-330M",
    cache_dir=MERT_CACHE_DIR,
)

# Get audio files
audio_files = sorted(AUDIO_DIR.glob('*.wav'))
print(f"\nAudio files found: {len(audio_files)}")

# Check cache
cached = len(list(MERT_CACHE_DIR.glob('*.pt')))
print(f"Already cached: {cached}")

# Extract embeddings
successful = 0
failed = []

for audio_path in tqdm(audio_files, desc="Extracting MERT embeddings"):
    try:
        embeddings = extractor.extract_from_file(audio_path, use_cache=True)
        successful += 1
    except Exception as e:
        failed.append((audio_path.stem, str(e)))

print("\n" + "="*60)
print("EXTRACTION COMPLETE")
print("="*60)
print(f"Successful: {successful}")
print(f"Failed: {len(failed)}")

if failed:
    print("\nFailed files:")
    for name, error in failed[:10]:
        print(f"  {name}: {error}")

# Verify cache
cached_files = list(MERT_CACHE_DIR.glob('*.pt'))
print(f"\nCached embeddings: {len(cached_files)}")

# Sample embedding stats
if cached_files:
    sample = torch.load(cached_files[0])
    print(f"Embedding shape: {sample.shape}")
    print(f"Embedding dtype: {sample.dtype}")

---
## Step 5: Dataset and DataLoader

### Goal: Create dataset class that loads MERT embeddings with PercePiano labels

Uses the same k-fold splits as the symbolic baseline for fair comparison.
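
The fold logic used by the dataset can be sketched in isolation: one held-out `test` group, one fold for validation, and the remaining folds for training. The segment keys below are hypothetical, and `select_keys` is an illustrative stand-in for the dataset's internal filtering, assuming 4 folds named `fold_0` to `fold_3`.

```python
def select_keys(fold_assignments, fold_id, mode, num_folds=4):
    """Return the set of segment keys for a given mode and validation fold."""
    if mode == "test":
        return set(fold_assignments.get("test", []))
    if mode == "val":
        return set(fold_assignments.get(f"fold_{fold_id}", []))
    if mode == "train":
        keys = set()
        for i in range(num_folds):
            if i != fold_id:
                keys.update(fold_assignments.get(f"fold_{i}", []))
        return keys
    raise ValueError(f"Unknown mode: {mode}")


# Hypothetical keys for illustration only
fold_assignments = {
    "test": ["t0"],
    "fold_0": ["a0", "a1"],
    "fold_1": ["b0"],
    "fold_2": ["c0", "c1"],
    "fold_3": ["d0"],
}

train = select_keys(fold_assignments, fold_id=2, mode="train")
val = select_keys(fold_assignments, fold_id=2, mode="val")
test = select_keys(fold_assignments, fold_id=2, mode="test")
print(sorted(train), sorted(val), sorted(test))
```

The three sets are disjoint by construction, so no segment used for validation or testing leaks into training for that fold.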

In [None]:
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

# 19 PercePiano dimensions (same order as symbolic baseline)
PERCEPIANO_DIMENSIONS = [
    "timing", "articulation_length", "articulation_touch",
    "pedal_amount", "pedal_clarity",
    "timbre_variety", "timbre_depth", "timbre_brightness", "timbre_loudness",
    "dynamic_range", "tempo", "space", "balance", "drama",
    "mood_valence", "mood_energy", "mood_imagination",
    "sophistication", "interpretation",
]

# Dimension categories for analysis
DIMENSION_CATEGORIES = {
    "timing": ["timing"],
    "articulation": ["articulation_length", "articulation_touch"],
    "pedal": ["pedal_amount", "pedal_clarity"],
    "timbre": ["timbre_variety", "timbre_depth", "timbre_brightness", "timbre_loudness"],
    "dynamics": ["dynamic_range"],
    "tempo_space": ["tempo", "space", "balance", "drama"],
    "emotion": ["mood_valence", "mood_energy", "mood_imagination"],
    "interpretation": ["sophistication", "interpretation"],
}


class AudioPercePianoDataset(Dataset):
    """
    Dataset for audio-based PercePiano evaluation.
    
    Loads pre-extracted MERT embeddings and PercePiano labels.
    Supports k-fold cross-validation with the same splits as symbolic baseline.
    """
    
    def __init__(
        self,
        mert_cache_dir: Path,
        label_file: Path,
        fold_assignments: Optional[Dict] = None,
        fold_id: int = 0,
        mode: str = "train",  # "train", "val", "test"
        max_frames: int = 1000,  # Max MERT frames (truncate if longer)
    ):
        self.mert_cache_dir = Path(mert_cache_dir)
        self.max_frames = max_frames
        
        # Load labels
        with open(label_file) as f:
            all_labels = json.load(f)
        
        # Filter to segments with cached embeddings
        available_keys = {p.stem for p in self.mert_cache_dir.glob('*.pt')}
        self.samples = []
        
        for key, label_values in all_labels.items():
            if key in available_keys:
                # Labels are [19 values, 0] - last value is unused
                labels = torch.tensor(label_values[:19], dtype=torch.float32)
                self.samples.append((key, labels))
        
        print(f"Total samples with embeddings: {len(self.samples)}")
        
        # Apply fold filtering if provided
        if fold_assignments is not None:
            self._apply_fold_filter(fold_assignments, fold_id, mode)
    
    def _apply_fold_filter(self, fold_assignments: Dict, fold_id: int, mode: str):
        """Filter samples based on fold assignment."""
        if mode == "test":
            valid_keys = set(fold_assignments.get("test", []))
        elif mode == "val":
            valid_keys = set(fold_assignments.get(f"fold_{fold_id}", []))
        else:  # train
            valid_keys = set()
            for i in range(4):  # Assuming 4 folds
                if i != fold_id:
                    valid_keys.update(fold_assignments.get(f"fold_{i}", []))
        
        self.samples = [(k, l) for k, l in self.samples if k in valid_keys]
        print(f"{mode} samples (fold {fold_id}): {len(self.samples)}")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        key, labels = self.samples[idx]
        
        # Load cached embeddings
        embed_path = self.mert_cache_dir / f"{key}.pt"
        embeddings = torch.load(embed_path)  # [T, 1024]
        
        # Truncate if too long
        if embeddings.shape[0] > self.max_frames:
            embeddings = embeddings[:self.max_frames]
        
        return {
            "embeddings": embeddings,
            "labels": labels,
            "key": key,
            "length": embeddings.shape[0],
        }


def audio_collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Collate function for variable-length MERT embeddings.
    
    Pads embeddings to max length in batch and creates attention mask.
    """
    embeddings = [item["embeddings"] for item in batch]
    labels = torch.stack([item["labels"] for item in batch])
    lengths = torch.tensor([item["length"] for item in batch])
    keys = [item["key"] for item in batch]
    
    # Pad embeddings
    padded_embeddings = pad_sequence(embeddings, batch_first=True)  # [B, T, H]
    
    # Create attention mask (1 for valid, 0 for padding)
    max_len = padded_embeddings.shape[1]
    attention_mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
    
    return {
        "embeddings": padded_embeddings,
        "attention_mask": attention_mask,
        "labels": labels,
        "lengths": lengths,
        "keys": keys,
    }


print("Dataset classes defined.")
print(f"Dimensions: {len(PERCEPIANO_DIMENSIONS)}")

In [None]:
# Create or load fold assignments
# Use same splits as symbolic baseline for fair comparison

FOLD_FILE = MERT_CACHE_DIR.parent / 'audio_fold_assignments.json'

# Try to load existing symbolic fold assignments
symbolic_fold_file = Path('/tmp/percepiano_vnet_84dim/fold_assignments.json')

if symbolic_fold_file.exists():
    print("Loading fold assignments from symbolic baseline...")
    with open(symbolic_fold_file) as f:
        fold_assignments = json.load(f)
    print(f"Loaded {len(fold_assignments)} fold groups")
elif FOLD_FILE.exists():
    print(f"Loading existing fold assignments from {FOLD_FILE}")
    with open(FOLD_FILE) as f:
        fold_assignments = json.load(f)
else:
    print("Creating new piece-based fold assignments...")
    
    # Group by piece
    from collections import defaultdict
    piece_to_keys = defaultdict(list)
    
    with open(LABEL_FILE) as f:
        labels = json.load(f)
    
    for key in labels.keys():
        # Extract piece name (everything before _Xbars_)
        parts = key.split('_')
        for i, part in enumerate(parts):
            if 'bars' in part:
                piece = '_'.join(parts[:i])
                break
        else:
            piece = key
        piece_to_keys[piece].append(key)
    
    print(f"Found {len(piece_to_keys)} unique pieces")
    
    # Assign pieces to folds (round-robin for simplicity)
    pieces = sorted(piece_to_keys.keys())
    fold_assignments = {"test": [], "fold_0": [], "fold_1": [], "fold_2": [], "fold_3": []}
    
    # First ~15% of segments go to test, taking whole pieces in alphabetical order
    test_count = 0
    target_test = len(labels) * 0.15
    test_pieces = []
    
    for piece in pieces:
        if test_count < target_test:
            fold_assignments["test"].extend(piece_to_keys[piece])
            test_count += len(piece_to_keys[piece])
            test_pieces.append(piece)
    
    # Remaining to folds
    remaining_pieces = [p for p in pieces if p not in test_pieces]
    for i, piece in enumerate(remaining_pieces):
        fold_idx = i % 4
        fold_assignments[f"fold_{fold_idx}"].extend(piece_to_keys[piece])
    
    # Save
    with open(FOLD_FILE, 'w') as f:
        json.dump(fold_assignments, f, indent=2)
    print(f"Saved to {FOLD_FILE}")

# Print fold statistics
print("\nFold statistics:")
for fold_name, keys in fold_assignments.items():
    print(f"  {fold_name}: {len(keys)} samples")

In [None]:
# Test dataset
print("Testing dataset...")

test_ds = AudioPercePianoDataset(
    mert_cache_dir=MERT_CACHE_DIR,
    label_file=LABEL_FILE,
    fold_assignments=fold_assignments,
    fold_id=2,
    mode="val",
)

# Test single sample
sample = test_ds[0]
print(f"\nSample:")
print(f"  Key: {sample['key']}")
print(f"  Embeddings shape: {sample['embeddings'].shape}")
print(f"  Labels shape: {sample['labels'].shape}")
print(f"  Labels range: [{sample['labels'].min():.3f}, {sample['labels'].max():.3f}]")

# Test dataloader
test_loader = DataLoader(
    test_ds,
    batch_size=4,
    shuffle=False,
    collate_fn=audio_collate_fn,
    num_workers=0,
)

batch = next(iter(test_loader))
print(f"\nBatch:")
print(f"  Embeddings: {batch['embeddings'].shape}")
print(f"  Attention mask: {batch['attention_mask'].shape}")
print(f"  Labels: {batch['labels'].shape}")
print(f"  Lengths: {batch['lengths']}")

---
## Step 6: Model Architecture

### Goal: Simple MERT baseline model

We start with the simplest possible architecture to establish a floor:
1. **MERT embeddings** (frozen, pre-extracted)
2. **Mean pooling** across time
3. **Linear layers** for regression

This is intentionally simple. If it works, we can add:
- Attention pooling (Phase 2)
- Bi-LSTM on MERT frames (Phase 2)
- Fine-tuning MERT (Phase 3)
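
Mean pooling over time has to ignore padded frames, otherwise shorter segments in a batch are biased toward zero. A minimal pure-Python sketch of masked mean pooling for one sequence, using lists in place of tensors (`masked_mean_pool` is an illustrative name):

```python
def masked_mean_pool(frames, mask):
    """Average only the frames whose mask entry is True.

    frames: list of [dim] feature vectors, padded to a common length.
    mask:   list of booleans, True for real frames and False for padding.
    """
    dim = len(frames[0])
    valid = [f for f, m in zip(frames, mask) if m]
    count = max(len(valid), 1)  # avoid division by zero for an all-padding row
    return [sum(f[h] for f in valid) / count for h in range(dim)]


frames = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]  # last frame is padding
mask = [True, True, False]

masked = masked_mean_pool(frames, mask)
naive = masked_mean_pool(frames, [True, True, True])
print("masked:", masked)
print("naive: ", naive)
```

The naive mean divides by the padded length and is pulled toward zero, while the masked mean matches the average of the two real frames.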

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
from sklearn.metrics import r2_score


class MERTBaselineModel(pl.LightningModule):
    """
    Simple MERT baseline for PercePiano evaluation.
    
    Architecture:
        MERT embeddings [B, T, 1024]
            -> Pooling over time (mean by default) [B, 1024]
            -> Linear(1024, 512) + GELU + Dropout
            -> Linear(512, 512) + GELU + Dropout
            -> Linear(512, 19) + Sigmoid
    
    This is a deliberately small model. We use it to establish
    a baseline before adding complexity.
    """
    
    def __init__(
        self,
        input_dim: int = 1024,
        hidden_dim: int = 512,
        num_labels: int = 19,
        dropout: float = 0.2,
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-5,
        pooling: str = "mean",  # "mean", "max", "attention"
    ):
        super().__init__()
        self.save_hyperparameters()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_labels = num_labels
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.pooling = pooling
        
        # Optional attention pooling
        if pooling == "attention":
            self.attention = nn.Sequential(
                nn.Linear(input_dim, 256),
                nn.Tanh(),
                nn.Linear(256, 1),
            )
        
        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_labels),
            nn.Sigmoid(),
        )
        
        # Loss function
        self.loss_fn = nn.MSELoss()
        
        # Metrics storage
        self.validation_outputs = []
    
    def forward(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            embeddings: [B, T, input_dim]
            attention_mask: [B, T] (1 for valid, 0 for padding)
        
        Returns:
            predictions: [B, num_labels]
        """
        # Pool across time dimension
        if self.pooling == "mean":
            if attention_mask is not None:
                # Masked mean pooling
                mask = attention_mask.unsqueeze(-1).float()  # [B, T, 1]
                pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            else:
                pooled = embeddings.mean(dim=1)
        
        elif self.pooling == "max":
            if attention_mask is not None:
                embeddings = embeddings.masked_fill(~attention_mask.unsqueeze(-1), float('-inf'))
            pooled = embeddings.max(dim=1).values
        
        elif self.pooling == "attention":
            # Attention weights
            scores = self.attention(embeddings).squeeze(-1)  # [B, T]
            if attention_mask is not None:
                scores = scores.masked_fill(~attention_mask, float('-inf'))
            weights = torch.softmax(scores, dim=-1).unsqueeze(-1)  # [B, T, 1]
            pooled = (embeddings * weights).sum(dim=1)  # [B, H]
        
        else:
            raise ValueError(f"Unknown pooling: {self.pooling}")
        
        # Classify
        predictions = self.classifier(pooled)
        return predictions
    
    def training_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        predictions = self(batch["embeddings"], batch["attention_mask"])
        loss = self.loss_fn(predictions, batch["labels"])
        
        self.log("train_loss", loss, prog_bar=True)
        return loss
    
    def validation_step(self, batch: Dict, batch_idx: int):
        predictions = self(batch["embeddings"], batch["attention_mask"])
        loss = self.loss_fn(predictions, batch["labels"])
        
        self.log("val_loss", loss, prog_bar=True)
        
        # Store for epoch-end metrics
        self.validation_outputs.append({
            "predictions": predictions.detach().cpu(),
            "labels": batch["labels"].detach().cpu(),
        })
    
    def on_validation_epoch_end(self):
        if not self.validation_outputs:
            return
        
        # Aggregate predictions and labels
        all_preds = torch.cat([x["predictions"] for x in self.validation_outputs])
        all_labels = torch.cat([x["labels"] for x in self.validation_outputs])
        
        # Compute R2
        r2 = r2_score(all_labels.numpy(), all_preds.numpy())
        self.log("val_r2", r2, prog_bar=True)
        
        # Per-dimension R2
        per_dim_r2 = {}
        for i, dim_name in enumerate(PERCEPIANO_DIMENSIONS):
            dim_r2 = r2_score(all_labels[:, i].numpy(), all_preds[:, i].numpy())
            per_dim_r2[dim_name] = dim_r2
            self.log(f"val_r2_{dim_name}", dim_r2)
        
        self.validation_outputs.clear()
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        
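        # Cosine decay over the full run; T_max should equal the Trainer's
        # max_epochs (100 in this notebook's configuration)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=100,
            eta_min=1e-6,
        )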
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
            },
        }


print("MERTBaselineModel defined.")
print("Pooling options: mean, max, attention")

---
## Step 7: Training Configuration
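
For orientation, the learning-rate schedule implied by the optimizer setup above and the configuration in this step can be computed by hand. The sketch below evaluates the closed-form cosine annealing formula, `eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / T_max))`, with the values used in this notebook (initial rate 1e-4, `T_max=100`, `eta_min=1e-6`). The function name `cosine_lr` is hypothetical, and the closed form is assumed to match `CosineAnnealingLR` when the scheduler is stepped once per epoch without restarts.

```python
import math


def cosine_lr(epoch, base_lr=1e-4, eta_min=1e-6, t_max=100):
    """Closed-form cosine annealing value at a given epoch (no restarts)."""
    cos_term = 1 + math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * cos_term


# Print the learning rate at a few points along the 100-epoch schedule.
for epoch in (0, 25, 50, 75, 100):
    print(f"epoch {epoch:3d}: lr = {cosine_lr(epoch):.2e}")
```

With an early-stopping patience of 15 epochs, a run may end well before the schedule reaches `eta_min`, so the low-rate tail of the curve is not guaranteed to be used.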

In [None]:
# Training configuration
torch.set_float32_matmul_precision('medium')

CONFIG = {
    # Model
    'input_dim': 1024,         # MERT-330M hidden size
    'hidden_dim': 512,         # Classifier hidden dim
    'num_labels': 19,          # PercePiano dimensions
    'dropout': 0.2,
    'pooling': 'mean',         # Start simple: mean pooling
    
    # Training
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'batch_size': 16,          # Can be larger: MERT is frozen and embeddings are read from cache
    'max_epochs': 100,
    'early_stopping_patience': 15,
    'gradient_clip_val': 1.0,
    
    # Data
    'max_frames': 1000,        # Max MERT frames per segment
    'num_workers': 4,
    
    # Fold
    'fold_id': 2,              # Same as best symbolic fold for comparison
}

print("="*60)
print("TRAINING CONFIGURATION")
print("="*60)
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

---
## Step 8: Training

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar
from pytorch_lightning.loggers import TensorBoardLogger

# Set seed
pl.seed_everything(42, workers=True)

# Create datasets
train_ds = AudioPercePianoDataset(
    mert_cache_dir=MERT_CACHE_DIR,
    label_file=LABEL_FILE,
    fold_assignments=fold_assignments,
    fold_id=CONFIG['fold_id'],
    mode="train",
    max_frames=CONFIG['max_frames'],
)

val_ds = AudioPercePianoDataset(
    mert_cache_dir=MERT_CACHE_DIR,
    label_file=LABEL_FILE,
    fold_assignments=fold_assignments,
    fold_id=CONFIG['fold_id'],
    mode="val",
    max_frames=CONFIG['max_frames'],
)

# Create dataloaders
train_loader = DataLoader(
    train_ds,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    collate_fn=audio_collate_fn,
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
)

val_loader = DataLoader(
    val_ds,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    collate_fn=audio_collate_fn,
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

In [None]:
# Create model
model = MERTBaselineModel(
    input_dim=CONFIG['input_dim'],
    hidden_dim=CONFIG['hidden_dim'],
    num_labels=CONFIG['num_labels'],
    dropout=CONFIG['dropout'],
    learning_rate=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
    pooling=CONFIG['pooling'],
)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

In [None]:
# Callbacks
callbacks = [
    ModelCheckpoint(
        dirpath=CHECKPOINT_ROOT / f"fold_{CONFIG['fold_id']}",
        filename="best-{epoch:02d}-{val_r2:.4f}",
        monitor="val_r2",
        mode="max",
        save_top_k=3,
    ),
    EarlyStopping(
        monitor="val_r2",
        mode="max",
        patience=CONFIG['early_stopping_patience'],
        verbose=True,
    ),
    RichProgressBar(),
]

# Logger
logger = TensorBoardLogger(
    save_dir=LOG_ROOT,
    name=f"fold_{CONFIG['fold_id']}",
)

# Trainer
trainer = pl.Trainer(
    max_epochs=CONFIG['max_epochs'],
    callbacks=callbacks,
    logger=logger,
    gradient_clip_val=CONFIG['gradient_clip_val'],
    precision='32',
    accelerator='auto',
    devices=1,
    log_every_n_steps=10,
)

print("="*60)
print("STARTING TRAINING")
print("="*60)

# Train
trainer.fit(model, train_loader, val_loader)

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)
print(f"Best checkpoint: {callbacks[0].best_model_path}")
print(f"Best val R2: {callbacks[0].best_model_score:.4f}")

---
## Step 9: Evaluation and Diagnostics

### Goals:
1. Per-dimension R2 analysis
2. Comparison to symbolic baseline
3. Category-level analysis (timbre, pedal, etc.)
4. Identify where audio excels vs. fails
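
As a reference for goals 1 and 3, the sketch below spells out how R2 is defined and how an overall score relates to per-dimension scores. It is a pure-Python illustration on made-up numbers, not the notebook's evaluation code, and the helper names `r2` and `mean_r2` are hypothetical. It assumes scikit-learn's default `multioutput="uniform_average"` behaviour, under which the overall R2 of a multi-output prediction is the unweighted mean of the per-column R2 values.

```python
from statistics import fmean


def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot.

    Assumes y_true is not constant (SS_tot > 0).
    """
    mean_true = fmean(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def mean_r2(columns_true, columns_pred):
    """Unweighted mean of per-dimension R2 (uniform averaging)."""
    return fmean([r2(t, p) for t, p in zip(columns_true, columns_pred)])


# Toy data: two dimensions, four segments each (illustrative values only).
labels = [[0.2, 0.4, 0.6, 0.8], [0.5, 0.5, 0.7, 0.3]]
preds = [[0.25, 0.35, 0.65, 0.75], [0.5, 0.5, 0.5, 0.5]]

per_dim = [r2(t, p) for t, p in zip(labels, preds)]
overall = mean_r2(labels, preds)
print("per-dimension:", per_dim)
print("overall:", overall)
```

In this toy case the second dimension is predicted as the constant label mean, which gives it an R2 of zero and halves the overall average. This is why the per-dimension and category breakdowns are reported alongside the single overall number.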

In [None]:
# Load best checkpoint
best_ckpt = callbacks[0].best_model_path
print(f"Loading best checkpoint: {best_ckpt}")

model = MERTBaselineModel.load_from_checkpoint(best_ckpt)
model.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [None]:
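# Run validation predictions.
# Note: this is the same fold used for checkpoint selection and early stopping,
# so the R2 values below are optimistic estimates rather than held-out test scores.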
all_predictions = []
all_labels = []
all_keys = []

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Evaluating"):
        embeddings = batch["embeddings"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        
        predictions = model(embeddings, attention_mask)
        
        all_predictions.append(predictions.cpu())
        all_labels.append(batch["labels"])
        all_keys.extend(batch["keys"])

all_predictions = torch.cat(all_predictions).numpy()
all_labels = torch.cat(all_labels).numpy()

print(f"Predictions shape: {all_predictions.shape}")
print(f"Labels shape: {all_labels.shape}")

In [None]:
# Per-dimension R2 analysis
print("="*80)
print("PER-DIMENSION R2 ANALYSIS")
print("="*80)

# Compute per-dimension R2
audio_r2 = {}
for i, dim_name in enumerate(PERCEPIANO_DIMENSIONS):
    r2 = r2_score(all_labels[:, i], all_predictions[:, i])
    audio_r2[dim_name] = r2

# Sort by R2
sorted_dims = sorted(audio_r2.items(), key=lambda x: x[1], reverse=True)

# Display table
table = Table(title="Audio Baseline Per-Dimension R2")
table.add_column("Dimension", style="cyan")
table.add_column("R2", justify="right")
table.add_column("Status", justify="center")

for dim_name, r2 in sorted_dims:
    if r2 >= 0.30:
        status = "[green]Strong[/green]"
    elif r2 >= 0.15:
        status = "[yellow]OK[/yellow]"
    elif r2 >= 0:
        status = "[orange3]Weak[/orange3]"
    else:
        status = "[red]Failed[/red]"
    
    table.add_row(dim_name, f"{r2:+.4f}", status)

console.print(table)

# Overall R2
overall_r2 = r2_score(all_labels, all_predictions)
print(f"\nOverall R2: {overall_r2:+.4f}")

In [None]:
# Category-level analysis
print("\n" + "="*80)
print("CATEGORY-LEVEL ANALYSIS")
print("="*80)

category_r2 = {}
for category, dims in DIMENSION_CATEGORIES.items():
    cat_r2s = [audio_r2[d] for d in dims]
    category_r2[category] = np.mean(cat_r2s)

# Display
table = Table(title="Category Average R2")
table.add_column("Category", style="cyan")
table.add_column("Dimensions", justify="right")
table.add_column("Mean R2", justify="right")
table.add_column("Expected Winner", justify="center")

expected_audio = ["timbre", "pedal", "emotion"]
expected_symbolic = ["timing", "articulation"]

for category in sorted(category_r2.keys(), key=lambda x: category_r2[x], reverse=True):
    r2 = category_r2[category]
    n_dims = len(DIMENSION_CATEGORIES[category])
    
    if category in expected_audio:
        expected = "[green]Audio[/green]"
    elif category in expected_symbolic:
        expected = "[blue]Symbolic[/blue]"
    else:
        expected = "Tie"
    
    table.add_row(category, str(n_dims), f"{r2:+.4f}", expected)

console.print(table)

# Hypothesis validation
print("\nHypothesis Validation:")
timbre_r2 = category_r2["timbre"]
pedal_r2 = category_r2["pedal"]
timing_r2 = category_r2["timing"]

print(f"  Timbre R2: {timbre_r2:+.4f} (target: >= 0.30)")
print(f"  Pedal R2: {pedal_r2:+.4f} (target: >= 0.25)")
print(f"  Timing R2: {timing_r2:+.4f} (expected to be lower than symbolic)")

In [None]:
# Compare to symbolic baseline (if available)
print("\n" + "="*80)
print("COMPARISON TO SYMBOLIC BASELINE")
print("="*80)

# Symbolic per-dimension results are not loaded here; they would come from
# the PercePiano replica training. For now, use the paper-reported overall
# values as a reference.

SYMBOLIC_REFERENCE = {
    "overall": 0.397,  # Bi-LSTM + SA + HAN (paper SOTA)
    "baseline": 0.185,  # Bi-LSTM baseline
    # Per-dimension values would come from our replica training
}

print(f"\nSymbolic baseline (paper): R2 = {SYMBOLIC_REFERENCE['baseline']:.3f}")
print(f"Symbolic SOTA (paper): R2 = {SYMBOLIC_REFERENCE['overall']:.3f}")
print(f"Audio baseline (ours): R2 = {overall_r2:.3f}")

if overall_r2 > SYMBOLIC_REFERENCE['baseline']:
    print("\n[SUCCESS] Audio baseline beats symbolic baseline!")
    print(f"  Improvement: {overall_r2 - SYMBOLIC_REFERENCE['baseline']:+.3f}")
else:
    print("\n[INFO] Audio baseline does not beat symbolic baseline")
    print(f"  Gap: {overall_r2 - SYMBOLIC_REFERENCE['baseline']:+.3f}")

print("\nNote: Per-dimension comparison requires symbolic replica results.")
print("See train_percepiano_replica.ipynb for symbolic per-dimension R2.")

---
## Step 10: Diagnostics and Error Analysis
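
Prediction collapse here means the model outputs nearly the same value for every segment, so its predictions vary far less than the labels do. The sketch below illustrates the standard-deviation ratio used as the collapse check in this step, on made-up numbers. The helper name `std_ratio` and the toy data are hypothetical; the 0.5 threshold and the zero fallback for constant labels follow the diagnostic cell in this step, and `statistics.pstdev` is used because it matches NumPy's default population standard deviation.

```python
from statistics import pstdev

COLLAPSE_THRESHOLD = 0.5  # same cut-off as the diagnostic cell in this step


def std_ratio(preds, labels):
    """Ratio of prediction spread to label spread (population std)."""
    label_std = pstdev(labels)
    if label_std == 0:
        return 0.0  # mirrors the notebook's fallback for constant labels
    return pstdev(preds) / label_std


# Toy data for a single dimension (illustrative values only).
labels = [0.2, 0.4, 0.6, 0.8]
healthy = [0.25, 0.35, 0.65, 0.75]
collapsed = [0.49, 0.50, 0.51, 0.50]

for name, preds in [("healthy", healthy), ("collapsed", collapsed)]:
    ratio = std_ratio(preds, labels)
    flag = "OK" if ratio > COLLAPSE_THRESHOLD else "COLLAPSED"
    print(f"{name}: ratio={ratio:.2f} [{flag}]")
```

A collapsed dimension typically also shows an R2 near zero, since predicting roughly the label mean explains almost none of the variance.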

In [None]:
# Prediction statistics
print("="*80)
print("PREDICTION DIAGNOSTICS")
print("="*80)

print(f"\nPrediction statistics:")
print(f"  Mean: {all_predictions.mean():.4f}")
print(f"  Std: {all_predictions.std():.4f}")
print(f"  Min: {all_predictions.min():.4f}")
print(f"  Max: {all_predictions.max():.4f}")

print(f"\nLabel statistics:")
print(f"  Mean: {all_labels.mean():.4f}")
print(f"  Std: {all_labels.std():.4f}")
print(f"  Min: {all_labels.min():.4f}")
print(f"  Max: {all_labels.max():.4f}")

# Check for prediction collapse
pred_std = all_predictions.std(axis=0)
label_std = all_labels.std(axis=0)

print("\nPer-dimension prediction std vs label std:")
collapsed = []
for i, dim_name in enumerate(PERCEPIANO_DIMENSIONS):
    # A ratio well below 1 means predictions vary far less than the labels
    ratio = pred_std[i] / label_std[i] if label_std[i] > 0 else 0
    status = "OK" if ratio > 0.5 else "COLLAPSED"
    if status == "COLLAPSED":
        collapsed.append(dim_name)
    # Report every dimension, not only the collapsed ones
    print(f"  {dim_name}: pred_std={pred_std[i]:.4f}, label_std={label_std[i]:.4f}, ratio={ratio:.2f} [{status}]")

if collapsed:
    print(f"\n[WARNING] {len(collapsed)} dimensions show prediction collapse:")
    for dim in collapsed:
        print(f"  - {dim}")
else:
    print(f"\n[OK] No prediction collapse detected")

In [None]:
# Worst performing samples
print("\n" + "="*80)
print("WORST PERFORMING SAMPLES")
print("="*80)

# Compute per-sample MSE
sample_mse = np.mean((all_predictions - all_labels) ** 2, axis=1)
worst_indices = np.argsort(sample_mse)[-10:][::-1]

print("\nTop 10 worst predictions:")
for idx in worst_indices:
    key = all_keys[idx]
    mse = sample_mse[idx]
    print(f"  {key}: MSE={mse:.4f}")

---
## Summary and Next Steps

In [None]:
print("="*80)
print("AUDIO BASELINE SUMMARY")
print("="*80)

print(f"\n1. OVERALL PERFORMANCE")
print(f"   Audio Baseline R2: {overall_r2:+.4f}")
print(f"   Symbolic Baseline (paper): {SYMBOLIC_REFERENCE['baseline']:.3f}")
print(f"   Symbolic SOTA (paper): {SYMBOLIC_REFERENCE['overall']:.3f}")

print(f"\n2. HYPOTHESIS VALIDATION")
print(f"   Timbre dimensions: {category_r2['timbre']:+.4f} (expected audio advantage)")
print(f"   Pedal dimensions: {category_r2['pedal']:+.4f} (expected audio advantage)")
print(f"   Timing dimension: {category_r2['timing']:+.4f} (expected symbolic advantage)")

print(f"\n3. STRONGEST DIMENSIONS (Audio)")
for dim, r2 in sorted_dims[:5]:
    print(f"   {dim}: {r2:+.4f}")

print(f"\n4. WEAKEST DIMENSIONS (Audio)")
for dim, r2 in sorted_dims[-5:]:
    print(f"   {dim}: {r2:+.4f}")

print(f"\n5. NEXT STEPS")
print(f"   [ ] Compare per-dimension with symbolic baseline from replica training")
print(f"   [ ] Try attention pooling instead of mean pooling")
print(f"   [ ] Add Bi-LSTM on MERT frames")
print(f"   [ ] Explore multimodal fusion if audio/symbolic show complementary strengths")

print("="*80)

In [None]:
# Save results for comparison
results = {
    "config": CONFIG,
    "overall_r2": float(overall_r2),
    "per_dimension_r2": {k: float(v) for k, v in audio_r2.items()},
    "category_r2": {k: float(v) for k, v in category_r2.items()},
    "best_checkpoint": str(best_ckpt),
}

results_path = CHECKPOINT_ROOT / "audio_baseline_results.json"
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

print(f"Results saved to: {results_path}")