# 07: Audio Encoder Comparison (A1 vs A2 vs A3)

Compare three MuQ domain adaptation strategies:
- **A1**: MuQ + LoRA multi-task (MuQLoRAModel)
- **A2**: Staged domain adaptation (MuQStagedModel)
- **A3**: Full unfreeze with gradual layer unfreezing (MuQFullUnfreezeModel)

Selection criteria:
1. Primary: Pairwise ranking accuracy
2. Tiebreak: R-squared on regression
3. Veto: a robustness score drop > 15% disqualifies the model regardless of criteria 1 and 2

---

## 1. Setup

In [None]:
import subprocess
import sys
import os
# Set before torch (and CUDA) are imported below. Presumably this is so cuBLAS
# can run deterministically, per PyTorch's reproducibility notes — confirm.
os.environ.update({'CUBLAS_WORKSPACE_CONFIG': ':4096:8'})

In [None]:
# -- Thunder Compute Setup --
# 1. Clone repo
# !git clone <repo-url> /workspace/crescendai
# %cd /workspace/crescendai/model

# 2. Install dependencies
# !uv sync

# 3. Pull cached data + embeddings from Google Drive via rclone
# !rclone sync gdrive:crescendai/model/data ./data --progress

# 4. Configure paths
from pathlib import Path

# THUNDER_COMPUTE marks the remote box. Parse it as a boolean flag: a bare
# os.environ.get() would treat any non-empty string, including '0' or
# 'false', as "remote".
IS_REMOTE = os.environ.get('THUNDER_COMPUTE', '').strip().lower() not in {'', '0', 'false', 'no', 'off'}
if IS_REMOTE:
    DATA_DIR = Path('/workspace/crescendai/model/data')
    CHECKPOINT_DIR = Path('/workspace/crescendai/model/checkpoints/model_improvement')
else:
    # Relative to the notebook's working directory.
    # NOTE(review): assumes the notebook runs from a direct subfolder of model/ — confirm.
    DATA_DIR = Path('../data')
    CHECKPOINT_DIR = Path('../checkpoints/model_improvement')

In [None]:
import json
import numpy as np
import torch
import matplotlib.pyplot as plt

# Make the project's `src/` directory importable (model_improvement.* below).
# NOTE(review): 'src' is resolved against the current working directory, so the
# notebook must run from the directory that contains `src/` (e.g. model/ after
# the `%cd` in the setup cell). Confirm this for local runs, where DATA_DIR
# instead assumes '../data'.
sys.path.insert(0, 'src')

from model_improvement.audio_encoders import MuQLoRAModel, MuQStagedModel, MuQFullUnfreezeModel
from model_improvement.metrics import MetricsSuite, compute_robustness_metrics, format_comparison_table
from model_improvement.augmentation import AudioAugmentor

## 2. Load Validation Data

Use the PercePiano 4-fold cross-validation splits. Note: the PSyllabus test set is not loaded by any of the cells below yet.

In [None]:
# Load PercePiano labels and pre-extracted MuQ embeddings from the local cache.
cache_dir = DATA_DIR / 'percepiano_cache'

# JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
with open(cache_dir / 'labels.json', encoding='utf-8') as f:
    labels = json.load(f)

# Pre-extracted MuQ embeddings, loaded onto CPU. weights_only=True avoids
# unpickling arbitrary objects.
embeddings = torch.load(cache_dir / 'muq_embeddings.pt', map_location='cpu', weights_only=True)

print(f'Loaded {len(labels)} labeled segments')
print(f'Embedding keys: {len(embeddings)}')

# Segments without an embedding cannot be scored; surface how many there are.
missing_embeddings = len(set(labels) - set(embeddings))
if missing_embeddings:
    print(f'Warning: {missing_embeddings} labeled segments have no embedding')

In [None]:
# Load the CV fold definitions: a sequence of dicts, each holding 'train' and
# 'val' lists of segment keys.
with open(cache_dir / 'folds.json', encoding='utf-8') as f:
    folds = json.load(f)

print(f'Number of folds: {len(folds)}')
for i, fold in enumerate(folds):
    print(f'  Fold {i}: {len(fold["train"])} train, {len(fold["val"])} val')

## 3. Load Checkpoints

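Load a single checkpoint per model variant from `CHECKPOINT_DIR/<variant>/` (see `load_best_checkpoint` for how it is chosen). That same checkpoint is then scored on every fold's validation split. Unless it was trained on data disjoint from those splits, the fold metrics below may be optimistic.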

In [None]:
def _checkpoint_val_loss(ckpt_path):
    """Return the number after ``val_loss=`` in a checkpoint filename, or None if absent."""
    import re

    match = re.search(r'val_loss=(\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)', ckpt_path.name)
    return float(match.group(1)) if match else None


def load_best_checkpoint(model_class, model_dir, **model_kwargs):
    """Load the checkpoint with the lowest val_loss from ``CHECKPOINT_DIR / model_dir``.

    The val_loss is read from filenames such as ``epoch=5-val_loss=0.1234.ckpt``.
    If no ``*.ckpt`` filename carries a val_loss, this falls back to the
    lexicographically last checkpoint. Beware that lexical order is not
    numeric order: 'epoch=9' sorts after 'epoch=10'.

    Args:
        model_class: Class providing ``load_from_checkpoint(path, **kwargs)``.
        model_dir: Sub-directory of ``CHECKPOINT_DIR`` holding ``*.ckpt`` files.
        **model_kwargs: Forwarded unchanged to ``load_from_checkpoint``.

    Returns:
        The loaded model instance.

    Raises:
        FileNotFoundError: If no ``*.ckpt`` file is found in the directory.
    """
    ckpt_dir = CHECKPOINT_DIR / model_dir
    ckpts = sorted(ckpt_dir.glob('*.ckpt'))
    if not ckpts:
        raise FileNotFoundError(f'No checkpoints found in {ckpt_dir}')

    scored = []
    for path in ckpts:
        loss = _checkpoint_val_loss(path)
        if loss is not None:
            scored.append((loss, path))

    if scored:
        # Lowest val_loss wins; on equal losses the first in sorted order is kept.
        best_ckpt = min(scored, key=lambda item: item[0])[1]
    else:
        # No val_loss in any filename: keep the historical "last sorted" choice.
        best_ckpt = ckpts[-1]

    print(f'Loading {model_class.__name__} from {best_ckpt.name}')
    model = model_class.load_from_checkpoint(str(best_ckpt), **model_kwargs)
    # NOTE(review): strictness is normally decided while loading (e.g. a
    # `strict` argument to load_from_checkpoint). Setting it after the load
    # presumably only affects later loads — confirm intent.
    model.set_strict_loading(False)
    return model

# Load one checkpoint per variant. A variant whose checkpoints are missing is
# reported and left out of the comparison instead of aborting the notebook.
models = {}
for variant, model_cls, extra_kwargs in (
    ('A1', MuQLoRAModel, {}),
    ('A2', MuQStagedModel, {'stage': 'supervised'}),
    ('A3', MuQFullUnfreezeModel, {}),
):
    try:
        models[variant] = load_best_checkpoint(model_cls, variant, **extra_kwargs)
    except FileNotFoundError as e:
        print(f'{variant}: {e}')

print(f'\nLoaded models: {list(models.keys())}')

## 4. Run MetricsSuite on Each Model

In [None]:
# Ambiguity cutoff handed to MetricsSuite. Its exact meaning is defined in
# model_improvement.metrics (presumably pairs with a label gap below it are
# treated as ambiguous — confirm).
AMBIGUOUS_THRESHOLD = 0.05
suite = MetricsSuite(ambiguous_threshold=AMBIGUOUS_THRESHOLD)

def build_val_pairs(val_keys, labels_dict, embeddings_dict):
    """Generate every unordered pair of scorable validation segments with their labels.

    Args:
        val_keys: Segment keys of the validation split.
        labels_dict: Segment key -> label sequence; only the first 19 values are used.
        embeddings_dict: Segment key -> embedding. Only membership is checked.
            Keys lacking an embedding (or a label) are dropped here rather
            than raising KeyError later in the forward pass.

    Returns:
        List of ``(key_a, key_b, labels_a, labels_b)`` tuples, with key_a
        preceding key_b in sorted key order. Labels are float32 tensors built
        from ``labels_dict[key][:19]``.
    """
    import itertools

    usable_keys = sorted(k for k in val_keys if k in labels_dict and k in embeddings_dict)
    # Build each label tensor once, instead of once per pair it appears in.
    label_tensors = {
        k: torch.tensor(labels_dict[k][:19], dtype=torch.float32) for k in usable_keys
    }
    return [
        (key_a, key_b, label_tensors[key_a], label_tensors[key_b])
        for key_a, key_b in itertools.combinations(usable_keys, 2)
    ]

def evaluate_model_on_fold(model, fold, labels_dict, embeddings_dict):
    """Compute pairwise-ranking accuracy and regression R2 for one CV fold.

    Args:
        model: Encoder called as ``model(emb_a, emb_b)``, which must return a
            dict with ``'ranking_logits'``, and ``model.predict_scores(emb)``.
            It is switched to eval mode as a side effect.
        fold: Mapping with a ``'val'`` list of segment keys.
        labels_dict: Segment key -> label sequence; only the first 19 values are used.
        embeddings_dict: Segment key -> unbatched embedding tensor.

    Returns:
        Dict with ``'pairwise'`` (``pw['overall']``), ``'pairwise_detail'``
        (the full ``suite.pairwise_accuracy`` result) and ``'r2'``. A key is
        omitted when no usable pair or segment exists for it.

    Uses the notebook-level ``suite`` (MetricsSuite).
    """
    # Score only segments that have both an embedding and a label. The
    # pairwise and regression metrics then cover the same keys, and a missing
    # embedding cannot raise KeyError mid-fold.
    val_keys = [k for k in fold['val'] if k in embeddings_dict and k in labels_dict]
    val_pairs = build_val_pairs(val_keys, labels_dict, embeddings_dict)

    results = {}
    model.eval()

    with torch.no_grad():
        # Pairwise ranking accuracy: one forward pass per unordered pair, so
        # the cost grows quadratically with the number of val segments.
        all_logits, all_labels_a, all_labels_b = [], [], []
        for key_a, key_b, lab_a, lab_b in val_pairs:
            emb_a = embeddings_dict[key_a].unsqueeze(0)
            emb_b = embeddings_dict[key_b].unsqueeze(0)
            out = model(emb_a, emb_b)
            all_logits.append(out['ranking_logits'])
            all_labels_a.append(lab_a.unsqueeze(0))
            all_labels_b.append(lab_b.unsqueeze(0))

        if all_logits:
            pw = suite.pairwise_accuracy(
                torch.cat(all_logits), torch.cat(all_labels_a), torch.cat(all_labels_b)
            )
            results['pairwise'] = pw['overall']
            results['pairwise_detail'] = pw

        # Regression R2 on per-segment absolute scores.
        all_preds, all_targets = [], []
        for key in val_keys:
            all_preds.append(model.predict_scores(embeddings_dict[key].unsqueeze(0)))
            all_targets.append(
                torch.tensor(labels_dict[key][:19], dtype=torch.float32).unsqueeze(0)
            )

        if all_preds:
            results['r2'] = suite.regression_r2(torch.cat(all_preds), torch.cat(all_targets))

    return results

In [None]:
# Aggregate metrics across all folds.
# NOTE(review): each variant is represented by one checkpoint, which is scored
# on every fold's validation split. Unless that checkpoint was trained on data
# disjoint from all of these splits, some folds overlap its training data and
# the averages below are optimistic. Confirm how checkpoints map to folds.


def _fmt_metric(value):
    """Format a metric with 4 decimals, or return 'N/A' when it is missing (None)."""
    return 'N/A' if value is None else f'{value:.4f}'


all_results = {}

for name, model in models.items():
    print(f'\nEvaluating {name}...')
    fold_metrics = []
    for fold_idx, fold in enumerate(folds):
        fold_res = evaluate_model_on_fold(model, fold, labels, embeddings)
        fold_metrics.append(fold_res)
        # Use the helper: applying ':.4f' to an 'N/A' fallback raises ValueError.
        print(f'  Fold {fold_idx}: pairwise={_fmt_metric(fold_res.get("pairwise"))}, '
              f'r2={_fmt_metric(fold_res.get("r2"))}')

    # Unweighted mean over the folds that produced each metric.
    avg = {}
    for metric_key in ('pairwise', 'r2'):
        values = [fm[metric_key] for fm in fold_metrics if metric_key in fm]
        if values:
            avg[metric_key] = sum(values) / len(values)

    # Per-dimension detail for the plot comes from the last fold that has it
    # (it is not averaged). This also copes with an empty fold list.
    details = [fm['pairwise_detail'] for fm in fold_metrics if 'pairwise_detail' in fm]
    if details:
        avg['pairwise_detail'] = details[-1]

    all_results[name] = avg
    print(f'  Average: pairwise={_fmt_metric(avg.get("pairwise"))}, '
          f'r2={_fmt_metric(avg.get("r2"))}')

## 5. Robustness Metrics (Noise-Perturbed Validation Embeddings)

In [None]:
def evaluate_robustness(model, val_keys, labels_dict, embeddings_dict, noise_std=0.05, seed=0):
    """Compare clean vs noise-perturbed predictions to estimate robustness.

    Each embedding is perturbed with additive Gaussian noise (std ``noise_std``)
    as a cheap proxy for audio augmentation. Clean and perturbed score tensors
    are then passed to ``compute_robustness_metrics``.

    Args:
        model: Encoder exposing ``predict_scores(emb)``; switched to eval mode.
        val_keys: Segment keys to evaluate; keys without an embedding are skipped.
        labels_dict: Not used here; kept for signature parity with the other evaluators.
        embeddings_dict: Segment key -> unbatched embedding tensor.
        noise_std: Standard deviation of the additive noise.
        seed: Seed for a dedicated noise generator, so every model is scored
            against identical perturbations (for the same key order). Pass
            ``None`` to draw from the global RNG instead (non-reproducible).

    Returns:
        The result of ``compute_robustness_metrics(clean, augmented)``.

    Raises:
        ValueError: If none of ``val_keys`` has an embedding.
    """
    # A fresh, re-seeded generator per call makes the noise identical across
    # models. With the global RNG each model would see different perturbations.
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    clean_scores = []
    aug_scores = []

    model.eval()
    with torch.no_grad():
        for key in val_keys:
            if key not in embeddings_dict:
                continue
            emb = embeddings_dict[key]

            # Clean prediction
            clean_scores.append(model.predict_scores(emb.unsqueeze(0)))

            # Perturbed prediction. Noise is drawn on CPU, where the generator
            # lives, then moved to emb's device and dtype.
            noise = torch.randn(emb.shape, generator=generator).to(device=emb.device, dtype=emb.dtype)
            emb_aug = emb + noise * noise_std
            aug_scores.append(model.predict_scores(emb_aug.unsqueeze(0)))

    if not clean_scores:
        raise ValueError('evaluate_robustness: none of val_keys has an embedding to evaluate')

    return compute_robustness_metrics(torch.cat(clean_scores), torch.cat(aug_scores))

In [None]:
# Robustness is measured on the last fold's validation split only.
robustness_results = {}

if models and folds:
    holdout_keys = folds[-1]['val']
    for variant, encoder in models.items():
        print(f'Robustness check for {variant}...')
        robustness_results[variant] = evaluate_robustness(encoder, holdout_keys, labels, embeddings)
        stats = robustness_results[variant]
        print(f'  pearson_r={stats["pearson_r"]:.4f}, score_drop_pct={stats["score_drop_pct"]:.1f}%')

## 6. Comparison Table

In [None]:
# One row per loaded variant. pairwise/r2 default to 0.0 when the aggregate
# entry lacks them. Robustness columns appear only if robustness was evaluated.
comparison = {}
for variant in models:
    row = {}
    if variant in all_results:
        agg = all_results[variant]
        row.update(pairwise=agg.get('pairwise', 0.0), r2=agg.get('r2', 0.0))
    if variant in robustness_results:
        rob_stats = robustness_results[variant]
        row.update(robustness=rob_stats['pearson_r'], score_drop_pct=rob_stats['score_drop_pct'])
    comparison[variant] = row

if not comparison:
    print('No results to compare yet -- run training first.')
else:
    table = format_comparison_table(comparison)
    print(table)

## 7. Per-Dimension Breakdown

In [None]:
# Short x-axis labels for the 19 PercePiano perceptual dimensions, in the order
# the 19-value label vectors (labels[key][:19]) are assumed to use.
# NOTE(review): order follows the PercePiano paper's taxonomy (timing,
# articulation, pedal, timbre, dynamics, music making, emotion/mood,
# interpretation). Confirm against the label definition used to build
# labels.json before trusting per-dimension plots.
DIMENSION_NAMES = [
    'Timing', 'Articulation (length)', 'Articulation (touch)',
    'Pedal (amount)', 'Pedal (clarity)',
    'Timbre (variety)', 'Timbre (depth)', 'Timbre (brightness)', 'Timbre (loudness)',
    'Dynamic (sophistication)', 'Dynamic range',
    'Tempo', 'Space', 'Balance', 'Drama',
    'Mood (valence)', 'Mood (energy)', 'Mood (imagination)',
    'Interpretation',
]

def plot_per_dimension_comparison(results_dict):
    """Draw a grouped bar chart of per-dimension pairwise accuracy, one bar per model.

    Args:
        results_dict: Model name -> aggregated results. Entries without a
            ``'pairwise_detail'`` key are skipped.
            ``pairwise_detail['per_dimension']`` is read with integer
            dimension indices, and missing dimensions are drawn at 0.5.

    Bar width and tick positions adapt to the number of plotted models. The
    x-axis labels come from the notebook-level ``DIMENSION_NAMES``.
    """
    plottable = {name: res for name, res in results_dict.items() if 'pairwise_detail' in res}

    fig, ax = plt.subplots(1, 1, figsize=(14, 6))

    n_dims = len(DIMENSION_NAMES)
    x = np.arange(n_dims)
    n_models = max(len(plottable), 1)
    width = 0.8 / n_models  # each dimension's bar group spans 0.8 of the tick spacing

    # Enumerate only plottable models, so skipped entries leave no empty slot.
    for i, (name, res) in enumerate(plottable.items()):
        per_dim = res['pairwise_detail']['per_dimension']
        values = [per_dim.get(d, 0.5) for d in range(n_dims)]
        ax.bar(x + i * width, values, width, label=name, alpha=0.8)

    # Centre each tick under its bar group.
    ax.set_xticks(x + width * (n_models - 1) / 2)
    ax.set_xticklabels(DIMENSION_NAMES, rotation=45, ha='right', fontsize=8)
    ax.set_ylabel('Pairwise Accuracy')
    ax.set_title('Per-Dimension Pairwise Ranking Accuracy')
    if plottable:
        ax.legend()
    ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)  # chance level
    ax.set_ylim(0.3, 1.0)
    plt.tight_layout()
    plt.show()

# Plot only once at least one model has been evaluated.
if not all_results:
    print('No results to plot yet.')
else:
    plot_per_dimension_comparison(all_results)

## 8. Winner Selection

Selection protocol:
1. **Primary**: Highest pairwise ranking accuracy
2. **Tiebreak**: Highest R-squared on regression
3. **Veto**: Robustness score_drop_pct > 15% disqualifies the model

In [None]:
ROBUSTNESS_VETO_THRESHOLD = 15.0  # percent


def select_winner(comparison_results, tie_tolerance=0.0):
    """Select the best audio encoder: veto on robustness, rank by pairwise, tiebreak on R2.

    Args:
        comparison_results: Model name -> metrics dict with optional keys
            ``'pairwise'`` and ``'r2'`` (both default to 0.0) and
            ``'score_drop_pct'`` (defaults to 0.0, so a model without
            robustness results is not vetoed).
        tie_tolerance: Candidates whose pairwise accuracy is within this gap
            of the best are treated as tied, and R2 decides among them. The
            default 0.0 keeps exact-match tie-breaking.

    Returns:
        The winning model name, or ``None`` if every model was vetoed or
        ``comparison_results`` is empty.
    """
    import math

    candidates = []
    for name, metrics in comparison_results.items():
        drop_pct = metrics.get('score_drop_pct', 0.0)
        # NaN compares False against the threshold and would silently pass
        # the veto, so a non-finite drop counts as a failed robustness check.
        if not math.isfinite(drop_pct) or drop_pct > ROBUSTNESS_VETO_THRESHOLD:
            print(f'{name}: VETOED (score_drop_pct={drop_pct:.1f}% > {ROBUSTNESS_VETO_THRESHOLD}%)')
            continue
        candidates.append((name, metrics.get('pairwise', 0.0), metrics.get('r2', 0.0)))

    if not candidates:
        print('All models vetoed!')
        return None

    # Report order: pairwise (desc), then r2 (desc).
    candidates.sort(key=lambda c: (c[1], c[2]), reverse=True)

    # Every candidate within tie_tolerance of the best pairwise score is tied.
    # Among those, the highest R2 wins (first in sorted order on an exact tie).
    best_pairwise = candidates[0][1]
    tied = [c for c in candidates if best_pairwise - c[1] <= tie_tolerance]
    winner = max(tied, key=lambda c: c[2])[0]

    print(f'\nWinner: {winner}')
    for name, pw, r2 in candidates:
        marker = ' <-- WINNER' if name == winner else ''
        print(f'  {name}: pairwise={pw:.4f}, r2={r2:.4f}{marker}')

    return winner

# Define winner on both branches. Later cells can then test `winner is None`
# instead of hitting a NameError when nothing was compared.
winner = None
if comparison:
    winner = select_winner(comparison)
else:
    print('Run training and evaluation first.')

In [None]:
def upload_checkpoint(local_path, remote_subdir):
    """Copy a finalized checkpoint to the project's Google Drive folder via ``rclone copy``.

    The destination is ``<drive checkpoint root>/<remote_subdir>``. Raises
    ``subprocess.CalledProcessError`` if rclone exits with a non-zero status.
    """
    remote_root = 'gdrive:crescendai/model/checkpoints/model_improvement'
    command = ['rclone', 'copy', str(local_path), f'{remote_root}/{remote_subdir}', '--progress']
    subprocess.run(command, check=True)