# Disentanglement Experiments: Separating Piece from Performer

This notebook implements three approaches to disentangle piece characteristics from performer expression in piano performance evaluation.

**Goal**: Improve pairwise ranking accuracy for same-piece comparisons
- Baseline: ~50% (random)
- Current model intra-piece std: 0.020

**Success Metric**: Pairwise ranking accuracy significantly above the 50% random baseline, and intra-piece prediction std above the current model's 0.020

## Approaches

1. **Approach A**: Contrastive Pairwise Ranking (InfoNCE + margin ranking)
2. **Approach B**: Siamese Dimension-Specific Ranking (per-dimension heads)
3. **Approach C**: Disentangled Dual-Encoder (adversarial piece classification)

## References

- Temperature 0.07: [Contrastive Learning Blog](https://lilianweng.github.io/posts/2021-05-31-contrastive/)
- Projection dim 256: [SimCLR](https://arxiv.org/abs/2002.05709)
- GRL: [DANN Paper](https://jmlr.org/papers/volume17/15-239/15-239.pdf)
- Pairwise ranking: [DirectRanker](https://arxiv.org/abs/1909.02768)

## 1. Setup & Configuration

In [None]:
# Thunder Compute Setup - Install rclone if needed
import os
import subprocess

# Install rclone only when no executable is reachable on PATH. Testing a single
# hard-coded location (/usr/bin/rclone) would trigger a needless sudo reinstall
# whenever rclone lives elsewhere, e.g. /usr/local/bin.
import shutil

if shutil.which("rclone") is None:
    print("Installing rclone...")
    # -f: fail on HTTP errors instead of saving an error page that would then be
    # executed by bash; -sS: quiet but still report errors; -L: follow redirects.
    subprocess.run(
        ['curl', '-fsSL', 'https://rclone.org/install.sh', '-o', '/tmp/install_rclone.sh'],
        check=True,
    )
    subprocess.run(['sudo', 'bash', '/tmp/install_rclone.sh'], check=True)

# Check rclone version (first line of `rclone version` output)
result = subprocess.run(['rclone', 'version'], capture_output=True, text=True)
version_line = result.stdout.split("\n")[0]
print(f"rclone: {version_line}")

In [None]:
# Core imports
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime

import torch
import pytorch_lightning as pl
pl.seed_everything(42, workers=True)

# Audio experiments (Paper 1)
from audio_experiments import PERCEPIANO_DIMENSIONS, BASE_CONFIG, SEED

# Disentanglement (Paper 2)
from disentanglement import (
    # Models
    ContrastivePairwiseRankingModel,
    SiameseDimensionRankingModel,
    DisentangledDualEncoderModel,
    ContrastiveDisentangledModel,
    SiameseDisentangledModel,
    FullCombinedModel,
    TripletRankingModel,
    # Data
    build_multi_performer_pieces,
    create_piece_stratified_folds,
    PairwiseRankingDataset,
    HardPairRankingDataset,
    DisentanglementDataset,
    TripletRankingDataset,
    pairwise_collate_fn,
    disentanglement_collate_fn,
    triplet_collate_fn,
    compute_pairwise_statistics,
    # Training
    run_pairwise_experiment,
    run_disentanglement_experiment,
    run_triplet_experiment,
    run_dimension_group_experiment,
    compute_pairwise_metrics,
    compute_intra_piece_std,
    evaluate_disentanglement,
    compute_regression_pairwise_accuracy,
)

print("Imports successful!")

In [None]:
# Hyperparameter sets for the three approaches. Entries annotated with a source
# follow the references listed in the notebook header.

# Approach A: Contrastive Pairwise Ranking
APPROACH_A_CONFIG = dict(
    # Architecture
    input_dim=1024,
    hidden_dim=512,
    projection_dim=256,  # SimCLR recommendation
    num_labels=19,
    dropout=0.2,
    pooling="attention",
    # Contrastive learning
    temperature=0.07,  # InfoNCE best practice
    lambda_contrastive=0.3,
    # Ranking
    margin=0.2,
    ambiguous_threshold=0.05,
    label_smoothing=0.0,
    # Training
    learning_rate=1e-4,
    weight_decay=1e-5,
    gradient_clip_val=1.0,
    batch_size=32,
    max_epochs=100,
    patience=15,
    max_frames=1000,
    n_folds=4,
    num_workers=2,
    seed=42,
)

# Approach B: Siamese Dimension-Specific Ranking.
# Starts from a copy of A's settings and overrides/extends them.
APPROACH_B_CONFIG = dict(
    APPROACH_A_CONFIG,
    comparison_type="concat_diff",  # [z_a; z_b; z_a-z_b; z_a*z_b]
    margin=0.3,
    label_smoothing=0.05,  # Slight smoothing helps
)

# Approach C: Disentangled Dual-Encoder.
# Starts from a copy of A's settings and adds the adversarial options.
APPROACH_C_CONFIG = dict(
    APPROACH_A_CONFIG,
    lambda_adversarial=0.5,
    grl_schedule="linear",  # DANN paper schedule
    num_pieces=206,  # Will be updated dynamically
)

print("Configurations defined.")

In [None]:
# Path configuration for Thunder Compute
import os
import subprocess
from pathlib import Path

# Local working locations on the Thunder Compute instance (fast local SSD).
PATHS = dict(
    muq_cache=Path("/workspace/data/cache", "muq_embeddings"),
    mert_cache=Path("/workspace/data/cache", "mert_layer12"),
    labels_file=Path("/workspace/data/cache", "percepiano_labels.json"),
    fold_assignments=Path("/workspace/data/cache", "fold_assignments.json"),
    checkpoints=Path("/workspace/checkpoints", "disentanglement"),
    results=Path("/workspace/results", "disentanglement"),
    logs=Path("/workspace/logs", "disentanglement"),
)

# Matching locations on the rclone remote, relative to RCLONE_BASE_PATH.
# There is no remote entry for "logs".
REMOTE_PATHS = dict(
    muq_cache="cache/muq_embeddings",
    mert_cache="cache/mert_layer12",
    labels_file="cache/percepiano_labels.json",
    fold_assignments="cache/fold_assignments.json",
    checkpoints="checkpoints/disentanglement",
    results="results/disentanglement",
)

# rclone configuration
RCLONE_BASE_PATH = "crescendai/model/data"  # Base path on remote
RCLONE_REMOTE = "gdrive"  # Name of your rclone remote


def rclone_sync(remote_path: str, local_path: str, direction: str = "download") -> None:
    """Sync a directory between remote and local using ``rclone sync``.

    ``rclone sync`` makes the destination identical to the source, which can
    delete files that exist only on the destination side. For that reason the
    direction is validated up front: previously any value other than
    ``"download"`` (including a typo) silently ran an upload.

    Args:
        remote_path: Directory on the remote, relative to ``RCLONE_BASE_PATH``.
        local_path: Local directory; created if it does not exist.
        direction: ``"download"`` (remote -> local) or ``"upload"``
            (local -> remote).

    Raises:
        ValueError: If ``direction`` is neither ``"download"`` nor ``"upload"``.
        RuntimeError: If rclone exits with a non-zero status.
    """
    # Validate before any filesystem or network side effect.
    if direction not in ("download", "upload"):
        raise ValueError(
            f"direction must be 'download' or 'upload', got {direction!r}"
        )

    local_dir = Path(local_path)
    local_dir.mkdir(parents=True, exist_ok=True)

    full_remote = f"{RCLONE_REMOTE}:{RCLONE_BASE_PATH}/{remote_path}"

    if direction == "download":
        source, destination = full_remote, str(local_dir)
    else:
        source, destination = str(local_dir), full_remote
    cmd = ["rclone", "sync", source, destination, "--progress"]

    print(f"Running: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"Error: {result.stderr}")
        raise RuntimeError(f"rclone sync failed: {result.stderr}")
    print(f"Sync complete: {local_dir}")


def rclone_copy_file(remote_path: str, local_path: str) -> None:
    """Fetch one remote file to ``local_path`` with ``rclone copyto``.

    The parent directory of ``local_path`` is created first. ``remote_path``
    is taken relative to ``RCLONE_BASE_PATH`` on ``RCLONE_REMOTE``.

    Raises:
        RuntimeError: If rclone exits with a non-zero status.
    """
    destination = Path(local_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    source = f"{RCLONE_REMOTE}:{RCLONE_BASE_PATH}/{remote_path}"
    command = ["rclone", "copyto", source, str(destination), "--progress"]
    print(f"Running: {' '.join(command)}")

    completed = subprocess.run(command, capture_output=True, text=True)
    if completed.returncode == 0:
        print(f"Copied: {destination}")
        return

    # Failure path: echo stderr, then surface it to the caller.
    print(f"Error: {completed.stderr}")
    raise RuntimeError(f"rclone copy failed: {completed.stderr}")


def upload_experiment(exp_id: str) -> None:
    """Push one experiment's checkpoint folder and results JSON to the remote.

    Either artifact is skipped when it does not exist locally. The checkpoint
    folder goes through ``rclone_sync`` (upload direction); the results file is
    sent with ``rclone copyto`` and a failure there raises
    ``subprocess.CalledProcessError`` because of ``check=True``.
    """
    checkpoint_dir = PATHS["checkpoints"] / exp_id
    if checkpoint_dir.exists():
        rclone_sync(
            f"{REMOTE_PATHS['checkpoints']}/{exp_id}",
            str(checkpoint_dir),
            direction="upload",
        )

    results_json = PATHS["results"] / f"{exp_id}.json"
    if not results_json.exists():
        return

    destination = (
        f"{RCLONE_REMOTE}:{RCLONE_BASE_PATH}/"
        f"{REMOTE_PATHS['results']}/{exp_id}.json"
    )
    subprocess.run(
        ["rclone", "copyto", str(results_json), destination, "--progress"],
        check=True,
    )
    print(f"Uploaded results: {exp_id}.json")


def upload_fold_checkpoint(exp_id: str, fold: int) -> None:
    """Copy ``fold{fold}_best.ckpt`` of ``exp_id`` to the remote, if present.

    Best-effort: a failed upload only prints a warning and never raises, and a
    missing checkpoint file is silently ignored.
    """
    checkpoint = PATHS["checkpoints"] / exp_id / f"fold{fold}_best.ckpt"
    if not checkpoint.exists():
        return

    destination = (
        f"{RCLONE_REMOTE}:{RCLONE_BASE_PATH}/"
        f"{REMOTE_PATHS['checkpoints']}/{exp_id}/fold{fold}_best.ckpt"
    )
    print(f"Uploading fold {fold} checkpoint...")
    outcome = subprocess.run(
        ["rclone", "copyto", str(checkpoint), destination, "--progress"],
        capture_output=True,
        text=True,
    )
    if outcome.returncode == 0:
        print(f"Uploaded: {checkpoint.name}")
    else:
        print(f"Warning: Failed to upload fold checkpoint: {outcome.stderr}")


# Create the output directories. Entries whose name mentions "cache", "labels"
# or "fold" are skipped here and not created.
for name, path in PATHS.items():
    if any(tag in name for tag in ("cache", "labels", "fold")):
        continue
    path.mkdir(parents=True, exist_ok=True)

# Report every configured path together with whether it currently exists.
print("Thunder Compute paths configured:")
for name, path in PATHS.items():
    status = "exists" if path.exists() else "MISSING"
    print(f"  {name}: {path} ({status})")

In [None]:
# Download data and existing experiments from remote storage
print("Syncing data from remote storage...")

# MuQ embeddings are required for training. The cache is listed once and the
# listing reused for both the emptiness check and the summary message.
muq_files = list(PATHS["muq_cache"].glob("*.pt")) if PATHS["muq_cache"].exists() else []
if not muq_files:
    print("\nDownloading MuQ embeddings...")
    rclone_sync(REMOTE_PATHS["muq_cache"], str(PATHS["muq_cache"]))
else:
    print(f"MuQ cache exists with {len(muq_files)} files")

# Download labels and fold assignments
if not PATHS["labels_file"].exists():
    print("\nDownloading labels...")
    rclone_copy_file(REMOTE_PATHS["labels_file"], str(PATHS["labels_file"]))

if not PATHS["fold_assignments"].exists():
    print("\nDownloading fold assignments...")
    rclone_copy_file(REMOTE_PATHS["fold_assignments"], str(PATHS["fold_assignments"]))

# Download existing checkpoints to resume training.
# NOTE(review): rclone_sync runs `rclone sync`, which makes the local folder
# match the remote one; local checkpoints/results that were never uploaded may
# be removed by the two syncs below. Confirm this is acceptable.
print("\nDownloading existing checkpoints...")
rclone_sync(REMOTE_PATHS["checkpoints"], str(PATHS["checkpoints"]))
n_ckpts = len(list(PATHS["checkpoints"].glob("**/*.ckpt")))
print(f"Found {n_ckpts} existing checkpoints")

# Download existing results
print("\nDownloading existing results...")
rclone_sync(REMOTE_PATHS["results"], str(PATHS["results"]))
n_results = len(list(PATHS["results"].glob("*.json")))
print(f"Found {n_results} existing result files")

# List completed experiments: a results JSON plus one best checkpoint per fold.
# The expected fold count comes from the shared config rather than a literal,
# so it cannot drift from the value the experiments are configured with.
expected_folds = APPROACH_A_CONFIG["n_folds"]
completed = []
for result_file in PATHS["results"].glob("*.json"):
    exp_id = result_file.stem
    exp_ckpt_dir = PATHS["checkpoints"] / exp_id
    exp_ckpts = list(exp_ckpt_dir.glob("fold*_best.ckpt")) if exp_ckpt_dir.exists() else []
    if len(exp_ckpts) == expected_folds:
        completed.append(exp_id)

if completed:
    print(f"\nCompleted experiments ({len(completed)}):")
    for exp_id in sorted(completed):
        print(f"  - {exp_id}")
else:
    print("\nNo completed experiments found. Starting fresh.")

print("\nData sync complete!")

In [None]:
# Read the PercePiano label file into `labels` (parsed JSON).
with open(PATHS["labels_file"]) as labels_fp:
    labels = json.load(labels_fp)

print(f"Loaded {len(labels)} recordings")

# Preview: first key and the first five entries of its label value
sample_key = list(labels)[0]
sample_preview = labels[sample_key][:5]
print(f"\nSample: {sample_key}")
print(f"Labels: {sample_preview}...")

In [None]:
# Read the original fold assignment file into `original_folds` (parsed JSON).
with open(PATHS["fold_assignments"]) as folds_fp:
    original_folds = json.load(folds_fp)

# One line per fold: its name and how many samples it holds
print("Original fold sizes:")
for fold_name, fold_members in original_folds.items():
    print(f"  {fold_name}: {len(fold_members)} samples")

In [None]:
# Map pieces to their recordings, requesting min_performers=2.
MULTI_PERFORMER_PIECES = build_multi_performer_pieces(
    labels,
    original_folds,
    min_performers=2,
)

# Pairwise statistics over that mapping
stats = compute_pairwise_statistics(MULTI_PERFORMER_PIECES, labels)

# Assemble the report first, then emit it in one go.
summary_lines = [
    "Multi-performer pieces:",
    f"  Pieces: {stats['n_pieces']}",
    f"  Recordings: {stats['n_recordings']}",
    f"  Possible pairs: {stats['n_possible_pairs']}",
    "",
    "Score differences:",
    f"  Mean: {stats['mean_diff']:.3f}",
    f"  Std: {stats['std_diff']:.3f}",
    f"  Range: [{stats['min_diff']:.3f}, {stats['max_diff']:.3f}]",
]
print("\n".join(summary_lines))

In [None]:
# Build the piece-stratified folds (intended to avoid piece leakage; the
# check that follows in this cell verifies it).
PIECE_STRATIFIED_FOLDS = create_piece_stratified_folds(
    MULTI_PERFORMER_PIECES,
    n_folds=4,
    seed=42,
)

print("Piece-stratified fold sizes:")
for fold_name, fold_members in PIECE_STRATIFIED_FOLDS.items():
    print(f"  {fold_name}: {len(fold_members)} samples")

# Verify no piece leakage
def get_pieces_in_fold(fold_keys):
    """Return the set of piece ids that have at least one recording in a fold.

    Args:
        fold_keys: Iterable of recording keys belonging to one fold.

    Returns:
        Set of ids from ``MULTI_PERFORMER_PIECES`` whose key collection shares
        at least one key with ``fold_keys``.
    """
    # One pass over the pieces with set membership, instead of rescanning every
    # piece's key collection once per fold key.
    fold_key_set = set(fold_keys)
    return {
        pid
        for pid, keys in MULTI_PERFORMER_PIECES.items()
        if not fold_key_set.isdisjoint(keys)
    }

# Piece ids present in each of the four folds ("fold_0" .. "fold_3")
fold_pieces = []
for fold_idx in range(4):
    fold_pieces.append(get_pieces_in_fold(PIECE_STRATIFIED_FOLDS[f"fold_{fold_idx}"]))

# Compare every unordered pair of folds; any shared piece id is leakage.
fold_pairs = [(i, j) for i in range(4) for j in range(i + 1, 4)]
for i, j in fold_pairs:
    overlap = fold_pieces[i] & fold_pieces[j]
    if not overlap:
        print(f"No leakage between fold {i} and {j}")
        continue
    print(f"WARNING: Piece leakage between fold {i} and {j}: {len(overlap)} pieces")

In [None]:
# Test dataset creation
from disentanglement.data import get_fold_piece_mapping

# Split fold 0 into its validation and training views
val_piece_map, val_keys = get_fold_piece_mapping(
    MULTI_PERFORMER_PIECES,
    PIECE_STRATIFIED_FOLDS,
    fold_id=0,
    mode="val",
)
train_piece_map, train_keys = get_fold_piece_mapping(
    MULTI_PERFORMER_PIECES,
    PIECE_STRATIFIED_FOLDS,
    fold_id=0,
    mode="train",
)

# Small pairwise dataset over the training split (max_frames kept low for a
# quick check)
test_ds = PairwiseRankingDataset(
    PATHS["muq_cache"],
    labels,
    train_piece_map,
    train_keys,
    max_frames=100,
)

print(f"Train dataset: {len(test_ds)} pairs")
print(f"Num pieces: {test_ds.get_num_pieces()}")

# Inspect the first item: tensors are summarised by shape, anything else is
# printed as-is.
sample = test_ds[0]
print("\nSample keys:")
for field, value in sample.items():
    shown = value.shape if isinstance(value, torch.Tensor) else value
    print(f"  {field}: {shown}")

## 3. Baseline Establishment

In [None]:
# E1b: Random baseline (expected 50%)
print("E1b: Random Baseline")
print("="*50)

n_pairs = stats['n_possible_pairs']
random_acc = 0.5  # Expected accuracy for random predictions

# Monte Carlo estimate with variance.
# A coin-flip ranker gets each comparison right with probability 0.5 whatever
# the true ordering is, so the number of correct comparisons in one simulated
# run is Binomial(n_comparisons, 0.5). Drawing that count directly is an exact
# simulation; the earlier loop generated a prediction matrix per iteration but
# never used it and sampled from a normal approximation instead.
n_simulations = 1000
n_comparisons = n_pairs * 19  # one binary decision per pair and per dimension (19)
if n_comparisons > 0:
    random_accs = (
        np.random.binomial(n_comparisons, 0.5, size=n_simulations) / n_comparisons
    ).tolist()
else:
    # No pairs to compare: fall back to the analytic expectation rather than
    # dividing by zero.
    random_accs = [random_acc] * n_simulations

print(f"Random baseline accuracy: {np.mean(random_accs):.4f} +/- {np.std(random_accs):.4f}")
print(f"95% CI: [{np.percentile(random_accs, 2.5):.4f}, {np.percentile(random_accs, 97.5):.4f}]")

In [None]:
# E1a: pairwise ranking accuracy of the current MuQ regression model
print("E1a: Current MuQ Model Pairwise Accuracy")
print("="*50)

# Baseline model class and the evaluation routine
from audio_experiments.models.muq_models import MuQBaseModel
from disentanglement.training import compute_regression_pairwise_accuracy

# Paper 1 checkpoint: remote folder and its local mirror next to the
# disentanglement checkpoints
muq_remote = "checkpoints/muq"
muq_local = PATHS["checkpoints"].parent / "muq"
muq_checkpoint = muq_local / "fold0_best.ckpt"

# Fetch the checkpoint folder only when the fold-0 file is not already local
if not muq_checkpoint.exists():
    print("Downloading MuQ checkpoint from remote storage...")
    muq_local.mkdir(parents=True, exist_ok=True)
    rclone_sync(muq_remote, str(muq_local))

if not muq_checkpoint.exists():
    # Still missing after the download attempt: report and leave no results.
    print("No MuQ checkpoint found. Run Paper 1 training first.")
    print("Expected path:", muq_checkpoint)
    baseline_results = None
else:
    print(f"Found MuQ checkpoint: {muq_checkpoint}")

    # Load the regression model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = MuQBaseModel.load_from_checkpoint(muq_checkpoint)

    # Compute pairwise accuracy
    print("\nEvaluating pairwise ranking accuracy...")
    baseline_results = compute_regression_pairwise_accuracy(
        model=model,
        cache_dir=PATHS["muq_cache"],
        labels=labels,
        piece_to_keys=MULTI_PERFORMER_PIECES,
        device=device,
        ambiguous_threshold=0.05,
    )
    per_dim = baseline_results['per_dimension']

    print("\nE1a Baseline Results:")
    print(f"  Overall pairwise accuracy: {baseline_results['overall_accuracy']:.4f}")
    print(f"  Number of pairs evaluated: {baseline_results['n_pairs']}")
    print(f"  Total comparisons: {baseline_results['n_comparisons']}")

    # One row per dimension; "!" flags accuracies above 0.55
    print("\nPer-dimension breakdown:")
    for dim_idx, acc in sorted(per_dim.items()):
        dim_name = PERCEPIANO_DIMENSIONS[dim_idx]
        indicator = "!" if acc > 0.55 else " "
        print(f"  {indicator} {dim_idx:2d}. {dim_name:<25}: {acc:.4f}")

    # Summary statistics
    dim_accs = list(per_dim.values())
    best_dim = max(per_dim, key=per_dim.get)
    worst_dim = min(per_dim, key=per_dim.get)
    print("\nSummary:")
    print(f"  Mean dimension accuracy: {np.mean(dim_accs):.4f}")
    print(f"  Std dimension accuracy: {np.std(dim_accs):.4f}")
    print(f"  Best dimension: {PERCEPIANO_DIMENSIONS[best_dim]}")
    print(f"  Worst dimension: {PERCEPIANO_DIMENSIONS[worst_dim]}")

    # Save baseline results
    e1a_results = {
        "experiment_id": "E1a_muq_baseline",
        "description": "MuQ regression model pairwise ranking accuracy",
        "summary": baseline_results,
    }
    e1a_path = PATHS["results"] / "E1a_muq_baseline.json"
    with open(e1a_path, "w") as f:
        json.dump(e1a_results, f, indent=2)
    print(f"\nResults saved to {e1a_path}")

## 4. Approach A: Contrastive Pairwise Ranking

**Architecture**: Shared encoder + projection head + ranking heads

**Loss**: `L_total = L_ranking + lambda * L_infonce`
- InfoNCE with piece-based positives (same piece = positive, different piece = negative)
- Margin-based pairwise ranking loss

In [None]:
# Model factory for Approach A
def make_approach_a_model(config):
    """Build a ContrastivePairwiseRankingModel (Approach A) from ``config``.

    Every constructor argument is looked up in ``config`` by name; a key that
    is absent falls back to the default listed below.
    """
    fallbacks = {
        "input_dim": 1024,
        "hidden_dim": 512,
        "projection_dim": 256,
        "num_labels": 19,
        "dropout": 0.2,
        "learning_rate": 1e-4,
        "weight_decay": 1e-5,
        "temperature": 0.07,
        "lambda_contrastive": 0.3,
        "margin": 0.2,
        "ambiguous_threshold": 0.05,
        "label_smoothing": 0.0,
        "pooling": "attention",
        "max_epochs": 100,
    }
    model_kwargs = {name: config.get(name, default) for name, default in fallbacks.items()}
    return ContrastivePairwiseRankingModel(**model_kwargs)

In [None]:
# E2a: Approach A with default config
e2a_kwargs = dict(
    exp_id="E2a_contrastive_default",
    description="Approach A: Contrastive pairwise ranking with default config",
    model_factory=make_approach_a_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    config=APPROACH_A_CONFIG,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)
results_e2a = run_pairwise_experiment(**e2a_kwargs)

# Push the finished experiment's artifacts to the remote
upload_experiment(e2a_kwargs["exp_id"])

In [None]:
# E2b: Temperature ablation
temperature_values = [0.05, 0.07, 0.1, 0.2]
temp_results = {}

# Arguments that are identical for every run of this sweep
e2b_shared = dict(
    model_factory=make_approach_a_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)

for temp in temperature_values:
    exp_id = f"E2b_temp_{temp}"
    # Approach A settings with only the temperature replaced
    config = dict(APPROACH_A_CONFIG, temperature=temp)
    result = run_pairwise_experiment(
        exp_id=exp_id,
        description=f"Approach A: temperature={temp}",
        config=config,
        **e2b_shared,
    )
    upload_experiment(exp_id)
    temp_results[temp] = result["summary"]["avg_pairwise_acc"]

print("\nTemperature ablation results:")
for tau, accuracy in temp_results.items():
    print(f"  tau={tau}: {accuracy:.4f}")

In [None]:
# E2c: Projection dimension ablation
proj_dims = [64, 128, 256, 512]
proj_results = {}

# Arguments that are identical for every run of this sweep
e2c_shared = dict(
    model_factory=make_approach_a_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)

for dim in proj_dims:
    exp_id = f"E2c_proj_{dim}"
    # Approach A settings with only the projection size replaced
    config = dict(APPROACH_A_CONFIG, projection_dim=dim)
    result = run_pairwise_experiment(
        exp_id=exp_id,
        description=f"Approach A: projection_dim={dim}",
        config=config,
        **e2c_shared,
    )
    upload_experiment(exp_id)
    proj_results[dim] = result["summary"]["avg_pairwise_acc"]

print("\nProjection dim ablation results:")
for width, accuracy in proj_results.items():
    print(f"  dim={width}: {accuracy:.4f}")

In [None]:
# E2d: Lambda contrastive ablation
lambda_values = [0.0, 0.1, 0.3, 0.5, 1.0]
lambda_results = {}

# Arguments that are identical for every run of this sweep
e2d_shared = dict(
    model_factory=make_approach_a_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)

for lam in lambda_values:
    exp_id = f"E2d_lambda_{lam}"
    # Approach A settings with only the contrastive weight replaced
    config = dict(APPROACH_A_CONFIG, lambda_contrastive=lam)
    result = run_pairwise_experiment(
        exp_id=exp_id,
        description=f"Approach A: lambda_contrastive={lam}",
        config=config,
        **e2d_shared,
    )
    upload_experiment(exp_id)
    lambda_results[lam] = result["summary"]["avg_pairwise_acc"]

print("\nLambda contrastive ablation results:")
for weight, accuracy in lambda_results.items():
    print(f"  lambda={weight}: {accuracy:.4f}")

## 5. Approach B: Siamese Dimension-Specific Ranking

**Architecture**: Shared encoder (both inputs) + comparison module + 19 dimension heads

**Comparison**: `[z_A; z_B; z_A - z_B; z_A * z_B]` or bilinear

**Loss**: BCE with label smoothing, ignoring ambiguous pairs

In [None]:
# Model factory for Approach B
def make_approach_b_model(config):
    """Build a SiameseDimensionRankingModel (Approach B) from ``config``.

    Every constructor argument is looked up in ``config`` by name; a key that
    is absent falls back to the default listed below.
    """
    fallbacks = {
        "input_dim": 1024,
        "hidden_dim": 512,
        "num_labels": 19,
        "dropout": 0.2,
        "learning_rate": 1e-4,
        "weight_decay": 1e-5,
        "comparison_type": "concat_diff",
        "margin": 0.3,
        "ambiguous_threshold": 0.05,
        "label_smoothing": 0.05,
        "pooling": "attention",
        "max_epochs": 100,
    }
    model_kwargs = {name: config.get(name, default) for name, default in fallbacks.items()}
    return SiameseDimensionRankingModel(**model_kwargs)

In [None]:
# E3a: Approach B with default config
e3a_kwargs = dict(
    exp_id="E3a_siamese_default",
    description="Approach B: Siamese dimension-specific ranking with default config",
    model_factory=make_approach_b_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    config=APPROACH_B_CONFIG,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)
results_e3a = run_pairwise_experiment(**e3a_kwargs)

# Push the finished experiment's artifacts to the remote
upload_experiment(e3a_kwargs["exp_id"])

In [None]:
# E3b: Comparison type ablation
comparison_types = ["concat_diff", "bilinear"]
comp_results = {}

# Arguments that are identical for every run of this sweep
e3b_shared = dict(
    model_factory=make_approach_b_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)

for comp_type in comparison_types:
    exp_id = f"E3b_{comp_type}"
    # Approach B settings with only the comparison module replaced
    config = dict(APPROACH_B_CONFIG, comparison_type=comp_type)
    result = run_pairwise_experiment(
        exp_id=exp_id,
        description=f"Approach B: comparison_type={comp_type}",
        config=config,
        **e3b_shared,
    )
    upload_experiment(exp_id)
    comp_results[comp_type] = result["summary"]["avg_pairwise_acc"]

print("\nComparison type ablation results:")
for variant, accuracy in comp_results.items():
    print(f"  {variant}: {accuracy:.4f}")

In [None]:
# E3c: Label smoothing ablation
smoothing_values = [0.0, 0.05, 0.1, 0.15]
smooth_results = {}

# Arguments that are identical for every run of this sweep
e3c_shared = dict(
    model_factory=make_approach_b_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)

for smooth in smoothing_values:
    exp_id = f"E3c_smooth_{smooth}"
    # Approach B settings with only the label smoothing replaced
    config = dict(APPROACH_B_CONFIG, label_smoothing=smooth)
    result = run_pairwise_experiment(
        exp_id=exp_id,
        description=f"Approach B: label_smoothing={smooth}",
        config=config,
        **e3c_shared,
    )
    upload_experiment(exp_id)
    smooth_results[smooth] = result["summary"]["avg_pairwise_acc"]

print("\nLabel smoothing ablation results:")
for amount, accuracy in smooth_results.items():
    print(f"  smoothing={amount}: {accuracy:.4f}")

## 6. Approach C: Disentangled Dual-Encoder

**Architecture**: Piece encoder + Style encoder + GRL + adversarial piece classifier

**Loss**: `L_total = L_regression + lambda_adv * L_adversarial`
- A Gradient Reversal Layer (GRL) trains the style encoder adversarially against the piece classifier
- Style encoder feeds dimension prediction heads

In [None]:
# Model factory for Approach C
def make_approach_c_model(config):
    """Build a DisentangledDualEncoderModel (Approach C) from ``config``.

    Every constructor argument is looked up in ``config`` by name; a key that
    is absent falls back to the default listed below.
    """
    fallbacks = {
        "input_dim": 1024,
        "hidden_dim": 512,
        "num_labels": 19,
        "num_pieces": 206,
        "dropout": 0.2,
        "learning_rate": 1e-4,
        "weight_decay": 1e-5,
        "lambda_adversarial": 0.5,
        "grl_schedule": "linear",
        "pooling": "attention",
        "max_epochs": 100,
    }
    model_kwargs = {name: config.get(name, default) for name, default in fallbacks.items()}
    return DisentangledDualEncoderModel(**model_kwargs)

In [None]:
# E4a: Approach C with default config
e4a_kwargs = dict(
    exp_id="E4a_dual_encoder_default",
    description="Approach C: Disentangled dual-encoder with default config",
    model_factory=make_approach_c_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    config=APPROACH_C_CONFIG,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)
results_e4a = run_disentanglement_experiment(**e4a_kwargs)

# Push the finished experiment's artifacts to the remote
upload_experiment(e4a_kwargs["exp_id"])

In [None]:
# E4b: Adversarial weight ablation
adv_weights = [0.1, 0.3, 0.5, 0.7]
adv_results = {}

# Arguments that are identical for every run of this sweep
e4b_shared = dict(
    model_factory=make_approach_c_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)

for weight in adv_weights:
    exp_id = f"E4b_adv_{weight}"
    # Approach C settings with only the adversarial weight replaced
    config = dict(APPROACH_C_CONFIG, lambda_adversarial=weight)
    result = run_disentanglement_experiment(
        exp_id=exp_id,
        description=f"Approach C: lambda_adversarial={weight}",
        config=config,
        **e4b_shared,
    )
    upload_experiment(exp_id)

    # Keep the regression score plus both disentanglement metrics per weight
    disentanglement = result["disentanglement"]
    adv_results[weight] = {
        "r2": result["summary"]["avg_r2"],
        "style_piece_acc": disentanglement["style_piece_accuracy"],
        "intra_piece_std": disentanglement["intra_piece_std"],
    }

print("\nAdversarial weight ablation results:")
for w, metrics in adv_results.items():
    r2, piece_acc = metrics['r2'], metrics['style_piece_acc']
    print(f"  lambda={w}: R2={r2:.4f}, style_piece_acc={piece_acc:.4f}")

In [None]:
# E4c: GRL schedule ablation
schedules = ["constant", "linear", "cosine"]
sched_results = {}

# Arguments that are identical for every run of this sweep
e4c_shared = dict(
    model_factory=make_approach_c_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)

for sched in schedules:
    exp_id = f"E4c_sched_{sched}"
    # Approach C settings with only the GRL schedule replaced
    config = dict(APPROACH_C_CONFIG, grl_schedule=sched)
    result = run_disentanglement_experiment(
        exp_id=exp_id,
        description=f"Approach C: grl_schedule={sched}",
        config=config,
        **e4c_shared,
    )
    upload_experiment(exp_id)
    sched_results[sched] = {
        "r2": result["summary"]["avg_r2"],
        "style_piece_acc": result["disentanglement"]["style_piece_accuracy"],
    }

print("\nGRL schedule ablation results:")
for s, metrics in sched_results.items():
    r2, piece_acc = metrics['r2'], metrics['style_piece_acc']
    print(f"  {s}: R2={r2:.4f}, style_piece_acc={piece_acc:.4f}")

## 7. Combination Experiments

In [None]:
# Model factories for combinations
def make_ac_model(config):
    """A+C: Contrastive + Adversarial.

    Builds a ContrastiveDisentangledModel; each constructor argument is read
    from ``config`` by name, falling back to the default listed below.
    """
    fallbacks = {
        "input_dim": 1024,
        "hidden_dim": 512,
        "projection_dim": 256,
        "num_labels": 19,
        "num_pieces": 206,
        "dropout": 0.2,
        "learning_rate": 1e-4,
        "weight_decay": 1e-5,
        "temperature": 0.07,
        "lambda_contrastive": 0.3,
        "lambda_adversarial": 0.5,
        "grl_schedule": "linear",
        "ambiguous_threshold": 0.05,
        "label_smoothing": 0.0,
        "pooling": "attention",
        "max_epochs": 100,
    }
    model_kwargs = {name: config.get(name, default) for name, default in fallbacks.items()}
    return ContrastiveDisentangledModel(**model_kwargs)

def make_bc_model(config):
    """B+C: Siamese + Adversarial.

    Builds a SiameseDisentangledModel; each constructor argument is read from
    ``config`` by name, falling back to the default listed below.
    """
    fallbacks = {
        "input_dim": 1024,
        "hidden_dim": 512,
        "num_labels": 19,
        "num_pieces": 206,
        "dropout": 0.2,
        "learning_rate": 1e-4,
        "weight_decay": 1e-5,
        "lambda_adversarial": 0.5,
        "grl_schedule": "linear",
        "comparison_type": "concat_diff",
        "ambiguous_threshold": 0.05,
        "label_smoothing": 0.05,
        "pooling": "attention",
        "max_epochs": 100,
    }
    model_kwargs = {name: config.get(name, default) for name, default in fallbacks.items()}
    return SiameseDisentangledModel(**model_kwargs)

def make_abc_model(config):
    """Build the A+B+C model (full combination) from a config mapping.

    Each hyperparameter is looked up with ``config.get``; a key that is
    absent from ``config`` falls back to the default listed in the table.
    """
    fallback_values = {
        "input_dim": 1024,
        "hidden_dim": 512,
        "projection_dim": 256,
        "num_labels": 19,
        "num_pieces": 206,
        "dropout": 0.2,
        "learning_rate": 1e-4,
        "weight_decay": 1e-5,
        "temperature": 0.07,
        "lambda_contrastive": 0.3,
        "lambda_adversarial": 0.5,
        "grl_schedule": "linear",
        "ambiguous_threshold": 0.05,
        "label_smoothing": 0.05,
        "pooling": "attention",
        "max_epochs": 100,
    }
    model_kwargs = {
        name: config.get(name, fallback)
        for name, fallback in fallback_values.items()
    }
    return FullCombinedModel(**model_kwargs)

In [None]:
# Combined config: one dict holding the settings of approaches A, B and C.
# The configs are applied in A -> B -> C order, so when two of them define the
# same key the value from the later config is the one that is kept.
COMBINED_CONFIG = {}
for _approach_config in (APPROACH_A_CONFIG, APPROACH_B_CONFIG, APPROACH_C_CONFIG):
    COMBINED_CONFIG.update(_approach_config)

In [None]:
# E5c: A+C (Contrastive + Adversarial)
# Collect the run arguments once so the experiment id used for the upload is
# guaranteed to match the one the experiment ran under.
e5c_run_args = dict(
    exp_id="E5c_contrastive_adversarial",
    description="A+C: Contrastive + Adversarial disentanglement",
    model_factory=make_ac_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    config=COMBINED_CONFIG,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)
results_e5c = run_pairwise_experiment(**e5c_run_args)
upload_experiment(e5c_run_args["exp_id"])

In [None]:
# E5b: B+C (Siamese + Adversarial)
# Collect the run arguments once so the experiment id used for the upload is
# guaranteed to match the one the experiment ran under.
e5b_run_args = dict(
    exp_id="E5b_siamese_adversarial",
    description="B+C: Siamese + Adversarial disentanglement",
    model_factory=make_bc_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    config=COMBINED_CONFIG,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)
results_e5b = run_pairwise_experiment(**e5b_run_args)
upload_experiment(e5b_run_args["exp_id"])

In [None]:
# E5d: A+B+C (Full combination)
# Collect the run arguments once so the experiment id used for the upload is
# guaranteed to match the one the experiment ran under.
e5d_run_args = dict(
    exp_id="E5d_full_combined",
    description="A+B+C: Full combination of all approaches",
    model_factory=make_abc_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    config=COMBINED_CONFIG,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)
results_e5d = run_pairwise_experiment(**e5d_run_args)
upload_experiment(e5d_run_args["exp_id"])

## 8. Analysis & Visualization

In [None]:
# Aggregate all results
# Fixed column schema: every metric column exists even when no loaded
# experiment reported it (the cells are NaN), and an empty results directory
# still yields a frame that can be sorted and filtered by "exp_id".
RESULT_COLUMNS = [
    "exp_id", "description",
    "pairwise_acc", "acc_std",
    "r2", "r2_std",
    "style_piece_acc", "intra_piece_std",
]

# Sorted so the load order does not depend on filesystem enumeration order.
results_files = sorted(PATHS["results"].glob("*.json"))
all_results = []

for result_path in results_files:
    # Keep the file open only while parsing it.
    with open(result_path) as fp:
        result = json.load(fp)

    # Ignore JSON files in the directory that are not experiment summaries.
    exp_id = result.get("experiment_id") if isinstance(result, dict) else None
    if exp_id is None:
        print(f"Skipping {result_path.name}: no 'experiment_id' field")
        continue

    summary = result.get("summary", {})

    # Extract key metrics
    row = {
        "exp_id": exp_id,
        "description": result.get("description", ""),
    }

    if "avg_pairwise_acc" in summary:
        row["pairwise_acc"] = summary["avg_pairwise_acc"]
        row["acc_std"] = summary.get("std_pairwise_acc", 0)

    if "avg_r2" in summary:
        row["r2"] = summary["avg_r2"]
        row["r2_std"] = summary.get("std_r2", 0)

    if "disentanglement" in result:
        disentanglement = result["disentanglement"]
        row["style_piece_acc"] = disentanglement.get("style_piece_accuracy", None)
        row["intra_piece_std"] = disentanglement.get("intra_piece_std", None)

    all_results.append(row)

results_df = pd.DataFrame(all_results, columns=RESULT_COLUMNS)
results_df = results_df.sort_values("exp_id")
print(f"Loaded {len(results_df)} experiment results")
results_df

In [None]:
# Main results comparison
# Single source of truth: experiment id -> short approach label. The list of
# main experiments is derived from it so the two can never drift apart.
approach_labels = {
    "E2a_contrastive_default": "A: Contrastive",
    "E3a_siamese_default": "B: Siamese",
    "E4a_dual_encoder_default": "C: Disentangled",
    "E5b_siamese_adversarial": "B+C",
    "E5c_contrastive_adversarial": "A+C",
    "E5d_full_combined": "A+B+C",
}
main_experiments = list(approach_labels)

main_df = results_df[results_df["exp_id"].isin(main_experiments)].copy()
main_df["approach"] = main_df["exp_id"].map(approach_labels)

# reindex (rather than plain column selection) so that a metric no loaded
# experiment reported shows up as a NaN column instead of raising KeyError.
summary_columns = ["approach", "pairwise_acc", "acc_std", "r2", "style_piece_acc"]
print("Main Results:")
print(main_df.reindex(columns=summary_columns).to_string(index=False))

In [None]:
# Per-dimension accuracy comparison
# Read the saved result JSON of each selected experiment and flatten its
# "per_dimension" mapping into one row per (experiment, dimension).
per_dim_data = []

for exp_id in ("E2a_contrastive_default", "E3a_siamese_default", "E5d_full_combined"):
    result_file = PATHS["results"] / f"{exp_id}.json"
    if not result_file.exists():
        # Experiment has not been run (or saved) yet; leave it out of the table.
        continue
    with open(result_file) as f:
        result = json.load(f)
        # Keys of "per_dimension" are dimension indices; int() converts them
        # before they are used to look up the dimension name.
        per_dim_data.extend(
            {
                "experiment": exp_id,
                "dimension": PERCEPIANO_DIMENSIONS[int(dim_idx)],
                "accuracy": acc,
            }
            for dim_idx, acc in result.get("per_dimension", {}).items()
        )

per_dim_df = pd.DataFrame(per_dim_data)
if len(per_dim_df):
    pivot = per_dim_df.pivot(index="dimension", columns="experiment", values="accuracy")
    print("Per-dimension accuracy:")
    print(pivot.to_string())

In [None]:
# Ablation bar charts (saved as ablation_heatmaps.png)
# One panel per ablation sweep. A panel is only drawn when the notebook
# variable holding that sweep's results exists, i.e. the ablation cell ran.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# "metric" is the key to read from each sweep value; None means the value
# itself is the number to plot. "chance_line" adds the 0.5 random baseline.
ablation_panels = [
    {
        "var": "temp_results", "xlabel": "Temperature",
        "ylabel": "Pairwise Accuracy", "title": "E2b: Temperature Ablation",
        "metric": None, "chance_line": True,
    },
    {
        "var": "lambda_results", "xlabel": "Lambda Contrastive",
        "ylabel": "Pairwise Accuracy", "title": "E2d: Lambda Ablation",
        "metric": None, "chance_line": True,
    },
    {
        "var": "adv_results", "xlabel": "Lambda Adversarial",
        "ylabel": "R2", "title": "E4b: Adversarial Weight Ablation",
        "metric": "r2", "chance_line": False,
    },
]

for ax, panel in zip(axes, ablation_panels):
    if panel["var"] not in globals():
        # Sweep not run in this session: leave the panel empty.
        continue
    sweep = globals()[panel["var"]]
    settings = list(sweep.keys())
    if panel["metric"] is None:
        values = list(sweep.values())
    else:
        values = [entry[panel["metric"]] for entry in sweep.values()]

    ax.bar(range(len(settings)), values)
    ax.set_xticks(range(len(settings)))
    ax.set_xticklabels([str(s) for s in settings])
    ax.set_xlabel(panel["xlabel"])
    ax.set_ylabel(panel["ylabel"])
    ax.set_title(panel["title"])
    if panel["chance_line"]:
        ax.axhline(y=0.5, color='r', linestyle='--', label='Random')
        # Without legend() the 'Random' label on the baseline is never shown.
        ax.legend()

plt.tight_layout()
# The figures directory is otherwise only created by a later cell; create it
# here so this cell does not fail with FileNotFoundError on a fresh run.
figures_dir = PATHS["results"].parent / "figures"
figures_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figures_dir / "ablation_heatmaps.png", dpi=150)
plt.show()

## 9. Paper Figures

In [None]:
# Create figures directory (sibling of the results directory).
# parents=True so the cell also works when the parent of the results
# directory has not been created yet.
figures_dir = PATHS["results"].parent / "figures"
figures_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Fig 1: Main results bar chart
# One bar per main approach: mean pairwise accuracy with the across-fold
# standard deviation as the error bar, against the 50% random baseline.
fig, ax = plt.subplots(figsize=(10, 6))

if len(main_df) > 0:
    approaches = main_df["approach"].tolist()
    accs = main_df["pairwise_acc"].tolist()
    stds = main_df["acc_std"].fillna(0).tolist()

    x = range(len(approaches))
    bars = ax.bar(x, accs, yerr=stds, capsize=5, color='steelblue', edgecolor='black')

    ax.set_xticks(x)
    ax.set_xticklabels(approaches, rotation=45, ha='right')
    ax.set_ylabel("Pairwise Ranking Accuracy")
    ax.set_title("Disentanglement Approaches: Pairwise Ranking Accuracy")
    ax.axhline(y=0.5, color='red', linestyle='--', label='Random Baseline (50%)')
    ax.legend()
    ax.set_ylim(0.4, 0.8)

    # Add value labels
    for bar, acc in zip(bars, accs):
        # An approach without a pairwise accuracy has NaN here; its bar has no
        # height, so skip the label instead of placing "nan" text at a
        # non-finite position.
        if not np.isfinite(acc):
            continue
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{acc:.3f}', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig(figures_dir / "fig1_main_results.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Fig 3: t-SNE of piece vs style embeddings (for Approach C)
# This requires loading a trained model and computing embeddings
from sklearn.manifold import TSNE

# Load best Approach C model (fold 0 checkpoint of experiment E4a)
best_c_ckpt = PATHS["checkpoints"] / "E4a_dual_encoder_default" / "fold0_best.ckpt"

if best_c_ckpt.exists():
    print("Loading model for t-SNE visualization...")
    model = DisentangledDualEncoderModel.load_from_checkpoint(best_c_ckpt)
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Create dataset for fold 0 validation
    # NOTE(review): the fold-specific piece map returned here is discarded and
    # the full MULTI_PERFORMER_PIECES mapping is passed to the dataset instead;
    # confirm this is intended.
    _, val_keys = get_fold_piece_mapping(
        MULTI_PERFORMER_PIECES, PIECE_STRATIFIED_FOLDS, 0, "val"
    )
    val_ds = DisentanglementDataset(
        PATHS["muq_cache"], labels, MULTI_PERFORMER_PIECES, val_keys
    )
    val_dl = torch.utils.data.DataLoader(
        val_ds, batch_size=32, collate_fn=disentanglement_collate_fn
    )

    # Extract embeddings: collect both encoder outputs plus the piece id of
    # every validation sample, batch by batch, without tracking gradients.
    z_pieces, z_styles, piece_ids = [], [], []
    with torch.no_grad():
        for batch in val_dl:
            outputs = model(
                batch["embeddings"].to(device),
                batch["attention_mask"].to(device),
            )
            z_pieces.append(outputs["z_piece"].cpu())
            z_styles.append(outputs["z_style"].cpu())
            piece_ids.extend(batch["piece_ids"].tolist())

    z_piece = torch.cat(z_pieces).numpy()
    z_style = torch.cat(z_styles).numpy()
    piece_ids = np.array(piece_ids)

    # t-SNE
    # scikit-learn rejects perplexity >= n_samples, so cap the usual value of
    # 30 when the validation fold is small.
    n_samples = len(piece_ids)
    perplexity = min(30, max(1, n_samples - 1))
    print("Computing t-SNE...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    z_piece_2d = tsne.fit_transform(z_piece)
    z_style_2d = tsne.fit_transform(z_style)

    # Plot both projections side by side, colored by piece id.
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Piece embeddings (should cluster by piece)
    ax = axes[0]
    ax.scatter(z_piece_2d[:, 0], z_piece_2d[:, 1], c=piece_ids, cmap='tab20', alpha=0.7, s=20)
    ax.set_title("Piece Encoder Embeddings (z_piece)")
    ax.set_xlabel("t-SNE 1")
    ax.set_ylabel("t-SNE 2")

    # Style embeddings (should NOT cluster by piece)
    ax = axes[1]
    ax.scatter(z_style_2d[:, 0], z_style_2d[:, 1], c=piece_ids, cmap='tab20', alpha=0.7, s=20)
    ax.set_title("Style Encoder Embeddings (z_style)")
    ax.set_xlabel("t-SNE 1")
    ax.set_ylabel("t-SNE 2")

    plt.suptitle("Disentanglement: Piece vs Style Representations", fontsize=14)
    plt.tight_layout()
    plt.savefig(figures_dir / "fig3_tsne_embeddings.png", dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No trained model found for t-SNE visualization.")

In [None]:
# Fig 4: Per-dimension breakdown
# Grouped bars: one group per perceptual dimension, one bar per experiment.
fig, ax = plt.subplots(figsize=(14, 6))

if len(per_dim_df) > 0:
    # Rows follow the canonical dimension order; dimensions an experiment did
    # not report become NaN and are drawn at the 0.5 chance level below.
    pivot = per_dim_df.pivot(
        index="dimension", columns="experiment", values="accuracy"
    ).reindex(PERCEPIANO_DIMENSIONS)

    x = np.arange(len(PERCEPIANO_DIMENSIONS))
    width = 0.25
    n_series = len(pivot.columns)

    for i, col in enumerate(pivot.columns):
        # Center the group of bars on each tick.
        offset = (i - n_series/2 + 0.5) * width
        heights = pivot[col].fillna(0.5)
        ax.bar(
            x + offset, heights, width,
            label=col.replace("_default", "").replace("_", " ").title(),
        )

    ax.axhline(y=0.5, color='red', linestyle='--', alpha=0.7, label='Random')
    ax.set_xticks(x)
    ax.set_xticklabels(PERCEPIANO_DIMENSIONS, rotation=45, ha='right')
    ax.set_ylabel("Pairwise Accuracy")
    ax.set_title("Per-Dimension Pairwise Ranking Accuracy")
    ax.legend(loc='upper right')
    ax.set_ylim(0.3, 0.9)

plt.tight_layout()
plt.savefig(figures_dir / "fig4_per_dimension.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Upload results and checkpoints to remote storage
print("Uploading results to remote storage...")

# Checkpoints first, then the results JSON files; both use the same key in
# PATHS (local directory) and REMOTE_PATHS (remote destination).
for artifact in ("checkpoints", "results"):
    print(f"\nUploading {artifact}...")
    rclone_sync(str(PATHS[artifact]), REMOTE_PATHS[artifact], direction="upload")

# Figures are optional: only uploaded when the directory was created.
figures_dir = PATHS["results"].parent / "figures"
if figures_dir.exists():
    print("\nUploading figures...")
    rclone_sync(str(figures_dir), "figures/disentanglement", direction="upload")

print("\nUpload complete! Results saved to remote storage.")

## 10. Extended Experiments: E7, E8, E11, E16

Additional experiments to strengthen disentanglement results:
- **E7**: Hard pair mining (focus on challenging pairs)
- **E8**: Per-dimension model groups
- **E11**: Triplet loss for performer discrimination
- **E16**: 4096-dim embedding comparison (concatenated layers)

In [None]:
# E7: Hard Pair Mining Experiments
# Focus training on challenging pairs with moderate score differences
from disentanglement.data import sample_hard_pairs

# Analyze difficulty distribution
print("E7: Hard Pair Mining")
print("="*50)

# Sample pairs at different difficulty levels, as (min_diff, max_diff, label).
# A smaller score difference means a harder pair to rank: E7a trains on the
# 0.05-0.20 range as its "hard pairs", so the narrowest band is labelled
# "hard" and the widest "easy".
difficulty_ranges = [
    (0.05, 0.10, "hard"),
    (0.10, 0.20, "medium"),
    (0.20, 0.30, "easy"),
]

for min_diff, max_diff, name in difficulty_ranges:
    # Request up to 1000 pairs per band with a fixed seed; the printed count
    # is however many pairs the sampler returned.
    hard_pairs = sample_hard_pairs(
        MULTI_PERFORMER_PIECES, labels,
        n_pairs=1000, min_diff=min_diff, max_diff=max_diff, seed=42
    )
    print(f"  {name} pairs ({min_diff:.2f}-{max_diff:.2f}): {len(hard_pairs)} samples")

In [None]:
# E7a: Train with hard pairs only (diff 0.05-0.20)
# Uses the best-performing base model with hard pair sampling
print("E7a: Hard Pair Training")
print("=" * 50)

# HardPairRankingDataset is now imported from disentanglement.data
# It filters pairs to the specified difficulty range during initialization

# Configuration for hard pair experiment: a copy of the Approach B settings
# with the score-difference window of the hard pairs set on top.
E7A_CONFIG = dict(APPROACH_B_CONFIG)
E7A_CONFIG.update(min_diff=0.05, max_diff=0.20)

# Custom runner that uses HardPairRankingDataset
def run_hard_pair_experiment(
    exp_id, description, model_factory, cache_dir, labels,
    piece_to_keys, fold_assignments, config, checkpoint_root,
    results_dir, log_dir, on_fold_complete=None
):
    """Run a cross-validated pairwise experiment restricted to hard pairs.

    For each fold, train/validation HardPairRankingDataset instances are built
    with the ``min_diff``/``max_diff`` window from ``config``. The model is
    trained (or reloaded from ``fold{N}_best.ckpt`` if that file exists),
    evaluated on the validation pairs, and the pooled predictions of all folds
    are scored with ``compute_pairwise_metrics``. The summary is written to
    ``results_dir / f"{exp_id}.json"``.

    Args:
        exp_id: Experiment identifier; names the checkpoint sub-directory,
            the logger run and the results file.
        description: Free-text description stored in the results.
        model_factory: Callable taking a config dict (with ``num_pieces``
            added) and returning a Lightning model.
        cache_dir: Embedding cache location passed to the datasets.
        labels: Label mapping passed to the datasets.
        piece_to_keys: Piece-to-recordings mapping used for the fold split.
        fold_assignments: Fold assignment passed to ``get_fold_piece_mapping``.
        config: Hyperparameter mapping; missing keys use the defaults below.
        checkpoint_root: Root directory for per-experiment checkpoints.
        results_dir: Directory receiving the results JSON.
        log_dir: Directory for the CSV training logs.
        on_fold_complete: Optional ``callback(exp_id, fold)`` invoked after a
            fold has been trained (not when it is reloaded from a checkpoint).

    Returns:
        The results dict, or the previously saved results if the experiment
        is already complete.

    Raises:
        ValueError: If no fold produced validation predictions (every fold
            was skipped for lack of hard pairs).
    """
    from disentanglement.training.runner import (
        experiment_completed, load_existing_results, get_fold_piece_mapping
    )
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
    from pytorch_lightning.loggers import CSVLogger
    from torch.utils.data import DataLoader
    import time

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    def _hard_pair_dataset(piece_map, keys):
        # Dataset limited to pairs inside the configured difficulty window.
        return HardPairRankingDataset(
            cache_dir, labels, piece_map, keys,
            max_frames=config.get("max_frames", 1000),
            ambiguous_threshold=config.get("ambiguous_threshold", 0.05),
            min_diff=config.get("min_diff", 0.05),
            max_diff=config.get("max_diff", 0.20),
        )

    def _loader(dataset, shuffle):
        # Pinned memory only helps host-to-GPU copies, so enable it with CUDA only.
        return DataLoader(
            dataset, batch_size=config.get("batch_size", 32),
            shuffle=shuffle, collate_fn=pairwise_collate_fn,
            num_workers=config.get("num_workers", 2), pin_memory=use_cuda,
        )

    def _on_device(tensor):
        # Masks are read with batch.get(), so they may be absent (None).
        return None if tensor is None else tensor.to(device)

    exp_checkpoint_dir = Path(checkpoint_root) / exp_id
    exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    existing = load_existing_results(exp_id, results_dir)
    if existing and experiment_completed(exp_id, checkpoint_root):
        print(f"SKIP {exp_id}: already completed")
        return existing

    print(f"\n{'='*70}")
    print(f"EXPERIMENT: {exp_id}")
    print(f"Description: {description}")
    print(f"{'='*70}")

    start_time = time.time()
    n_folds = config.get("n_folds", 4)
    fold_results = {}
    all_logits, all_labels_a, all_labels_b = [], [], []

    for fold in range(n_folds):
        ckpt_path = exp_checkpoint_dir / f"fold{fold}_best.ckpt"

        train_piece_map, train_keys = get_fold_piece_mapping(
            piece_to_keys, fold_assignments, fold, "train"
        )
        val_piece_map, val_keys = get_fold_piece_mapping(
            piece_to_keys, fold_assignments, fold, "val"
        )

        # Use HardPairRankingDataset instead of regular dataset
        train_ds = _hard_pair_dataset(train_piece_map, train_keys)
        val_ds = _hard_pair_dataset(val_piece_map, val_keys)

        if len(train_ds) == 0 or len(val_ds) == 0:
            print(f"Fold {fold}: No hard pairs available, skipping")
            continue

        train_dl = _loader(train_ds, shuffle=True)
        val_dl = _loader(val_ds, shuffle=False)

        print(f"Fold {fold}: {len(train_ds)} train hard pairs, {len(val_ds)} val hard pairs")

        config_with_pieces = {**config, "num_pieces": train_ds.get_num_pieces()}

        # The factory is called in both branches: it is the only way to learn
        # which model class to reload the checkpoint with.
        model = model_factory(config_with_pieces)
        if ckpt_path.exists():
            model = model.__class__.load_from_checkpoint(ckpt_path)
        else:
            callbacks = [
                ModelCheckpoint(
                    dirpath=exp_checkpoint_dir, filename=f"fold{fold}_best",
                    monitor="val_pairwise_acc", mode="max", save_top_k=1,
                ),
                EarlyStopping(
                    monitor="val_pairwise_acc", mode="max",
                    patience=config.get("patience", 15), verbose=True,
                ),
            ]
            trainer = pl.Trainer(
                max_epochs=config.get("max_epochs", 100),
                callbacks=callbacks,
                logger=CSVLogger(save_dir=log_dir, name=exp_id, version=f"fold{fold}"),
                accelerator="auto", devices=1,
                gradient_clip_val=config.get("gradient_clip_val", 1.0),
                enable_progress_bar=True, deterministic=True, log_every_n_steps=10,
            )
            trainer.fit(model, train_dl, val_dl)
            # Explicit None check: a legitimate best score of 0.0 is falsy and
            # must not be replaced by the 0.5 placeholder.
            best_score = callbacks[0].best_model_score
            fold_results[fold] = 0.5 if best_score is None else float(best_score)
            model = model.__class__.load_from_checkpoint(ckpt_path)
            if on_fold_complete:
                on_fold_complete(exp_id, fold)

        # Evaluate
        model.eval().to(device)
        fold_logits, fold_labels_a, fold_labels_b = [], [], []
        with torch.no_grad():
            for batch in val_dl:
                outputs = model(
                    batch["embeddings_a"].to(device),
                    batch["embeddings_b"].to(device),
                    _on_device(batch.get("mask_a")),
                    _on_device(batch.get("mask_b")),
                )
                # Models may return the logits directly or inside a dict.
                if isinstance(outputs, dict):
                    logits = outputs.get("ranking_logits", outputs)
                else:
                    logits = outputs
                fold_logits.append(logits.cpu().numpy())
                fold_labels_a.append(batch["labels_a"].numpy())
                fold_labels_b.append(batch["labels_b"].numpy())
        del model
        if use_cuda:
            torch.cuda.empty_cache()

        # A fold reloaded from its checkpoint has no training-time score, so
        # score it from its own validation predictions; otherwise resumed
        # folds would be missing from the reported mean and std.
        if fold not in fold_results:
            fold_metrics = compute_pairwise_metrics(
                np.vstack(fold_logits),
                np.vstack(fold_labels_a),
                np.vstack(fold_labels_b),
            )
            fold_results[fold] = float(fold_metrics["overall_accuracy"])

        all_logits.extend(fold_logits)
        all_labels_a.extend(fold_labels_a)
        all_labels_b.extend(fold_labels_b)

    if not all_logits:
        raise ValueError(
            f"{exp_id}: no fold produced validation predictions "
            "(no hard pairs in the configured min_diff/max_diff range)"
        )

    # Aggregate
    all_logits = np.vstack(all_logits)
    all_labels_a = np.vstack(all_labels_a)
    all_labels_b = np.vstack(all_labels_b)
    metrics = compute_pairwise_metrics(all_logits, all_labels_a, all_labels_b)

    avg_acc = np.mean(list(fold_results.values()))
    std_acc = np.std(list(fold_results.values()))

    results = {
        "experiment_id": exp_id,
        "description": description,
        "config": {k: v for k, v in config.items() if not callable(v)},
        "summary": {
            "avg_pairwise_acc": float(avg_acc),
            "std_pairwise_acc": float(std_acc),
            "overall_accuracy": metrics["overall_accuracy"],
            "n_comparisons": metrics["n_comparisons"],
        },
        "fold_results": {str(k): float(v) for k, v in fold_results.items()},
        "per_dimension": metrics["per_dimension"],
        "training_time_seconds": time.time() - start_time,
    }

    with open(results_dir / f"{exp_id}.json", "w") as f:
        json.dump(results, f, indent=2)

    print(f"\n{exp_id} COMPLETE: Acc={avg_acc:.4f} +/- {std_acc:.4f}")
    return results

# Run E7a experiment
# Collect the run arguments once so the experiment id used for the upload is
# guaranteed to match the one the experiment ran under.
e7a_run_args = dict(
    exp_id="E7a_hard_pairs",
    description="Hard pair mining: train on pairs with 0.05-0.20 mean diff",
    model_factory=make_approach_b_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    config=E7A_CONFIG,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)
results_e7a = run_hard_pair_experiment(**e7a_run_args)
upload_experiment(e7a_run_args["exp_id"])

In [None]:
# E8: Per-Dimension Model Groups
# Train specialized models for different dimension categories
print("E8: Per-Dimension Model Groups")
print("=" * 50)

# Dimension groups based on Paper 1 analysis.
# Values are indices into PERCEPIANO_DIMENSIONS. The names in the trailing
# comments are the author's annotations; the loop below prints the actual
# names looked up from PERCEPIANO_DIMENSIONS.
DIMENSION_GROUPS = dict(
    expression=[0, 3, 6],       # timing, dynamic_range, articulation_touch
    technical=[8, 11, 13],      # pedal_clarity, balance, timbre_brightness
    interpretive=[1, 4, 7],     # drama, mood_valence, interpretation
    structural=[2, 5, 9],       # phrasing, tempo_stability, voicing
    aesthetic=[10, 12, 14],     # tone_quality, expressiveness, overall_quality
    nuance=[15, 16, 17, 18],    # remaining dimensions
)

print("Dimension groups:")
for group_name, dim_indices in DIMENSION_GROUPS.items():
    dim_names = []
    for dim_index in dim_indices:
        dim_names.append(PERCEPIANO_DIMENSIONS[dim_index])
    print(f"  {group_name}: {dim_names}")

In [None]:
# E8a: Train separate models per dimension group
# Uses run_dimension_group_experiment which trains specialized models for each category

print("E8a: Per-Dimension Group Training")
print("=" * 50)

# Model factory with dimension_indices support
def make_approach_b_grouped(config):
    """Build a Siamese ranking model, optionally restricted to some dimensions.

    Each hyperparameter is looked up with ``config.get``; a key that is
    absent from ``config`` falls back to the default listed in the table.
    ``dimension_indices`` defaults to ``None`` when the config does not set it.
    """
    fallback_values = {
        "input_dim": 1024,
        "hidden_dim": 512,
        "num_labels": 19,
        "dropout": 0.2,
        "learning_rate": 1e-4,
        "weight_decay": 1e-5,
        "comparison_type": "concat_diff",
        "margin": 0.3,
        "ambiguous_threshold": 0.05,
        "label_smoothing": 0.05,
        "pooling": "attention",
        "max_epochs": 100,
        "dimension_indices": None,
    }
    model_kwargs = {
        name: config.get(name, fallback)
        for name, fallback in fallback_values.items()
    }
    return SiameseDimensionRankingModel(**model_kwargs)

# Run E8a: Per-dimension group experiment
# Collect the run arguments once so the experiment id used for the upload is
# guaranteed to match the one the experiment ran under.
e8a_run_args = dict(
    exp_id="E8a_dimension_groups",
    description="Per-dimension group specialized models",
    model_factory=make_approach_b_grouped,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    config=APPROACH_B_CONFIG,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    dimension_groups=DIMENSION_GROUPS,
    on_fold_complete=upload_fold_checkpoint,
)
results_e8a = run_dimension_group_experiment(**e8a_run_args)
upload_experiment(e8a_run_args["exp_id"])

# Print per-group results; a group without a recorded accuracy prints 0.
print("\nPer-group accuracy breakdown:")
for group_name, group_result in results_e8a.get("group_results", {}).items():
    print(
        f"  {group_name}: {group_result.get('avg_pairwise_acc', 0):.4f} "
        f"(dims: {group_result.get('dimensions', [])})"
    )

In [None]:
# E11: Triplet Loss for Performer Discrimination
# Uses triplet sampling within same-piece performances

print("E11: Triplet Loss for Performer Discrimination")
print("=" * 50)

from disentanglement import TripletRankingModel
from disentanglement.data import TripletRankingDataset, triplet_collate_fn

# Configuration for triplet model
E11_CONFIG = dict(
    input_dim=1024,
    hidden_dim=512,
    embedding_dim=256,
    num_labels=19,
    dropout=0.2,
    learning_rate=1e-4,
    weight_decay=1e-5,
    margin=0.5,
    lambda_ranking=0.5,
    ambiguous_threshold=0.05,
    pooling="attention",
    distance_fn="euclidean",
    max_epochs=100,
    batch_size=32,
    max_frames=1000,
)

# Test triplet dataset creation on the fold-0 training split.
train_piece_map, train_keys = get_fold_piece_mapping(
    MULTI_PERFORMER_PIECES, PIECE_STRATIFIED_FOLDS, fold_id=0, mode="train"
)

# max_frames is kept small here because this dataset is only a smoke test.
triplet_ds = TripletRankingDataset(
    PATHS["muq_cache"], labels, train_piece_map, train_keys,
    max_frames=100,
    min_score_diff=0.05,
)

n_triplets = len(triplet_ds)
print(f"Triplet dataset: {n_triplets} triplets")
print(f"Num pieces with 3+ performers: {triplet_ds.get_num_pieces()}")

In [None]:
# E11a: Train triplet ranking model
# Uses run_triplet_experiment with TripletRankingDataset

print("E11a: Triplet Ranking Model Training")
print("=" * 50)

def make_triplet_model(config):
    """Build a TripletRankingModel from a config mapping.

    Each hyperparameter is looked up with ``config.get``; a key that is
    absent from ``config`` falls back to the default listed in the table.
    """
    fallback_values = {
        "input_dim": 1024,
        "hidden_dim": 512,
        "embedding_dim": 256,
        "num_labels": 19,
        "dropout": 0.2,
        "learning_rate": 1e-4,
        "weight_decay": 1e-5,
        "margin": 0.5,
        "lambda_ranking": 0.5,
        "ambiguous_threshold": 0.05,
        "pooling": "attention",
        "distance_fn": "euclidean",
        "max_epochs": 100,
    }
    model_kwargs = {
        name: config.get(name, fallback)
        for name, fallback in fallback_values.items()
    }
    return TripletRankingModel(**model_kwargs)

# Test model creation: instantiate once and report the parameter count.
test_model = make_triplet_model(E11_CONFIG)
n_triplet_params = sum(p.numel() for p in test_model.parameters())
print(f"TripletRankingModel parameters: {n_triplet_params:,}")

# Run E11a experiment using run_triplet_experiment
# Collect the run arguments once so the experiment id used for the upload is
# guaranteed to match the one the experiment ran under.
e11a_run_args = dict(
    exp_id="E11a_triplet_ranking",
    description="Triplet loss for performer discrimination",
    model_factory=make_triplet_model,
    cache_dir=PATHS["muq_cache"],
    labels=labels,
    piece_to_keys=MULTI_PERFORMER_PIECES,
    fold_assignments=PIECE_STRATIFIED_FOLDS,
    config=E11_CONFIG,
    checkpoint_root=PATHS["checkpoints"],
    results_dir=PATHS["results"],
    log_dir=PATHS["logs"],
    on_fold_complete=upload_fold_checkpoint,
)
results_e11a = run_triplet_experiment(**e11a_run_args)
upload_experiment(e11a_run_args["exp_id"])

e11a_summary = results_e11a['summary']
print(f"\nE11a Results: Pairwise Acc = {e11a_summary['avg_pairwise_acc']:.4f}")

In [None]:
# E16: 4096-dim Embedding Extraction and Comparison
# Compare 1024-dim (averaged layers) vs 4096-dim (concatenated layers)
print("E16: Embedding Dimension Comparison (1024 vs 4096)")
print("=" * 50)

# Path for 4096-dim embeddings
PATHS["muq_cache_4096"] = Path("/workspace/data/cache/muq_embeddings_4096")
cache_4096 = PATHS["muq_cache_4096"]

# The cache counts as present only if the directory exists and holds at
# least one .pt file; otherwise extraction is attempted.
cache_ready = cache_4096.exists() and len(list(cache_4096.glob("*.pt"))) != 0

if cache_ready:
    n_cached = len(list(cache_4096.glob("*.pt")))
    print(f"4096-dim cache exists with {n_cached} files")
else:
    print("Extracting 4096-dim embeddings (concatenated layers 9-12)...")
    cache_4096.mkdir(parents=True, exist_ok=True)

    # Import the extraction function
    from audio_experiments.extractors.muq import extract_muq_embeddings

    # Get all keys that need extraction
    all_keys = list(labels.keys())

    # Note: This requires the audio files to be available
    # The extraction concatenates layers 9-12 (4 layers * 1024 = 4096 dims)
    try:
        extract_muq_embeddings(
            audio_dir=Path("/workspace/data/audio"),
            cache_dir=cache_4096,
            keys=all_keys,
            layer_start=9,
            layer_end=13,  # exclusive, so 9,10,11,12 = 4 layers
            layer_aggregation="concat",
        )
        print(f"Extracted {len(list(cache_4096.glob('*.pt')))} 4096-dim embeddings")
    except Exception as e:
        # Best effort: a failed extraction is reported and the notebook
        # continues without the 4096-dim cache.
        print(f"Extraction failed: {e}")
        print("Audio files may not be available. Skipping 4096-dim extraction.")

In [None]:
# E16a: Run experiments with 4096-dim embeddings
# Uses same approaches but with higher-dimensional input
print("E16a: 4096-dim Siamese Ranking")
print("=" * 50)

# Approach B settings with the input widened to 4 layers * 1024 = 4096 and
# the hidden size scaled up with it.
E16_CONFIG = dict(APPROACH_B_CONFIG)
E16_CONFIG.update(input_dim=4096, hidden_dim=1024)

# Model factory for 4096-dim input
def make_approach_b_4096(config):
    """Build a Siamese ranking model sized for 4096-dim embeddings.

    Each hyperparameter is looked up with ``config.get``; a key that is
    absent from ``config`` falls back to the default listed in the table.
    """
    fallback_values = {
        "input_dim": 4096,
        "hidden_dim": 1024,
        "num_labels": 19,
        "dropout": 0.2,
        "learning_rate": 1e-4,
        "weight_decay": 1e-5,
        "comparison_type": "concat_diff",
        "margin": 0.3,
        "ambiguous_threshold": 0.05,
        "label_smoothing": 0.05,
        "pooling": "attention",
        "max_epochs": 100,
    }
    model_kwargs = {
        name: config.get(name, fallback)
        for name, fallback in fallback_values.items()
    }
    return SiameseDimensionRankingModel(**model_kwargs)

# The experiment only runs when the 4096-dim cache holds at least one file
if not (PATHS["muq_cache_4096"].exists() and any(PATHS["muq_cache_4096"].glob("*.pt"))):
    print("E16a requires 4096-dim embeddings. Extract them first (see cell above).")
    print("\nExpected outcome: Marginal improvement over 1024-dim")
    print("(More dimensions may help capture nuanced performance differences)")
else:
    print("Running E16a with 4096-dim embeddings...")

    results_e16a = run_pairwise_experiment(
        exp_id="E16a_siamese_4096",
        description="Siamese ranking with 4096-dim concatenated embeddings",
        model_factory=make_approach_b_4096,
        cache_dir=PATHS["muq_cache_4096"],
        labels=labels,
        piece_to_keys=MULTI_PERFORMER_PIECES,
        fold_assignments=PIECE_STRATIFIED_FOLDS,
        config=E16_CONFIG,
        checkpoint_root=PATHS["checkpoints"],
        results_dir=PATHS["results"],
        log_dir=PATHS["logs"],
        on_fold_complete=upload_fold_checkpoint,
    )
    upload_experiment("E16a_siamese_4096")

    print(f"\nE16a Results: Pairwise Acc = {results_e16a['summary']['avg_pairwise_acc']:.4f}")

    # Put the result next to the 1024-dim baseline (E3a) when that result
    # file is present on disk
    e3a_file = PATHS["results"] / "E3a_siamese_default.json"
    if e3a_file.exists():
        with e3a_file.open() as f:
            e3a_results = json.load(f)
        baseline_acc = e3a_results["summary"]["avg_pairwise_acc"]
        e16a_acc = results_e16a["summary"]["avg_pairwise_acc"]
        improvement = e16a_acc - baseline_acc
        print(f"  1024-dim baseline: {baseline_acc:.4f}")
        print(f"  4096-dim result:   {e16a_acc:.4f}")
        print(f"  Improvement:       {improvement:+.4f}")

## 11. Extended Results Summary

Expected outcomes from extended experiments:

| Experiment | Description | Expected Result |
|------------|-------------|-----------------|
| E1a | MuQ regression baseline | ~50-55% (piece bias) |
| E7a | Hard pair mining | +2-5% over default |
| E8a | Per-dimension groups | Expression > Interpretive |
| E11a | Triplet loss | Better performer variance |
| E16a | 4096-dim embeddings | Marginal improvement |

In [None]:
# Final results aggregation including extended experiments
print("All Experiment Results")
print("=" * 60)

# Reload every JSON result on disk so newly finished experiments are included
results_files = list(PATHS["results"].glob("*.json"))
all_results = []

for f in results_files:
    # Only the parse needs the file handle; close it before building the row
    with open(f) as fp:
        result = json.load(fp)

    # A JSON file whose top-level value is not an object (e.g. a list) is
    # not an experiment record; skip it instead of crashing on .get()
    if not isinstance(result, dict):
        print(f"Skipping {f.name}: top-level JSON value is not an object")
        continue

    exp_id = result.get("experiment_id", f.stem)

    row = {
        "exp_id": exp_id,
        "description": result.get("description", ""),
    }
    # Handle different result formats: the accuracy is read from either
    # "overall_accuracy" or "avg_pairwise_acc" in the summary
    summary = result.get("summary", {})
    if isinstance(summary, dict):
        if "overall_accuracy" in summary:
            row["pairwise_acc"] = summary["overall_accuracy"]
        elif "avg_pairwise_acc" in summary:
            row["pairwise_acc"] = summary["avg_pairwise_acc"]
            row["acc_std"] = summary.get("std_pairwise_acc", 0)

        if "avg_r2" in summary:
            row["r2"] = summary["avg_r2"]

    all_results.append(row)

if all_results:
    final_df = pd.DataFrame(all_results).sort_values("exp_id")
else:
    # A frame built from an empty list has no "exp_id" column, so
    # sort_values("exp_id") would raise KeyError; build an empty frame
    # with the always-present columns instead
    final_df = pd.DataFrame(columns=["exp_id", "description"])
print(f"\nLoaded {len(final_df)} experiment results")

# Display key results (only when at least one result reported an accuracy)
if "pairwise_acc" in final_df.columns:
    print("\nTop performers by pairwise accuracy:")
    top_results = final_df.nlargest(5, "pairwise_acc")[["exp_id", "pairwise_acc", "description"]]
    print(top_results.to_string(index=False))

## 12. Pre-flight Verification

Verify that all imports resolve and that every model can be instantiated before running on the cloud.

In [None]:
# Pre-flight verification: confirm that imports resolve and models can be built
print("Pre-flight Verification")
print("=" * 60)

# 1. Verify imports
print("\n1. Verifying imports...")
try:
    from disentanglement import (
        run_pairwise_experiment,
        run_disentanglement_experiment,
        run_triplet_experiment,
        run_dimension_group_experiment,
        HardPairRankingDataset,
        TripletRankingDataset,
        ContrastivePairwiseRankingModel,
        SiameseDimensionRankingModel,
        DisentangledDualEncoderModel,
        TripletRankingModel,
    )
except ImportError as import_error:
    # Report which import failed, then abort the cell
    print(f"   Import error: {import_error}")
    raise
else:
    print("   All imports successful!")

# 2. Verify model instantiation
# Each model is built with the arguments shown below and run once on a random
# tensor. The check only confirms that forward() executes; outputs are not
# inspected.
print("\n2. Verifying model instantiation...")
test_input = torch.randn(2, 100, 1024)  # [batch, seq, dim]

models_to_test = [
    ("ContrastivePairwiseRankingModel", ContrastivePairwiseRankingModel()),
    ("SiameseDimensionRankingModel", SiameseDimensionRankingModel()),
    ("SiameseDimensionRankingModel (grouped)", SiameseDimensionRankingModel(
        num_labels=3, dimension_indices=[0, 1, 2]
    )),
    ("DisentangledDualEncoderModel", DisentangledDualEncoderModel(num_pieces=10)),
    ("TripletRankingModel", TripletRankingModel()),
]

for name, model in models_to_test:
    try:
        model.eval()
        # Smoke test only: no gradients are needed, so disable autograd
        # tracking to avoid building a graph for every forward pass
        with torch.no_grad():
            if "Triplet" in name:
                # Triplet model needs 3 inputs
                output = model(test_input, test_input, test_input)
            elif "Disentangled" in name:
                # Disentanglement model needs single input
                output = model(test_input)
            else:
                # Pairwise models need 2 inputs
                output = model(test_input, test_input)
        print(f"   {name}: OK")
    except Exception as e:
        # Name the failing model, then re-raise so the pre-flight cell stops
        print(f"   {name}: FAILED - {e}")
        raise

# 3. Verify dataset classes
# No dataset is constructed here; the check only confirms that each name
# resolves to something callable before it is reported as available.
print("\n3. Verifying dataset classes...")
for dataset_name, dataset_cls in [
    ("HardPairRankingDataset", HardPairRankingDataset),
    ("TripletRankingDataset", TripletRankingDataset),
]:
    if not callable(dataset_cls):
        print(f"   {dataset_name}: FAILED")
        raise ValueError(f"{dataset_name} is not callable")
    print(f"   {dataset_name}: Available")

# 4. Verify runner functions exist
print("\n4. Verifying runner functions...")
runners = [
    "run_pairwise_experiment",
    "run_disentanglement_experiment",
    "run_triplet_experiment",
    "run_dimension_group_experiment",
]
for runner in runners:
    # Look the name up in the notebook namespace rather than eval()-ing a
    # string; a missing name yields None and is reported as FAILED below
    if callable(globals().get(runner)):
        print(f"   {runner}: OK")
    else:
        print(f"   {runner}: FAILED")
        raise ValueError(f"{runner} is not callable")

# 5. Verify config keys
# Every key is reported first so that all missing keys are visible at once;
# the cell then aborts if any are absent instead of announcing success.
print("\n5. Verifying config keys...")
required_keys = ["input_dim", "hidden_dim", "num_labels", "batch_size", "max_epochs"]
missing_keys = [key for key in required_keys if key not in APPROACH_B_CONFIG]
for key in required_keys:
    if key in missing_keys:
        print(f"   {key}: MISSING")
    else:
        print(f"   {key}: OK ({APPROACH_B_CONFIG[key]})")

if missing_keys:
    # Stop here so the PASSED banner below is never printed for a bad config
    raise ValueError(f"APPROACH_B_CONFIG is missing required keys: {missing_keys}")

print("\n" + "="*60)
print("Pre-flight verification PASSED!")
print("All models, imports, and configurations are ready.")
print("="*60)