# Appendix B — Calibration Methods

### What is this appendix about?

In the main notebook (Sections 5-6), we built $U(s)$ (uncertainty) and $C(s)$ (confidence). The agent says "I am 80% confident about this cell". But **is that true?** When it says 80%, is it actually correct 80% of the time?

This is the question of **calibration**: does the agent's confidence reflect its true performance?

This appendix presents the tools for measuring the quality of this calibration:
- **ECE**: the mean gap between confidence and actual performance
- **Reliability diagram**: the standard calibration visualization
- **MI** (Metacognitive Index): does the agent know *where* it is wrong?

**Prerequisites:** [00_prism_concepts.ipynb](00_prism_concepts.ipynb) (Sections 4-6)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from scipy.stats import spearmanr
from IPython.display import display
import sys, os

sys.path.insert(0, os.path.abspath('..'))
from prism.pedagogy.toy_grid import ToyGrid

%matplotlib inline
plt.rcParams['figure.dpi'] = 100

---
## 1. What is calibration?

### The weather example

A weather forecaster says "80% chance of rain". How do we know if they are reliable?

We look at **all the days when they said "80%"**. If it actually rained about 80% of those days, they are **well calibrated**. If it only rained 50% of the time, they are **overconfident** — they announce probabilities that are too high.
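
As a minimal sketch of this check, we can select the days on which the forecaster announced 80% and compare with how often it actually rained. The forecasts and outcomes below are made-up illustrative data, not real observations.

```python
import numpy as np

# Hypothetical forecasts and outcomes (1 = rain, 0 = no rain); made-up data
forecast = np.array([0.8, 0.8, 0.2, 0.8, 0.8, 0.2, 0.8, 0.8, 0.8, 0.2, 0.8, 0.8, 0.8])
rained = np.array([1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1])

# Keep only the days where the forecaster announced 80%
said_80 = np.isclose(forecast, 0.8)
observed_freq = rained[said_80].mean()

print(f"Days with an 80% forecast: {said_80.sum()}")
print(f"Observed rain frequency on those days: {observed_freq:.2f}")
```

In this toy data it rained on 8 of the 10 days with an 80% forecast, so the forecaster would count as well calibrated at that confidence level.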

### For PRISM

The agent says $C(s) = 0.8$ — "I am 80% confident in my prediction for cell $s$".

We can verify: among all cells where the agent says $C \approx 0.8$, is its SR prediction correct ~80% of the time?

| Situation | What it means | Consequence |
|-----------|-------------------|-------------|
| $C(s) \approx accuracy(s)$ | **Well calibrated** — confidence reflects reality | The agent can rely on its own signal C |
| $C(s) > accuracy(s)$ | **Overconfident** — the agent believes it knows but is wrong | Dangerous: it does not explore enough |
| $C(s) < accuracy(s)$ | **Underconfident** — the agent doubts when it is right | Less severe but wastes exploration |

### Why this matters

The entire metacognitive chain in PRISM (adaptive exploration in Section 7, change detection in Section 8, IDK signal in Section 9) relies on $C(s)$. If $C$ does not reflect reality, these mechanisms make poor decisions.

The graph below shows 3 synthetic examples (well calibrated, overconfident, underconfident).

In [None]:
# Generate synthetic calibration data
np.random.seed(42)
n = 200

# Case 1: well calibrated (success probability equals the stated confidence)
conf_good = np.random.beta(2, 2, n)
acc_good = np.array([np.random.binomial(1, c) for c in conf_good])

# Case 2: overconfident (success probability is 0.3 below the stated confidence)
conf_over = np.random.beta(5, 1, n)
acc_over = np.array([np.random.binomial(1, max(0, c - 0.3)) for c in conf_over])

# Case 3: underconfident (success probability is 0.3 above the stated confidence)
conf_under = np.random.beta(1, 3, n)
acc_under = np.array([np.random.binomial(1, min(1, c + 0.3)) for c in conf_under])

fig, axes = plt.subplots(1, 3, figsize=(13, 4))
cases = [
    ('Well calibrated', conf_good, acc_good),
    ('Overconfident', conf_over, acc_over),
    ('Underconfident', conf_under, acc_under),
]

for ax, (name, conf, acc) in zip(axes, cases):
    n_bins = 8
    bin_edges = np.linspace(0, 1, n_bins + 1)
    bin_centers = []
    bin_accs = []
    for i in range(n_bins):
        mask = (conf >= bin_edges[i]) & (conf < bin_edges[i+1])
        if i == n_bins - 1:
            mask = (conf >= bin_edges[i]) & (conf <= bin_edges[i+1])
        if mask.sum() > 0:
            bin_centers.append(conf[mask].mean())
            bin_accs.append(acc[mask].mean())

    ax.plot([0, 1], [0, 1], 'r--', linewidth=2, label='Perfect calibration')
    ax.scatter(bin_centers, bin_accs, s=60, c='steelblue', edgecolors='white', zorder=3)
    ax.plot(bin_centers, bin_accs, '-', color='steelblue', alpha=0.5)
    ax.set_xlabel("Agent's confidence")
    ax.set_ylabel('Actual accuracy')
    ax.set_title(name)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("How to read the 3 plots:")
print()
print("Each plot is a 'reliability diagram':")
print("  Horizontal axis = what the agent SAYS (its confidence)")
print("  Vertical axis   = what is TRUE (its actual accuracy)")
print("  Red diagonal = perfect calibration (confidence = reality)")
print("  Each point = a group of predictions with similar confidence")
print()
print("Left: well calibrated")
print("  The points follow the diagonal.")
print("  When the agent says '80% confident', it is right ~80% of the time. ✓")
print()
print("Center: overconfident")
print("  The points are BELOW the diagonal.")
print("  The agent says '90% confident' but is right only ~60% of the time.")
print("  → It overestimates itself. Dangerous: it does not explore enough.")
print()
print("Right: underconfident")
print("  The points are ABOVE the diagonal.")
print("  The agent says '20% confident' but is right ~50% of the time.")
print("  → It underestimates itself. Less severe, but wastes exploration.")

---
## 2. Expected Calibration Error (ECE)

### The need

We want a **single number** that summarizes "how well calibrated the agent is". This is the role of the ECE.

### The intuition

The idea is simple:
1. Group predictions by confidence level (e.g., all cells where $C \in [0.7, 0.8)$)
2. For each group, compare the mean confidence with the mean accuracy
3. Take the weighted average of the gaps, weighted by group size

### The formula

$$ECE = \sum_{b=1}^{B} \frac{|B_b|}{N} \cdot |accuracy_b - confidence_b|$$

- $B$ = number of groups (bins)
- $B_b$ = set of predictions in bin $b$
- $|B_b|/N$ = proportion of predictions in this bin (weight)
- $|accuracy_b - confidence_b|$ = gap in this bin
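
To make the formula concrete, here is a small worked example with three bins. The counts, confidences, and accuracies are illustrative numbers chosen for easy arithmetic, not PRISM results.

```python
import numpy as np

# Hypothetical per-bin summaries (illustrative numbers, not PRISM results)
bin_counts = np.array([50, 30, 20])        # |B_b|
bin_conf = np.array([0.20, 0.50, 0.90])    # mean confidence per bin
bin_acc = np.array([0.25, 0.50, 0.70])     # mean accuracy per bin

weights = bin_counts / bin_counts.sum()    # |B_b| / N
gaps = np.abs(bin_acc - bin_conf)          # |accuracy_b - confidence_b|
ece = np.sum(weights * gaps)

for w, g in zip(weights, gaps):
    print(f"weight = {w:.2f}, gap = {g:.2f}, contribution = {w * g:.3f}")
print(f"ECE = {ece:.3f}")
```

By hand: $0.5 \times 0.05 + 0.3 \times 0 + 0.2 \times 0.2 = 0.065$. The third bin holds only 20% of the predictions but contributes the most, because its gap is the largest.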

### How to read the result

- **ECE = 0** → perfect calibration (each bin has accuracy = confidence)
- **ECE = 0.05** → good (average gap of 5 percentage points)
- **ECE = 0.20** → poor (average gap of 20 percentage points)

**PRISM target**: ECE < 0.15

### The number of bins

The choice of $B$ (number of bins) is a trade-off:
- **Few bins** (3-5): each bin contains a lot of data → stable estimates, but resolution is lost
- **Many bins** (15-20): fine resolution, but some bins may be nearly empty → noisy estimates
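
The sketch below illustrates this trade-off on a synthetic agent that is perfectly calibrated by construction. The helper `ece_equal_width` is a hypothetical, self-contained function written for this illustration, not the PRISM implementation. With nested equal-width bins (5, 10, 20), splitting a bin can only keep or raise the estimate, because positive and negative gaps can no longer cancel inside a bin. At a fixed sample size, more bins therefore push the measured ECE upward even though the true calibration has not changed.

```python
import numpy as np

def ece_equal_width(conf, acc, n_bins):
    """ECE with equal-width bins (illustrative helper, not the PRISM code)."""
    # Bin index in 0..n_bins-1 (a confidence of exactly 1.0 goes into the last bin)
    idx = np.minimum((conf * n_bins).astype(int), n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        mask = idx == b
        if mask.any():
            # mask.mean() is the bin weight |B_b| / N
            ece += mask.mean() * abs(acc[mask].mean() - conf[mask].mean())
    return ece

rng = np.random.default_rng(0)
conf = rng.uniform(0, 1, 200)
# Perfectly calibrated by construction: P(correct) equals the stated confidence
acc = (rng.uniform(0, 1, 200) < conf).astype(float)

for n_bins in (5, 10, 20):
    print(f"{n_bins:>2} bins -> ECE = {ece_equal_width(conf, acc, n_bins):.3f}")
```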

In [None]:
def compute_ece(confidences, accuracies, n_bins=10):
    """Compute the ECE and the data for the reliability diagram.

    Returns (ece, bin_confs, bin_accs, bin_counts); empty bins hold NaN.
    """
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_confs = []
    bin_accs = []
    bin_counts = []

    for i in range(n_bins):
        lo, hi = bin_boundaries[i], bin_boundaries[i + 1]
        mask = (confidences >= lo) & (confidences < hi)
        if i == n_bins - 1:  # the last bin includes the upper bound
            mask = (confidences >= lo) & (confidences <= hi)

        count = mask.sum()
        if count > 0:
            bin_confs.append(confidences[mask].mean())
            bin_accs.append(accuracies[mask].mean())
            bin_counts.append(count)
        else:
            bin_confs.append(np.nan)
            bin_accs.append(np.nan)
            bin_counts.append(0)

    bin_confs = np.array(bin_confs)
    bin_accs = np.array(bin_accs)
    bin_counts = np.array(bin_counts)

    # ECE
    valid = bin_counts > 0
    n_total = bin_counts.sum()
    ece = np.sum((bin_counts[valid] / n_total) * np.abs(bin_accs[valid] - bin_confs[valid]))

    return ece, bin_confs, bin_accs, bin_counts

# Quick check on the synthetic data from Section 1
ece, _, _, _ = compute_ece(conf_good, acc_good)
print(f"ECE (well calibrated): {ece:.4f}")

ece_over, _, _, _ = compute_ece(conf_over, acc_over)
print(f"ECE (overconfident): {ece_over:.4f}")

In [None]:
def plot_ece_interactive(n_bins, calibration_type):
    """Visualize the ECE for different numbers of bins."""
    data = {
        'Well calibrated': (conf_good, acc_good),
        'Overconfident': (conf_over, acc_over),
        'Underconfident': (conf_under, acc_under),
    }
    conf, acc = data[calibration_type]
    ece, bin_confs, bin_accs, bin_counts = compute_ece(conf, acc, n_bins)

    fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))

    # Reliability diagram: x = mean confidence of each non-empty bin, so the
    # vertical distance to the diagonal is exactly the gap used in the ECE
    bin_centers = np.linspace(1/(2*n_bins), 1 - 1/(2*n_bins), n_bins)
    valid = bin_counts > 0

    axes[0].plot([0, 1], [0, 1], 'r--', linewidth=2, label='Perfect')
    axes[0].scatter(bin_confs[valid], bin_accs[valid], s=60,
                    c='steelblue', edgecolors='white', zorder=3, label='Accuracy')
    axes[0].plot(bin_confs[valid], bin_accs[valid], '-', color='steelblue', alpha=0.5)
    axes[0].set_xlabel('Confidence (bin mean)')
    axes[0].set_ylabel('Accuracy (bin mean)')
    axes[0].set_title(f'Reliability Diagram (ECE = {ece:.4f})')
    axes[0].set_xlim(0, 1)
    axes[0].set_ylim(0, 1)
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Distribution per bin (bars, since these are counts)
    width = 1.0 / n_bins
    axes[1].bar(bin_centers, bin_counts, width=width*0.8,
                color='lightcoral', edgecolor='white')
    axes[1].set_xlabel('Confidence (bin)')
    axes[1].set_ylabel('Number of predictions')
    axes[1].set_title(f'Distribution ({n_bins} bins, N={int(bin_counts.sum())})')
    axes[1].set_xlim(0, 1)

    plt.tight_layout()
    plt.show()

    # Per-bin details
    delta_header = '|Δ|'
    print(f"{'Bin':>6} {'Conf':>8} {'Acc':>8} {'Count':>8} {delta_header:>8}")
    for i in range(n_bins):
        if bin_counts[i] > 0:
            gap = abs(bin_accs[i] - bin_confs[i])
            print(f"{i+1:>6} {bin_confs[i]:>8.3f} {bin_accs[i]:>8.3f} {int(bin_counts[i]):>8} {gap:>8.3f}")

    print()
    print("How to read the plots:")
    print()
    print("Left: reliability diagram")
    print("  Each point = a group of predictions with similar confidence")
    print("  Vertical position = mean accuracy of the group")
    print("  Red diagonal = perfect calibration")
    print("  ECE = weighted mean distance between the points and the diagonal")
    print()
    print("Right: distribution of the data per bin")
    print("  Shows how many predictions fall in each bin")
    print("  A bin with little data → unreliable point")
    print()
    print("Things to try:")
    print("  Bins = 3  → few points, stable but coarse estimate")
    print("  Bins = 20 → many points, but some bins may be empty or nearly empty")
    print("  Change the calibration type to see how the curve changes")

widgets.interact(
    plot_ece_interactive,
    n_bins=widgets.IntSlider(value=10, min=3, max=20, step=1,
                              description='Bins', continuous_update=False),
    calibration_type=widgets.Dropdown(
        options=['Well calibrated', 'Overconfident', 'Underconfident'],
        value='Well calibrated', description='Type')
);

---
## 3. Reliability Diagram: step-by-step construction on the SR

### What is a reliability diagram?

It is the standard visualization for calibration: we check whether the agent's confidence matches its true performance. We already saw it with the synthetic examples (Section 1) and the ECE (Section 2).

Here, we will **build it from scratch** on our ToyGrid, following the complete chain:

### The 3 steps (= the 3 grids)

| Step | What we compute | Grid |
|------|-----------------|------|
| **1. True error** | For each cell: $\lVert M(s,:) - M^*(s,:) \rVert$ | Left: where M is wrong |
| **2. Ground truth** | We set a threshold $\tau$ (median): correct if $error(s) < \tau$ | Center: green/red grid |
| **3. What the agent believes** | The confidence $C(s)$ computed via $U \to C$ | Right — green/red grid |

**Compare center and right**: are the green zones in the same places? If so → the agent is well calibrated.
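
The $U \to C$ step can be sketched as a decreasing sigmoid. The slope of 8 and the centring on the median of $U$ follow the code cell of this section; the function name `confidence_from_uncertainty` and the `U` values are hypothetical, chosen only for illustration.

```python
import numpy as np

def confidence_from_uncertainty(U, theta, slope=8.0):
    """Decreasing sigmoid: C = 1 / (1 + exp(slope * (U - theta)))."""
    return 1.0 / (1.0 + np.exp(slope * (U - theta)))

# Hypothetical uncertainties for five states (illustrative values)
U = np.array([0.05, 0.20, 0.40, 0.60, 0.90])
theta = np.median(U)   # centre of the sigmoid: C = 0.5 when U = theta
C = confidence_from_uncertainty(U, theta)

for u, c in zip(U, C):
    print(f"U = {u:.2f} -> C = {c:.2f}")
```

States with $U$ below the median get $C > 0.5$ and states above it get $C < 0.5$, which is what produces the green/red contrast on the right-hand grid.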

In [None]:
# Step-by-step construction on the SR data (ToyGrid)
grid = ToyGrid.two_rooms()
gamma = 0.95
M_star = grid.true_sr(gamma)

# Partially trained agent (1000 steps, as in the main notebook)
M = np.eye(grid.n_states)
traj = grid.random_walk(1000, seed=42)

# Accumulate TD errors only after a 500-step warm-up
warmup = 500
error_sum = np.zeros(grid.n_states)
error_count = np.zeros(grid.n_states)
recent_norms = []

for t in range(len(traj) - 1):
    s, s_next = traj[t], traj[t+1]
    delta, M = grid.td_update(M, s, s_next, gamma, 0.1)
    if t >= warmup:
        norm = np.linalg.norm(delta)
        error_sum[s] += norm
        error_count[s] += 1
        recent_norms.append(norm)

# --- Step 1: true error per state ---
errors = np.array([np.linalg.norm(M[s] - M_star[s]) for s in range(grid.n_states)])
# Min-max scaling: the smallest error → 0 (yellow), the largest → 1 (red)
e_min, e_max = errors.min(), errors.max()
errors_norm = (errors - e_min) / (e_max - e_min + 1e-8)

# --- Step 2: binary accuracy (threshold = median) ---
tau = np.median(errors)
accuracy = (errors < tau).astype(float)

# --- Step 3: U and confidence C ---
p99 = np.percentile(recent_norms, 99) if recent_norms else 1.0
mean_err = np.zeros(grid.n_states)
visited = error_count > 0
mean_err[visited] = error_sum[visited] / error_count[visited]
# Unvisited states get a default uncertainty of 0.8
U = np.where(visited, np.clip(mean_err / max(p99, 1e-8), 0, 1), 0.8)

# Center the sigmoid on the median of U to get contrast
theta_C = np.median(U)
C = 1 / (1 + np.exp(8 * (U - theta_C)))

# === Visualization: 3 grids ===
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

grid.plot(values=errors_norm, ax=axes[0], show_goal=False,
          title='Step 1: error of M', cmap='YlOrRd', vmin=0, vmax=1)

grid.plot(values=accuracy, ax=axes[1], show_goal=False,
          title='Step 2: ground truth',
          cmap='RdYlGn', vmin=0, vmax=1)

grid.plot(values=C, ax=axes[2], show_goal=False,
          title='Step 3: what the agent believes',
          cmap='RdYlGn', vmin=0, vmax=1)

plt.tight_layout()
plt.show()

# --- Explanations ---
print("How to read the 3 grids:")
print()
print("LEFT: true error of M per cell")
print("  Dark red = M is very wrong, pale yellow = M is correct")
print()
print("CENTER: ground truth (accuracy)")
print(f"  Threshold τ = median of the errors = {tau:.3f}")
print(f"  Green = M correct ({int(accuracy.sum())} states), red = incorrect ({int((1 - accuracy).sum())} states)")
print()
print("RIGHT: what the agent believes (confidence C)")
print(f"  Sigmoid centered on median(U) = {theta_C:.2f}")
print("  Green = confident (low U), red = doubtful (high U)")
print()
print("Compare CENTER and RIGHT:")
agree = ((accuracy == 1) & (C >= 0.5)) | ((accuracy == 0) & (C < 0.5))
n_agree = int(agree.sum())
print(f"  Agreement: {n_agree}/{grid.n_states} ({100*n_agree/grid.n_states:.0f}%)")
print()
if n_agree > grid.n_states * 0.7:
    print("  → Good calibration: the green/red zones look alike")
else:
    print("  → Imperfect calibration: the two grids do not look alike enough")
print()
print("  Same cell green on both sides → the agent knows what it knows ✓")
print("  Red in the center, green on the right → OVERCONFIDENT")
print("  Green in the center, red on the right → UNDERCONFIDENT")

---
## 4. Metacognitive Index (MI)

### The question

The ECE measures whether overall confidence is well calibrated. But there is a finer question:

> Does the agent know **where** it is wrong?

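Even with a low ECE, the agent could report the same confidence everywhere (e.g., $C = 0.5$ on all cells). With the median threshold of Section 3, about half of the states count as correct, so this constant confidence would give an ECE close to 0. That would be technically calibrated, but useless: it does not distinguish well-learned zones from poorly known zones.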
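
A minimal sketch of this failure mode, under the assumption of 20 hypothetical states of which exactly half are correct:

```python
import numpy as np

# Hypothetical setting: 20 states, half of them correct (as with a median threshold)
accuracy = np.array([1.0] * 10 + [0.0] * 10)
confidence = np.full(20, 0.5)   # the agent reports C = 0.5 everywhere

# All predictions fall in a single confidence bin, so the ECE reduces to one gap
ece = abs(accuracy.mean() - confidence.mean())
spread = confidence.max() - confidence.min()

print(f"ECE = {ece:.2f}")
print(f"Confidence spread = {spread:.2f}")
```

The ECE is 0, yet the confidence has zero spread: there is no ranking of states by confidence, so a rank correlation with the errors is not even defined.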

### The idea

The **Metacognitive Index** (MI) measures whether the uncertainty $U(s)$ is **correlated** with the true SR errors.

$$MI = \rho_{Spearman}\left(U(s),\; \|M(s,:) - M^*(s,:)\|_2\right)$$

- We take the 21 states and their $U$ values
- We take the 21 true errors $\|M(s,:) - M^*(s,:)\|$
- We compute the rank correlation ($\rho_{Spearman}$)
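
The three steps above can be sketched as follows. The matrices `M_star` and `M` are random stand-ins (not a trained agent), and the two uncertainty signals `U_informative` and `U_random` are hypothetical; only `scipy.stats.spearmanr` and the row-wise norm correspond to the definition.

```python
import numpy as np
from scipy.stats import spearmanr

rng = np.random.default_rng(0)
n_states = 21

# Hypothetical stand-ins for M* and a noisy estimate M (not a trained agent)
M_star = rng.uniform(0, 1, (n_states, n_states))
noise_scale = np.linspace(0.01, 0.5, n_states)[:, None]  # noise grows with the state index
M = M_star + noise_scale * rng.normal(size=(n_states, n_states))

# True error per state: L2 norm of each row of M - M*
true_errors = np.linalg.norm(M - M_star, axis=1)

# Two hypothetical uncertainty signals
U_informative = true_errors / true_errors.max()  # same ordering as the errors
U_random = rng.uniform(0, 1, n_states)           # unrelated to the errors

mi_informative, _ = spearmanr(U_informative, true_errors)
mi_random, _ = spearmanr(U_random, true_errors)
print(f"MI (informative U) = {mi_informative:.3f}")
print(f"MI (random U)      = {mi_random:.3f}")
```

The informative signal has the same ordering as the errors, so its MI is 1; the random signal has no systematic relation to the errors.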

### Why Spearman (rank) rather than Pearson?

The Spearman correlation does not measure whether $U$ is proportional to the error, but whether **the ordering is the same**: are the cells with the largest $U$ also those with the largest error? This is more robust to extreme values.
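
To illustrate the difference, the sketch below uses a hypothetical $U$ that is a monotone but saturating function of the error. The ordering is perfectly preserved, so Spearman equals 1, while Pearson is lower because the relation is not linear. The error values and the saturating formula are assumptions made for this example.

```python
import numpy as np
from scipy.stats import pearsonr, spearmanr

# Hypothetical errors for 21 states, and a U that saturates for large errors
true_errors = np.linspace(0.1, 2.0, 21)
U = 1.0 - np.exp(-3.0 * true_errors)   # monotone increasing, strongly nonlinear

rho_s, _ = spearmanr(U, true_errors)
rho_p, _ = pearsonr(U, true_errors)
print(f"Spearman = {rho_s:.3f}")
print(f"Pearson  = {rho_p:.3f}")
```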

### How to read the result

- **MI > 0.5** → the agent knows well where it is wrong (PRISM target)
- **MI ≈ 0.3** → partial signal, but noisy
- **MI ≈ 0** → $U$ does not reflect the true errors — the metacognitive signal is useless

### What the widget shows

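We can add **noise** to $U$ to see how the MI degrades. In this widget, $U$ is synthetic: it is the true error divided by its 99th percentile, plus Gaussian noise, clipped to $[0, 1]$. Without noise, $U$ has the same ordering as the errors → MI ≈ 1. With a lot of noise, $U$ becomes essentially random → MI drops toward 0.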

In [None]:
def plot_mi_interactive(noise_level):
    """Visualize the effect of noise on the MI."""
    # True errors
    true_errors = errors.copy()

    # U = normalized errors + Gaussian noise (fixed seed for reproducibility)
    noise = np.random.RandomState(42).randn(grid.n_states) * noise_level
    U_noisy = np.clip(true_errors / max(np.percentile(true_errors, 99), 1e-8) + noise, 0, 1)

    # MI
    rho, p_value = spearmanr(U_noisy, true_errors)

    fig, axes = plt.subplots(1, 3, figsize=(13, 4))

    # Scatter of U against the true error
    axes[0].scatter(true_errors, U_noisy, alpha=0.6, s=30, c='steelblue')
    axes[0].set_xlabel('True error ||M(s)-M*(s)||')
    axes[0].set_ylabel('U(s) (with noise)')
    axes[0].set_title(f'MI = {rho:.3f} (p = {p_value:.4f})')
    axes[0].grid(True, alpha=0.3)

    # Heatmap of the true error (min-max scaling for a full gradient)
    te_min, te_max = true_errors.min(), true_errors.max()
    err_norm = (true_errors - te_min) / (te_max - te_min + 1e-8)
    grid.plot(values=err_norm, ax=axes[1], show_goal=False,
              title='True error', cmap='YlOrRd', vmin=0, vmax=1)

    # Heatmap of the noisy U
    grid.plot(values=U_noisy, ax=axes[2], show_goal=False,
              title=f'U(s) (noise={noise_level:.1f})', cmap='YlOrRd', vmin=0, vmax=1)

    plt.tight_layout()
    plt.show()

    print("How to read the plots:")
    print()
    print("Left: each point = one state")
    print("  X axis = true error of M, Y axis = uncertainty U estimated by the agent")
    print("  If the points follow a diagonal → U reflects the true errors (high MI)")
    print("  If the points are scattered → U is noise (low MI)")
    print()
    print("Center: true error on the grid")
    print("  Red = M is wrong here, pale yellow = M is correct")
    print()
    print("Right: U estimated by the agent")
    print("  Should look like the center if the agent knows what it does not know")
    print()

    if rho > 0.5:
        verdict = f"MI = {rho:.3f} → Strong correlation: the agent knows what it does not know"
    elif rho > 0:
        verdict = f"MI = {rho:.3f} → Weak correlation: partial metacognitive signal"
    else:
        verdict = f"MI = {rho:.3f} → No correlation: the U signal is noise"
    print(verdict)
    print()
    print("Things to try:")
    print("  Noise = 0   → maximal MI (U follows the true error, the two maps match)")
    print("  Noise = 0.5 → MI decreases (the two maps diverge)")
    print("  Noise = 1.0 → MI close to 0 (U is mostly random, little link with the error)")

widgets.interact(
    plot_mi_interactive,
    noise_level=widgets.FloatSlider(value=0.0, min=0.0, max=1.0, step=0.1,
                                     description='Noise', continuous_update=False)
);

---
## 5. Defining accuracy in the SR context

### The problem

To build a reliability diagram, we need a **binary accuracy** per state: "is the agent's prediction correct at $s$, yes or no?"

But the SR is not a classifier — there is no obvious binary "correct answer". The error $\|M(s,:) - M^*(s,:)\|$ is a continuous number. How do we turn it into "correct / incorrect"?

### The method: median threshold

1. Compute the true error for each state: $error(s) = \|M(s,:) - M^*(s,:)\|$
2. Take the **median** of all errors as the threshold $\tau$
3. Define: $accuracy(s) = 1$ if $error(s) < \tau$, otherwise $0$
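
The three steps above can be sketched on toy data. The array `toy_errors` is a hypothetical set of 21 distinct error values, used only to show the mechanics of the threshold.

```python
import numpy as np

# Hypothetical true errors for 21 states (distinct, illustrative values)
toy_errors = np.linspace(0.05, 1.05, 21) ** 2

tau = np.median(toy_errors)                  # step 2: the median is the threshold
accuracy = (toy_errors < tau).astype(float)  # step 3: strict inequality

print(f"tau = {tau:.3f}")
print(f"correct states: {int(accuracy.sum())} / {toy_errors.size}")
```

With an odd number of states and a strict inequality, the median state itself counts as incorrect, so the split is 10 correct and 11 incorrect rather than exactly 50/50.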

### Why the median?

- It guarantees ~50% of states are "correct" and ~50% "incorrect" → the reliability diagram has enough data on both sides
- A threshold that is too low (e.g., 10th percentile) would make almost everything "incorrect" → the diagram would be trivial
- A threshold that is too high (e.g., 90th percentile) would make almost everything "correct" → same problem

The widget below shows how the choice of percentile changes the ground-truth grid, and therefore its agreement with the agent's confidence.

### Important note

The MI (Section 4) uses rank correlation, which does **not** depend on the threshold. This is an advantage: the MI is robust to the choice of $\tau$, unlike the ECE which depends on it.

In [None]:
def plot_threshold_effect(quantile):
    """Show how the threshold choice affects the ground truth vs belief comparison."""
    tau_q = np.percentile(errors, quantile)
    acc = (errors < tau_q).astype(float)

    fig, axes = plt.subplots(1, 3, figsize=(14, 4))

    # Left: error distribution + threshold
    axes[0].hist(errors, bins=15, color='steelblue', edgecolor='white', alpha=0.7)
    axes[0].axvline(x=tau_q, color='red', linewidth=2, linestyle='--',
                    label=f'threshold = {tau_q:.3f}')
    axes[0].set_xlabel('Error ||M(s)-M*(s)||')
    axes[0].set_ylabel('Number of states')
    axes[0].set_title(f'Distribution (percentile {quantile})')
    axes[0].legend(fontsize=8)
    axes[0].grid(True, alpha=0.3)

    # Center: ground truth (accuracy), CHANGES with the slider
    grid.plot(values=acc, ax=axes[1], show_goal=False,
              title=f'Ground truth ({acc.mean()*100:.0f}% correct)',
              cmap='RdYlGn', vmin=0, vmax=1)

    # Right: what the agent believes (C), DOES NOT CHANGE
    grid.plot(values=C, ax=axes[2], show_goal=False,
              title='What the agent believes (fixed)',
              cmap='RdYlGn', vmin=0, vmax=1)

    plt.tight_layout()
    plt.show()

    # Agreement between ground truth and belief
    agree = ((acc == 1) & (C >= 0.5)) | ((acc == 0) & (C < 0.5))
    n_agree = int(agree.sum())
    n_disagree = grid.n_states - n_agree

    print("How to read the plots:")
    print()
    print("LEFT: distribution of the errors of M")
    print(f"  Red line = threshold τ (percentile {quantile} = {tau_q:.3f})")
    print("  States to the left of the threshold are 'correct', those to the right 'incorrect'")
    print()
    print("CENTER: ground truth (changes with the slider)")
    print(f"  Green = correct ({int(acc.sum())} states), red = incorrect ({int((1-acc).sum())} states)")
    print()
    print("RIGHT: what the agent believes (fixed, does not change)")
    print("  Green = confident, red = not confident")
    print()
    print(f"Agreement center ↔ right: {n_agree}/{grid.n_states} ({100*n_agree/grid.n_states:.0f}%)")
    print()
    print("Things to try:")
    print("  Move the slider: only the CENTER changes, the RIGHT stays fixed.")
    print("  Percentile 50 → ~50/50, a good balance for judging calibration")
    print("  Percentile 80 → almost all green in the center, agreement says little about calibration")
    print("  Percentile 20 → almost all red in the center, same problem")
    print("  → The MI (Section 4) does not depend on this choice, which is its advantage")

widgets.interact(
    plot_threshold_effect,
    quantile=widgets.IntSlider(value=50, min=10, max=90, step=5,
                                description='Percentile τ', continuous_update=False)
);

---
## 6. Hybrid demo: calibration on real data

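This cell uses the PRISM code to compute the ECE and the MI for an agent trained for 200 episodes on MiniGrid FourRooms.

**Requires**: `pip install -e .` (run from the `PRISM` directory) plus `minigrid` and `gymnasium`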
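
The ground-truth matrix in this demo is obtained in closed form as $M^* = (I - \gamma T)^{-1}$. As a sketch of what this computes, here is the same formula on a hypothetical 3-state transition matrix (not the FourRooms dynamics), with two sanity checks: $M^*$ satisfies $M^* = I + \gamma T M^*$, and each row sums to $1/(1-\gamma)$ because the rows of $T$ sum to 1.

```python
import numpy as np

gamma = 0.95

# Hypothetical 3-state transition matrix (rows sum to 1); not the FourRooms dynamics
T = np.array([
    [0.50, 0.50, 0.00],
    [0.25, 0.50, 0.25],
    [0.00, 0.50, 0.50],
])

# Closed-form successor representation: M* = (I - gamma * T)^(-1)
I = np.eye(3)
M_star = np.linalg.inv(I - gamma * T)

# Sanity checks: fixed-point equation and row sums of 1 / (1 - gamma)
print(np.allclose(M_star, I + gamma * T @ M_star))
print(M_star.sum(axis=1), 1 / (1 - gamma))
```

For larger state spaces, solving the linear system with `np.linalg.solve(I - gamma * T, I)` is usually preferable to forming an explicit inverse.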

In [None]:
try:
    import minigrid
    import gymnasium as gym
    from prism.agent.prism_agent import PRISMAgent
    from prism.env.state_mapper import StateMapper
    from prism.env.dynamics_wrapper import DynamicsWrapper
    from prism.analysis.calibration import (
        sr_errors, sr_accuracies,
        expected_calibration_error,
        metacognitive_index,
        reliability_diagram_data,
        plot_reliability_diagram
    )

    # Setup
    env = DynamicsWrapper(gym.make("MiniGrid-FourRooms-v0", max_steps=500))
    agent = PRISMAgent(env)

    # Train for 200 episodes
    for ep in range(200):
        agent.train_episode()
        if (ep + 1) % 50 == 0:
            print(f"Episode {ep+1}/200")

    # Compute M*. The *_real names avoid overwriting the ToyGrid variables
    # `errors` and `M_star`, which the widgets in Sections 4 and 5 still use.
    T = env.compute_transition_matrix(agent.mapper)
    M_star_real = np.linalg.inv(np.eye(agent.mapper.n_states) - 0.95 * T)

    # Metrics
    errors_real = sr_errors(agent.sr.M, M_star_real)
    accuracies = sr_accuracies(errors_real)
    confidences = agent.meta.all_confidences()
    uncertainties = agent.meta.all_uncertainties()

    ece_real = expected_calibration_error(confidences, accuracies)
    rho, p_val = metacognitive_index(uncertainties, agent.sr.M, M_star_real)

    print(f"\nECE = {ece_real:.4f} (target < 0.15)")
    print(f"MI = {rho:.4f} (p = {p_val:.4f}, target > 0.5)")

    # Reliability diagram
    fig = plot_reliability_diagram(confidences, accuracies, label='PRISM')
    plt.show()

    env.close()

except ImportError as e:
    print(f"Missing dependency: {e}")
    print("To run this cell: pip install minigrid gymnasium")
    print("Then: cd PRISM && pip install -e .")

---
## Summary

| Metric | Formula | Interpretation | Target |
|--------|---------|----------------|--------|
| ECE | $\sum_b \frac{\lvert B_b \rvert}{N} \, \lvert acc_b - conf_b \rvert$ | Mean confidence/accuracy gap | < 0.15 |
| MI | $\rho_{Spearman}(U, error)$ | The agent knows where it is wrong | > 0.5 |
| Accuracy | $\mathbb{1}(\lVert M(s,:) - M^*(s,:) \rVert < \tau)$ | Binary "correct" prediction | ~50% baseline |

$\leftarrow$ [Back to the main notebook](00_prism_concepts.ipynb)