# 08: Symbolic Encoder Comparison (S1 vs S2 vs S3)

Compare three symbolic encoder architectures:
- **S1**: Transformer on REMI-tokenized MIDI (TransformerSymbolicEncoder)
- **S2**: GNN on score graph with GATConv (GNNSymbolicEncoder)
- **S3**: 1D-CNN + Transformer on continuous MIDI features (ContinuousSymbolicEncoder)

Selection criteria (same as audio track):
1. Primary: Pairwise ranking accuracy
2. Tiebreak: R-squared on regression
3. Veto: Robustness drop > 15%

Additional symbolic-specific assessment:
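- Score alignment on ASAP: cosine similarity between an encoder's embedding of a performance MIDI and its embedding of the matching score MIDI (informational; not used for ranking)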

---

## 1. Setup

In [None]:
import subprocess
import sys
import os
# cuBLAS workspace configuration; presumably set so deterministic CUDA matmuls are
# permitted (see PyTorch reproducibility notes). Must be set before CUDA is initialised.
os.environ.update(CUBLAS_WORKSPACE_CONFIG=':4096:8')

In [None]:
# -- Thunder Compute Setup --
# 1. Clone repo
# !git clone <repo-url> /workspace/crescendai
# %cd /workspace/crescendai/model

# 2. Install dependencies
# !uv sync

# 3. Pull cached data from Google Drive via rclone
# !rclone sync gdrive:crescendai/model/data ./data --progress

# 4. Configure paths
from pathlib import Path


def env_flag(name):
    """Return True if environment variable *name* holds a truthy value.

    Environment values are always strings, so a plain truthiness test would
    treat 'THUNDER_COMPUTE=0' or 'THUNDER_COMPUTE=false' as enabled. Here,
    unset, empty, '0', 'false', 'no' and 'off' (case-insensitive, surrounding
    whitespace ignored) are False; any other value is True.
    """
    return os.environ.get(name, '').strip().lower() not in {'', '0', 'false', 'no', 'off'}


IS_REMOTE = env_flag('THUNDER_COMPUTE')
if IS_REMOTE:
    # Absolute paths matching the clone location used in step 1.
    DATA_DIR = Path('/workspace/crescendai/model/data')
    CHECKPOINT_DIR = Path('/workspace/crescendai/model/checkpoints/model_improvement')
else:
    # Relative to the kernel's working directory (one level below data/ and checkpoints/).
    DATA_DIR = Path('../data')
    CHECKPOINT_DIR = Path('../checkpoints/model_improvement')

In [None]:
import json
import numpy as np
import torch
import matplotlib.pyplot as plt

# Make the project package (model_improvement) importable.
# 'src' relative to the working directory matches the remote setup, where the
# kernel runs from .../model after `%cd`. Locally DATA_DIR is '../data', i.e. the
# kernel runs one level below the directory holding data/, so the sibling src/
# is added as a fallback. 'src' is inserted last so it keeps top priority.
sys.path.insert(0, str(DATA_DIR.parent / 'src'))
sys.path.insert(0, 'src')

from model_improvement.symbolic_encoders import (
    TransformerSymbolicEncoder,
    GNNSymbolicEncoder,
    ContinuousSymbolicEncoder,
)
from model_improvement.tokenizer import PianoTokenizer, extract_continuous_features
from model_improvement.metrics import MetricsSuite, compute_robustness_metrics, format_comparison_table

## 2. Load Validation Data

Load PercePiano labels, cross-validation folds, and MIDI paths, plus ASAP performance/score MIDI pairs for the score-alignment assessment.

In [None]:
cache_dir = DATA_DIR / 'percepiano_cache'


def _read_json(path):
    """Parse one JSON file from the PercePiano cache directory."""
    with path.open() as fh:
        return json.load(fh)


# labels: segment key -> label sequence; folds: list of folds, each with a 'val' key list.
labels, folds = (_read_json(cache_dir / fname) for fname in ('labels.json', 'folds.json'))

print(f'Loaded {len(labels)} labeled segments')
print(f'Number of folds: {len(folds)}')

In [None]:
# Index the PercePiano MIDI files by file stem; the stem is the segment key
# that is matched against the labels later.
midi_dir = DATA_DIR / 'percepiano_midi'
midi_paths = {}
if not midi_dir.exists():
    print(f'MIDI directory not found: {midi_dir}')
else:
    midi_paths.update((path.stem, path) for path in midi_dir.glob('*.mid'))
    print(f'Found {len(midi_paths)} MIDI files')

In [None]:
# Collect (performance MIDI, score MIDI) pairs from the ASAP cache for the
# score-alignment assessment.
asap_dir = DATA_DIR / 'asap_cache'
asap_pairs = []
if not asap_dir.exists():
    print(f'ASAP data not found at {asap_dir}')
else:
    print(f'ASAP data found at {asap_dir}')
    for perf_midi in sorted(asap_dir.glob('**/performance*.mid')):
        # The score file sits next to the performance, with 'performance' -> 'score' in its name.
        score_midi = perf_midi.with_name(perf_midi.name.replace('performance', 'score'))
        if score_midi.exists():
            asap_pairs.append((perf_midi, score_midi))
    print(f'Found {len(asap_pairs)} performance-score MIDI pairs')

## 3. Prepare Inputs for Each Encoder Type

In [None]:
# S1 input: token ids per MIDI file (max_seq_len presumably bounds the
# sequence length -- confirm in PianoTokenizer).
tokenizer = PianoTokenizer(max_seq_len=2048)
print(f'Tokenizer vocab size: {tokenizer.vocab_size}')

s1_tokens = {}
for seg_key, path in midi_paths.items():
    try:
        s1_tokens[seg_key] = tokenizer.encode(path)
    except Exception as err:
        # Best effort: report and skip files the tokenizer cannot handle.
        print(f'  Failed to tokenize {seg_key}: {err}')

print(f'Tokenized {len(s1_tokens)} MIDI files for S1')

In [None]:
# S3 input: continuous MIDI features extracted at frame_rate=50, stored as float32 tensors.
s3_features = {}
for seg_key, path in midi_paths.items():
    try:
        s3_features[seg_key] = torch.tensor(
            extract_continuous_features(path, frame_rate=50),
            dtype=torch.float32,
        )
    except Exception as err:
        # Best effort: report and skip files whose features cannot be extracted.
        print(f'  Failed to extract features for {seg_key}: {err}')

print(f'Extracted continuous features for {len(s3_features)} files (S3)')

## 4. Load Checkpoints

In [None]:
def natural_sort_key(name):
    """Split *name* into text and integer chunks so embedded numbers sort numerically.

    With plain string sorting 'epoch=10' < 'epoch=9'; with this key 'epoch=9'
    sorts before 'epoch=10'. Only runs of ASCII digits count as numbers, so
    decimals such as 'loss=0.25' are compared chunk by chunk.
    """
    import re

    # re.split with a capturing group alternates text (even index) and digit runs (odd index).
    return [int(chunk) if idx % 2 else chunk for idx, chunk in enumerate(re.split(r'([0-9]+)', name))]


def load_best_checkpoint(model_class, model_dir, **model_kwargs):
    """Load the last checkpoint (by natural filename order) and put it in eval mode.

    Checkpoints are looked up in ``CHECKPOINT_DIR / model_dir``. Numeric parts of
    the filenames are compared as numbers, so 'epoch=10...' ranks after 'epoch=9...'.

    NOTE(review): despite the name, no validation metric is inspected. The file that
    sorts last by name is used; for example, a 'last.ckpt' sorts after 'epoch=...'
    files. Confirm that the checkpoint naming scheme selects the intended file.

    Raises:
        FileNotFoundError: if no ``*.ckpt`` file is found.
    """
    ckpt_dir = CHECKPOINT_DIR / model_dir
    ckpts = sorted(ckpt_dir.glob('*.ckpt'), key=lambda p: natural_sort_key(p.name))
    if not ckpts:
        raise FileNotFoundError(f'No checkpoints found in {ckpt_dir}')
    best_ckpt = ckpts[-1]
    print(f'Loading {model_class.__name__} from {best_ckpt.name}')
    model = model_class.load_from_checkpoint(str(best_ckpt), **model_kwargs)
    model.eval()
    return model

# Encoder label (also its checkpoint sub-directory) -> model class.
# Every encoder is loaded with stage='finetune'.
_encoder_classes = {
    'S1': TransformerSymbolicEncoder,
    'S2': GNNSymbolicEncoder,
    'S3': ContinuousSymbolicEncoder,
}

models = {}
for enc_name, enc_cls in _encoder_classes.items():
    try:
        models[enc_name] = load_best_checkpoint(enc_cls, enc_name, stage='finetune')
    except FileNotFoundError as err:
        # Skip encoders without checkpoints so the remaining ones can still be compared.
        print(f'{enc_name}: {err}')

print(f'\nLoaded models: {list(models)}')

## 5. Run MetricsSuite on Each Model

In [None]:
suite = MetricsSuite(ambiguous_threshold=0.05)


def _evaluate_encoder(model, val_keys, labels_dict, inputs_dict, to_input):
    """Shared fold evaluation for encoders that take (input tensor, boolean mask).

    Each validation segment is encoded exactly once. All unordered pairs of the
    cached embeddings are then passed to ``model.compare`` for pairwise ranking
    accuracy, and one forward pass per segment gives the regression scores for R^2.

    Args:
        model: encoder exposing ``encode(x, mask)``, ``compare(z_a, z_b)`` and
            ``model(x, mask)`` returning a dict with a 'scores' entry.
        val_keys: segment keys in this fold's validation split.
        labels_dict: segment key -> label sequence (only the first 19 entries are used).
        inputs_dict: segment key -> raw model input for that segment.
        to_input: converts one ``inputs_dict`` value into a tensor with a leading batch dim of 1.

    Returns:
        dict with 'pairwise' and 'pairwise_detail' (only when at least 2 segments
        are usable) and 'r2' (when at least 1 segment is usable).
    """
    results = {}
    model.eval()
    valid_keys = [k for k in val_keys if k in inputs_dict and k in labels_dict]
    if not valid_keys:
        return results

    with torch.no_grad():
        embeddings, preds, targets = [], [], []
        for key in valid_keys:
            x = to_input(inputs_dict[key])
            # Each segment is fed alone (batch of 1), so the mask is all-True: no padding.
            mask = torch.ones(1, x.size(1), dtype=torch.bool)
            embeddings.append(model.encode(x, mask))
            preds.append(model(x, mask)['scores'])
            targets.append(torch.tensor(labels_dict[key][:19], dtype=torch.float32).unsqueeze(0))

        # Every unordered pair (i < j), in nested i/j order, compared through the cached
        # embeddings. This avoids re-encoding both segments for each of the O(n^2) pairs.
        all_logits, all_la, all_lb = [], [], []
        n_keys = len(valid_keys)
        for i in range(n_keys):
            for j in range(i + 1, n_keys):
                all_logits.append(model.compare(embeddings[i], embeddings[j]))
                all_la.append(targets[i])
                all_lb.append(targets[j])

        if all_logits:
            pw = suite.pairwise_accuracy(torch.cat(all_logits), torch.cat(all_la), torch.cat(all_lb))
            results['pairwise'] = pw['overall']
            results['pairwise_detail'] = pw

        results['r2'] = suite.regression_r2(torch.cat(preds), torch.cat(targets))
    return results


def evaluate_s1(model, val_keys, labels_dict, tokens_dict):
    """Evaluate the S1 token encoder on one fold.

    Assumes each ``tokens_dict`` value is a flat sequence of token ids; it is
    batched as a (1, T) tensor.
    """
    return _evaluate_encoder(
        model, val_keys, labels_dict, tokens_dict,
        lambda tokens: torch.tensor(tokens).unsqueeze(0),
    )


def evaluate_s3(model, val_keys, labels_dict, features_dict):
    """Evaluate the S3 continuous-feature encoder on one fold.

    Each ``features_dict`` value is a float32 tensor; a batch dim of 1 is prepended.
    """
    return _evaluate_encoder(
        model, val_keys, labels_dict, features_dict,
        lambda feats: feats.unsqueeze(0),
    )

In [None]:
def _is_number(value):
    """Return True for real-valued scalars (Python or NumPy), excluding bools.

    Metric helpers may return NumPy scalars such as float32, which are not
    ``float`` instances; checking ``numbers.Real`` keeps them in the averages.
    """
    import numbers

    return isinstance(value, numbers.Real) and not isinstance(value, bool)


def _fmt_metric(value):
    """Format a metric with 4 decimals, or 'N/A' when it is missing or non-numeric."""
    return f'{value:.4f}' if _is_number(value) else 'N/A'


def _average_fold_metrics(fold_metrics):
    """Average 'pairwise' and 'r2' over the folds that produced them.

    'pairwise_detail' is copied from the last fold that has one. Its
    'per_dimension' map is replaced by the mean over all folds that report each
    dimension, so the per-dimension breakdown matches the fold-averaged
    'pairwise' score instead of reflecting a single fold.
    """
    avg = {}
    for mk in ('pairwise', 'r2'):
        vals = [float(fm[mk]) for fm in fold_metrics if _is_number(fm.get(mk))]
        if vals:
            avg[mk] = sum(vals) / len(vals)

    details = [fm['pairwise_detail'] for fm in fold_metrics if 'pairwise_detail' in fm]
    if details:
        dim_sums, dim_counts = {}, {}
        for detail in details:
            for dim, acc in detail.get('per_dimension', {}).items():
                if _is_number(acc):
                    dim_sums[dim] = dim_sums.get(dim, 0.0) + float(acc)
                    dim_counts[dim] = dim_counts.get(dim, 0) + 1
        merged = dict(details[-1])
        merged['per_dimension'] = {dim: dim_sums[dim] / dim_counts[dim] for dim in dim_sums}
        avg['pairwise_detail'] = merged
    return avg


all_results = {}

for name, model in models.items():
    print(f'\nEvaluating {name}...')
    fold_metrics = []

    for fold_idx, fold in enumerate(folds):
        val_keys = fold['val']
        if name == 'S1':
            fold_res = evaluate_s1(model, val_keys, labels, s1_tokens)
        elif name == 'S2':
            # No graph-construction pipeline yet. Record no metrics rather than fake
            # 0.0 scores, so S2 is not averaged or reported as if it had been evaluated.
            fold_res = {}
            print(f'  Fold {fold_idx}: S2 requires graph construction pipeline (skipped)')
        elif name == 'S3':
            fold_res = evaluate_s3(model, val_keys, labels, s3_features)
        else:
            fold_res = {}
        fold_metrics.append(fold_res)
        pw = fold_res.get('pairwise')
        if _is_number(pw) and pw > 0:
            print(f'  Fold {fold_idx}: pairwise={pw:.4f}, r2={_fmt_metric(fold_res.get("r2"))}')

    # Empty when there are no folds or nothing was evaluated (previously an IndexError).
    avg = _average_fold_metrics(fold_metrics)
    all_results[name] = avg
    print(f'  Average: pairwise={avg.get("pairwise", "N/A")}, r2={avg.get("r2", "N/A")}')

## 6. Score Alignment Assessment (ASAP)

In [None]:
def evaluate_alignment(model, model_type, asap_pairs, tok=None, max_pairs=50, frame_rate=50):
    """Measure embedding similarity between performance and score MIDI.

    For each of the first *max_pairs* (performance, score) MIDI pairs, both files
    are encoded with *model*, and the cosine similarity of the two embeddings is
    recorded.

    Args:
        model: encoder exposing ``encode(x, mask)``.
        model_type: 'S1' (token ids via *tok*) or 'S3' (continuous features).
            Any other value, or 'S1' without *tok*, scores no pairs.
        asap_pairs: sequence of (performance_path, score_path) tuples.
        tok: tokenizer exposing ``encode(path)``; required for 'S1'.
        max_pairs: number of leading pairs to evaluate.
        frame_rate: passed to ``extract_continuous_features`` for 'S3'. It defaults
            to 50, the rate used when building the S3 evaluation features.

    Returns:
        ``{'mean_cosine_sim', 'std', 'n', 'n_failed'}``, or None when no pair
        could be scored. Pairs that raise are skipped but counted in 'n_failed',
        and the first error is printed rather than silently dropped.
    """
    similarities = []
    n_failed = 0
    first_error = None
    model.eval()
    with torch.no_grad():
        for perf_path, score_path in asap_pairs[:max_pairs]:
            try:
                if model_type == 'S1' and tok is not None:
                    perf_x = torch.tensor(tok.encode(perf_path)).unsqueeze(0)
                    score_x = torch.tensor(tok.encode(score_path)).unsqueeze(0)
                elif model_type == 'S3':
                    perf_x = torch.tensor(
                        extract_continuous_features(perf_path, frame_rate=frame_rate),
                        dtype=torch.float32,
                    ).unsqueeze(0)
                    score_x = torch.tensor(
                        extract_continuous_features(score_path, frame_rate=frame_rate),
                        dtype=torch.float32,
                    ).unsqueeze(0)
                else:
                    continue
                z_perf = model.encode(perf_x, torch.ones(1, perf_x.size(1), dtype=torch.bool))
                z_score = model.encode(score_x, torch.ones(1, score_x.size(1), dtype=torch.bool))
                cos_sim = torch.nn.functional.cosine_similarity(z_perf, z_score).item()
                similarities.append(cos_sim)
            except Exception as err:
                # Best effort: one unreadable file should not abort the whole assessment,
                # but failures are counted and surfaced below.
                n_failed += 1
                if first_error is None:
                    first_error = f'{perf_path}: {err!r}'
    if n_failed:
        print(f'  {n_failed} alignment pair(s) failed; first error: {first_error}')
    if not similarities:
        return None
    return {
        'mean_cosine_sim': float(np.mean(similarities)),
        'std': float(np.std(similarities)),
        'n': len(similarities),
        'n_failed': n_failed,
    }

In [None]:
# Score-alignment check for each loaded encoder (S2 has no graph builder yet).
alignment_results = {}
if not asap_pairs:
    print('No ASAP data available for alignment assessment.')
else:
    for enc_name, enc_model in models.items():
        if enc_name == 'S2':
            print(f'{enc_name}: Alignment requires graph builder (skipped)')
            continue
        print(f'Score alignment for {enc_name}...')
        enc_tok = tokenizer if enc_name == 'S1' else None
        result = evaluate_alignment(enc_model, enc_name, asap_pairs, tok=enc_tok)
        if not result:
            continue
        alignment_results[enc_name] = result
        print(f'  cosine_sim={result["mean_cosine_sim"]:.4f} +/- {result["std"]:.4f} (n={result["n"]})')

## 7. Comparison Table + Per-Dimension Breakdown

In [None]:
# One row per loaded encoder. Evaluated encoders always get 'pairwise' and 'r2'
# (0.0 when absent); 'alignment' is added only when it was computed.
comparison = {}
for name in models:
    row = {}
    if name in all_results:
        fold_avg = all_results[name]
        row['pairwise'] = fold_avg.get('pairwise', 0.0)
        row['r2'] = fold_avg.get('r2', 0.0)
    if name in alignment_results:
        row['alignment'] = alignment_results[name]['mean_cosine_sim']
    comparison[name] = row

print(format_comparison_table(comparison) if comparison else 'No results to compare yet.')

In [None]:
# Display names for the 19 label dimensions, intended to be in label-vector order
# (per-dimension index d is shown as DIMENSION_NAMES[d]).
# NOTE(review): nothing in this notebook derives these names from the label files.
# Confirm they match the meaning and order of the 19 label columns in labels.json.
# Some of them (e.g. 'Memorization', 'Stage presence') cannot be observed in MIDI,
# which suggests the list may not correspond to the actual label schema.
DIMENSION_NAMES = [
    'Correctness of pitch',
    'Correctness of rhythm',
    'Correctness of tempo',
    'Tone quality',
    'Dynamic range',
    'Articulation',
    'Balance',
    'Rhythm. stability',
    'Tempo stability',
    'Phrasing',
    'Stylistic accuracy',
    'Ornamentation',
    'Pedaling',
    'Expressiveness',
    'Technical proficiency',
    'Memorization',
    'Stage presence',
    'Communication',
    'Overall',
]

def plot_per_dimension_comparison(results_dict):
    """Draw a grouped bar chart of per-dimension pairwise accuracy, one bar series per model.

    Only models whose results include 'pairwise_detail' are drawn, packed without
    gaps for skipped models. Bar width and tick positions adapt to the number of
    drawn models; with 3 models they match the original 0.25-wide layout.
    Dimensions missing from a model's 'per_dimension' map (keyed by dimension
    index) are drawn at 0.5, the dashed chance line.
    """
    plotted = [
        (name, res['pairwise_detail']['per_dimension'])
        for name, res in results_dict.items()
        if 'pairwise_detail' in res
    ]
    if not plotted:
        print('No per-dimension results to plot.')
        return

    n_dims = len(DIMENSION_NAMES)
    x = np.arange(n_dims)
    # Keep each group within 0.75 of a unit; 3 models -> 0.25 per bar.
    width = 0.75 / len(plotted)
    _, ax = plt.subplots(1, 1, figsize=(14, 6))
    for i, (name, per_dim) in enumerate(plotted):
        values = [per_dim.get(d, 0.5) for d in range(n_dims)]
        ax.bar(x + i * width, values, width, label=name, alpha=0.8)
    # Bars start at x, so the group centre lies (len(plotted) - 1) / 2 bar widths to the right.
    ax.set_xticks(x + (len(plotted) - 1) * width / 2)
    ax.set_xticklabels(DIMENSION_NAMES, rotation=45, ha='right', fontsize=8)
    ax.set_ylabel('Pairwise Accuracy')
    ax.set_title('Per-Dimension Pairwise Ranking Accuracy (Symbolic Encoders)')
    ax.legend()
    ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
    ax.set_ylim(0.3, 1.0)
    plt.tight_layout()
    plt.show()

# Plot only once at least one model has been evaluated.
if not all_results:
    print('No results to plot yet.')
else:
    plot_per_dimension_comparison(all_results)

## 8. Winner Selection

1. **Primary**: Highest pairwise ranking accuracy
2. **Tiebreak**: Highest R-squared
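3. **Veto**: Robustness `score_drop_pct` > 15%. This is applied only when a model's metrics include `score_drop_pct`. The comparison assembled in Section 7 does not include it yet, so the veto is currently inactive.
4. **Bonus**: Score alignment (informational only; not used for ranking)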

In [None]:
ROBUSTNESS_VETO_THRESHOLD = 15.0


def select_winner(comp):
    """Pick the best encoder from a ``{name: metrics}`` comparison dict.

    Ranking uses the highest 'pairwise' first, then the highest 'r2'. The R^2
    tiebreak only matters on an exact pairwise tie. Missing 'pairwise' or 'r2'
    values count as 0.0. Models whose 'score_drop_pct' exceeds
    ROBUSTNESS_VETO_THRESHOLD are vetoed. 'alignment' is printed for
    information only and does not affect the ranking.

    A model without 'score_drop_pct' cannot be vetoed. Such models are listed
    in a warning, so a missing robustness evaluation is visible instead of
    silently passing the veto.

    Returns:
        The winning model name, or None if *comp* is empty or every model is vetoed.
    """
    candidates = []
    unchecked = []
    for name, metrics in comp.items():
        if 'score_drop_pct' not in metrics:
            unchecked.append(name)
        drop_pct = metrics.get('score_drop_pct', 0.0)
        if drop_pct > ROBUSTNESS_VETO_THRESHOLD:
            print(f'{name}: VETOED (score_drop_pct={drop_pct:.1f}%)')
            continue
        candidates.append(
            (name, metrics.get('pairwise', 0.0), metrics.get('r2', 0.0), metrics.get('alignment'))
        )
    if unchecked:
        print(f'WARNING: no score_drop_pct for {", ".join(unchecked)}; robustness veto not applied to them.')
    if not candidates:
        print('All models vetoed!' if comp else 'No models to compare.')
        return None
    # Stable sort: exact ties on (pairwise, r2) keep their input order.
    candidates.sort(key=lambda c: (c[1], c[2]), reverse=True)
    winner = candidates[0][0]
    print(f'\nWinner: {winner}')
    for name, pw, r2, align in candidates:
        marker = ' <-- WINNER' if name == winner else ''
        align_str = 'N/A' if align is None else f'{align:.4f}'
        print(f'  {name}: pairwise={pw:.4f}, r2={r2:.4f}, alignment={align_str}{marker}')
    return winner

# Default to None so `winner` is always defined for later cells, even when
# there is nothing to compare yet.
winner = None
if comparison:
    winner = select_winner(comparison)
else:
    print('Run training and evaluation first.')

In [None]:
def upload_checkpoint(local_path, remote_subdir):
    """Copy *local_path* into the project's Google Drive checkpoint folder via rclone.

    The destination is
    ``gdrive:crescendai/model/checkpoints/model_improvement/<remote_subdir>``.
    A non-zero rclone exit raises ``subprocess.CalledProcessError`` (``check=True``).
    """
    remote_root = 'gdrive:crescendai/model/checkpoints/model_improvement'
    command = ['rclone', 'copy', str(local_path), f'{remote_root}/{remote_subdir}', '--progress']
    subprocess.run(command, check=True)