# Score Alignment Exploration

This notebook explores aligning student piano performances to rendered MIDI scores
in MuQ embedding space using DTW-based algorithms.

**Experiments:**
- A: DTW Baseline on raw MuQ embeddings
- B: Learned projection MLP with soft-DTW loss
- C: Measure-level alignment (coarser granularity)

**Success metric:** Mean onset error below 30 ms (roughly the threshold at which humans perceive timing offsets)

## 1. Setup

In [None]:
import subprocess
import sys
import os
# NOTE(review): presumably set so PyTorch's deterministic-algorithms mode can be
# used with cuBLAS (PyTorch requires this variable for deterministic CUDA matmuls) -- confirm.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

# Install rclone (used later to sync embeddings, checkpoints and results with gdrive).
# NOTE(review): the `|| echo "rclone installed"` branch runs whenever grep finds no
# "successfully"/"already" line -- including when the install failed -- so that
# message is not proof of success; the later `rclone listremotes` check is.
!curl -fsSL https://rclone.org/install.sh | sudo bash 2>&1 | grep -E "(successfully|already)" || echo "rclone installed"

In [None]:
from pathlib import Path

NOTEBOOK_DIR = Path.cwd()
MODEL_ROOT = Path("/workspace/crescendai/model")
SRC_DIR = MODEL_ROOT / "src"

# Prepend src/ to the import path (only once) so the `score_alignment`
# package imported in the next cell resolves from this checkout.
_src_path = str(SRC_DIR)
if _src_path not in sys.path:
    sys.path.insert(0, _src_path)

for _label, _path in (("Model root", MODEL_ROOT), ("Source dir", SRC_DIR)):
    print(f"{_label}: {_path}")

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

# Score alignment imports
from score_alignment.config import (
    MUQ_FRAME_RATE,
    ProjectionConfig,
    TrainingConfig,
    ASAP_REPO_URL,
)
from score_alignment.data.asap import (
    parse_asap_metadata,
    load_note_alignments,
    extract_onset_pairs,
    get_performance_key,
)
from score_alignment.data.alignment_dataset import (
    FrameAlignmentDataset,
    MeasureAlignmentDataset,
    frame_alignment_collate_fn,
)
from score_alignment.alignment.dtw import align_embeddings, compute_cost_matrix
from score_alignment.alignment.metrics import (
    onset_error,
    evaluate_dtw_alignment,
    compute_alignment_summary,
)
from score_alignment.models.projection import AlignmentProjectionModel
from score_alignment.training.runner import (
    run_alignment_experiment,
    run_dtw_baseline,
)

# Use the GPU when one is present; otherwise fall back to CPU so that DEVICE
# (and the printout below) reflects the hardware actually available instead
# of claiming "cuda" on a machine without it.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE == "cpu":
    print("WARNING: CUDA is not available; using CPU.")

print(f"MuQ frame rate: {MUQ_FRAME_RATE} fps")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {DEVICE}")

## 2. Data Paths

In [None]:
# Every artifact this notebook reads or writes lives under model/data/.
DATA_ROOT = MODEL_ROOT / "data"
ASAP_ROOT = DATA_ROOT / "asap-dataset"
_MUQ_CACHE_ROOT = DATA_ROOT / "muq_cache"
SCORE_CACHE_DIR = _MUQ_CACHE_ROOT / "scores"
PERF_CACHE_DIR = _MUQ_CACHE_ROOT / "performances"
CHECKPOINT_DIR = DATA_ROOT / "checkpoints" / "score_alignment"
RESULTS_DIR = DATA_ROOT / "results" / "score_alignment"
LOG_DIR = DATA_ROOT / "logs" / "score_alignment"

# Pre-create the cache/output directories. ASAP_ROOT is not created here:
# the clone cell only runs `git clone` when that directory does not exist yet.
_dirs_to_create = (
    DATA_ROOT,
    SCORE_CACHE_DIR,
    PERF_CACHE_DIR,
    CHECKPOINT_DIR,
    RESULTS_DIR,
    LOG_DIR,
)
for _directory in _dirs_to_create:
    _directory.mkdir(parents=True, exist_ok=True)

print(f"Data root: {DATA_ROOT}")
print(f"ASAP root: {ASAP_ROOT}")

In [None]:
# Remote storage (Google Drive via rclone): remote name plus the base folder on it.
RCLONE_REMOTE = "gdrive"
RCLONE_BASE = "crescendai_data"

# Remote folder for each artifact type, relative to the rclone remote root.
_REMOTE_MUQ_CACHE = f"{RCLONE_BASE}/score_alignment/muq_cache"
REMOTE_PATHS = dict(
    muq_scores=f"{_REMOTE_MUQ_CACHE}/scores",
    muq_performances=f"{_REMOTE_MUQ_CACHE}/performances",
    checkpoints=f"{RCLONE_BASE}/checkpoints/score_alignment",
    results=f"{RCLONE_BASE}/results/score_alignment",
    logs=f"{RCLONE_BASE}/logs/score_alignment",
)

def rclone_sync(remote_path: str, local_path: str, direction: str = "download") -> None:
    """Mirror a directory between local disk and the rclone remote.

    Uses ``rclone sync``, which makes the destination identical to the source:
    files that exist only at the destination are deleted.

    Args:
        remote_path: Path on the remote, relative to ``RCLONE_REMOTE``.
        local_path: Local directory; created if missing.
        direction: ``"download"`` (remote -> local) or ``"upload"`` (local -> remote).

    Raises:
        ValueError: If ``direction`` is neither ``"download"`` nor ``"upload"``.
        RuntimeError: If the rclone command exits with a non-zero status.
    """
    # Reject typos up front: previously any value other than "download" was
    # treated as an upload, and an accidental upload-sync can delete remote files.
    if direction not in ("download", "upload"):
        raise ValueError(f"direction must be 'download' or 'upload', got {direction!r}")

    local_dir = Path(local_path)
    local_dir.mkdir(parents=True, exist_ok=True)

    # Syncing an empty local directory up would delete everything under the
    # remote path (e.g. in a session whose earlier download failed), so skip it.
    if direction == "upload" and not any(local_dir.iterdir()):
        print(f"Skipping upload: {local_dir} is empty (sync would delete remote contents)")
        return

    full_remote = f"{RCLONE_REMOTE}:{remote_path}"

    if direction == "download":
        cmd = ["rclone", "sync", full_remote, str(local_dir), "--progress"]
    else:
        cmd = ["rclone", "sync", str(local_dir), full_remote, "--progress"]

    print(f"Running: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"rclone sync failed: {result.stderr}")
    print(f"Sync complete: {local_dir}")


def rclone_copy_file(remote_path: str, local_path: str) -> None:
    """Download a single file from the rclone remote to ``local_path``.

    The parent directory of ``local_path`` is created if needed.

    Raises:
        RuntimeError: If ``rclone copyto`` exits with a non-zero status.
    """
    target = Path(local_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    source = f"{RCLONE_REMOTE}:{remote_path}"
    command = ["rclone", "copyto", source, str(target), "--progress"]
    print(f"Running: {' '.join(command)}")

    proc = subprocess.run(command, capture_output=True, text=True)
    if proc.returncode:
        raise RuntimeError(f"rclone copy failed: {proc.stderr}")
    print(f"Copied: {target}")


def upload_experiment(exp_id: str) -> None:
    """Push one experiment's artifacts to the remote.

    Syncs ``CHECKPOINT_DIR/<exp_id>`` up via ``rclone_sync`` when that directory
    exists, then copies ``RESULTS_DIR/<exp_id>.json`` when that file exists.
    Missing artifacts are skipped without a message.

    Raises:
        RuntimeError: If an rclone command fails.
    """
    ckpt_dir = CHECKPOINT_DIR / exp_id
    if ckpt_dir.exists():
        rclone_sync(
            f"{REMOTE_PATHS['checkpoints']}/{exp_id}",
            str(ckpt_dir),
            direction="upload",
        )

    results_name = f"{exp_id}.json"
    results_path = RESULTS_DIR / results_name
    if not results_path.exists():
        return

    destination = f"{RCLONE_REMOTE}:{REMOTE_PATHS['results']}/{results_name}"
    proc = subprocess.run(
        ["rclone", "copyto", str(results_path), destination, "--progress"],
        capture_output=True,
        text=True,
    )
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to upload results: {proc.stderr}")
    print(f"Uploaded results: {results_name}")


# Fail fast if rclone is unusable or the gdrive remote has not been configured.
try:
    check = subprocess.run(["rclone", "listremotes"], capture_output=True, text=True)
except FileNotFoundError as err:
    # subprocess raises FileNotFoundError (instead of returning a non-zero
    # exit code) when the rclone binary is not on PATH.
    raise RuntimeError("rclone not installed or not working") from err
if check.returncode != 0:
    raise RuntimeError(f"rclone not installed or not working: {check.stderr.strip()}")
# `rclone listremotes` prints one "name:" entry per line; splitlines() also
# copes with CRLF endings and an empty listing (no bogus "" entry).
remotes = [line.strip() for line in check.stdout.splitlines() if line.strip()]
if f"{RCLONE_REMOTE}:" not in remotes:
    raise RuntimeError(
        f"rclone remote '{RCLONE_REMOTE}' not configured. "
        f"Available remotes: {remotes}. Run 'rclone config' to add it."
    )
print(f"rclone remote '{RCLONE_REMOTE}' verified")
print(f"Remote base: {RCLONE_BASE}")

## 3. Clone ASAP Dataset

In [None]:
import subprocess

# Clone the ASAP dataset once; later runs reuse the existing checkout as-is.
if ASAP_ROOT.exists():
    print(f"ASAP dataset already exists at {ASAP_ROOT}")
else:
    print(f"Cloning ASAP dataset from {ASAP_REPO_URL}...")
    clone_cmd = ["git", "clone", ASAP_REPO_URL, str(ASAP_ROOT)]
    subprocess.run(clone_cmd, check=True)  # check=True: a failed clone stops the notebook
    print("Done!")

In [None]:
# Build the ASAP index and report what it contains.
asap_index = parse_asap_metadata(ASAP_ROOT)

print(f"Total performances: {len(asap_index)}")
print(f"Composers: {len(asap_index.get_composers())}")
print(f"Pieces: {len(asap_index.get_pieces())}")

# Keep only performances that come with alignment annotations.
aligned_perfs = asap_index.filter_with_alignments()
print(f"Performances with alignments: {len(aligned_perfs)}")

# Pieces played by several performers (kept for a later disentanglement study).
MIN_PERFORMERS = 3
multi_performer = asap_index.get_multi_performer_pieces(min_performers=MIN_PERFORMERS)
print(f"Pieces with {MIN_PERFORMERS}+ performers: {len(multi_performer)}")

In [None]:
# Download cached MuQ embeddings from gdrive (for Thunder Compute sessions).
# Only .pt embeddings are needed at runtime -- audio files are not required.
def _ensure_embedding_cache(cache_dir, remote_key, long_name, short_name):
    """Download MuQ embeddings into ``cache_dir`` unless it already holds ``.pt`` files.

    NOTE(review): a partially populated cache (e.g. from an interrupted earlier
    download) counts as present and is not topped up.
    """
    n_cached = sum(1 for _ in cache_dir.glob("*.pt"))
    if n_cached == 0:
        print(f"Downloading MuQ {long_name} embeddings...")
        rclone_sync(REMOTE_PATHS[remote_key], str(cache_dir))
    else:
        print(f"MuQ {short_name} cache exists with {n_cached} files")


def _download_existing(remote_key, local_dir, pattern, found_noun):
    """Best-effort download of earlier artifacts (used by the experiment skip logic).

    A failed rclone command (e.g. the remote folder does not exist yet on a
    fresh start) is reported and ignored. Only ``RuntimeError`` -- what
    ``rclone_sync`` raises for a failed command -- is caught, so genuine bugs
    (typos, bad keys) still surface instead of being reported as "fresh start".
    """
    print(f"\nDownloading existing {remote_key}...")
    try:
        rclone_sync(REMOTE_PATHS[remote_key], str(local_dir))
    except RuntimeError as e:
        print(f"No existing {remote_key} (fine for fresh start): {e}")
        return
    n_found = sum(1 for _ in local_dir.glob(pattern))
    print(f"Found {n_found} existing {found_noun}")


_ensure_embedding_cache(SCORE_CACHE_DIR, "muq_scores", "score", "score")
_ensure_embedding_cache(PERF_CACHE_DIR, "muq_performances", "performance", "perf")

# NOTE: rclone_sync runs `rclone sync`, which mirrors the remote onto the local
# dir -- local checkpoints/results that were never uploaded get deleted here.
_download_existing("checkpoints", CHECKPOINT_DIR, "**/*.ckpt", "checkpoints")
_download_existing("results", RESULTS_DIR, "*.json", "result files")

print("\nData sync complete!")

## 4. Train/Val Split

In [None]:
from sklearn.model_selection import train_test_split

# Split at the performance level, 80/20, with a fixed seed so the split can be
# reproduced (checkpoints and results from earlier sessions are reused).
# NOTE(review): performances of the same piece can land in both train and val,
# so the score side of a val pair may already have been seen in training. If that
# leakage matters, a split grouped by piece is needed -- confirm with the owner.
# NOTE(review): the split also depends on the order of `aligned_perfs`; if that
# order can change between sessions, reused checkpoints may not match this split.
all_keys = list(map(get_performance_key, aligned_perfs))
train_keys, val_keys = train_test_split(
    all_keys,
    test_size=0.2,
    random_state=42,
)

for _split_name, _split_keys in (("Train", train_keys), ("Val", val_keys)):
    print(f"{_split_name}: {len(_split_keys)} performances")

## 5. Experiment A: DTW Baseline

Run standard DTW on raw MuQ embeddings without any learned projection.

In [None]:
# Experiment A: plain DTW with cosine distance on the validation split.
# run_dtw_baseline receives RESULTS_DIR and a results_key, so its metrics are
# cached to disk under that key.
baseline_metrics = run_dtw_baseline(
    performances=aligned_perfs,
    keys=val_keys,
    distance_metric="cosine",
    score_cache_dir=SCORE_CACHE_DIR,
    perf_cache_dir=PERF_CACHE_DIR,
    asap_root=ASAP_ROOT,
    results_dir=RESULTS_DIR,
    results_key="A_dtw_baseline",
)

# NOTE(review): the "Within 30ms" label assumes the metric's threshold is 30 ms -- confirm.
_m = baseline_metrics
print("\nDTW Baseline Results:")
print(f"  Mean onset error: {_m['weighted_mean_error_ms']:.1f} ms")
print(f"  Within 30ms: {_m['weighted_percent_within_threshold']:.1f}%")
print(f"  Num performances: {_m['num_performances']}")
print(f"  Total notes evaluated: {_m['total_notes']}")

In [None]:
upload_experiment(exp_id="A_dtw_baseline")  # push Experiment A artifacts to gdrive

In [None]:
def visualize_alignment(sample_idx=0, max_perfs=10):
    """Plot the DTW alignment for one validation sample.

    Builds a FrameAlignmentDataset from the first ``max_perfs`` validation
    performances and, for ``dataset[sample_idx]``, draws:

    * left: the cosine cost matrix (score frames on the y axis, performance
      frames on the x axis) with the DTW path overlaid;
    * right: ground-truth note onsets, score time vs. performance time.

    Args:
        sample_idx: Index into the small visualization dataset.
        max_perfs: Number of validation performances to load (default 10).
    """
    val_key_set = set(val_keys)  # O(1) membership test per performance
    subset = [p for p in aligned_perfs if get_performance_key(p) in val_key_set][:max_perfs]
    dataset = FrameAlignmentDataset(
        subset,
        SCORE_CACHE_DIR,
        PERF_CACHE_DIR,
        ASAP_ROOT,
    )

    n_samples = len(dataset)
    if n_samples == 0:
        print("No samples available for visualization")
        return
    if not -n_samples <= sample_idx < n_samples:
        print(f"sample_idx {sample_idx} is out of range for {n_samples} samples")
        return

    sample = dataset[sample_idx]

    score_emb = sample["score_embeddings"].numpy()
    perf_emb = sample["perf_embeddings"].numpy()

    cost_matrix = compute_cost_matrix(score_emb, perf_emb, metric="cosine")
    # NOTE(review): align_embeddings is called with its default distance metric;
    # confirm it is cosine so the plotted path matches the heatmap underneath it.
    path_score, path_perf, _, _ = align_embeddings(score_emb, perf_emb)

    fig, (ax_cost, ax_onsets) = plt.subplots(1, 2, figsize=(14, 5))

    ax_cost.imshow(cost_matrix, aspect='auto', origin='lower', cmap='viridis')
    ax_cost.plot(path_perf, path_score, 'r-', linewidth=2, label='DTW path')
    ax_cost.set_xlabel('Performance frames')
    ax_cost.set_ylabel('Score frames')
    ax_cost.set_title(f'DTW Alignment - {sample["key"]}')
    ax_cost.legend()

    gt_score = sample["score_onsets"].numpy()
    gt_perf = sample["perf_onsets"].numpy()

    ax_onsets.scatter(gt_score, gt_perf, alpha=0.6, label='Ground truth')
    # max() of an empty onset array would raise, so only draw the reference
    # line (origin to the largest onsets, i.e. a constant-tempo mapping) when
    # there are onsets to span.
    if gt_score.size and gt_perf.size:
        ax_onsets.plot(
            [0, gt_score.max()], [0, gt_perf.max()], 'k--', alpha=0.3, label='Diagonal'
        )
    ax_onsets.set_xlabel('Score onset (sec)')
    ax_onsets.set_ylabel('Performance onset (sec)')
    ax_onsets.set_title('Note-level Alignment')
    ax_onsets.legend()

    plt.tight_layout()
    plt.show()

visualize_alignment(0)

## 6. Experiment B: Learned Projection

Train a projection MLP with soft-DTW divergence loss.

In [None]:
# Projection MLP mapping raw MuQ embeddings (input_dim) into a smaller
# alignment space (output_dim).
projection_config = ProjectionConfig(
    input_dim=1024,
    hidden_dim=512,
    output_dim=256,
    num_layers=3,
    dropout=0.1,
)

# Optimisation settings, shared by Experiments B and C.
training_config = TrainingConfig(
    learning_rate=1e-4,
    weight_decay=1e-5,
    batch_size=4,  # Keep small: soft-DTW divergence allocates T*T matrices per sample
    max_epochs=50,
    patience=10,
    soft_dtw_gamma=1.0,
    num_workers=4,  # Thunder Compute has 4 vCPU
)

print("Projection config:")
_layer_dims = (
    projection_config.input_dim,
    projection_config.hidden_dim,
    projection_config.output_dim,
)
print("  " + " -> ".join(f"{dim}" for dim in _layer_dims))
print(f"  Layers: {projection_config.num_layers}")
print(f"  Batch size: {training_config.batch_size}")

In [None]:
# Experiment B: learned projection MLP trained with a soft-DTW loss.
exp_b_results = run_alignment_experiment(
    exp_id="B_learned_projection",
    description="Learned projection MLP with soft-DTW loss",
    # data
    performances=aligned_perfs,
    train_keys=train_keys,
    val_keys=val_keys,
    asap_root=ASAP_ROOT,
    score_cache_dir=SCORE_CACHE_DIR,
    perf_cache_dir=PERF_CACHE_DIR,
    # model / optimisation
    projection_config=projection_config,
    training_config=training_config,
    # outputs
    checkpoint_dir=CHECKPOINT_DIR,
    results_dir=RESULTS_DIR,
    log_dir=LOG_DIR,
)

In [None]:
upload_experiment(exp_id="B_learned_projection")  # push Experiment B artifacts to gdrive

## 7. Experiment C: Measure-Level Alignment

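Pool embeddings by measure for a coarser alignment: faster and lighter on memory than frame-level alignment, at the cost of temporal resolution. Uses the same projection and training configs as Experiment B.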

In [None]:
# Experiment C: same model and training configs as B, but on measure-pooled
# embeddings (use_measures=True).
exp_c_results = run_alignment_experiment(
    exp_id="C_measure_level",
    description="Measure-level alignment with pooled embeddings",
    use_measures=True,
    # data
    performances=aligned_perfs,
    train_keys=train_keys,
    val_keys=val_keys,
    asap_root=ASAP_ROOT,
    score_cache_dir=SCORE_CACHE_DIR,
    perf_cache_dir=PERF_CACHE_DIR,
    # model / optimisation
    projection_config=projection_config,
    training_config=training_config,
    # outputs
    checkpoint_dir=CHECKPOINT_DIR,
    results_dir=RESULTS_DIR,
    log_dir=LOG_DIR,
)

In [None]:
upload_experiment(exp_id="C_measure_level")  # push Experiment C artifacts to gdrive

In [None]:
# Final sync: upload checkpoints, results and training logs to gdrive.
# NOTE: rclone_sync runs `rclone sync`, so remote checkpoint/result files with no
# local counterpart are deleted; the local dirs were (best effort) populated from
# the remote at the start of the notebook.
# NOTE(review): the comparison figure saved to RESULTS_DIR further down is written
# after this cell runs, so it is not included in this upload.
print("Final sync to remote storage...")
rclone_sync(REMOTE_PATHS["checkpoints"], str(CHECKPOINT_DIR), direction="upload")
rclone_sync(REMOTE_PATHS["results"], str(RESULTS_DIR), direction="upload")

# LOG_DIR is not uploaded by any cell above, so training logs would stay on the
# instance. Use `rclone copy` rather than sync: logs are never downloaded at
# startup, so a sync would delete earlier sessions' logs on the remote.
_log_remote = f"{RCLONE_REMOTE}:{REMOTE_PATHS['logs']}"
_log_upload = subprocess.run(
    ["rclone", "copy", str(LOG_DIR), _log_remote],
    capture_output=True,
    text=True,
)
if _log_upload.returncode != 0:
    raise RuntimeError(f"rclone log upload failed: {_log_upload.stderr}")
print(f"Uploaded logs: {LOG_DIR}")
print("Final sync complete!")

## 8. Results Comparison

In [None]:
import pandas as pd

def _results_row(label, metrics):
    """Build one comparison-table row from an experiment's metrics dict.

    Missing error/accuracy metrics become NaN and a missing note count becomes 0.
    """
    return {
        "Experiment": label,
        "Mean Error (ms)": metrics.get("weighted_mean_error_ms", np.nan),
        "Within 30ms (%)": metrics.get("weighted_percent_within_threshold", np.nan),
        "Notes Evaluated": metrics.get("total_notes", 0),
    }


# The baseline is always listed; B and C only when their run returned metrics.
results_data = [_results_row("A: DTW Baseline", baseline_metrics)]
for _exp_label, _exp_results in (
    ("B: Learned Projection", exp_b_results),
    ("C: Measure-Level", exp_c_results),
):
    if "metrics" in _exp_results:
        results_data.append(_results_row(_exp_label, _exp_results["metrics"]))

results_df = pd.DataFrame(results_data)
print("\nResults Comparison:")
print(results_df.to_string(index=False))

In [None]:
# Two bar charts comparing the experiments against the 30 ms target.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
ax_error, ax_accuracy = axes

experiments = results_df["Experiment"].values
mean_errors = results_df["Mean Error (ms)"].values
within_30ms = results_df["Within 30ms (%)"].values

# Green when the mean error is under 30 ms, red otherwise (NaN compares False -> red).
_bar_colors = [('#2ecc71' if err < 30 else '#e74c3c') for err in mean_errors]
ax_error.bar(experiments, mean_errors, color=_bar_colors)
ax_error.axhline(y=30, color='red', linestyle='--', label='30ms threshold')
ax_error.set_ylabel('Mean Onset Error (ms)')
ax_error.set_title('Mean Onset Error by Experiment')
ax_error.legend()
ax_error.tick_params(axis='x', rotation=15)

ax_accuracy.bar(experiments, within_30ms, color='steelblue')
ax_accuracy.set_ylabel('Onsets within 30ms (%)')
ax_accuracy.set_title('Alignment Accuracy by Experiment')
ax_accuracy.tick_params(axis='x', rotation=15)

plt.tight_layout()
plt.savefig(RESULTS_DIR / 'alignment_comparison.png', dpi=150)
plt.show()

## 9. Recommendations

In [None]:
# Rank experiments by mean onset error and compare the best one to the 30 ms target.
_mean_errors = results_df["Mean Error (ms)"]
if not _mean_errors.notna().any():
    # idxmin() has nothing to rank when every error is NaN (depending on the
    # pandas version it returns NaN or raises), and .loc[NaN] would then fail.
    print("No experiment reported a mean onset error; cannot pick a best experiment.")
else:
    best_idx = _mean_errors.idxmin()  # NaN rows are skipped by default
    best_exp = results_df.loc[best_idx]

    print("Best performing experiment:")
    print(f"  {best_exp['Experiment']}")
    print(f"  Mean error: {best_exp['Mean Error (ms)']:.1f} ms")
    print(f"  Within 30ms: {best_exp['Within 30ms (%)']:.1f}%")

    if best_exp["Mean Error (ms)"] < 30:
        print("\nSuccess: Mean error below 30ms human perception threshold")
    else:
        print("\nNeeds improvement: Mean error above 30ms threshold")
        print("Consider:")
        print("  - Larger projection network")
        print("  - Different soft-DTW gamma")
        print("  - More training data")
        print("  - Sakoe-Chiba band constraint for DTW")