# Score Alignment Exploration

This notebook explores aligning student piano performances to rendered MIDI scores
in MuQ embedding space using DTW-based algorithms.

**Experiments:**
- A: DTW Baseline on raw MuQ embeddings
- B: Learned projection MLP with soft-DTW loss
- C: Measure-level alignment (coarser granularity)

**Success metric:** Mean onset error < 30ms (human perception threshold)

## 1. Setup

In [1]:
import sys
from pathlib import Path

# Notebook is at: model/notebooks/score_alignment/
# src is at: model/src/
NOTEBOOK_DIR = Path.cwd()
MODEL_ROOT = NOTEBOOK_DIR.parent.parent
SRC_DIR = MODEL_ROOT / "src"

# Add src to path
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))

print(f"Model root: {MODEL_ROOT}")
print(f"Source dir: {SRC_DIR}")

Model root: /Users/jdhiman/Documents/crescendai/model
Source dir: /Users/jdhiman/Documents/crescendai/model/src


In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt

# Score alignment imports
from score_alignment.config import (
    MUQ_FRAME_RATE,
    ProjectionConfig,
    TrainingConfig,
    ASAP_REPO_URL,
)
from score_alignment.data.asap import (
    parse_asap_metadata,
    load_note_alignments,
    extract_onset_pairs,
    get_performance_key,
)
from score_alignment.data.alignment_dataset import (
    FrameAlignmentDataset,
    MeasureAlignmentDataset,
    frame_alignment_collate_fn,
)
from score_alignment.alignment.dtw import align_embeddings, compute_cost_matrix
from score_alignment.alignment.metrics import (
    onset_error,
    evaluate_dtw_alignment,
    compute_alignment_summary,
)
from score_alignment.models.projection import AlignmentProjectionModel
from score_alignment.training.runner import (
    run_alignment_experiment,
    run_dtw_baseline,
)

# Check device
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"MuQ frame rate: {MUQ_FRAME_RATE} fps")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {DEVICE}")

MuQ frame rate: 75 fps
PyTorch version: 2.10.0
Device: mps


## 2. Data Paths

In [3]:
# All data lives under model/data/
DATA_ROOT = MODEL_ROOT / "data"
ASAP_ROOT = DATA_ROOT / "asap-dataset"
SCORE_CACHE_DIR = DATA_ROOT / "muq_cache" / "scores"
PERF_CACHE_DIR = DATA_ROOT / "muq_cache" / "performances"
CHECKPOINT_DIR = DATA_ROOT / "checkpoints" / "score_alignment"
RESULTS_DIR = DATA_ROOT / "results" / "score_alignment"
LOG_DIR = DATA_ROOT / "logs" / "score_alignment"

# Create directories
for d in [DATA_ROOT, SCORE_CACHE_DIR, PERF_CACHE_DIR, CHECKPOINT_DIR, RESULTS_DIR, LOG_DIR]:
    d.mkdir(parents=True, exist_ok=True)

print(f"Data root: {DATA_ROOT}")
print(f"ASAP root: {ASAP_ROOT}")

Data root: /Users/jdhiman/Documents/crescendai/model/data
ASAP root: /Users/jdhiman/Documents/crescendai/model/data/asap-dataset


## 3. Clone ASAP Dataset

In [4]:
import subprocess

if not ASAP_ROOT.exists():
    print(f"Cloning ASAP dataset from {ASAP_REPO_URL}...")
    subprocess.run(["git", "clone", ASAP_REPO_URL, str(ASAP_ROOT)], check=True)
    print("Done!")
else:
    print(f"ASAP dataset already exists at {ASAP_ROOT}")

ASAP dataset already exists at /Users/jdhiman/Documents/crescendai/model/data/asap-dataset


In [5]:
# Parse ASAP metadata
asap_index = parse_asap_metadata(ASAP_ROOT)

print(f"Total performances: {len(asap_index)}")
print(f"Composers: {len(asap_index.get_composers())}")
print(f"Pieces: {len(asap_index.get_pieces())}")

