# Embedding Model Ablation Study

This notebook systematically evaluates whether the current embedding model (`all-MiniLM-L6-v2`) is leaving prediction performance on the table, and quantifies the impact of alternative encoders on the VIF Critic.

**Motivation:**  
The embedding model is the first link in the VIF inference chain:

```
Journal Entry → [Embedding Model] → State Vector → [Critic MLP] → Alignment Scores
```

Prior reports suggest that `all-MiniLM-L6-v2` is weak at separating opposite sentiment: for example, "I love reading" and "I hate reading" score roughly 0.59 cosine similarity. Because Twinkl detects value alignment through emotional nuance in journal entries, this weakness may directly cap downstream performance.

**Experiments:**
1. **Embedding Quality Probe** — Do the encoders differentiate emotionally opposite value statements?
2. **End-to-End Ablation** — Train identical CriticMLP models with different encoders, compare on the same test set
3. **Matryoshka Dimension Search** — For models supporting it, find the optimal embedding dimension
4. **Latency Benchmark** — Measure encoding speed for practical deployment

**Models Under Test:**

| Model | Dims | Max Tokens | MTEB | Notes |
|-------|------|------------|------|-------|
| `all-MiniLM-L6-v2` | 384 | 256 | ~56.3 | Current baseline |
| `all-mpnet-base-v2` | 768 | 384 | ~60.0 | Easy upgrade, same library |
| `nomic-ai/nomic-embed-text-v1.5` | 768 | 8192 | ~62.3 | Long-context, Matryoshka support |

**Contents:**
0. Setup & Imports
1. Embedding Quality Probe (Emotional Discrimination)
2. Data Loading (Shared Splits)
3. End-to-End Encoder Ablation
4. Side-by-Side Results
5. Per-Dimension Deep Dive
6. Matryoshka Dimension Search
7. Latency Benchmark
8. Conclusions & Recommendations

In [None]:
# Setup
import os
import sys
import time
from pathlib import Path

# Walk up to find project root (contains src/ and pyproject.toml)
_dir = Path.cwd()
while _dir != _dir.parent:
    if (_dir / "src").is_dir() and (_dir / "pyproject.toml").is_file():
        os.chdir(_dir)
        break
    _dir = _dir.parent
sys.path.insert(0, ".")

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from scipy.spatial.distance import cosine as cosine_dist
from scipy import stats
import polars as pl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Working directory: {os.getcwd()}")
print(f"Using device: {device}")
print(f"Random seed: {SEED}")

## 1. Embedding Quality Probe

Before training any models, test whether each encoder can distinguish emotionally opposite statements about the same topic — the core signal Twinkl needs for value detection.

**Test Design:**
- **Pairs of value-laden statements** where the topic is identical but the value alignment flips
- A good encoder should give **low** cosine similarity to opposite pairs and **high** similarity to semantically equivalent pairs
- Covers each Schwartz value dimension with realistic journal-style text
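
The test design above reduces to one number per encoder: the mean similarity on equivalent pairs minus the mean similarity on opposite pairs. The sketch below illustrates that calculation with hand-written toy vectors rather than real embeddings; `cosine_similarity` and `discrimination_gap` are illustrative names and not part of the project code.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def discrimination_gap(opposite_sims, equivalent_sims):
    """Mean similarity of equivalent pairs minus mean similarity of opposite pairs."""
    mean_opp = sum(opposite_sims) / len(opposite_sims)
    mean_eq = sum(equivalent_sims) / len(equivalent_sims)
    return mean_eq - mean_opp


# Toy 2-D vectors standing in for embeddings (assumed values, illustration only).
opposite_pairs = [([1.0, 0.0], [0.0, 1.0]), ([1.0, 1.0], [1.0, -1.0])]
equivalent_pairs = [([1.0, 0.0], [2.0, 0.0]), ([1.0, 1.0], [3.0, 3.0])]

opp = [cosine_similarity(a, b) for a, b in opposite_pairs]
eq = [cosine_similarity(a, b) for a, b in equivalent_pairs]
gap = discrimination_gap(opp, eq)
print(f"gap = {gap:.3f}")
```

An encoder that cannot tell opposite statements apart would score both groups similarly, so its gap would shrink toward zero.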

In [None]:
# Value-opposite pairs: same topic, flipped alignment
# Format: (positive_alignment, negative_alignment, value_dimension)
PROBE_PAIRS = [
    # Known failing case from literature
    ("I love reading books", "I hate reading books", "sentiment_control"),
    # Benevolence vs Power
    ("I helped my colleague succeed at their presentation",
     "I ensured my colleague knew their place during the presentation",
     "benevolence_vs_power"),
    # Conformity vs Self-Direction
    ("I followed the team's decision even though I disagreed",
     "I pushed back on the team's decision and went my own way",
     "conformity_vs_self_direction"),
    # Security vs Stimulation
    ("I stayed with the familiar routine because it felt safe",
     "I tried something completely new and unpredictable today",
     "security_vs_stimulation"),
    # Tradition vs Self-Direction
    ("I honored the customs my family has always followed",
     "I broke with tradition and chose my own path forward",
     "tradition_vs_self_direction"),
    # Achievement (positive vs negative)
    ("I pushed myself hard and achieved my ambitious goal today",
     "I let the opportunity pass because the effort wasn't worth it",
     "achievement"),
    # Universalism (positive vs negative)
    ("I volunteered for the environmental cleanup because every ecosystem matters",
     "I skipped the cleanup because it's not my problem to solve",
     "universalism"),
    # Hedonism (positive vs negative)
    ("I indulged in a long luxurious evening doing exactly what I enjoy",
     "I forced myself through another joyless evening of obligations",
     "hedonism"),
    # Nuanced benevolence
    ("Listening to my friend's struggles reminded me how much I care about their wellbeing",
     "Listening to my friend's struggles reminded me how draining other people's problems are",
     "benevolence_nuance"),
    # Power (positive vs negative)
    ("Leading the meeting felt natural — I thrive when I'm in control",
     "Leading the meeting felt exhausting — I'd rather someone else take charge",
     "power_nuance"),
]

# Semantic equivalence pairs (should have HIGH similarity)
EQUIVALENT_PAIRS = [
    ("I helped my colleague with their work",
     "I assisted my coworker on their project",
     "benevolence_equivalent"),
    ("I took a risk and tried something new",
     "I stepped outside my comfort zone today",
     "stimulation_equivalent"),
    ("I followed the rules even when it was hard",
     "I stayed compliant with the guidelines despite wanting to deviate",
     "conformity_equivalent"),
]

print(f"Opposite pairs: {len(PROBE_PAIRS)}")
print(f"Equivalent pairs: {len(EQUIVALENT_PAIRS)}")

In [None]:
from sentence_transformers import SentenceTransformer

# Define encoders to test
ENCODER_CONFIGS = {
    "MiniLM-L6 (baseline)": {
        "model_name": "all-MiniLM-L6-v2",
        "prefix": "",  # No prefix needed
    },
    "MPNet-base": {
        "model_name": "all-mpnet-base-v2",
        "prefix": "",
    },
    "Nomic-v1.5": {
        "model_name": "nomic-ai/nomic-embed-text-v1.5",
        "prefix": "search_document: ",  # Nomic requires task prefix
    },
}

# Load all encoder models
encoder_models = {}
for name, cfg in ENCODER_CONFIGS.items():
    print(f"Loading {name} ({cfg['model_name']})...")
    model = SentenceTransformer(cfg["model_name"], trust_remote_code=True)
    encoder_models[name] = model
    dim = model.get_sentence_embedding_dimension()
    print(f"  Embedding dim: {dim}")

print("\nAll encoders loaded.")

In [None]:
def compute_cosine_similarity(emb1, emb2):
    """Cosine similarity between two embeddings."""
    return 1 - cosine_dist(emb1, emb2)


def run_probe(encoder_name, model, pairs, prefix=""):
    """Run probe pairs through an encoder and return similarities."""
    results = []
    for text_a, text_b, label in pairs:
        emb_a = model.encode(prefix + text_a)
        emb_b = model.encode(prefix + text_b)
        sim = compute_cosine_similarity(emb_a, emb_b)
        results.append({"label": label, "similarity": sim})
    return results


# Run probes for all encoders
probe_results = {}  # encoder_name -> {"opposite": [...], "equivalent": [...]}

for name, cfg in ENCODER_CONFIGS.items():
    model = encoder_models[name]
    prefix = cfg["prefix"]

    opposite = run_probe(name, model, PROBE_PAIRS, prefix)
    equivalent = run_probe(name, model, EQUIVALENT_PAIRS, prefix)

    probe_results[name] = {"opposite": opposite, "equivalent": equivalent}

# Display results table
print("=" * 90)
print("EMBEDDING QUALITY PROBE: Emotionally Opposite Pairs")
print("(Lower similarity = better discrimination)")
print("=" * 90)
print(f"{'Pair':<30}", end="")
for name in ENCODER_CONFIGS:
    print(f"{name:>20}", end="")
print()
print("-" * 90)

for i, (text_a, text_b, label) in enumerate(PROBE_PAIRS):
    print(f"{label:<30}", end="")
    for name in ENCODER_CONFIGS:
        sim = probe_results[name]["opposite"][i]["similarity"]
        print(f"{sim:>20.3f}", end="")
    print()

print("-" * 90)
# Mean opposite similarity
print(f"{'MEAN (opposite):':<30}", end="")
for name in ENCODER_CONFIGS:
    mean_sim = np.mean([r["similarity"] for r in probe_results[name]["opposite"]])
    print(f"{mean_sim:>20.3f}", end="")
print()

print()
print("SEMANTIC EQUIVALENCE PAIRS (Higher = better)")
print("-" * 90)
for i, (text_a, text_b, label) in enumerate(EQUIVALENT_PAIRS):
    print(f"{label:<30}", end="")
    for name in ENCODER_CONFIGS:
        sim = probe_results[name]["equivalent"][i]["similarity"]
        print(f"{sim:>20.3f}", end="")
    print()

print("-" * 90)
print(f"{'MEAN (equivalent):':<30}", end="")
for name in ENCODER_CONFIGS:
    mean_sim = np.mean([r["similarity"] for r in probe_results[name]["equivalent"]])
    print(f"{mean_sim:>20.3f}", end="")
print()

# Discrimination gap
print()
print(f"{'DISCRIMINATION GAP:':<30}", end="")
for name in ENCODER_CONFIGS:
    mean_opp = np.mean([r["similarity"] for r in probe_results[name]["opposite"]])
    mean_eq = np.mean([r["similarity"] for r in probe_results[name]["equivalent"]])
    gap = mean_eq - mean_opp
    print(f"{gap:>20.3f}", end="")
print()
print("(Higher gap = better at distinguishing opposite vs equivalent meaning)")
print("=" * 90)

In [None]:
# Visualize probe results
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

encoder_names = list(ENCODER_CONFIGS.keys())
x = np.arange(len(PROBE_PAIRS))
width = 0.25

# Opposite pairs
for i, name in enumerate(encoder_names):
    sims = [r["similarity"] for r in probe_results[name]["opposite"]]
    axes[0].bar(x + i * width, sims, width, label=name, alpha=0.85)

axes[0].set_ylabel("Cosine Similarity")
axes[0].set_title("Opposite Pairs (Lower = Better Discrimination)")
axes[0].set_xticks(x + width)
pair_labels = [p[2] for p in PROBE_PAIRS]
axes[0].set_xticklabels(pair_labels, rotation=45, ha="right", fontsize=7)
# Draw the reference line before building the legend so its label is included
axes[0].axhline(y=0.5, color="red", linestyle="--", alpha=0.3, label="0.5 reference")
axes[0].legend(fontsize=8)
axes[0].set_ylim(0, 1)
axes[0].grid(True, alpha=0.2)

# Summary: mean opposite vs mean equivalent
bar_width = 0.3
x_summary = np.arange(len(encoder_names))

mean_opp = [np.mean([r["similarity"] for r in probe_results[n]["opposite"]]) for n in encoder_names]
mean_eq = [np.mean([r["similarity"] for r in probe_results[n]["equivalent"]]) for n in encoder_names]

bars1 = axes[1].bar(x_summary - bar_width/2, mean_opp, bar_width, label="Opposite pairs", color="#e74c3c", alpha=0.85)
bars2 = axes[1].bar(x_summary + bar_width/2, mean_eq, bar_width, label="Equivalent pairs", color="#2ecc71", alpha=0.85)

axes[1].set_ylabel("Mean Cosine Similarity")
axes[1].set_title("Discrimination Summary\n(Bigger gap between bars = better)")
axes[1].set_xticks(x_summary)
axes[1].set_xticklabels(encoder_names, fontsize=9)
axes[1].legend()
axes[1].set_ylim(0, 1)
axes[1].grid(True, alpha=0.2)

# Add gap annotations
for i, (opp, eq) in enumerate(zip(mean_opp, mean_eq)):
    gap = eq - opp
    mid = (opp + eq) / 2
    axes[1].annotate(f"gap={gap:.3f}", xy=(i, mid), fontsize=8, ha="center",
                     bbox=dict(boxstyle="round,pad=0.2", facecolor="yellow", alpha=0.7))

plt.tight_layout()
plt.show()

## 2. Data Loading

Load the dataset and create identical train/val/test splits. Splitting is done by persona, so all of a persona's entries land in a single split, and the same seed produces the same split for every encoder. This keeps the comparison fair.
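
As a rough illustration of persona-level splitting, the sketch below partitions persona IDs rather than individual entries. It is not the project's `split_by_persona` implementation, which is not shown here; the function name `split_personas`, the 15% validation and test fractions, and the synthetic persona IDs are all assumptions made for the example.

```python
import random


def split_personas(persona_ids, val_frac=0.15, test_frac=0.15, seed=42):
    """Partition unique persona IDs into disjoint train/val/test sets."""
    unique_ids = sorted(set(persona_ids))  # sort first so the shuffle is reproducible
    rng = random.Random(seed)
    rng.shuffle(unique_ids)
    n_test = max(1, round(len(unique_ids) * test_frac))
    n_val = max(1, round(len(unique_ids) * val_frac))
    test_ids = set(unique_ids[:n_test])
    val_ids = set(unique_ids[n_test:n_test + n_val])
    train_ids = set(unique_ids[n_test + n_val:])
    return train_ids, val_ids, test_ids


# Hypothetical data: 20 personas with 5 journal entries each.
entry_personas = [f"p{i:02d}" for i in range(20) for _ in range(5)]
train_ids, val_ids, test_ids = split_personas(entry_personas, seed=42)

# Route each entry by its persona, so no persona appears in two splits.
train_entries = [p for p in entry_personas if p in train_ids]
print(len(train_ids), len(val_ids), len(test_ids))
```

Because the split depends only on the persona IDs and the seed, every encoder sees exactly the same partition.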

In [None]:
from src.vif.dataset import load_all_data, split_by_persona
from src.models.judge import SCHWARTZ_VALUE_ORDER

labels_df, entries_df = load_all_data()
train_df, val_df, test_df = split_by_persona(labels_df, entries_df, seed=SEED)

print(f"Labels: {labels_df.shape[0]} entries, {labels_df.select('persona_id').n_unique()} personas")
print(f"Entries: {entries_df.shape[0]} entries")
print()
print(f"Train: {len(train_df)} entries ({train_df.select('persona_id').n_unique()} personas)")
print(f"Val:   {len(val_df)} entries ({val_df.select('persona_id').n_unique()} personas)")
print(f"Test:  {len(test_df)} entries ({test_df.select('persona_id').n_unique()} personas)")

## 3. End-to-End Encoder Ablation

For each encoder:
1. Create a new `SBERTEncoder` → `StateEncoder` (embedding dim varies, so state dim varies)
2. Create new `VIFDataset` instances (re-caches embeddings for this encoder)
3. Train a fresh `CriticMLP` with identical hyperparameters
4. Evaluate on the test set with MC Dropout uncertainty

**Controlled variables:** Data splits, hidden_dim, dropout, learning rate, scheduler, early stopping, batch size, random seed.  
**Independent variable:** Encoder model (and therefore embedding dimension and state dimension).
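
Early stopping is one of the controlled variables, and its exact rule affects how many epochs each encoder trains. The sketch below isolates the rule used in this notebook's training loop: a validation loss counts as an improvement only if it beats the best so far by more than a minimum delta. The helper name `early_stopping_epoch` and the toy loss sequence are assumptions for illustration.

```python
def early_stopping_epoch(val_losses, patience=20, min_delta=0.001):
    """Return (stop_epoch, best_epoch, best_loss) using 1-based epochs.

    A loss is an improvement only if it is lower than the best loss
    seen so far by more than min_delta.
    """
    best_loss = float("inf")
    best_epoch = 0
    stale = 0  # consecutive epochs without a sufficient improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss - min_delta:
            best_loss, best_epoch, stale = loss, epoch, 0
        else:
            stale += 1
        if stale >= patience:
            return epoch, best_epoch, best_loss
    return len(val_losses), best_epoch, best_loss


# Toy validation curve: improvements after epoch 2 are smaller than min_delta.
losses = [0.50, 0.40, 0.3995, 0.3993, 0.3992, 0.41]
stop, best, best_loss = early_stopping_epoch(losses, patience=3, min_delta=0.001)
print(stop, best, best_loss)
```

With this rule, slow improvements smaller than the delta are treated as no progress, so the checkpoint kept is the last one that cleared the threshold.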

In [None]:
from src.vif.encoders import SBERTEncoder
from src.vif.state_encoder import StateEncoder
from src.vif.critic import CriticMLP
from src.vif.dataset import VIFDataset
from src.vif.eval import evaluate_with_uncertainty, format_results_table

# Shared hyperparameters
WINDOW_SIZE = 3
EMA_ALPHA = 0.3
HIDDEN_DIM = 256
DROPOUT = 0.2
BATCH_SIZE = 16
N_EPOCHS = 100
LR = 0.001
WEIGHT_DECAY = 0.01
EARLY_STOP_PATIENCE = 20
N_MC_SAMPLES = 50

print("Ablation hyperparameters:")
print(f"  Window size:   {WINDOW_SIZE}")
print(f"  EMA alpha:     {EMA_ALPHA}")
print(f"  Hidden dim:    {HIDDEN_DIM}")
print(f"  Dropout:       {DROPOUT}")
print(f"  Batch size:    {BATCH_SIZE}")
print(f"  Max epochs:    {N_EPOCHS}")
print(f"  Learning rate: {LR}")
print(f"  Weight decay:  {WEIGHT_DECAY}")
print(f"  Early stop:    {EARLY_STOP_PATIENCE} epochs")
print(f"  MC samples:    {N_MC_SAMPLES}")

In [None]:
def train_with_encoder(encoder_name, encoder_config, train_df, val_df, test_df):
    """Train a CriticMLP with a specific encoder and evaluate.

    Returns dict with training history, test results, and metadata.
    """
    print(f"\n{'=' * 70}")
    print(f"ENCODER: {encoder_name}")
    print(f"{'=' * 70}")

    # Reset seeds for each encoder
    np.random.seed(SEED)
    torch.manual_seed(SEED)

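    # Create encoder and state encoder.
    # NOTE: encoder_config["prefix"] is not applied in this function, so Nomic is
    # encoded without the task prefix used in the probe unless SBERTEncoder adds it.
    model_name = encoder_config["model_name"]
    print(f"Loading encoder: {model_name}")
    text_encoder = SBERTEncoder(model_name)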
    state_encoder = StateEncoder(
        text_encoder, window_size=WINDOW_SIZE, ema_alpha=EMA_ALPHA
    )
    emb_dim = text_encoder.embedding_dim
    state_dim = state_encoder.state_dim
    print(f"  Embedding dim: {emb_dim}")
    print(f"  State dim:     {state_dim}")

    # Create datasets (caches embeddings)
    print("Creating datasets (caching embeddings)...")
    t0 = time.time()
    train_dataset = VIFDataset(train_df, state_encoder, cache_embeddings=True)
    val_dataset = VIFDataset(val_df, state_encoder, cache_embeddings=True)
    test_dataset = VIFDataset(test_df, state_encoder, cache_embeddings=True)
    cache_time = time.time() - t0
    print(f"  Embedding cache time: {cache_time:.1f}s")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Create model
    model = CriticMLP(
        input_dim=state_dim,
        hidden_dim=HIDDEN_DIM,
        dropout=DROPOUT,
    )
    model.to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"  Model params:  {n_params:,}")

    # Training setup
    criterion = nn.MSELoss()
    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10, min_lr=1e-5
    )

    history = {"train_loss": [], "val_loss": [], "lr": []}
    best_val_loss = float("inf")
    best_model_state = None
    patience_counter = 0

    # Training loop
    print("\nTraining...")
    t0 = time.time()
    for epoch in range(N_EPOCHS):
        # Train
        model.train()
        train_loss = 0.0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            pred = model(batch_x)
            loss = criterion(pred, batch_y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        # Validate
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                pred = model(batch_x)
                val_loss += criterion(pred, batch_y).item()
        val_loss /= len(val_loader)

        # Scheduler
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["lr"].append(current_lr)

        # Early stopping
        if val_loss < best_val_loss - 0.001:
            best_val_loss = val_loss
            best_model_state = {
                k: v.cpu().clone() for k, v in model.state_dict().items()
            }
            patience_counter = 0
            print(
                f"  Epoch {epoch + 1:3d}: train={train_loss:.4f}, "
                f"val={val_loss:.4f}, lr={current_lr:.6f} [BEST]"
            )
        else:
            patience_counter += 1
            if epoch % 10 == 0:
                print(
                    f"  Epoch {epoch + 1:3d}: train={train_loss:.4f}, "
                    f"val={val_loss:.4f}, lr={current_lr:.6f}"
                )

        if patience_counter >= EARLY_STOP_PATIENCE:
            print(f"  Early stopping at epoch {epoch + 1}")
            break

    train_time = time.time() - t0
    print(f"  Training time: {train_time:.1f}s")

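    # Load best model and evaluate.
    # best_model_state stays None if the validation loss never improved (e.g. NaN).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    model.to(device)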

    print("\nEvaluating with MC Dropout...")
    test_results = evaluate_with_uncertainty(
        model, test_loader, n_mc_samples=N_MC_SAMPLES, device=device
    )
    print(format_results_table(test_results))

    return {
        "encoder_name": encoder_name,
        "model_name": model_name,
        "embedding_dim": emb_dim,
        "state_dim": state_dim,
        "n_params": n_params,
        "best_val_loss": best_val_loss,
        "history": history,
        "test_results": test_results,
        "cache_time": cache_time,
        "train_time": train_time,
        "epochs_trained": epoch + 1,
    }

In [None]:
# Run ablation for each encoder
ablation_results = {}

for name, cfg in ENCODER_CONFIGS.items():
    result = train_with_encoder(name, cfg, train_df, val_df, test_df)
    ablation_results[name] = result

print("\n" + "=" * 70)
print("ALL ENCODER ABLATIONS COMPLETE")
print("=" * 70)

## 4. Side-by-Side Results

In [None]:
# Side-by-side comparison table
encoder_names = list(ablation_results.keys())

print("\n" + "=" * 100)
print("ENCODER ABLATION: SIDE-BY-SIDE COMPARISON")
print("=" * 100)

header = f"{'Metric':<30}"
for name in encoder_names:
    header += f"{name:>22}"
header += f"{'Winner':>12}"
print(header)
print("-" * 100)


def find_winner(values_and_names, higher_is_better=True):
    valid = [(v, n) for v, n in values_and_names if v is not None]
    if not valid:
        return "N/A"
    if higher_is_better:
        return max(valid, key=lambda x: x[0])[1]
    return min(valid, key=lambda x: x[0])[1]


rows = [
    ("Embedding Dim", [str(ablation_results[n]["embedding_dim"]) for n in encoder_names], None),
    ("State Dim", [str(ablation_results[n]["state_dim"]) for n in encoder_names], None),
    ("Model Params", [f"{ablation_results[n]['n_params']:,}" for n in encoder_names], None),
    ("Best Val Loss",
     [f"{ablation_results[n]['best_val_loss']:.4f}" for n in encoder_names],
     find_winner([(ablation_results[n]["best_val_loss"], n) for n in encoder_names], higher_is_better=False)),
    ("Test MSE",
     [f"{ablation_results[n]['test_results']['mse_mean']:.4f}" for n in encoder_names],
     find_winner([(ablation_results[n]["test_results"]["mse_mean"], n) for n in encoder_names], higher_is_better=False)),
    ("Test Spearman",
     [f"{ablation_results[n]['test_results']['spearman_mean']:.4f}" for n in encoder_names],
     find_winner([(ablation_results[n]["test_results"]["spearman_mean"], n) for n in encoder_names], higher_is_better=True)),
    ("Test Accuracy",
     [f"{ablation_results[n]['test_results']['accuracy_mean']:.2%}" for n in encoder_names],
     find_winner([(ablation_results[n]["test_results"]["accuracy_mean"], n) for n in encoder_names], higher_is_better=True)),
    ("Error-Uncertainty Corr",
     [f"{ablation_results[n]['test_results']['calibration']['error_uncertainty_correlation']:.3f}" for n in encoder_names],
     find_winner([(ablation_results[n]["test_results"]["calibration"]["error_uncertainty_correlation"], n) for n in encoder_names], higher_is_better=True)),
    ("Embedding Cache Time (s)",
     [f"{ablation_results[n]['cache_time']:.1f}" for n in encoder_names],
     find_winner([(ablation_results[n]["cache_time"], n) for n in encoder_names], higher_is_better=False)),
    ("Training Time (s)",
     [f"{ablation_results[n]['train_time']:.1f}" for n in encoder_names],
     find_winner([(ablation_results[n]["train_time"], n) for n in encoder_names], higher_is_better=False)),
    ("Epochs Trained",
     [str(ablation_results[n]["epochs_trained"]) for n in encoder_names],
     None),
]

for metric, values, winner in rows:
    line = f"{metric:<30}"
    for v in values:
        line += f"{v:>22}"
    if winner:
        line += f"{winner:>12}"
    print(line)

print("=" * 100)

In [None]:
# Training curves comparison
fig, axes = plt.subplots(1, len(encoder_names), figsize=(6 * len(encoder_names), 4))

if len(encoder_names) == 1:
    axes = [axes]

for ax, name in zip(axes, encoder_names):
    h = ablation_results[name]["history"]
    ax.plot(h["train_loss"], label="Train", alpha=0.7)
    ax.plot(h["val_loss"], label="Val", alpha=0.7)
    ax.set_title(f"{name}\n(val={ablation_results[name]['best_val_loss']:.4f})")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("MSE Loss")
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle("Training Curves by Encoder", y=1.02, fontsize=14)
plt.tight_layout()
plt.show()

## 5. Per-Dimension Deep Dive

Compare how each encoder performs on individual Schwartz value dimensions.

In [None]:
# Per-dimension comparison tables
for metric_key, metric_label, higher_better in [
    ("mse_per_dim", "MSE (lower is better)", False),
    ("spearman_per_dim", "Spearman Correlation (higher is better)", True),
    ("accuracy_per_dim", "Accuracy (higher is better)", True),
]:
    print(f"\n{'=' * 90}")
    print(f"{metric_label}")
    print(f"{'=' * 90}")

    header = f"{'Dimension':<20}"
    for name in encoder_names:
        header += f"{name:>22}"
    header += f"{'Winner':>12}"
    print(header)
    print("-" * 90)

    for dim_name in SCHWARTZ_VALUE_ORDER:
        line = f"{dim_name:<20}"
        vals = []
        for name in encoder_names:
            v = ablation_results[name]["test_results"][metric_key][dim_name]
            vals.append((v, name))
            if metric_key == "accuracy_per_dim":
                line += f"{v:>21.2%} "
            else:
                v_str = f"{v:.4f}" if not np.isnan(v) else "N/A"
                line += f"{v_str:>22}"

        valid_vals = [(v, n) for v, n in vals if not np.isnan(v)]
        if valid_vals:
            winner = find_winner(valid_vals, higher_is_better=higher_better)
            line += f"{winner:>12}"
        print(line)

    # Mean row
    mean_key = metric_key.replace("_per_dim", "_mean")
    line = f"{'MEAN':<20}"
    vals = []
    for name in encoder_names:
        v = ablation_results[name]["test_results"][mean_key]
        vals.append((v, name))
        if metric_key == "accuracy_per_dim":
            line += f"{v:>21.2%} "
        else:
            line += f"{v:>22.4f}"
    winner = find_winner(vals, higher_is_better=higher_better)
    line += f"{winner:>12}"
    print(line)
    print("=" * 90)

In [None]:
# Per-dimension visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

dim_names = SCHWARTZ_VALUE_ORDER
x = np.arange(len(dim_names))
width = 0.25
colors = ["#3498db", "#e74c3c", "#2ecc71"]

# MSE per dimension
for i, name in enumerate(encoder_names):
    vals = [ablation_results[name]["test_results"]["mse_per_dim"][d] for d in dim_names]
    axes[0].barh(x + i * width, vals, width, label=name, color=colors[i], alpha=0.85)
axes[0].set_xlabel("MSE")
axes[0].set_title("MSE per Dimension (lower = better)")
axes[0].set_yticks(x + width)
axes[0].set_yticklabels(dim_names, fontsize=8)
axes[0].legend(fontsize=8)
axes[0].grid(True, alpha=0.2, axis="x")

# Spearman per dimension
for i, name in enumerate(encoder_names):
    vals = [ablation_results[name]["test_results"]["spearman_per_dim"][d] for d in dim_names]
    axes[1].barh(x + i * width, vals, width, label=name, color=colors[i], alpha=0.85)
axes[1].set_xlabel("Spearman Correlation")
axes[1].set_title("Spearman per Dimension (higher = better)")
axes[1].set_yticks(x + width)
axes[1].set_yticklabels(dim_names, fontsize=8)
axes[1].legend(fontsize=8)
axes[1].grid(True, alpha=0.2, axis="x")

# Accuracy per dimension
for i, name in enumerate(encoder_names):
    vals = [ablation_results[name]["test_results"]["accuracy_per_dim"][d] for d in dim_names]
    axes[2].barh(x + i * width, vals, width, label=name, color=colors[i], alpha=0.85)
axes[2].set_xlabel("Accuracy")
axes[2].set_title("Accuracy per Dimension (higher = better)")
axes[2].set_yticks(x + width)
axes[2].set_yticklabels(dim_names, fontsize=8)
axes[2].legend(fontsize=8)
axes[2].grid(True, alpha=0.2, axis="x")

plt.tight_layout()
plt.show()

In [None]:
# Calibration comparison: uncertainty vs error for each encoder
fig, axes = plt.subplots(1, len(encoder_names), figsize=(6 * len(encoder_names), 5))

if len(encoder_names) == 1:
    axes = [axes]

for ax, name in zip(axes, encoder_names):
    res = ablation_results[name]["test_results"]
    errors = np.abs(res["predictions"] - res["targets"]).flatten()
    uncerts = res["uncertainties"].flatten()

    ax.scatter(uncerts, errors, alpha=0.3, s=10)
    ax.set_xlabel("Predicted Uncertainty (std)")
    ax.set_ylabel("Absolute Error")

    corr = res["calibration"]["error_uncertainty_correlation"]
    ax.set_title(f"{name}\ncorr={corr:.3f}")

    # Trend line
    z = np.polyfit(uncerts, errors, 1)
    p = np.poly1d(z)
    x_line = np.linspace(uncerts.min(), uncerts.max(), 100)
    ax.plot(x_line, p(x_line), "r-", linewidth=2, label="trend")
    ax.legend()
    ax.grid(True, alpha=0.2)

plt.suptitle("Uncertainty Calibration by Encoder\n(Positive slope = uncertainty tracks error)", y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Per-dimension calibration for each encoder
print("\n" + "=" * 90)
print("PER-DIMENSION CALIBRATION (Error-Uncertainty Spearman Correlation)")
print("(Positive = uncertainty rises with error; a ranking check, not full calibration)")
print("=" * 90)

header = f"{'Dimension':<20}"
for name in encoder_names:
    header += f"{name:>22}"
print(header)
print("-" * 90)

for dim_idx, dim_name in enumerate(SCHWARTZ_VALUE_ORDER):
    line = f"{dim_name:<20}"
    for name in encoder_names:
        res = ablation_results[name]["test_results"]
        dim_errors = np.abs(res["predictions"][:, dim_idx] - res["targets"][:, dim_idx])
        dim_uncert = res["uncertainties"][:, dim_idx]

        if np.std(dim_uncert) < 1e-8 or np.std(dim_errors) < 1e-8:
            line += f"{'N/A':>22}"
        else:
            corr, _ = stats.spearmanr(dim_uncert, dim_errors)
            line += f"{corr:>22.3f}"
    print(line)

print("-" * 90)
# Overall
line = f"{'OVERALL':<20}"
for name in encoder_names:
    corr = ablation_results[name]["test_results"]["calibration"]["error_uncertainty_correlation"]
    line += f"{corr:>22.3f}"
print(line)
print("=" * 90)

## 6. Matryoshka Dimension Search

`nomic-embed-text-v1.5` supports [Matryoshka Representation Learning](https://huggingface.co/blog/matryoshka), meaning its embeddings can be truncated to lower dimensions with graceful quality degradation.

This is useful because:
- Smaller embeddings → smaller state vectors → smaller MLP → faster training/inference
- May reduce overfitting with the current small dataset
- Helps find the optimal quality/size tradeoff

We test dimensions: 64, 128, 256, 384, 512, 768
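
One detail worth keeping in mind when truncating: dropping trailing components shortens the vector, so a unit-norm embedding is no longer unit-norm afterwards. If anything downstream assumes unit length (cosine similarity computed as a dot product, for instance), the truncated vector can be rescaled. The sketch below shows truncation followed by L2 re-normalization on a toy vector; `truncate_and_renormalize` is an illustrative name, and the model's published usage notes should be checked for any normalization it expects before truncation.

```python
import math


def truncate_and_renormalize(embedding, target_dim):
    """Keep the first target_dim components, then rescale to unit L2 norm."""
    head = embedding[:target_dim]
    norm = math.sqrt(sum(x * x for x in head))
    if norm == 0.0:
        return list(head)  # leave an all-zero vector unchanged
    return [x / norm for x in head]


full = [0.5, 0.5, 0.5, 0.5]  # toy unit-norm "embedding" (assumed values)
short = truncate_and_renormalize(full, 2)
short_norm = math.sqrt(sum(x * x for x in short))
print(short, short_norm)
```

Whether re-normalization helps the Critic MLP is an empirical question; the sketch only shows the operation itself.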

In [None]:
from src.vif.encoders import TextEncoder


class TruncatedSBERTEncoder:
    """Wrap an SBERTEncoder and truncate its embeddings to a target dimension.

    For Matryoshka models, the leading N dimensions are trained to be useful
    on their own, so truncation needs no retraining. Truncated vectors are
    returned as-is and are not re-normalized.
    """

    def __init__(self, base_encoder, target_dim: int):
        self._base = base_encoder
        self._target_dim = target_dim
        self._model_name = f"{base_encoder.model_name}[:{target_dim}]"

    @property
    def embedding_dim(self) -> int:
        return self._target_dim

    @property
    def model_name(self) -> str:
        return self._model_name

    def encode(self, texts: list[str]) -> np.ndarray:
        full = self._base.encode(texts)
        return full[:, :self._target_dim]

    def encode_batch(self, texts: list[str], batch_size: int = 32) -> np.ndarray:
        full = self._base.encode_batch(texts, batch_size=batch_size)
        return full[:, :self._target_dim]


# Dimensions to evaluate (768 is the full, untruncated embedding)
MATRYOSHKA_DIMS = [64, 128, 256, 384, 512, 768]

# Load the Nomic encoder once and share it across all truncation levels
nomic_base = SBERTEncoder("nomic-ai/nomic-embed-text-v1.5")
print(f"Base nomic dim: {nomic_base.embedding_dim}")

matryoshka_results = {}

for dim in MATRYOSHKA_DIMS:
    print(f"\n{'=' * 60}")
    print(f"MATRYOSHKA DIM: {dim}")
    print(f"{'=' * 60}")

    # Reset seeds
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Truncated encoder
    trunc_encoder = TruncatedSBERTEncoder(nomic_base, dim)
    state_enc = StateEncoder(trunc_encoder, window_size=WINDOW_SIZE, ema_alpha=EMA_ALPHA)
    state_dim = state_enc.state_dim
    print(f"  State dim: {state_dim}")

    # Create datasets
    print("  Caching embeddings...")
    tr_ds = VIFDataset(train_df, state_enc, cache_embeddings=True)
    va_ds = VIFDataset(val_df, state_enc, cache_embeddings=True)
    te_ds = VIFDataset(test_df, state_enc, cache_embeddings=True)

    tr_dl = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True)
    va_dl = DataLoader(va_ds, batch_size=BATCH_SIZE, shuffle=False)
    te_dl = DataLoader(te_ds, batch_size=BATCH_SIZE, shuffle=False)

    # Model
    model = CriticMLP(input_dim=state_dim, hidden_dim=HIDDEN_DIM, dropout=DROPOUT)
    model.to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"  Model params: {n_params:,}")

    # Train
    criterion = nn.MSELoss()
    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=10, min_lr=1e-5)

    best_val_loss = float("inf")
    best_state = None
    patience = 0

    print("  Training...")
    for epoch in range(N_EPOCHS):
        model.train()
        t_loss = 0.0
        for bx, by in tr_dl:
            bx, by = bx.to(device), by.to(device)
            optimizer.zero_grad()
            loss = criterion(model(bx), by)
            loss.backward()
            optimizer.step()
            t_loss += loss.item()
        t_loss /= len(tr_dl)

        model.eval()
        v_loss = 0.0
        with torch.no_grad():
            for bx, by in va_dl:
                bx, by = bx.to(device), by.to(device)
                v_loss += criterion(model(bx), by).item()
        v_loss /= len(va_dl)

        scheduler.step(v_loss)

        if v_loss < best_val_loss - 0.001:
            best_val_loss = v_loss
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience = 0
        else:
            patience += 1

        if patience >= EARLY_STOP_PATIENCE:
            print(f"  Early stop at epoch {epoch + 1}")
            break

    # Evaluate
    model.load_state_dict(best_state)
    model.to(device)
    test_res = evaluate_with_uncertainty(model, te_dl, n_mc_samples=N_MC_SAMPLES, device=device)

    matryoshka_results[dim] = {
        "state_dim": state_dim,
        "n_params": n_params,
        "best_val_loss": best_val_loss,
        "test_results": test_res,
        "epochs": epoch + 1,
    }

    print(f"  Val loss: {best_val_loss:.4f} | Test MSE: {test_res['mse_mean']:.4f} | "
          f"Spearman: {test_res['spearman_mean']:.4f} | Acc: {test_res['accuracy_mean']:.2%}")

print("\nMatryoshka search complete.")
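
For readers unfamiliar with Matryoshka truncation, the sketch below shows the core operation in plain Python: keep the first `dim` components of a full embedding, then rescale to unit length so cosine similarity stays meaningful. It is an illustration only. The helper name `truncate_and_normalize` is hypothetical, the internals of `TruncatedSBERTEncoder` are not shown in this section and may differ, and the published usage example for `nomic-embed-text-v1.5` also applies a layer norm before truncating, which this sketch omits.

```python
import math


def truncate_and_normalize(embedding: list[float], dim: int) -> list[float]:
    """Keep the first `dim` components and rescale to unit L2 norm (hypothetical helper)."""
    if not 0 < dim <= len(embedding):
        raise ValueError(f"dim must be in 1..{len(embedding)}, got {dim}")
    head = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    if norm == 0.0:
        return head  # all-zero prefix: nothing to rescale
    return [x / norm for x in head]


# Toy 8-dimensional vector standing in for a full encoder output
full = [0.5, -0.5, 0.5, 0.5, 0.1, -0.2, 0.3, 0.0]
small = truncate_and_normalize(full, 4)
print(len(small))                  # number of components kept
print(sum(x * x for x in small))   # squared norm, approximately 1.0
```

Because the truncated vector is shorter, the state vector and the first `CriticMLP` layer shrink with it, which is why the parameter count is tracked alongside the metrics in this search.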

In [None]:
# Matryoshka results table
print("\n" + "=" * 90)
print("MATRYOSHKA DIMENSION SEARCH (nomic-embed-text-v1.5)")
print("=" * 90)
print(f"{'Emb Dim':>8} {'State Dim':>10} {'Params':>12} {'Val Loss':>10} "
      f"{'Test MSE':>10} {'Spearman':>10} {'Accuracy':>10} {'Epochs':>8}")
print("-" * 90)

for dim in MATRYOSHKA_DIMS:
    r = matryoshka_results[dim]
    tr = r["test_results"]
    print(f"{dim:>8} {r['state_dim']:>10} {r['n_params']:>12,} {r['best_val_loss']:>10.4f} "
          f"{tr['mse_mean']:>10.4f} {tr['spearman_mean']:>10.4f} {tr['accuracy_mean']:>10.2%} "
          f"{r['epochs']:>8}")

print("=" * 90)

In [None]:
# Matryoshka dimension vs. performance visualization
dims = MATRYOSHKA_DIMS
mse_vals = [matryoshka_results[d]["test_results"]["mse_mean"] for d in dims]
spearman_vals = [matryoshka_results[d]["test_results"]["spearman_mean"] for d in dims]
acc_vals = [matryoshka_results[d]["test_results"]["accuracy_mean"] for d in dims]
param_vals = [matryoshka_results[d]["n_params"] for d in dims]

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# MSE vs dim
axes[0, 0].plot(dims, mse_vals, "o-", color="#e74c3c", linewidth=2, markersize=8)
axes[0, 0].set_xlabel("Embedding Dimension")
axes[0, 0].set_ylabel("Test MSE")
axes[0, 0].set_title("MSE vs Embedding Dimension")
axes[0, 0].grid(True, alpha=0.3)
# Add baseline reference
if "MiniLM-L6 (baseline)" in ablation_results:
    baseline_mse = ablation_results["MiniLM-L6 (baseline)"]["test_results"]["mse_mean"]
    axes[0, 0].axhline(y=baseline_mse, color="blue", linestyle="--", alpha=0.5,
                        label=f"MiniLM baseline ({baseline_mse:.4f})")
    axes[0, 0].legend(fontsize=8)

# Spearman vs dim
axes[0, 1].plot(dims, spearman_vals, "o-", color="#2ecc71", linewidth=2, markersize=8)
axes[0, 1].set_xlabel("Embedding Dimension")
axes[0, 1].set_ylabel("Test Spearman")
axes[0, 1].set_title("Spearman vs Embedding Dimension")
axes[0, 1].grid(True, alpha=0.3)
if "MiniLM-L6 (baseline)" in ablation_results:
    baseline_sp = ablation_results["MiniLM-L6 (baseline)"]["test_results"]["spearman_mean"]
    axes[0, 1].axhline(y=baseline_sp, color="blue", linestyle="--", alpha=0.5,
                        label=f"MiniLM baseline ({baseline_sp:.4f})")
    axes[0, 1].legend(fontsize=8)

# Accuracy vs dim
axes[1, 0].plot(dims, acc_vals, "o-", color="#f39c12", linewidth=2, markersize=8)
axes[1, 0].set_xlabel("Embedding Dimension")
axes[1, 0].set_ylabel("Test Accuracy")
axes[1, 0].set_title("Accuracy vs Embedding Dimension")
axes[1, 0].grid(True, alpha=0.3)
if "MiniLM-L6 (baseline)" in ablation_results:
    baseline_acc = ablation_results["MiniLM-L6 (baseline)"]["test_results"]["accuracy_mean"]
    axes[1, 0].axhline(y=baseline_acc, color="blue", linestyle="--", alpha=0.5,
                        label=f"MiniLM baseline ({baseline_acc:.2%})")
    axes[1, 0].legend(fontsize=8)

# Params vs dim
axes[1, 1].plot(dims, [p / 1000 for p in param_vals], "o-", color="#9b59b6", linewidth=2, markersize=8)
axes[1, 1].set_xlabel("Embedding Dimension")
axes[1, 1].set_ylabel("Model Parameters (K)")
axes[1, 1].set_title("Model Size vs Embedding Dimension")
axes[1, 1].grid(True, alpha=0.3)

plt.suptitle("Matryoshka Dimension Search (nomic-embed-text-v1.5)", y=1.02, fontsize=14)
plt.tight_layout()
plt.show()

## 7. Latency Benchmark

Measure encoding speed for each model. In the VIF pipeline, the encoder and the MLP are invoked at different rates:
- The encoder runs once per journal entry (during data caching or real-time inference)
- MC Dropout runs 50 forward passes, but only through the MLP, not the encoder

So encoder latency mainly matters for real-time inference responsiveness.
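
To make that split concrete, here is a back-of-the-envelope sketch of a per-entry latency budget: the encoder is charged once, and the MC Dropout passes are charged to the MLP. The function name `per_entry_latency_ms` and the timings passed to it are placeholders chosen for illustration, not measurements from this notebook; substitute the values reported by the benchmark.

```python
def per_entry_latency_ms(encoder_ms: float, mlp_pass_ms: float, n_mc_samples: int = 50) -> dict:
    """Split one real-time prediction into encoder cost and MC Dropout cost."""
    # MC Dropout repeats only the MLP forward pass; the embedding is computed once
    mc_dropout_ms = mlp_pass_ms * n_mc_samples
    total_ms = encoder_ms + mc_dropout_ms
    return {
        "encoder_ms": encoder_ms,
        "mc_dropout_ms": mc_dropout_ms,
        "total_ms": total_ms,
        "encoder_share": encoder_ms / total_ms,
    }


# Placeholder timings (assumed, not measured)
budget = per_entry_latency_ms(encoder_ms=20.0, mlp_pass_ms=0.1)
for key, value in budget.items():
    print(f"{key:<15} {value:.2f}")
```

If the encoder share dominates in a budget like this, switching encoders changes responsiveness far more than changing the number of MC samples would.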

In [None]:
# Benchmark encoding latency
# Use realistic journal entry lengths
SAMPLE_TEXTS = [
    # Short entry (two sentences)
    "Today I helped my colleague with their presentation. It felt good to contribute.",
    # Medium entry (one paragraph)
    ("I spent the morning reflecting on my career goals. I realized that I've been prioritizing "
     "achievement over my relationships. My partner mentioned feeling neglected, and it struck "
     "me that maybe success isn't worth it if I'm alone at the top. I need to find a better balance."),
    # Long entry with nudge and response (three paragraphs)
    ("Today was a difficult day at work. The team disagreed about the project direction, and I "
     "found myself torn between following the consensus and speaking up about what I believed was "
     "right. In the end, I chose to voice my concerns, even though it made some colleagues "
     "uncomfortable. I value harmony, but I also can't ignore my principles.\n\n"
     'Nudge: "You mentioned feeling torn between harmony and principles. What would it look '
     'like to honor both?"\n\n'
     "Response: I think I could have been more diplomatic in how I raised my concerns. Instead of "
     "directly challenging the team lead's proposal, I could have framed it as building on their "
     "idea. That way I stay true to my values while respecting the group dynamic."),
]

N_WARMUP = 5
N_RUNS = 50

latency_results = {}

for name, cfg in ENCODER_CONFIGS.items():
    model = encoder_models[name]
    prefix = cfg["prefix"]

    # Prepare texts with prefix
    texts = [prefix + t for t in SAMPLE_TEXTS]

    # Warmup
    for _ in range(N_WARMUP):
        model.encode(texts)

    # Benchmark single encoding
    single_times = []
    for _ in range(N_RUNS):
        t0 = time.perf_counter()
        model.encode([texts[1]])  # Medium entry
        single_times.append(time.perf_counter() - t0)

    # Benchmark batch encoding (3 entries = one window)
    batch_times = []
    for _ in range(N_RUNS):
        t0 = time.perf_counter()
        model.encode(texts)
        batch_times.append(time.perf_counter() - t0)

    latency_results[name] = {
        "single_mean_ms": np.mean(single_times) * 1000,
        "single_std_ms": np.std(single_times) * 1000,
        "batch_mean_ms": np.mean(batch_times) * 1000,
        "batch_std_ms": np.std(batch_times) * 1000,
    }

# Display results
print("=" * 80)
print("ENCODING LATENCY BENCHMARK")
print(f"(Device: CPU, {N_RUNS} runs after {N_WARMUP} warmup)")
print("=" * 80)
print(f"{'Encoder':<25} {'Single (ms)':>15} {'Batch/3 (ms)':>15} {'Speedup':>10}")
print("-" * 80)

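# Speedup is relative to the first configured encoder (the baseline); values below 1.00x are slower
baseline_name = next(iter(ENCODER_CONFIGS))
baseline_single = latency_results[baseline_name]["single_mean_ms"]
for name in ENCODER_CONFIGS:
    r = latency_results[name]
    speedup = baseline_single / r["single_mean_ms"]
    # Build the "mean ± std" cells first so they right-align under the 15-character headers
    single_cell = f"{r['single_mean_ms']:.1f} ± {r['single_std_ms']:.1f}"
    batch_cell = f"{r['batch_mean_ms']:.1f} ± {r['batch_std_ms']:.1f}"
    print(f"{name:<25} {single_cell:>15} {batch_cell:>15} {speedup:>9.2f}x")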

print("=" * 80)

## 8. Conclusions & Recommendations

In [None]:
# Final summary
print("\n" + "=" * 70)
print("EMBEDDING ABLATION STUDY — SUMMARY")
print("=" * 70)

# Probe results
print("\n1. EMBEDDING QUALITY PROBE")
print("-" * 40)
for name in encoder_names:
    mean_opp = np.mean([r["similarity"] for r in probe_results[name]["opposite"]])
    mean_eq = np.mean([r["similarity"] for r in probe_results[name]["equivalent"]])
    gap = mean_eq - mean_opp
    print(f"  {name:<25} opp={mean_opp:.3f}  eq={mean_eq:.3f}  gap={gap:.3f}")

# End-to-end results
print("\n2. END-TO-END CRITIC PERFORMANCE")
print("-" * 40)
for name in encoder_names:
    r = ablation_results[name]
    tr = r["test_results"]
    print(f"  {name:<25} MSE={tr['mse_mean']:.4f}  Spearman={tr['spearman_mean']:.4f}  "
          f"Acc={tr['accuracy_mean']:.2%}  Calib={tr['calibration']['error_uncertainty_correlation']:.3f}")

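# Matryoshka best (ranked by test MSE for reporting only; when choosing a dimension
# to deploy, rank on best_val_loss so the test set is not used for selection)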
print("\n3. MATRYOSHKA DIMENSION SEARCH (nomic-embed-text-v1.5)")
print("-" * 40)
best_dim = min(matryoshka_results.keys(),
               key=lambda d: matryoshka_results[d]["test_results"]["mse_mean"])
best_r = matryoshka_results[best_dim]
print(f"  Best dimension: {best_dim}")
print(f"  MSE: {best_r['test_results']['mse_mean']:.4f}")
print(f"  Spearman: {best_r['test_results']['spearman_mean']:.4f}")
print(f"  Params: {best_r['n_params']:,}")

# Latency
print("\n4. LATENCY")
print("-" * 40)
for name in encoder_names:
    r = latency_results[name]
    print(f"  {name:<25} {r['single_mean_ms']:.1f}ms/entry")

print("\n" + "=" * 70)
print("RECOMMENDATIONS")
print("=" * 70)
print("""
Interpret the results above to determine:

1. Does the embedding quality probe predict end-to-end performance?
   (i.e., does better emotional discrimination → better VIF scores?)

2. Is the improvement from a larger encoder significant enough to
   justify the latency and parameter cost increase?

3. What Matryoshka dimension gives the best quality/size tradeoff?

4. Is the performance gap large enough to warrant further investment
   (e.g., domain fine-tuning or emotion-aware embeddings)?
""")
print("=" * 70)
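
One possible way to approach question 3 is a tolerance rule: take the smallest embedding dimension whose validation loss is within a chosen margin of the best one. The sketch below assumes a dictionary shaped like `matryoshka_results`; the function name `smallest_dim_within_tolerance`, the 0.005 margin, and the example numbers are all hypothetical. Selecting on validation loss rather than test MSE keeps the test set out of the choice.

```python
def smallest_dim_within_tolerance(results: dict, tolerance: float = 0.005) -> int:
    """Return the smallest dimension whose validation loss is within `tolerance` of the best."""
    best_loss = min(r["best_val_loss"] for r in results.values())
    eligible = [dim for dim, r in results.items() if r["best_val_loss"] <= best_loss + tolerance]
    return min(eligible)


# Made-up validation losses, for illustration only
example_results = {
    64: {"best_val_loss": 0.2500},
    128: {"best_val_loss": 0.2380},
    256: {"best_val_loss": 0.2350},
    512: {"best_val_loss": 0.2340},
    768: {"best_val_loss": 0.2345},
}
chosen = smallest_dim_within_tolerance(example_results)
print(f"Chosen dimension: {chosen}")
```

The margin is a judgment call: a tighter tolerance favors accuracy, a looser one favors a smaller state vector and fewer critic parameters.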