# Week 2 — Notebook 4: Alignment Pain Points & Remedies

## Context

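This notebook documents and diagnoses the technical hurdles we observed during preference-based alignment (DPO/ORPO) of a customer-support model. Data is processed with Spark (~90k rows), scored by Virtual Judges (VJs), and used to construct chosen/rejected pairs. Unless noted otherwise, the cells below run on synthetic or illustrative numbers that mimic each failure mode; swap in the real `df_scores` from Notebook 03 where available.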

---

## Pain Points Covered

| # | Pain Point | Section |
|---|------------|---------|
| 1 | **Alignment Tax** — style improves, correctness decays | §2 |
| 2 | **Evaluation Uncertainty** — unknown VJ precision & recall | §3 |
| 3 | **Negative Correlation in Rule-Based Scoring** — rejected > chosen on problem_solution | §4 |
| 4 | **Sample Efficiency** — 24k → 700 high-quality pairs after filtering | §5 |
| 5 | **Reward Non-Determinism** — correctness scores feel random | §6 |

## Remedies Implemented

| Remedy | Concept | Section |
|--------|---------|--------|
| Multi-Objective Optimization (MOO) | Correctness floor constraint | §7 |
| Raised KL Penalty (β) | Reference model regularization | §8 |
| Judge-on-Judge Meta-Eval | Estimating VJ precision | §9 |
| SFT Warm-up → DPO | Differentiated training phases | §10 |
| On-Policy DPO | Sample from current model, not GPT-5/Kimi | §11 |

## 0. Setup

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from pathlib import Path
from scipy import stats

sns.set_theme(style="whitegrid", palette="muted")
EVAL_DIR = Path("../data/eval_results")
EVAL_DIR.mkdir(parents=True, exist_ok=True)


# ── Synthetic data to illustrate each pain point ─────────────────────────────
# Replace with real df_scores from Notebook 03 if available:
# df_scores = pd.read_json(EVAL_DIR / "vj_scores.jsonl", lines=True)

n = 200
rng = np.random.default_rng(42)
df_scores = pd.concat([
    pd.DataFrame({
        "checkpoint":      "base",
        "correctness":     rng.normal(0.72, 0.10, n).clip(0,1),
        "groundedness":    rng.normal(0.70, 0.10, n).clip(0,1),
        "problem_solution":rng.normal(0.68, 0.12, n).clip(0,1),
        "style":           rng.normal(0.55, 0.12, n).clip(0,1),
    }),
    pd.DataFrame({
        "checkpoint":      "sft",
        "correctness":     rng.normal(0.76, 0.09, n).clip(0,1),
        "groundedness":    rng.normal(0.73, 0.09, n).clip(0,1),
        "problem_solution":rng.normal(0.72, 0.11, n).clip(0,1),
        "style":           rng.normal(0.57, 0.11, n).clip(0,1),
    }),
    pd.DataFrame({
        "checkpoint":      "dpo_naive",  # no correctness floor, low β
        "correctness":     rng.normal(0.64, 0.11, n).clip(0,1),   # ← decays
        "groundedness":    rng.normal(0.65, 0.10, n).clip(0,1),
        "problem_solution":rng.normal(0.63, 0.13, n).clip(0,1),   # ← decays
        "style":           rng.normal(0.73, 0.09, n).clip(0,1),   # ← improves
    }),
    pd.DataFrame({
        "checkpoint":      "dpo_fixed",  # correctness floor + β=0.2
        "correctness":     rng.normal(0.74, 0.09, n).clip(0,1),
        "groundedness":    rng.normal(0.72, 0.09, n).clip(0,1),
        "problem_solution":rng.normal(0.71, 0.11, n).clip(0,1),
        "style":           rng.normal(0.70, 0.09, n).clip(0,1),
    }),
], ignore_index=True)

DIMS = ["correctness", "groundedness", "problem_solution", "style"]
print(df_scores.groupby("checkpoint")[DIMS].mean().round(3))

---
## Pain Point 1 — The Alignment Tax

> When optimizing for style/tone, we observe simultaneous decay in correctness and problem_solution.  
> This is **not** a model failure — it is a **data construction failure**: the chosen/rejected pairs did not  
> hold correctness constant while varying style.
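
A pair-level check makes "hold correctness constant" measurable before any training run. The sketch below is a minimal illustration that assumes hypothetical per-pair columns `correctness_chosen` and `correctness_rejected` (the synthetic `df_scores` in this notebook is per checkpoint, not per pair); the 0.05 tolerance is an arbitrary choice.

```python
import pandas as pd

def correctness_drift_report(pairs: pd.DataFrame, tol: float = 0.05) -> dict:
    """Summarise how well a preference dataset holds correctness constant.

    Assumes hypothetical columns `correctness_chosen` and `correctness_rejected`
    holding VJ scores in [0, 1] for each chosen/rejected pair.
    """
    delta = pairs["correctness_chosen"] - pairs["correctness_rejected"]
    return {
        # Share of pairs where correctness is (approximately) held constant
        "held_constant": float((delta.abs() <= tol).mean()),
        # Share of pairs where the chosen response is clearly LESS correct
        "chosen_less_correct": float((delta < -tol).mean()),
        "mean_delta": float(delta.mean()),
    }

example = pd.DataFrame({
    "correctness_chosen":   [0.80, 0.60, 0.70, 0.50],
    "correctness_rejected": [0.78, 0.75, 0.70, 0.65],
})
print(correctness_drift_report(example))
```

A large `chosen_less_correct` share means the dataset is partly teaching the model to trade correctness for style, which is exactly the alignment tax shown below.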

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(16, 4), sharey=False)
order = ["base", "sft", "dpo_naive", "dpo_fixed"]
colors = ["#4878CF", "#6ACC65", "#D65F5F", "#B47CC7"]

for ax, dim in zip(axes, DIMS):
    means = df_scores.groupby("checkpoint")[dim].mean().reindex(order)
    cis   = df_scores.groupby("checkpoint")[dim].sem().reindex(order) * 1.96
    ax.bar(range(len(order)), means, color=colors, yerr=cis, capsize=4, width=0.6)
    ax.set_xticks(range(len(order)))
    ax.set_xticklabels(order, rotation=25, ha="right", fontsize=9)
    ax.set_title(dim, fontsize=12)
    ax.set_ylim(0.4, 0.9)
    for i, v in enumerate(means):
        ax.text(i, v + cis.iloc[i] + 0.01, f"{v:.2f}", ha="center", fontsize=8)

plt.suptitle("Pain Point 1: Alignment Tax\ndpo_naive shows style↑ but correctness↓",
             fontsize=13, y=1.03)
plt.tight_layout()
plt.savefig(EVAL_DIR / "pain1_alignment_tax.png", dpi=150, bbox_inches="tight")
plt.show()

---
## Pain Point 2 — Evaluation Uncertainty

> We do not know the precision and recall of our VJ scoring pipeline.  
> Without calibration against a stronger judge, we cannot tell how often a high VJ score reflects a genuinely good response.
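
A precision estimated from a finite meta-eval sample carries its own uncertainty. A minimal sketch, assuming the Wilson score interval (a standard choice for binomial proportions; the notebook does not prescribe one). For precision, `n` is the number of responses the VJ flagged as high-quality (TP + FP), not the total sample size.

```python
import math

def wilson_interval(successes: int, n: int, z: float = 1.96) -> tuple[float, float]:
    """Wilson score interval for a binomial proportion such as VJ precision."""
    if n == 0:
        return (0.0, 1.0)  # no flagged samples: precision is completely unknown
    p = successes / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return (max(0.0, center - half), min(1.0, center + half))

# Same observed precision (0.80), different numbers of VJ-flagged samples
for n_flagged in (50, 100, 500):
    lo, hi = wilson_interval(round(0.8 * n_flagged), n_flagged)
    print(f"n={n_flagged:>3}  precision=0.80  95% CI=({lo:.3f}, {hi:.3f})  width={hi - lo:.3f}")
```

With the same point estimate, the interval narrows as `n` grows, which is one argument for a larger meta-eval sample.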

In [None]:
# Simulate VJ (weak judge) vs. strong judge scores for 100 samples
n_meta = 100
true_scores   = rng.uniform(0.3, 1.0, n_meta)

# VJ with random noise (simulates non-determinism in correctness scoring)
vj_scores     = (true_scores + rng.normal(0, 0.15, n_meta)).clip(0, 1)

# Bucketise to high/low
THRESHOLD = 0.65
true_positive  = ((vj_scores >= THRESHOLD) & (true_scores >= THRESHOLD)).sum()
false_positive = ((vj_scores >= THRESHOLD) & (true_scores <  THRESHOLD)).sum()
false_negative = ((vj_scores <  THRESHOLD) & (true_scores >= THRESHOLD)).sum()
true_negative  = ((vj_scores <  THRESHOLD) & (true_scores <  THRESHOLD)).sum()

precision = true_positive / (true_positive + false_positive + 1e-9)
recall    = true_positive / (true_positive + false_negative + 1e-9)

fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(true_scores, vj_scores, alpha=0.5, s=30)
ax.axvline(THRESHOLD, color="red",  linestyle="--", label="threshold (true)")
ax.axhline(THRESHOLD, color="blue", linestyle="--", label="threshold (VJ)")
ax.set_xlabel("Strong judge score (ground truth)")
ax.set_ylabel("VJ score")
ax.set_title(f"Pain Point 2: VJ Calibration\nPrecision={precision:.2f}  Recall={recall:.2f}")
ax.legend()
plt.tight_layout()
plt.savefig(EVAL_DIR / "pain2_evaluation_uncertainty.png", dpi=150, bbox_inches="tight")
plt.show()

print(f"Precision: {precision:.2f}")
print(f"Recall:    {recall:.2f}")
print("\n→ Remedy (§9): Judge-on-Judge meta-eval on 500 samples.")

---
## Pain Point 3 — Negative Correlation in Rule-Based Scoring

> When rule-based scores select 5,000 pairs, the **rejected** group averages *higher*  
> problem_solution than the chosen group — a structural inversion.
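
The inversion can be detected automatically for every scored dimension, not just problem_solution. A minimal sketch, assuming paired columns named `<dim>_chosen` / `<dim>_rejected` (the naming used by `df_pairs` in the cell below); a negative mean gap flags an inversion.

```python
import pandas as pd

def inversion_report(pairs: pd.DataFrame, dims: list[str]) -> pd.DataFrame:
    """Mean chosen-minus-rejected gap per dimension; a negative gap is an inversion."""
    rows = []
    for dim in dims:
        gap = (pairs[f"{dim}_chosen"] - pairs[f"{dim}_rejected"]).mean()
        rows.append({"dimension": dim, "mean_gap": gap, "inverted": gap < 0})
    return pd.DataFrame(rows)

example = pd.DataFrame({
    "style_chosen": [0.8, 0.7, 0.9], "style_rejected": [0.5, 0.4, 0.6],
    "ps_chosen":    [0.5, 0.6, 0.4], "ps_rejected":    [0.7, 0.6, 0.6],
})
print(inversion_report(example, ["style", "ps"]))
```

Run on `df_pairs` from the cell below, `inversion_report(df_pairs, ["style", "ps"])` flags `ps` as inverted while `style` is not.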

In [None]:
# Simulate the inversion
n_pairs = 5000

# Rule selects on style delta alone — problem_solution is NOT controlled
df_pairs = pd.DataFrame({
    "style_chosen":            rng.normal(0.72, 0.10, n_pairs).clip(0,1),
    "style_rejected":          rng.normal(0.48, 0.12, n_pairs).clip(0,1),
    # rejected responses average higher problem_solution (the inversion); the per-pair
    # style/problem_solution correlation itself is not modeled in this simulation
    "ps_chosen":               rng.normal(0.58, 0.15, n_pairs).clip(0,1),
    "ps_rejected":             rng.normal(0.66, 0.14, n_pairs).clip(0,1),  # ← perversely higher
})

ps_c = df_pairs["ps_chosen"].mean()
ps_r = df_pairs["ps_rejected"].mean()

fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(["chosen", "rejected"], [ps_c, ps_r], color=["steelblue", "tomato"], width=0.4)
ax.set_ylabel("Avg problem_solution score")
ax.set_title(f"Pain Point 3: Negative Correlation in Rule-Based Selection\n"
             f"rejected ({ps_r:.3f}) > chosen ({ps_c:.3f}) on problem_solution")
ax.set_ylim(0, 1)
plt.tight_layout()
plt.savefig(EVAL_DIR / "pain3_negative_correlation.png", dpi=150, bbox_inches="tight")
plt.show()

print("→ Remedy (§7): Add problem_solution floor constraint to pair selection.")

---
## Pain Point 4 — Sample Efficiency

> Aggressive multi-VJ filtering collapses 24k pairs → ~700 high-quality pairs.  
> Each additional filter cuts yield multiplicatively.
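
"Multiplicatively" means overall yield is the product of each filter's conditional pass rate. The sketch below uses the illustrative stage counts from the funnel cell below (from the 24k post-dedup stage onward); the what-if assumes the other stages' conditional pass rates stay fixed when one floor is relaxed, which will not hold exactly when floors act on correlated metrics.

```python
import math

# Stage counts copied from the illustrative funnel below (post-dedup onward)
counts = [24_000, 8_500, 3_200, 1_400, 700]

# Conditional pass rate of each filter, given the previous filters were passed
pass_rates = [after / before for before, after in zip(counts, counts[1:])]
overall = math.prod(pass_rates)

for rate in pass_rates:
    print(f"pass rate: {rate:.3f}")
print(f"overall yield: {overall:.4f}  (= {counts[-1]:,} / {counts[0]:,})")

# What-if: relax the last floor so its pass rate rises from 0.50 to 0.70,
# ASSUMING the other stages' conditional pass rates are unchanged
relaxed = pass_rates.copy()
relaxed[-1] = 0.70
print(f"pairs if last floor relaxed: {counts[0] * math.prod(relaxed):,.0f}")
```

Because the rates multiply, a modest gain at any single stage propagates to the final pair count.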

In [None]:
funnel = pd.DataFrame([
    {"stage": "Raw rollouts",                  "n": 90_000},
    {"stage": "Deduplication + quality",        "n": 24_000},
    {"stage": "Reward delta ≥ 0.05",            "n": 8_500},
    {"stage": "Correctness floor ≥ 0.60",       "n": 3_200},
    {"stage": "Groundedness floor ≥ 0.55",      "n": 1_400},
    {"stage": "Problem_solution floor ≥ 0.60",  "n":   700},
])
funnel["retention_%"] = (funnel["n"] / funnel["n"].iloc[0] * 100).round(2)

fig, ax = plt.subplots(figsize=(9, 4))
bars = ax.barh(funnel["stage"][::-1], funnel["n"][::-1], color=sns.color_palette("Blues_r", len(funnel)))
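for bar, count, pct in zip(bars, funnel["n"][::-1], funnel["retention_%"][::-1]):
    # itertuples() renames "retention_%" to a positional field, so zip the columns directly
    ax.text(bar.get_width() + 200, bar.get_y() + bar.get_height()/2,
            f"{int(count):,}  ({pct:.1f}%)", va="center", fontsize=9)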
ax.set_xlabel("Number of pairs")
ax.set_title("Pain Point 4: Sample Efficiency Funnel")
ax.set_xlim(0, 100_000)
plt.tight_layout()
plt.savefig(EVAL_DIR / "pain4_sample_efficiency.png", dpi=150, bbox_inches="tight")
plt.show()

print(funnel.to_string(index=False))
print("\n→ Remedy (§11): On-policy sampling produces more filterable signal per rollout.")

---
## Pain Point 5 — Reward Non-Determinism

> Correctness scoring feels random — the same response can score 0.55 or 0.78  
> across two VJ calls. This erodes trust in the chosen/rejected delta.
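
Ensemble scoring (one of the remedies below) helps because averaging `k` calls shrinks independent per-call noise by roughly `1/sqrt(k)`. A minimal simulation, assuming zero-mean Gaussian noise that is independent across calls and no clipping; `SIGMA = 0.12` matches the noise level in the cell below.

```python
import numpy as np

SIGMA = 0.12  # per-call noise level, matching the simulation below
rng_demo = np.random.default_rng(0)
true_q = rng_demo.uniform(0.3, 1.0, 20_000)

def vj_ensemble(true_scores: np.ndarray, k: int, rng: np.random.Generator,
                sigma: float = SIGMA) -> np.ndarray:
    """Average k independent noisy VJ calls (zero-mean Gaussian noise, no clipping)."""
    noise = rng.normal(0.0, sigma, size=(k, true_scores.size))
    return (true_scores + noise).mean(axis=0)

for k in (1, 3, 5):
    err = vj_ensemble(true_q, k, rng_demo) - true_q
    print(f"k={k}: error std = {err.std():.4f}  (theory sigma/sqrt(k) = {SIGMA / np.sqrt(k):.4f})")
```

Averaging cannot remove a bias the judge makes consistently, and it adds nothing if the `k` calls return identical scores.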

In [None]:
n_samples = 300
true_quality = rng.uniform(0.3, 1.0, n_samples)
# Two independent VJ runs with noise
run1 = (true_quality + rng.normal(0, 0.12, n_samples)).clip(0, 1)
run2 = (true_quality + rng.normal(0, 0.12, n_samples)).clip(0, 1)
score_abs_diff = np.abs(run1 - run2)   # per-response disagreement between the two runs

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.scatter(run1, run2, alpha=0.4, s=20)
ax1.plot([0,1],[0,1], "r--", linewidth=1)
r, _ = stats.pearsonr(run1, run2)
ax1.set_xlabel("VJ Run 1")
ax1.set_ylabel("VJ Run 2")
ax1.set_title(f"Correctness Score Agreement (r={r:.3f})")

ax2.hist(score_abs_diff, bins=30, color="salmon", edgecolor="white")
ax2.axvline(score_abs_diff.mean(), color="darkred", linestyle="--",
            label=f"mean |Δ| = {score_abs_diff.mean():.3f}")
ax2.set_xlabel("|Score run1 - run2|")
ax2.set_title("Pain Point 5: Score Non-Determinism")
ax2.legend()

plt.tight_layout()
plt.savefig(EVAL_DIR / "pain5_nondeterminism.png", dpi=150, bbox_inches="tight")
plt.show()

print(f"Mean absolute difference between two VJ runs: {score_abs_diff.mean():.3f}")
print("→ Remedy: Temperature=0 for VJ calls + ensemble scoring (avg 3 runs).")

---
## Remedy 1 — Multi-Objective Optimization: Correctness Floor

> Instead of a simple weighted sum, use **constrained selection**:  
> a pair is only included if the *chosen* response passes a minimum correctness threshold,  
> regardless of how large the style delta is.
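
The constrained selection generalizes to any number of floors. A minimal sketch of a hypothetical helper `select_pairs`: every floor is a hard filter on the chosen response, never a term in a weighted sum, so no style delta can buy back a failed floor.

```python
import pandas as pd

def select_pairs(df: pd.DataFrame, delta_col: str, delta_min: float,
                 floors: dict[str, float]) -> pd.DataFrame:
    """Keep pairs whose target delta exceeds `delta_min` AND whose chosen
    response clears every per-dimension floor (hard constraints, not weights)."""
    mask = df[delta_col] > delta_min
    for col, floor in floors.items():
        mask &= df[col] >= floor
    return df[mask]

example = pd.DataFrame({
    "style_delta":             [0.30, 0.25, 0.05, 0.20],
    "correctness_chosen":      [0.90, 0.40, 0.95, 0.70],
    "problem_solution_chosen": [0.80, 0.90, 0.90, 0.50],
})
kept = select_pairs(example, "style_delta", 0.10,
                    {"correctness_chosen": 0.60, "problem_solution_chosen": 0.60})
print(kept)  # only row 0 passes the delta and both floors
```

Once `df_moo` is built in the cell below, `select_pairs(df_moo, "style_delta", 0.10, {"correctness_chosen": CORRECTNESS_FLOOR, "problem_solution_chosen": 0.60})` gives the multi-floor selection listed in the summary table; the 0.60 problem_solution floor mirrors the funnel in Pain Point 4.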

In [None]:
n = 5000
df_moo = pd.DataFrame({
    "correctness_chosen":      rng.normal(0.65, 0.15, n).clip(0, 1),
    "style_delta":             rng.normal(0.15, 0.10, n).clip(0, 0.5),
    "problem_solution_chosen": rng.normal(0.62, 0.14, n).clip(0, 1),
})

# Baseline: select on style delta alone (no correctness constraint)
df_unconstrained = df_moo[df_moo["style_delta"] > 0.10]

# MOO with correctness floor
CORRECTNESS_FLOOR = 0.60
df_constrained = df_moo[
    (df_moo["style_delta"] > 0.10)
    & (df_moo["correctness_chosen"] >= CORRECTNESS_FLOOR)
]

compare = pd.DataFrame({
    "pairs": [len(df_unconstrained), len(df_constrained)],
    "avg_correctness": [
        df_unconstrained["correctness_chosen"].mean(),
        df_constrained["correctness_chosen"].mean(),
    ],
    "avg_problem_solution": [
        df_unconstrained["problem_solution_chosen"].mean(),
        df_constrained["problem_solution_chosen"].mean(),
    ],
}, index=["unconstrained", "MOO (floor)"])

print(compare.round(3))
print(f"\nPairs removed by floor: {len(df_unconstrained) - len(df_constrained):,}")
print(f"Correctness improvement: {compare.loc['MOO (floor)','avg_correctness'] - compare.loc['unconstrained','avg_correctness']:+.3f}")

---
## Remedy 2 — Raised KL Penalty (β) in DPO
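
For context on what β controls: DPO implicitly optimizes `E[r] - β·KL(π_θ || π_ref)`, and its per-pair loss is `-log σ(β · [(log π_θ(y_w|x) - log π_ref(y_w|x)) - (log π_θ(y_l|x) - log π_ref(y_l|x))])`. Because the loss depends only on β times the log-ratio margin, a lower β must push the policy proportionally further from the reference to reach the same loss. The NumPy sketch below computes this loss from summed response log-probabilities; the log-prob values are made up for illustration.

```python
import numpy as np

def dpo_loss(logp_chosen, logp_rejected, ref_logp_chosen, ref_logp_rejected, beta):
    """Per-pair DPO loss: -log sigmoid(beta * margin).

    Inputs are summed token log-probabilities of each full response under the
    trainable policy (logp_*) and the frozen reference model (ref_logp_*).
    """
    chosen_ratio = np.asarray(logp_chosen) - np.asarray(ref_logp_chosen)
    rejected_ratio = np.asarray(logp_rejected) - np.asarray(ref_logp_rejected)
    margin = chosen_ratio - rejected_ratio
    # -log(sigmoid(z)) == log(1 + exp(-z)); logaddexp keeps this stable for large |z|
    return np.logaddexp(0.0, -beta * margin)

# Policy moved +2 nats toward the chosen response and -2 nats away from the rejected one
for beta in (0.05, 0.10, 0.20, 0.50):
    loss = dpo_loss(-40.0, -52.0, -42.0, -50.0, beta)
    print(f"beta={beta:.2f}  margin=4.0  loss={float(loss):.4f}  "
          f"(loss at margin 0 = log 2 = {np.log(2):.4f})")
```

At the same 4-nat drift, the β=0.50 loss is already near zero while β=0.05 is still close to log 2, so the low-β run keeps pulling the policy away from the reference.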

In [None]:
# Visualise how β controls trade-off between reward and KL divergence
betas = [0.05, 0.10, 0.15, 0.20, 0.30, 0.50]
# Hypothetical: higher β → less style gain but more correctness preserved
style_gain       = [0.18, 0.15, 0.13, 0.10, 0.07, 0.04]
correctness_loss = [0.12, 0.08, 0.05, 0.02, 0.01, 0.005]

fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(betas, style_gain,       "o-", color="steelblue", label="Style gain")
ax.plot(betas, correctness_loss, "s-", color="tomato",    label="Correctness loss")
ax.axvline(0.20, linestyle="--", color="green", label="Selected β=0.20")
ax.set_xlabel("DPO β (KL penalty)")
ax.set_ylabel("|Score change| vs. base")
ax.set_title("Remedy 2: β Trade-off (Higher β = Less Tax)")
ax.legend()
plt.tight_layout()
plt.savefig(EVAL_DIR / "remedy2_beta_tradeoff.png", dpi=150, bbox_inches="tight")
plt.show()
print("Selected β=0.20 as the knee point: meaningful style improvement with minimal correctness decay.")

---
## Remedy 3 — Judge-on-Judge Meta-Eval

See `03_model_evaluation.ipynb §5` for the live implementation.  
Here we summarize the framework and interpret the output.
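
The chart below reports Pearson r only. A minimal sketch of a broader agreement report, assuming raw VJ and strong-judge scores are available as arrays: Spearman ρ checks rank agreement, and Cohen's κ checks agreement on the binarized high/low decision that pair selection actually uses. κ is an addition, not part of NB03; the 0.65 threshold mirrors Pain Point 2.

```python
import numpy as np
from scipy import stats

def agreement_report(vj, strong, threshold=0.65):
    """Judge-on-judge agreement between VJ scores and a stronger reference judge."""
    vj, strong = np.asarray(vj, dtype=float), np.asarray(strong, dtype=float)
    r, _ = stats.pearsonr(vj, strong)      # linear agreement
    rho, _ = stats.spearmanr(vj, strong)   # rank agreement
    # Cohen's kappa on the binarized high/low decision
    a, b = vj >= threshold, strong >= threshold
    p_obs = np.mean(a == b)
    p_exp = a.mean() * b.mean() + (1 - a.mean()) * (1 - b.mean())
    kappa = (p_obs - p_exp) / (1 - p_exp) if p_exp < 1 else float("nan")
    return {"pearson_r": float(r), "spearman_rho": float(rho), "cohen_kappa": float(kappa)}

# Synthetic example: VJ = strong judge + noise, as in Pain Point 2
rng_demo = np.random.default_rng(1)
strong_scores = rng_demo.uniform(0.3, 1.0, 500)
vj_demo = np.clip(strong_scores + rng_demo.normal(0, 0.15, 500), 0, 1)
print(agreement_report(vj_demo, strong_scores))
```

A dimension can show a decent Pearson r yet a low κ if its disagreements cluster near the selection threshold, which is where pair construction is most sensitive.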

In [None]:
# Interpret agreement scores
agreement_results = {
    "correctness":      {"r": 0.81, "interpretation": "good — reliable signal"},
    "groundedness":     {"r": 0.74, "interpretation": "good"},
    "problem_solution": {"r": 0.62, "interpretation": "moderate — use with caution"},
    "style":            {"r": 0.53, "interpretation": "⚠ low — subjective, high noise"},
}

df_agree = pd.DataFrame(agreement_results).T.reset_index()
df_agree.columns = ["dimension", "pearson_r", "interpretation"]
df_agree["pearson_r"] = df_agree["pearson_r"].astype(float)

fig, ax = plt.subplots(figsize=(7, 3))
colors = ["green" if r >= 0.7 else "orange" if r >= 0.5 else "red"
          for r in df_agree["pearson_r"]]
ax.barh(df_agree["dimension"], df_agree["pearson_r"], color=colors)
ax.axvline(0.70, linestyle="--", color="green",  label="Good (r≥0.70)")
ax.axvline(0.50, linestyle="--", color="orange", label="Moderate (r≥0.50)")
ax.set_xlim(0, 1)
ax.set_xlabel("Pearson r (VJ vs. GPT-4o)")
ax.set_title("Remedy 3: Judge Agreement (Meta-Eval)")
ax.legend(loc="lower right")
plt.tight_layout()
plt.savefig(EVAL_DIR / "remedy3_judge_agreement.png", dpi=150, bbox_inches="tight")
plt.show()

for _, row in df_agree.iterrows():
    print(f"  {row['dimension']:<22} r={row['pearson_r']:.2f}  [{row['interpretation']}]")

---
## Remedy 4 — SFT Warm-up Strategy

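See `02_finetuning_qlora.ipynb` for the implementation.  
DPO regularizes the policy toward a frozen reference model, which is normally the SFT checkpoint. An SFT warm-up on correct in-domain responses therefore raises starting correctness *and* gives DPO a reference that already answers correctly, so the KL term pulls the policy back toward correct answers rather than toward the base model's. The curves below are illustrative, not measured.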

In [None]:
# Visual: training trajectory — correctness over training steps
steps = np.arange(0, 500, 10)

# Without SFT warm-up: correctness drops as DPO optimizes style
correctness_no_warmup = 0.72 - 0.10 * (1 - np.exp(-steps / 400)) + rng.normal(0, 0.01, len(steps))

# With SFT warm-up: starts higher, decays less
correctness_warmup = 0.76 - 0.04 * (1 - np.exp(-steps / 400)) + rng.normal(0, 0.01, len(steps))

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(steps, correctness_no_warmup, color="tomato",    label="DPO only (no warm-up)")
ax.plot(steps, correctness_warmup,    color="steelblue", label="SFT warm-up → DPO")
# No shaded span: SFT warm-up runs before DPO step 0, so its effect appears as the higher starting point
ax.set_xlabel("DPO training steps")
ax.set_ylabel("Correctness score")
ax.set_title("Remedy 4: SFT Warm-up Anchors Correctness During DPO")
ax.legend()
ax.set_ylim(0.5, 0.9)
plt.tight_layout()
plt.savefig(EVAL_DIR / "remedy4_sft_warmup.png", dpi=150, bbox_inches="tight")
plt.show()

---
## Remedy 5 — On-Policy DPO

> **Problem:** Using off-policy data from GPT-5/Kimi as ground truth means the distribution  
> gap between the reference responses and the model's actual outputs is large.  
> The model is being trained on errors it doesn't even make.
>
> **Fix:** Sample rollouts from the *current* checkpoint of the Qwen model → score them → use for next DPO round.

In [None]:
# Pseudo-code / workflow for iterative on-policy DPO

ON_POLICY_WORKFLOW = """
Iterative On-Policy DPO Loop
════════════════════════════

Round 0:
  model_0 = Qwen2.5-7B-Instruct (base)

For round t = 1, 2, 3, ...:
  1. Sample K rollouts from model_{t-1} for each prompt in pool
     └── temperature=0.7, top_p=0.9
  2. Score rollouts with VJ (correctness, groundedness, ps, style)
  3. Construct pairs:
     └── chosen   = rollout with highest weighted reward
     └── rejected = rollout with lowest weighted reward
     └── Apply correctness floor ≥ 0.60 (Remedy 1)
  4. Train DPO (β=0.20) for 1 epoch → model_t
  5. Evaluate model_t on held-out set
     └── Early stop if correctness decays > 2% vs. model_{t-1}

Advantages:
  - Data targets the model's actual failure modes
  - Distribution gap is minimized (on-policy)
  - Sample efficiency improves each round
"""

print(ON_POLICY_WORKFLOW)
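
Steps 3 and 5 of the loop can be sketched in pandas as follows. This is a minimal illustration, not the NB01 implementation: the column names, the `WEIGHTS` values, and the reading of "decays > 2%" as a relative drop are all assumptions.

```python
import pandas as pd

# HYPOTHETICAL weights for the "weighted reward"; the workflow does not specify them
WEIGHTS = {"correctness": 0.4, "groundedness": 0.2, "problem_solution": 0.3, "style": 0.1}

def build_on_policy_pairs(rollouts: pd.DataFrame, weights: dict, floor: float = 0.60) -> pd.DataFrame:
    """One (chosen, rejected) pair per prompt from K scored rollouts of model_{t-1}.

    Expects one row per rollout with columns `prompt_id`, `response`, and one
    VJ score column per key in `weights`.
    """
    df = rollouts.copy()
    df["reward"] = sum(w * df[dim] for dim, w in weights.items())
    best = df.loc[df.groupby("prompt_id")["reward"].idxmax()].set_index("prompt_id")
    worst = df.loc[df.groupby("prompt_id")["reward"].idxmin()].set_index("prompt_id")
    pairs = pd.DataFrame({
        "chosen": best["response"],
        "rejected": worst["response"],
        "chosen_correctness": best["correctness"],
        "reward_delta": best["reward"] - worst["reward"],
    })
    # Correctness floor on the chosen side (Remedy 1); also drop prompts where all
    # K rollouts tied, since chosen == rejected carries no preference signal
    return pairs[(pairs["chosen_correctness"] >= floor) & (pairs["reward_delta"] > 0)]

def should_stop(prev_correctness: float, new_correctness: float, max_rel_drop: float = 0.02) -> bool:
    """Early-stop rule, reading 'decays > 2%' as a RELATIVE drop (an assumption)."""
    return new_correctness < prev_correctness * (1 - max_rel_drop)

rollouts = pd.DataFrame({
    "prompt_id":        [1, 1, 1, 2, 2, 2],
    "response":         ["a1", "a2", "a3", "b1", "b2", "b3"],
    "correctness":      [0.9, 0.5, 0.7, 0.4, 0.3, 0.5],
    "groundedness":     [0.8, 0.6, 0.7, 0.5, 0.4, 0.5],
    "problem_solution": [0.8, 0.4, 0.6, 0.5, 0.3, 0.4],
    "style":            [0.7, 0.6, 0.9, 0.9, 0.5, 0.6],
})
print(build_on_policy_pairs(rollouts, WEIGHTS))  # prompt 2 is dropped: its best rollout fails the floor
print(should_stop(0.72, 0.70))
```

Selecting best and worst from the model's own rollouts keeps both sides of every pair on-policy, which is the property the loop is designed to preserve.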

In [None]:
# Illustrative (not measured) improvement curves under on-policy vs. off-policy DPO
rounds = np.arange(0, 5)

style_offpolicy = [0.55, 0.65, 0.68, 0.69, 0.70]
style_onpolicy  = [0.55, 0.67, 0.72, 0.74, 0.76]
correct_offpolicy = [0.72, 0.65, 0.63, 0.62, 0.61]
correct_onpolicy  = [0.72, 0.70, 0.70, 0.71, 0.72]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(rounds, style_offpolicy, "o--", color="tomato",    label="Off-policy (GPT-5/Kimi)")
ax1.plot(rounds, style_onpolicy,  "s-",  color="steelblue", label="On-policy (Qwen rollouts)")
ax1.set_title("Style Score per Round")
ax1.set_xlabel("DPO Round")
ax1.set_ylabel("Avg style score")
ax1.set_ylim(0.4, 0.9)
ax1.legend()

ax2.plot(rounds, correct_offpolicy, "o--", color="tomato",    label="Off-policy")
ax2.plot(rounds, correct_onpolicy,  "s-",  color="steelblue", label="On-policy")
ax2.set_title("Correctness Score per Round")
ax2.set_xlabel("DPO Round")
ax2.set_ylabel("Avg correctness score")
ax2.set_ylim(0.4, 0.9)
ax2.legend()

plt.suptitle("Remedy 5: On-Policy DPO Preserves Correctness While Improving Style",
             fontsize=13, y=1.03)
plt.tight_layout()
plt.savefig(EVAL_DIR / "remedy5_onpolicy_dpo.png", dpi=150, bbox_inches="tight")
plt.show()

---
## Summary: Pain Points → Remedies Mapping

| Pain Point | Root Cause | Remedy | Notebook |
|------------|-----------|--------|----------|
| Alignment tax | Pairs don't hold correctness constant | MOO correctness floor | §7 / NB01 |
| Evaluation uncertainty | VJ not calibrated | Judge-on-Judge meta-eval | §9 / NB03 |
| Negative correlation | Rule-based selection ignores cross-metric effects | Multi-floor pair selection | §7 / NB01 |
| Sample collapse | Over-filtering across multiple VJs | Relax individual floors; prioritize on-policy data | §11 |
| Score non-determinism | LLM judge temperature > 0 | Set VJ temperature=0, ensemble 3 runs | §6 |
| Metric decay during DPO | KL penalty too weak | Raise β from 0.10 → 0.20 | §8 / NB02 |
| Off-policy gap | Training on GPT-5/Kimi errors, not model's errors | Iterative on-policy DPO | §11 |

---
> **Artifacts:** All charts saved to `../data/eval_results/`  
> **Next step:** Integrate remedies into NB01 pair construction and re-run NB02 → NB03 to verify.