# Filter to performances with alignments
aligned_perfs = asap_index.filter_with_alignments()
print(f"Performances with alignments: {len(aligned_perfs)}")

# Get multi-performer pieces (for disentanglement later)
multi_performer = asap_index.get_multi_performer_pieces(min_performers=3)
print(f"Pieces with 3+ performers: {len(multi_performer)}")

Total performances: 1066
Composers: 16
Pieces: 242
Performances with alignments: 1063
Pieces with 3+ performers: 126


## 4. Embedding Extraction

Extract MuQ embeddings for scores (rendered MIDI) and performances.

**Note:** MuQ extraction works best with GPU. On M4 Mac, MPS can be used but may be slower.

In [6]:
# Check existing cache
score_cached = list(SCORE_CACHE_DIR.glob("*.pt"))
perf_cached = list(PERF_CACHE_DIR.glob("*.pt"))

print(f"Cached score embeddings: {len(score_cached)}")
print(f"Cached performance embeddings: {len(perf_cached)}")

Cached score embeddings: 0
Cached performance embeddings: 0


In [12]:
!uv add nnAudio

[2K[2mResolved [1m218 packages[0m [2min 1.06s[0m[0m                                       [0m
[2K   [36m[1mBuilding[0m[39m piano-eval-mvp[2m @ file:///Users/jdhiman/Documents/crescendai/model[0m[0m [2mPreparing packages...[0m (0/0)                                                   
[2K[1A   [36m[1mBuilding[0m[39m piano-eval-mvp[2m @ file:///Users/jdhiman/Documents/crescendai/model[0m
[37m⠙[0m [2mPreparing packages...[0m (0/2)
[2K[2A   [36m[1mBuilding[0m[39m piano-eval-mvp[2m @ file:///Users/jdhiman/Documents/crescendai/model[0m
[37m⠙[0m [2mPreparing packages...[0m (0/2)
[2K[2A   [36m[1mBuilding[0m[39m piano-eval-mvp[2m @ file:///Users/jdhiman/Documents/crescendai/model[0m
[37m⠙[0m [2mPreparing packages...[0m (0/2)
[2K[2A   [36m[1mBuilding[0m[39m piano-eval-mvp[2m @ file:///Users/jdhiman/Documents/crescendai/model[0m
[37m⠙[0m [2mPreparing packages...[0m (0/2)
[2K[2A   [36m[1mBuilding[0m[39m piano-eval-mvp[2m @ file

In [13]:
def extract_embeddings_if_needed(performances, cache_dir, asap_root, is_score=False):
    """Extract MuQ embeddings for audio files not yet cached."""
    from audio_experiments.extractors.muq import MuQExtractor

    cache_dir = Path(cache_dir)
    cached = {p.stem for p in cache_dir.glob("*.pt")}

    to_extract = []
    for perf in performances:
        key = get_performance_key(perf)
        if is_score:
            if perf.midi_score_path:
                key = perf.midi_score_path.stem

        if key not in cached:
            audio_path = perf.audio_path
            if audio_path:
                full_path = asap_root / audio_path
                if full_path.exists():
                    to_extract.append((key, full_path))

    if not to_extract:
        print(f"All embeddings already cached in {cache_dir}")
        return 0

    print(f"Extracting {len(to_extract)} MuQ embeddings...")
    extractor = MuQExtractor(cache_dir=cache_dir)

    for key, audio_path in to_extract:
        extractor.extract_from_file(audio_path)

    return len(to_extract)

extract_embeddings_if_needed(aligned_perfs, PERF_CACHE_DIR, ASAP_ROOT, is_score=False)

All embeddings already cached in /Users/jdhiman/Documents/crescendai/model/data/muq_cache/performances


0

## 5. Train/Val Split

In [14]:
from sklearn.model_selection import train_test_split

# Get all performance keys
all_keys = [get_performance_key(p) for p in aligned_perfs]

# Split 80/20
train_keys, val_keys = train_test_split(all_keys, test_size=0.2, random_state=42)

print(f"Train: {len(train_keys)} performances")
print(f"Val: {len(val_keys)} performances")

Train: 850 performances
Val: 213 performances


## 6. Experiment A: DTW Baseline

Run standard DTW on raw MuQ embeddings without any learned projection.

In [15]:
# Run DTW baseline on validation set
baseline_metrics = run_dtw_baseline(
    performances=aligned_perfs,
    score_cache_dir=SCORE_CACHE_DIR,
    perf_cache_dir=PERF_CACHE_DIR,
    asap_root=ASAP_ROOT,
    keys=val_keys,
    distance_metric="cosine",
)

print("\nDTW Baseline Results:")
print(f"  Mean onset error: {baseline_metrics['weighted_mean_error_ms']:.1f} ms")
print(f"  Within 30ms: {baseline_metrics['weighted_percent_within_threshold']:.1f}%")
print(f"  Num performances: {baseline_metrics['num_performances']}")
print(f"  Total notes evaluated: {baseline_metrics['total_notes']}")

Running DTW baseline on 0 samples...

DTW Baseline Results:
  Mean onset error: 0.0 ms
  Within 30ms: 0.0%
  Num performances: 0
  Total notes evaluated: 0


In [None]:
def visualize_alignment(sample_idx=0):
    """Visualize DTW alignment for a single sample."""
    dataset = FrameAlignmentDataset(
        [p for p in aligned_perfs if get_performance_key(p) in val_keys][:10],
        SCORE_CACHE_DIR,
        PERF_CACHE_DIR,
        ASAP_ROOT,
    )
    
    if len(dataset) == 0:
        print("No samples available for visualization")
        return
    
    sample = dataset[sample_idx]
    
    score_emb = sample["score_embeddings"].numpy()
    perf_emb = sample["perf_embeddings"].numpy()
    
    cost_matrix = compute_cost_matrix(score_emb, perf_emb, metric="cosine")
    path_score, path_perf, cost, _ = align_embeddings(score_emb, perf_emb)
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    axes[0].imshow(cost_matrix, aspect='auto', origin='lower', cmap='viridis')
    axes[0].plot(path_perf, path_score, 'r-', linewidth=2, label='DTW path')
    axes[0].set_xlabel('Performance frames')
    axes[0].set_ylabel('Score frames')
    axes[0].set_title(f'DTW Alignment - {sample["key"]}')
    axes[0].legend()
    
    gt_score = sample["score_onsets"].numpy()
    gt_perf = sample["perf_onsets"].numpy()
    
    axes[1].scatter(gt_score, gt_perf, alpha=0.6, label='Ground truth')
    axes[1].plot([0, max(gt_score)], [0, max(gt_perf)], 'k--', alpha=0.3, label='Diagonal')
    axes[1].set_xlabel('Score onset (sec)')
    axes[1].set_ylabel('Performance onset (sec)')
    axes[1].set_title('Note-level Alignment')
    axes[1].legend()
    
    plt.tight_layout()
    plt.show()

# visualize_alignment(0)

## 7. Experiment B: Learned Projection

Train a projection MLP with soft-DTW divergence loss.

In [None]:
projection_config = ProjectionConfig(
    input_dim=1024,
    hidden_dim=512,
    output_dim=256,
    num_layers=3,
    dropout=0.1,
)

training_config = TrainingConfig(
    learning_rate=1e-4,
    weight_decay=1e-5,
    batch_size=8,
    max_epochs=50,
    patience=10,
    soft_dtw_gamma=1.0,
    num_workers=0,  # Use 0 on Mac to avoid multiprocessing issues
)

print("Projection config:")
print(f"  {projection_config.input_dim} -> {projection_config.hidden_dim} -> {projection_config.output_dim}")
print(f"  Layers: {projection_config.num_layers}")

In [None]:
exp_b_results = run_alignment_experiment(
    exp_id="B_learned_projection",
    description="Learned projection MLP with soft-DTW loss",
    performances=aligned_perfs,
    score_cache_dir=SCORE_CACHE_DIR,
    perf_cache_dir=PERF_CACHE_DIR,
    asap_root=ASAP_ROOT,
    train_keys=train_keys,
    val_keys=val_keys,
    projection_config=projection_config,
    training_config=training_config,
    checkpoint_dir=CHECKPOINT_DIR,
    results_dir=RESULTS_DIR,
    log_dir=LOG_DIR,
)

## 8. Experiment C: Measure-Level Alignment

Pool embeddings by measure for coarser alignment (faster, less memory).

In [None]:
exp_c_results = run_alignment_experiment(
    exp_id="C_measure_level",
    description="Measure-level alignment with pooled embeddings",
    performances=aligned_perfs,
    score_cache_dir=SCORE_CACHE_DIR,
    perf_cache_dir=PERF_CACHE_DIR,
    asap_root=ASAP_ROOT,
    train_keys=train_keys,
    val_keys=val_keys,
    projection_config=projection_config,
    training_config=training_config,
    checkpoint_dir=CHECKPOINT_DIR,
    results_dir=RESULTS_DIR,
    log_dir=LOG_DIR,
    use_measures=True,
)

## 9. Results Comparison

In [None]:
import pandas as pd

results_data = [
    {
        "Experiment": "A: DTW Baseline",
        "Mean Error (ms)": baseline_metrics.get("weighted_mean_error_ms", np.nan),
        "Within 30ms (%)": baseline_metrics.get("weighted_percent_within_threshold", np.nan),
        "Notes Evaluated": baseline_metrics.get("total_notes", 0),
    },
]

if "metrics" in exp_b_results:
    results_data.append({
        "Experiment": "B: Learned Projection",
        "Mean Error (ms)": exp_b_results["metrics"].get("weighted_mean_error_ms", np.nan),
        "Within 30ms (%)": exp_b_results["metrics"].get("weighted_percent_within_threshold", np.nan),
        "Notes Evaluated": exp_b_results["metrics"].get("total_notes", 0),
    })

if "metrics" in exp_c_results:
    results_data.append({
        "Experiment": "C: Measure-Level",
        "Mean Error (ms)": exp_c_results["metrics"].get("weighted_mean_error_ms", np.nan),
        "Within 30ms (%)": exp_c_results["metrics"].get("weighted_percent_within_threshold", np.nan),
        "Notes Evaluated": exp_c_results["metrics"].get("total_notes", 0),
    })

results_df = pd.DataFrame(results_data)
print("\nResults Comparison:")
print(results_df.to_string(index=False))

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

experiments = results_df["Experiment"].values
mean_errors = results_df["Mean Error (ms)"].values
within_30ms = results_df["Within 30ms (%)"].values

colors = ['#2ecc71' if e < 30 else '#e74c3c' for e in mean_errors]
axes[0].bar(experiments, mean_errors, color=colors)
axes[0].axhline(y=30, color='red', linestyle='--', label='30ms threshold')
axes[0].set_ylabel('Mean Onset Error (ms)')
axes[0].set_title('Mean Onset Error by Experiment')
axes[0].legend()
axes[0].tick_params(axis='x', rotation=15)

axes[1].bar(experiments, within_30ms, color='steelblue')
axes[1].set_ylabel('Onsets within 30ms (%)')
axes[1].set_title('Alignment Accuracy by Experiment')
axes[1].tick_params(axis='x', rotation=15)

plt.tight_layout()
plt.savefig(RESULTS_DIR / 'alignment_comparison.png', dpi=150)
plt.show()

## 10. Recommendations

In [None]:
best_idx = results_df["Mean Error (ms)"].idxmin()
best_exp = results_df.loc[best_idx]

print("Best performing experiment:")
print(f"  {best_exp['Experiment']}")
print(f"  Mean error: {best_exp['Mean Error (ms)']:.1f} ms")
print(f"  Within 30ms: {best_exp['Within 30ms (%)']:.1f}%")

if best_exp["Mean Error (ms)"] < 30:
    print("\nSuccess: Mean error below 30ms human perception threshold")
else:
    print("\nNeeds improvement: Mean error above 30ms threshold")
    print("Consider:")
    print("  - Larger projection network")
    print("  - Different soft-DTW gamma")
    print("  - More training data")
    print("  - Sakoe-Chiba band constraint for DTW")