# 03 — Verifier Calibration & Threshold Analysis

This notebook demonstrates and analyses the verifier calibration pipeline,
including temperature scaling and isotonic regression. It also provides
a threshold sweep over τ_e (the entailment threshold) and τ_c (the
contradiction threshold).

## Key Questions:
1. Are raw verifier scores well-calibrated?
2. How much does calibration improve reliability diagrams?
3. What is the precision/recall trade-off as we vary τ_e and τ_c?
4. Where is the optimal operating point?

In [None]:
import sys, os
sys.path.insert(0, os.path.abspath(".."))

import numpy as np
from certirag.utils import set_all_seeds

# Seed the project's global RNGs for reproducibility.
# NOTE(review): exactly which RNGs set_all_seeds covers is defined in
# certirag.utils — confirm there. The synthetic data below draws from its own
# np.random.default_rng(42), which is explicitly seeded either way.
set_all_seeds(42)

# We simulate verifier outputs for demonstration.
# In practice, these come from MiniCheck/NLI on a held-out calibration set.

## 1. Simulated verifier outputs

We create synthetic calibration data with known ground truth to demonstrate
the calibration pipeline. The raw scores are intentionally miscalibrated
(overconfident) to show the effect of calibration.

In [None]:
N_SAMPLES = 500

# Sample ground-truth labels with class mix
# entailment : neutral : contradiction = 60% : 20% : 20%.
rng = np.random.default_rng(42)
true_labels = rng.choice(
    ["entailment", "neutral", "contradiction"],
    size=N_SAMPLES,
    p=[0.6, 0.2, 0.2],
)

# Simulate overconfident raw scores
def generate_raw_scores(true_label, rng):
    """Draw one synthetic ``(entail, contra, neutral)`` score triple.

    The entailment score is drawn first and the contradiction score second,
    each from a Beta distribution chosen by ``true_label``. Any label other
    than ``"entailment"`` or ``"contradiction"`` uses the neutral parameters.
    The neutral score is the leftover mass ``1 - entail - contra`` floored at
    0. Because the two draws are independent, their sum can exceed 1, so the
    triple is not guaranteed to sum to 1.

    Args:
        true_label: Ground-truth label string.
        rng: A NumPy ``Generator``; two ``beta`` draws are consumed per call.

    Returns:
        Tuple ``(entail, contra, neutral)``.
    """
    # ((alpha, beta) for the entail draw, (alpha, beta) for the contra draw).
    # A Beta(a, b) variable has mean a / (a + b).
    beta_params_by_label = {
        # entail mean 0.8, contra mean ~0.09: deliberately overconfident
        "entailment": ((8, 2), (1, 10)),
        # entail mean ~0.11, contra mean 0.75: deliberately overconfident
        "contradiction": ((1, 8), (6, 2)),
    }
    neutral_params = ((3, 4), (2, 5))

    entail_ab, contra_ab = beta_params_by_label.get(true_label, neutral_params)
    entail = rng.beta(*entail_ab)
    contra = rng.beta(*contra_ab)
    return entail, contra, max(0, 1 - entail - contra)

# Draw one (entail, contra, neutral) triple per sample. `rng` is consumed in
# sample order, so the data is reproducible for a fixed seed.
raw_entail = np.zeros(N_SAMPLES)
raw_contra = np.zeros(N_SAMPLES)
raw_neutral = np.zeros(N_SAMPLES)

for i, label in enumerate(true_labels):
    raw_entail[i], raw_contra[i], raw_neutral[i] = generate_raw_scores(label, rng)

# Convert NumPy scalars to plain str/int so the dict prints cleanly.
# On NumPy >= 2 the scalar repr is e.g. np.str_('neutral') / np.int64(97).
label_values, label_counts = np.unique(true_labels, return_counts=True)
label_distribution = dict(zip(label_values.tolist(), label_counts.tolist()))

print(f"Label distribution: {label_distribution}")
print(f"\nRaw entail  scores — mean: {raw_entail.mean():.3f}, std: {raw_entail.std():.3f}")
print(f"Raw contra  scores — mean: {raw_contra.mean():.3f}, std: {raw_contra.std():.3f}")
print(f"Raw neutral scores — mean: {raw_neutral.mean():.3f}, std: {raw_neutral.std():.3f}")

## 2. Reliability diagram (before calibration)

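A well-calibrated model should have predicted probability ≈ actual frequency.
Points falling below the diagonal (observed accuracy lower than the predicted
probability) indicate over-confidence. Points above it indicate under-confidence.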

In [None]:
def reliability_diagram(predicted_probs, true_binary, n_bins=10, title=""):
    """Print a reliability table and return the Expected Calibration Error.

    Predictions are grouped into ``n_bins`` equal-width bins over [0, 1].
    Bin ``b`` is half-open, ``[edges[b], edges[b+1])``, except that the last
    bin also includes 1.0. For every non-empty bin the table lists:

    - the bin centre;
    - the mean predicted probability ("Conf");
    - the observed positive rate ("Accuracy");
    - the sample count;
    - the gap ``accuracy - confidence``.

    Gap markers: ▲ means under-confident, ▼ means over-confident, and ≈ means
    within ±0.05.

    ECE is the count-weighted mean of ``|accuracy_b - confidence_b|``. Here
    ``confidence_b`` is the mean predicted probability inside bin ``b``
    (the standard definition), not the bin centre.

    Args:
        predicted_probs: Predicted probabilities. Entries outside [0, 1]
            (including NaN) are ignored.
        true_binary: 0/1 ground-truth labels aligned with ``predicted_probs``.
        n_bins: Number of equal-width bins; must be >= 1.
        title: Heading printed above the table.

    Returns:
        The ECE as a float; 0.0 when no valid predictions remain.

    Raises:
        ValueError: If ``n_bins < 1`` or the two inputs differ in shape.
    """
    probs = np.asarray(predicted_probs, dtype=float)
    labels = np.asarray(true_binary, dtype=float)
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    if probs.shape != labels.shape:
        raise ValueError(
            f"predicted_probs and true_binary shapes differ: {probs.shape} vs {labels.shape}"
        )

    # Keep only valid probabilities. NaN fails both comparisons, so it is dropped.
    valid = (probs >= 0) & (probs <= 1)
    probs, labels = probs[valid], labels[valid]

    edges = np.linspace(0, 1, n_bins + 1)
    # Searching only the interior edges with side="right" reproduces the
    # half-open [edges[b], edges[b+1]) bins. It also sends p == 1.0 to the
    # last bin instead of dropping it; isotonic output with y_max=1 hits 1.0.
    bin_ids = np.searchsorted(edges[1:-1], probs, side="right")

    rows = []  # (centre, mean confidence, accuracy, count) per non-empty bin
    for b in range(n_bins):
        in_bin = bin_ids == b
        count = int(in_bin.sum())
        if count == 0:
            continue
        rows.append((
            (edges[b] + edges[b + 1]) / 2,
            float(probs[in_bin].mean()),
            float(labels[in_bin].mean()),
            count,
        ))

    # Every valid prediction lands in exactly one bin, so the weights sum to 1.
    total = len(probs)
    ece = 0.0
    for _, conf, acc, count in rows:
        ece += (count / total) * abs(acc - conf)

    print(f"\n{title}")
    print(f"{'Bin Center':>12} {'Conf':>8} {'Accuracy':>10} {'Count':>8} {'Gap':>8}")
    print("-" * 50)
    for centre, conf, acc, count in rows:
        gap = acc - conf
        marker = "▲" if gap > 0.05 else ("▼" if gap < -0.05 else "≈")
        print(f"{centre:>12.2f} {conf:>8.3f} {acc:>10.3f} {count:>8d} {gap:>+8.3f} {marker}")
    print(f"\nECE = {ece:.4f}")
    return float(ece)

# One-vs-rest 0/1 (float) targets for the entailment and contradiction scores.
true_entail_binary = np.where(true_labels == "entailment", 1.0, 0.0)
true_contra_binary = np.where(true_labels == "contradiction", 1.0, 0.0)

# Reliability of the uncalibrated scores, one table per target.
ece_entail_raw = reliability_diagram(
    raw_entail,
    true_entail_binary,
    title="Reliability: Entailment (RAW)",
)
ece_contra_raw = reliability_diagram(
    raw_contra,
    true_contra_binary,
    title="Reliability: Contradiction (RAW)",
)

## 3. Apply calibration

CertiRAG supports two methods:
- **Temperature scaling**: Learn a single scalar T to soften/sharpen logits
- **Isotonic regression**: Non-parametric monotone recalibration

In [None]:
from sklearn.isotonic import IsotonicRegression
from scipy.optimize import minimize_scalar

# --- Temperature Scaling ---
def temperature_scale(logits, T):
    """Return ``sigmoid(logits / T)``.

    A temperature ``T > 1`` shrinks the logits toward 0, softening the
    probabilities toward 0.5. ``T < 1`` sharpens them.
    """
    z = -logits / T
    return 1 / (1 + np.exp(z))

def find_temperature(raw_probs, true_binary):
    """Fit a scalar temperature by minimising binary negative log-likelihood.

    The probabilities are first mapped to logits. Small 1e-8 guards prevent
    division by zero and log(0). A bounded scalar search then finds the
    ``T`` in [0.1, 10.0] for which ``sigmoid(logit / T)`` best fits
    ``true_binary``.

    Args:
        raw_probs: Array of raw predicted probabilities.
        true_binary: Array of 0/1 labels aligned with ``raw_probs``.

    Returns:
        The fitted temperature (float).
    """
    eps = 1e-8
    logits = np.log(raw_probs / (1 - raw_probs + eps) + eps)

    def _neg_log_likelihood(temperature):
        # Clip so that log() never sees exactly 0 or 1.
        p = np.clip(1 / (1 + np.exp(-logits / temperature)), eps, 1 - eps)
        log_lik = true_binary * np.log(p) + (1 - true_binary) * np.log(1 - p)
        return -np.mean(log_lik)

    fit = minimize_scalar(_neg_log_likelihood, bounds=(0.1, 10.0), method="bounded")
    return fit.x

# Fit calibrators on the first 70% of samples; evaluate on the last 30%.
split = int(0.7 * N_SAMPLES)
cal_entail = raw_entail[:split]
test_entail = raw_entail[split:]
cal_labels = true_entail_binary[:split]
test_labels = true_entail_binary[split:]

# Temperature scaling: fit T on the calibration split, then apply it to the
# test logits (same logit transform as used inside find_temperature).
T_opt = find_temperature(cal_entail, cal_labels)
logits_test = np.log(test_entail / (1 - test_entail + 1e-8) + 1e-8)
temp_scaled = temperature_scale(logits_test, T_opt)

# Isotonic regression: monotone map fitted on the calibration split. Test
# scores outside the fitted range are clipped to it.
iso = IsotonicRegression(y_min=0, y_max=1, out_of_bounds='clip')
iso_calibrated = iso.fit(cal_entail, cal_labels).predict(test_entail)

print(f"Optimal temperature: T = {T_opt:.3f}")
print("\n--- Test set results ---")

ece_raw = reliability_diagram(test_entail, test_labels,
                              title="Entailment — Raw (test set)")
ece_temp = reliability_diagram(temp_scaled, test_labels,
                               title="Entailment — Temperature Scaled")
ece_iso = reliability_diagram(iso_calibrated, test_labels,
                              title="Entailment — Isotonic Regression")

# Side-by-side ECE summary, dot-leader padded to 25 chars.
print(f"\n{'Method':<25} {'ECE':>8}")
print("-" * 35)
for method_label, method_ece in (
    ("Raw", ece_raw),
    (f"Temperature (T={T_opt:.2f})", ece_temp),
    ("Isotonic", ece_iso),
):
    print(f"{method_label:.<25} {method_ece:>8.4f}")

## 4. Threshold sweep: τ_e × τ_c grid search

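We sweep over threshold pairs on the held-out test split and measure precision,
recall, and F1 for the "VERIFIED" label under the Theorem 1 decision rule. Under
that rule, a claim is BLOCKED if its contradiction score is ≥ τ_c. Otherwise it is
VERIFIED if its entailment score is ≥ τ_e, and UNVERIFIED if not.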

In [None]:
from certirag.schemas.certificate import RenderState

def apply_theorem1(entail_score, contra_score, tau_e, tau_c, min_evidence=1):
    """Map one claim's verifier scores to a render state (Theorem 1 rule).

    The checks run in priority order:

    1. A contradiction score at or above ``tau_c`` blocks the claim,
       whatever its entailment score.
    2. Otherwise, an entailment score at or above ``tau_e`` verifies it.
    3. Anything else is left unverified.

    Note: ``min_evidence`` is accepted but not consulted by this
    implementation.

    Returns:
        ``RenderState.BLOCKED``, ``RenderState.VERIFIED`` or
        ``RenderState.UNVERIFIED``.
    """
    return (
        RenderState.BLOCKED if contra_score >= tau_c
        else RenderState.VERIFIED if entail_score >= tau_e
        else RenderState.UNVERIFIED
    )

# NOTE: the threshold analysis uses the RAW (uncalibrated) test-split scores.
# Section 3 only calibrated the entailment score.
cal_contra, test_contra = raw_contra[:split], raw_contra[split:]
test_true_labels = true_labels[split:]

tau_e_range = np.arange(0.50, 0.96, 0.05)
tau_c_range = np.arange(0.40, 0.86, 0.05)

# Ground truth is fixed across the sweep, so compute it once.
true_entail_mask = test_true_labels == "entailment"
n_true_entail = true_entail_mask.sum()

print(f"\n{'τ_e':>6} {'τ_c':>6} {'Prec':>8} {'Recall':>8} {'F1':>8} {'Block%':>8} {'Unver%':>8}")
print("=" * 58)

# NOTE(review): thresholds are selected on the same split the metrics are
# reported on, so the best F1 below is optimistically biased. Select on a
# separate split for an unbiased estimate.
best_f1 = -1.0  # below any achievable F1, so the first pair always registers
best_params = None

for tau_e in tau_e_range:
    for tau_c in tau_c_range:
        decisions = [apply_theorem1(e, c, tau_e, tau_c)
                     for e, c in zip(test_entail, test_contra)]
        # dtype=bool keeps `~mask` valid even if the split is empty.
        verified_mask = np.array([d == RenderState.VERIFIED for d in decisions], dtype=bool)
        blocked_mask = np.array([d == RenderState.BLOCKED for d in decisions], dtype=bool)

        n_verified = verified_mask.sum()
        n_hits = (verified_mask & true_entail_mask).sum()
        # Precision: of the claims we verified, how many are truly entailed?
        # Vacuously 1.0 when nothing is verified.
        precision = n_hits / n_verified if n_verified > 0 else 1.0
        # Recall: of the truly entailed claims, how many did we verify?
        recall = n_hits / n_true_entail if n_true_entail > 0 else 1.0
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        block_pct = blocked_mask.mean() * 100
        unver_pct = (~verified_mask & ~blocked_mask).mean() * 100

        # One row per threshold pair, matching the header printed above.
        print(f"{tau_e:>6.2f} {tau_c:>6.2f} {precision:>8.3f} {recall:>8.3f} "
              f"{f1:>8.3f} {block_pct:>7.1f}% {unver_pct:>7.1f}%")

        if f1 > best_f1:
            best_f1 = f1
            best_params = (tau_e, tau_c, precision, recall, f1, block_pct, unver_pct)

# Print best result
if best_params is None:
    print("\nNo threshold pairs were evaluated (empty τ_e or τ_c range).")
else:
    tau_e, tau_c, prec, rec, f1, blk, unv = best_params
    print(f"\n★ Best F1: τ_e={tau_e:.2f}, τ_c={tau_c:.2f}")
    print(f"  Precision={prec:.3f}, Recall={rec:.3f}, F1={f1:.3f}")
    print(f"  Blocked={blk:.1f}%, Unverified={unv:.1f}%")

## 5. Precision-Recall curve at fixed τ_c

In [None]:
fixed_tau_c = 0.70  # default CertiRAG value
default_tau_e = 0.85  # τ_e row marked with ◄ as the CertiRAG default
print(f"Precision-Recall trade-off (τ_c fixed at {fixed_tau_c})\n")
print(f"{'τ_e':>6} {'Precision':>10} {'Recall':>10} {'F1':>10} {'Verified%':>10}")
print("-" * 50)

# Ground truth does not depend on τ_e, so compute it once.
true_ent = test_true_labels == "entailment"
n_true_ent = true_ent.sum()

for tau_e in np.arange(0.30, 0.96, 0.05):
    decisions = [apply_theorem1(e, c, tau_e, fixed_tau_c)
                 for e, c in zip(test_entail, test_contra)]
    verified = np.array([d == RenderState.VERIFIED for d in decisions], dtype=bool)

    n_verified = verified.sum()
    n_hits = (verified & true_ent).sum()
    # Same convention as the τ_e × τ_c grid search: precision is vacuously
    # 1.0 when nothing is verified (previously 0 here, 1.0 there).
    prec = n_hits / n_verified if n_verified > 0 else 1.0
    rec = n_hits / n_true_ent if n_true_ent > 0 else 1.0
    f1 = 2 * prec * rec / (prec + rec + 1e-8)
    ver_pct = verified.mean() * 100

    # np.arange accumulates float error, so compare with a tolerance.
    marker = " ◄" if abs(tau_e - default_tau_e) < 0.01 else ""
    print(f"{tau_e:>6.2f} {prec:>10.3f} {rec:>10.3f} {f1:>10.3f} {ver_pct:>9.1f}%{marker}")

print("\n◄ = CertiRAG default threshold")

## 6. Summary

**Key findings from calibration analysis:**

1. Raw verifier scores are **overconfident** (by construction in this synthetic setup), so ECE is significant.
2. Both temperature scaling and isotonic regression **reduce ECE** substantially.
3. Isotonic regression typically achieves lower ECE but requires more calibration data.
4. The default thresholds (τ_e=0.85, τ_c=0.70) provide a **fail-safe operating point** —
   high precision at the cost of some recall.
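5. Grid search confirms the precision/recall trade-off described in the paper. Because the grid
   search selects thresholds on the same split it reports on, the best-F1 figure is optimistic.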