# PRISM — Mathematical and Algorithmic Concepts

This notebook explains the mathematical foundations of the PRISM project in an interactive way.
Each section introduces a concept with its formula, then offers a widget to explore it visually.

**Dependencies:** numpy, matplotlib, ipywidgets, plus the local `prism` package imported in the setup cell (no need for MiniGrid).

**Appendices:**
- `00a_spectral_deep_dive.ipynb` — Eigendecomposition of M
- `00b_calibration_methods.ipynb` — Calibration metrics (ECE, MI)


---
## Section 0 — Setup


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, Markdown
import sys, os

# Ensure prism package is importable
sys.path.insert(0, os.path.abspath('..'))
from prism.pedagogy.toy_grid import ToyGrid

%matplotlib inline
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 10

In [None]:
# Reference grid: two rooms connected by a passage
grid = ToyGrid.two_rooms()
print(f"Grid {grid.rows}×{grid.cols}, {grid.n_states} accessible states, 4 walls")
print(f"Goal: {grid.goal}")

fig, axes = plt.subplots(1, 2, figsize=(8, 3.5))
grid.plot(ax=axes[0], title="Grid structure")

# Number the states on a neutral background
grid.plot(ax=axes[1], title="State indices")
for i, pos in grid.idx_to_pos.items():
    axes[1].text(pos[1], pos[0], str(i), ha='center', va='center',
                 fontsize=7, fontweight='bold', color='steelblue')

plt.tight_layout()
plt.show()

---
## Section 1 — The Successor Representation (SR)

### Intuition

Imagine an agent randomly walking around the grid. At each step, it can move in 4 directions (up, down, left, right).

The **Successor Representation** ($M$) is a large table that answers this question for **every pair of cells**:

> "If I start from cell $s$, **how many times** will I visit cell $s'$ in the future?"

Concretely, $M$ is a **matrix** (a two-dimensional table):
- **Row** = starting cell $s$
- **Column** = destination cell $s'$
- **Value** $M(s, s')$ = expected number of future visits to $s'$ starting from $s$, with later visits counting less (discounted by $\gamma$, see the formula below)

Our grid has 21 accessible cells, so $M$ is a 21×21 table with 441 values. Each row is a "heat map" that says: "from this starting cell, here are the cells I will visit the most".

For example, the row of $M$ for the top-left corner will show:
- **Nearby** cells with high values (visited often)
- **Distant** cells or those behind a wall with low values (rarely reached)

### Why this is useful

$M$ allows the agent to **evaluate any goal instantly**.

Example: food is placed in cell 15. The agent wants to know "is cell $s$ a good starting point for reaching the food?". It simply reads $M(s, 15)$: if the value is high, it will pass through cell 15 often, so it is a good starting point.

And if the food **changes location** (say cell 3)? The agent does not need to relearn everything — it reads $M(s, 3)$ instead. The table $M$ stays the same; only the goal changes.

This is the strength of the SR: it separates **"where I can go"** ($M$, the structure) from **"where I want to go"** (the reward). We will formalize this idea in Section 3 with $V = M \cdot R$.

### Formula

$$M(s, s') = \mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^t \, \mathbb{1}(s_t = s') \;\middle|\; s_0 = s\right]$$

Let us break it down term by term:
- $s$: the starting cell
- $s'$: any cell whose visitation frequency we want to know
- $\mathbb{1}(s_t = s')$: equals 1 if the agent is at $s'$ at time $t$, 0 otherwise
- $\gamma^t$: a weight that **decreases with time** ($\gamma < 1$), so distant visits count less
- $\sum_{t=0}^{\infty}$: we sum over the entire future
- $\mathbb{E}[...]$: we average over all possible trajectories

The parameter $\gamma$ controls **how far ahead the agent looks**; a useful rule of thumb is an effective horizon of about $1/(1-\gamma)$ steps:
- $\gamma = 0.5$ → the agent only considers its immediate neighbors (horizon ≈ 2 steps)
- $\gamma = 0.95$ → the agent anticipates ~20 steps into the future
- $\gamma = 0.99$ → the agent sees very far (horizon ≈ 100 steps)
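
These orders of magnitude follow from the fact that the weights $\gamma^t$ form a geometric series whose sum is $1/(1-\gamma)$. The short sketch below reproduces the three values; `effective_horizon` is a hypothetical helper written for this illustration, not part of the `prism` package.

```python
def effective_horizon(gamma: float) -> float:
    """Rule-of-thumb horizon 1 / (1 - gamma) for a discount factor in [0, 1)."""
    if not 0.0 <= gamma < 1.0:
        raise ValueError("gamma must be in [0, 1)")
    return 1.0 / (1.0 - gamma)

for gamma in (0.5, 0.95, 0.99):
    # Truncated geometric sum of the weights gamma^t, for comparison
    total_weight = sum(gamma ** t for t in range(5000))
    print(f"gamma = {gamma:.2f} -> horizon ~ {effective_horizon(gamma):.0f} steps "
          f"(sum of weights = {total_weight:.1f})")
```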

### The transition matrix $T$

To compute $M$, we need to know **how the agent moves**. This is what the transition matrix $T$ describes:

$$T_{ij} = P(s' = j \mid s = i)$$

In plain terms: $T_{ij}$ is the **probability of going from cell $i$ to cell $j$** in one step.

In our grid, the policy is uniform (4 directions equally probable), so:
- If $j$ is an accessible neighbor of $i$: $T_{ij} = 0.25$
- If $j$ is a wall or not adjacent: $T_{ij} = 0$
- If the agent tries to go into a wall, it stays in place: this increases $T_{ii}$

$T$ summarizes the entire **environment structure** + the **agent's policy**.
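
The rules above can be sketched on a tiny example. The code below is an illustration under stated assumptions: it builds $T$ for a hypothetical 2×3 layout with a single wall, and `build_transition_matrix` is a made-up helper, not the way `ToyGrid` is implemented.

```python
import numpy as np

def build_transition_matrix(rows, cols, walls=()):
    """Uniform random-walk T on a grid; blocked moves leave the agent in place."""
    walls = set(walls)
    cells = [(r, c) for r in range(rows) for c in range(cols) if (r, c) not in walls]
    index = {cell: i for i, cell in enumerate(cells)}
    T = np.zeros((len(cells), len(cells)))
    for (r, c), i in index.items():
        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            # A move into a wall or outside the grid falls back to the current cell
            j = index.get((r + dr, c + dc), i)
            T[i, j] += 0.25
    return T, index

# Hypothetical layout: 2 rows, 3 columns, one wall at (0, 1)
T, index = build_transition_matrix(2, 3, walls=[(0, 1)])
corner = index[(0, 0)]
print("row sums:", T.sum(axis=1))          # each row is a probability distribution
print("T[corner, corner] =", T[corner, corner])
```

In this layout the corner cell (0, 0) has three blocked directions (up, left, and the wall on its right), so three quarters of its probability mass stays on the diagonal.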

### Computing $M$

We can decompose $M$ by summing the contributions of each time step:

$$M = \underbrace{I}_{t=0} + \underbrace{\gamma \, T}_{t=1} + \underbrace{\gamma^2 T^2}_{t=2} + \underbrace{\gamma^3 T^3}_{t=3} + \ldots$$

Each term has a concrete meaning:

| Term | Meaning |
|------|--------|
| $I$ | At $t=0$, the agent is at its starting cell (contribution = 1) |
| $\gamma \, T$ | At $t=1$, it has taken one step → $T$ gives the probabilities of reaching each cell |
| $\gamma^2 T^2$ | At $t=2$, it has taken two steps → $T^2$ gives the probabilities in 2 steps |
| $\gamma^t T^t$ | At any $t$, weighted by $\gamma^t$ (distant steps count less) |

This infinite sum converges (because $\gamma < 1$) and equals:

$$M = (I - \gamma T)^{-1}$$
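
The equality between the series and the closed form can be checked numerically. The sketch below uses a hypothetical 3-state chain (not the notebook's grid) and compares a long truncated sum with the matrix inverse.

```python
import numpy as np

gamma = 0.95
# Hypothetical 3-state chain; each row sums to 1
T = np.array([[0.50, 0.50, 0.00],
              [0.25, 0.50, 0.25],
              [0.00, 0.50, 0.50]])

# Closed form
M_closed = np.linalg.inv(np.eye(3) - gamma * T)

# Truncated series I + gamma T + gamma^2 T^2 + ...
M_series = np.zeros((3, 3))
term = np.eye(3)
for _ in range(2000):
    M_series += term
    term = gamma * (term @ T)

print("max difference:", np.abs(M_closed - M_series).max())
print("row sums:", M_closed.sum(axis=1))
```

Each row of $M$ sums to $1/(1-\gamma)$ because the agent is always in exactly one state, so the discounted visit counts add up to the sum of the weights $\gamma^t$.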

The plot below shows **one row** of this table: for a chosen starting state, the visitation frequency of each cell.


In [None]:
def plot_sr_interactive(gamma, state):
    """Plot M*[s, :] for a given starting state."""
    M_star = grid.true_sr(gamma)
    horizon = 1 / (1 - gamma)

    fig, ax = plt.subplots(1, 1, figsize=(5, 4))

    row = M_star[state]
    row_norm = row / row.max()
    grid.plot(values=row_norm, ax=ax,
              title=f"Future occupancy from s={state} (γ={gamma:.2f})",
              cmap='plasma', vmin=0, vmax=1)
    pos = grid.idx_to_pos[state]
    ax.plot(pos[1], pos[0], 'wo', markersize=12, zorder=5,
            markeredgecolor='black', markeredgewidth=1.5)

    plt.tight_layout()
    plt.show()
    print('"If the agent starts from ⚪, which cells will it visit?"')
    print("Yellow = often, purple = rarely, gray = wall")
    print(f"γ = {gamma:.2f} → horizon ≈ {horizon:.0f} steps")

widgets.interact(
    plot_sr_interactive,
    gamma=widgets.FloatSlider(value=0.95, min=0.5, max=0.99, step=0.01,
                              description='γ', continuous_update=False),
    state=widgets.IntSlider(value=0, min=0, max=grid.n_states-1,
                            description='s (start)')
);

### Reading the plot

> "If the agent starts from ⚪, which cells will it visit in the future?"

- **Yellow** = cell visited often (close to ⚪ or easy to access)
- **Purple** = cell rarely reached (far away, or on the other side of the wall)
- **Gray** = wall (inaccessible)

**Try this:**
- Move **s** → the yellow pattern follows the starting state
- Increase **γ toward 0.99** → the agent "sees" far, yellow everywhere
- Decrease **γ toward 0.5** → the agent only sees its immediate neighbors


---
## Section 2 — TD(0) Learning of M

### The problem

In Section 1, we computed $M = (I - \gamma T)^{-1}$. But this computation requires knowing $T$ — all the transition probabilities of the environment.

In practice, **the agent does not know $T$**. It does not know in advance where the walls and passages are. It must discover the structure by **moving around and observing**.

It is as if you arrived in an unknown city without a map: you cannot compute the best path in advance. You must explore and **build your mental map as you go**.

### Why the agent needs a good $M$

Recall from Section 1: $M$ allows evaluating any goal (food at cell 15 → read $M(s, 15)$). But this only works **if $M$ correctly reflects the environment structure**.

A wrong $M$ would give bad evaluations: the agent might believe a cell is easy to reach when there is actually a wall between the two.

We denote:
- $M^*$ = the result of the formula $(I - \gamma T)^{-1}$ (Section 1). This is the **target** — the perfect $M$ one would obtain if $T$ were known.
- $M$ = what the agent has **learned so far**. Initially wrong, but with experience $M$ approaches $M^*$.

The goal of learning is that $M \to M^*$: the agent's mental map becomes as good as if it had had the complete map from the start.

### How to learn: Temporal Difference (TD)

At each step, the agent is at $s$ and transitions to $s'$. It can compare **what it predicts** with **what it observes**:

- **Predicted**: $M(s, :)$, its current map of discounted visit counts from $s$
- **Observed**: $e(s) + \gamma \cdot M(s', :)$: "I am at $s$ right now (the one-hot vector $e(s)$, which is the $t=0$ term of the Section 1 formula), then I move to $s'$, and from $s'$ I predict $M(s', :)$ for the rest"

With this target, the update settles where $M = I + \gamma T M$, that is $M = (I - \gamma T)^{-1}$, the same matrix as in Section 1.

The difference is the **TD error**:

$$\delta_M(s) = \underbrace{e(s) + \gamma \cdot M(s', :)}_{\text{observed}} - \underbrace{M(s, :)}_{\text{predicted}}$$

The agent corrects its prediction in the direction of the error:

$$M(s, :) \leftarrow M(s, :) + \alpha_M \cdot \delta_M(s)$$

### The learning rate $\alpha_M$

$\alpha_M$ controls **how much the agent corrects its prediction** after each observation.

**Analogy**: you think it takes 30 minutes to get to work. Today it took 40 minutes. How do you update your estimate?

| $\alpha_M$ | Correction | New estimate | Behavior |
|-----------|-----------|-------------------|-------------|
| 0.01 | 1% of the error | 30 + 0.01 × (40-30) = **30.1 min** | Very cautious, slow to adapt |
| 0.1 | 10% of the error | 30 + 0.1 × (40-30) = **31 min** | Balanced |
| 1.0 | 100% of the error | 30 + 1.0 × (40-30) = **40 min** | Blindly trusts the latest observation |

An $\alpha_M$ that is too large makes the agent oscillate (it overreacts to each observation). An $\alpha_M$ that is too small makes it very slow to learn.

$M$ is initialized to $I$ (the identity matrix: each state predicts only itself, no knowledge of the structure).
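
The update rule can be sketched without the `prism` helpers. The example below is a minimal illustration on a hypothetical 3-state chain: the agent only sees sampled transitions, starts from $M = I$, and its error against $M^*$ is measured before and after learning. It uses the one-hot vector of the current state $s$ as the immediate term, so that the fixed point of the update matches $M^* = (I - \gamma T)^{-1}$ as defined in Section 1 (where the visit at $t=0$ is counted). The chain, step count, and seed are assumptions chosen for the demo.

```python
import numpy as np

gamma, alpha_M, n_steps = 0.95, 0.05, 20_000

# Hypothetical 3-state chain, used only to sample transitions and to build the reference M*
T = np.array([[0.50, 0.50, 0.00],
              [0.25, 0.50, 0.25],
              [0.00, 0.50, 0.50]])
M_star = np.linalg.inv(np.eye(3) - gamma * T)

rng = np.random.default_rng(0)
M = np.eye(3)                                  # initialization to I
initial_error = np.linalg.norm(M - M_star)

s = 0
for _ in range(n_steps):
    s_next = rng.choice(3, p=T[s])             # sampled transition s -> s'
    e_s = np.eye(3)[s]                         # one-hot vector of the current state
    delta = e_s + gamma * M[s_next] - M[s]     # TD error (a vector, one entry per state)
    M[s] += alpha_M * delta                    # move the prediction toward the target
    s = s_next

final_error = np.linalg.norm(M - M_star)
print(f"error before learning: {initial_error:.2f}, after: {final_error:.2f}")
```

With a constant $\alpha_M$ the error does not reach exactly zero: it fluctuates around a small residual level that grows with $\alpha_M$.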


In [None]:
# M* = the true M computed in Section 1 (used as the reference)
M_star_ref = grid.true_sr(0.95)

def td_learning_demo(alpha_M, n_steps, state):
    """Simulate n_steps of TD(0) and show the convergence."""
    gamma = 0.95
    M = np.eye(grid.n_states)  # Initialize to I
    errors = []

    # Random walk
    traj = grid.random_walk(n_steps, seed=42)

    for t in range(len(traj) - 1):
        s, s_next = traj[t], traj[t + 1]
        delta, M = grid.td_update(M, s, s_next, gamma, alpha_M)
        if t % 10 == 0:
            err = np.linalg.norm(M - M_star_ref, 'fro') / grid.n_states
            errors.append((t, err))

    fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))

    # Left: convergence curve
    steps, errs = zip(*errors)
    axes[0].plot(steps, errs, 'b-', linewidth=1.5)
    axes[0].set_xlabel('Steps (experience)')
    axes[0].set_ylabel('Error of M vs M*')
    axes[0].set_title('Convergence M → M*')
    axes[0].set_ylim(bottom=0)
    axes[0].grid(True, alpha=0.3)

    # Right: learned M for the chosen state (normalized, as in Section 1)
    row = M[state]
    row_norm = row / max(row.max(), 1e-8)
    grid.plot(values=row_norm, ax=axes[1],
              title=f"Learned M from s={state}", cmap='plasma', vmin=0, vmax=1)
    pos = grid.idx_to_pos[state]
    axes[1].plot(pos[1], pos[0], 'wo', markersize=12, zorder=5,
                 markeredgecolor='black', markeredgewidth=1.5)

    plt.tight_layout()
    plt.show()
    print(f"Final error: {errs[-1]:.4f}")
    print()
    print("Left: gap between the learned M and the true M (Section 1); it decreases with experience")
    print("Right: with enough steps, this map converges to the Section 1 map (= M*)")
    print()
    print("Try this:")
    print("  small α (0.01) → slow but stable learning")
    print("  large α (0.5)  → fast but unstable learning (oscillating curve)")

widgets.interact(
    td_learning_demo,
    alpha_M=widgets.FloatSlider(value=0.1, min=0.01, max=0.5, step=0.01,
                                description='α_M', continuous_update=False),
    n_steps=widgets.IntSlider(value=500, min=50, max=3000, step=50,
                               description='Steps', continuous_update=False),
    state=widgets.IntSlider(value=0, min=0, max=grid.n_states-1,
                            description='s (start)')
);

---
## Section 3 — Value Decomposition V = M · R

### The question

The agent knows the structure of the environment (via $M$). Now it wants to **make decisions**: "which cell is the best for me right now?"

For this, it needs two pieces of information:
1. **Where is the reward?** → this is the vector $R$
2. **Can I reach it easily?** → this is already in $M$

### The reward vector $R$

$R$ is a simple vector: one value per cell.

$$R(s) = \text{reward obtained upon arriving at } s$$

In our grid, there is food at a single cell (the "goal"):
- $R(\text{goal}) = 1$
- $R(\text{everywhere else}) = 0$

### The value $V(s)$

$V(s)$ answers the question: **"how much reward will I accumulate starting from $s$?"** Later rewards count less, because $M$ already discounts distant visits by $\gamma$.

- High $V$ = good position (close to the food, easy access)
- Low $V$ = bad position (far away, or blocked by a wall)

### The formula: V = M · R

$$V(s) = \sum_{s'} M(s, s') \cdot R(s')$$

In plain terms: the value of $s$ = sum over all cells of (visitation frequency × reward at that cell).

This is a simple matrix-vector product: $V = M \cdot R$.

### Why this is powerful

- **If the goal moves** (R changes): simply recompute $V = M \cdot R$. **M stays the same** — the environment structure has not changed.
- **If a wall appears** (structure changes): then $M$ must be relearned (→ Section 8).
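
The first point can be illustrated in a few lines. The sketch below reuses one $M$ for two different goals on a hypothetical 3-state chain; `one_hot_reward` is a made-up helper for this example, and the chain stands in for the grid.

```python
import numpy as np

gamma = 0.95
# Hypothetical 3-state chain (structure of the environment)
T = np.array([[0.50, 0.50, 0.00],
              [0.25, 0.50, 0.25],
              [0.00, 0.50, 0.50]])
M = np.linalg.inv(np.eye(3) - gamma * T)   # computed once

def one_hot_reward(goal, n_states=3):
    """R(goal) = 1 and 0 everywhere else."""
    R = np.zeros(n_states)
    R[goal] = 1.0
    return R

# Moving the goal only changes R; M is reused as is
V_goal_2 = M @ one_hot_reward(2)
V_goal_0 = M @ one_hot_reward(0)

print("V with goal at state 2:", np.round(V_goal_2, 2))
print("V with goal at state 0:", np.round(V_goal_0, 2))
```

With a one-hot reward, $V$ is simply the column of $M$ for the goal state, which is why reading $M(s, \text{goal})$ was enough in Section 1.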


In [None]:
def plot_v_decomposition(goal_row, goal_col, gamma):
    """Show V = M · R for different goal positions."""
    goal = (goal_row, goal_col)
    if goal in grid.walls or goal_row >= grid.rows or goal_col >= grid.cols:
        print(f"Position ({goal_row}, {goal_col}) is a wall or outside the grid.")
        return

    M_star = grid.true_sr(gamma)
    R = grid.reward_vector(goal)
    V = M_star @ R

    # Normalize so both panels share a fixed scale
    row0 = M_star[0]
    row0_norm = row0 / max(row0.max(), 1e-8)
    V_norm = V / max(V.max(), 1e-8)

    fig, axes = plt.subplots(1, 2, figsize=(9, 3.5))

    # Left: M[0, :] + goal
    grid.plot(values=row0_norm, ax=axes[0], show_goal=False,
              title='M from s=0 (does not change)', cmap='plasma', vmin=0, vmax=1)
    axes[0].plot(goal[1], goal[0], 'r*', markersize=18, zorder=5)

    # Right: V = M · R + goal, same colormap as M
    grid.plot(values=V_norm, ax=axes[1], show_goal=False,
              title='V = M · R', cmap='plasma', vmin=0, vmax=1)
    axes[1].plot(goal[1], goal[0], 'r*', markersize=18, zorder=5)

    plt.tight_layout()
    plt.show()
    print("Red ★ = food position (move it with the sliders)")
    print("Left: M does not move; the structure of the environment does not change")
    print("Right: V changes; the value of each cell depends on the goal position")
    print()
    print("→ This is the strength of the SR: a single M serves any goal.")

widgets.interact(
    plot_v_decomposition,
    goal_row=widgets.IntSlider(value=0, min=0, max=grid.rows-1,
                                description='Goal row'),
    goal_col=widgets.IntSlider(value=4, min=0, max=grid.cols-1,
                                description='Goal column'),
    gamma=widgets.FloatSlider(value=0.95, min=0.5, max=0.99, step=0.01,
                              description='γ', continuous_update=False)
);

---
## Section 4 — Prediction Error $\|\delta_M\|$

### Motivation

In Section 2, we saw that the agent learns $M$ by correcting its errors ($\delta_M$). We know that the overall error decreases with experience.

But the agent needs more than that: it needs to know **where** its predictions are still poor. "Do I know this area of the grid well, or am I still making frequent mistakes here?"

The TD error $\delta_M(s)$ answers exactly this question:
- **Large error** at $s$ → "I was wrong arriving here, I know this area poorly"
- **Small error** at $s$ → "no surprise, my prediction was good"

This is the basic signal for the agent's **metacognition** (Sections 5-9): its ability to know what it knows and what it does not know.

### From vector to scalar

$\delta_M(s)$ is a **vector** of 21 values (one per cell). To get a single number summarizing "did the agent make a mistake at $s$?", we take the **norm** (the "length" of the vector):

$$\|\delta_M(s)\| = \sqrt{\sum_{s'} \delta_M(s, s')^2}$$

- $\|\delta_M\| = 0$ → perfect prediction
- $\|\delta_M\|$ large → big surprise

### Normalization by the 99th percentile

Raw errors can range from 0.003 to 47 depending on $\gamma$ and grid size. To make them comparable and bring them into $[0, 1]$, we need a **reference**: "what counts as a large error?"

We use the **99th percentile** (p99) of all errors observed so far:

$$\delta_{norm} = \min\left(\frac{\|\delta_M\|}{\text{p99}},\; 1\right)$$

**Why p99?**

Imagine the observed errors sorted from smallest to largest. The p99 is the value below which 99% of errors fall. By dividing by p99:
- A "normal" error gives $\delta_{norm}$ between 0 and 1
- Only the top 1% most extreme errors are clipped to 1

**Why not the maximum?** The max is often an isolated spike (for example, the very first error of the agent that knows nothing yet). Dividing by that spike would squash all other errors toward 0, making the map unreadable.

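**Why not the mean?** Every error above the mean would get a normalized value > 1 and be clipped to 1. That is a large share of the observations (about half for a roughly symmetric distribution), so most of the contrast between cells would be lost.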

The p99 is a good compromise: it ignores extreme spikes while preserving contrast between cells.
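
The effect of the reference value can be seen on synthetic data. The sketch below is an illustration only: the error norms are drawn from an arbitrary exponential distribution (an assumption, not PRISM measurements) with one isolated spike added, and the p99 normalization is compared with a division by the maximum.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic error norms (assumption): mostly small values plus one isolated spike
norms = rng.exponential(scale=0.2, size=1000)
norms[0] = 47.0

# Normalization by the 99th percentile, clipped to 1
p99 = np.percentile(norms, 99)
delta_norm = np.minimum(norms / p99, 1.0)

# Alternative: normalization by the maximum (dominated by the spike)
by_max = norms / norms.max()

print(f"p99 = {p99:.3f}, max = {norms.max():.1f}")
print(f"median after p99 normalization: {np.median(delta_norm):.3f}")
print(f"median after max normalization: {np.median(by_max):.4f}")
print(f"values clipped to 1: {(delta_norm == 1.0).sum()} out of {norms.size}")
```

Dividing by the maximum pushes almost every value close to 0, while the p99 version keeps the typical errors spread over a usable part of $[0, 1]$.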


In [None]:
def plot_td_errors(n_steps, alpha_M):
    """Show the per-state TD errors during learning."""
    gamma = 0.95
    M = np.eye(grid.n_states)

    # Accumulate errors per state
    error_sum = np.zeros(grid.n_states)
    error_count = np.zeros(grid.n_states)
    all_norms = []

    traj = grid.random_walk(n_steps, seed=42)

    for t in range(len(traj) - 1):
        s, s_next = traj[t], traj[t + 1]
        delta, M = grid.td_update(M, s, s_next, gamma, alpha_M)
        norm = np.linalg.norm(delta)
        error_sum[s] += norm
        error_count[s] += 1
        all_norms.append(norm)

    # Mean error per state
    mean_error = np.zeros(grid.n_states)
    visited = error_count > 0
    mean_error[visited] = error_sum[visited] / error_count[visited]

    # p99 normalization
    p99 = np.percentile(all_norms, 99) if all_norms else 1.0
    normalized = np.clip(mean_error / max(p99, 1e-8), 0, 1)

    fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))

    # p99-normalized error
    grid.plot(values=normalized, ax=axes[0], show_goal=False,
              title='Normalized error per cell', cmap='YlOrRd', vmin=0, vmax=1)

    # Visit counts
    grid.plot(values=error_count, ax=axes[1], show_goal=False,
              title='Number of visits', cmap='Blues')

    plt.tight_layout()
    plt.show()
    print("Left: red = the agent is often wrong here (poorly known area)")
    print("      pale yellow = few errors (well-learned area)")
    print("Right: dark blue = frequently visited cell, light = rarely")
    print()
    print(f"p99 = {p99:.4f}, unvisited states: {(~visited).sum()}/{grid.n_states}")
    print()
    print("→ Rarely visited cells often have larger errors (the agent does not know them)")
    print("→ Increase the steps: errors decrease everywhere")

widgets.interact(
    plot_td_errors,
    n_steps=widgets.IntSlider(value=500, min=50, max=3000, step=50,
                               description='Steps', continuous_update=False),
    alpha_M=widgets.FloatSlider(value=0.1, min=0.01, max=0.5, step=0.01,
                                description='α_M (cf. Sec. 2)', continuous_update=False)
);

---
## Section 5 — Uncertainty Map U(s)

### The need

In Section 4, we saw that the error $\|\delta_M(s)\|$ measures "did the agent make a mistake at this instant, at this cell". But this error is **noisy**: a single observation can produce an error spike just by bad luck.

The agent needs a more **stable** measure: "do I know this area well, **in general**?" — not just "did I make a mistake this time".

This is the role of $U(s)$: transforming one-off errors into a **stable uncertainty map**, cell by cell.

### Intuition

Imagine you are visiting a neighborhood in a city:
- The **first time**, you often take the wrong turn → high uncertainty
- After **a few visits**, you make fewer mistakes → uncertainty decreases
- After **many visits**, you know the neighborhood well → low and stable uncertainty

$U(s)$ reproduces exactly this behavior. For each cell $s$, the agent stores a **history of its recent errors** and averages them.

### Three regimes

$U(s)$ depends on how many times the agent has visited cell $s$:

| Regime | Condition | What the agent does | Formula |
|---------|-----------|-------------------|----------|
| **Never visited** | $visits(s) = 0$ | "I have never been here → I fall back on my prior uncertainty" | $U(s) = U_{prior}$ |
| **Few visits** | $0 < visits(s) < K$ | "I have a few observations but not enough for a reliable average" | $U(s) = U_{prior} \cdot decay^{visits(s)}$ |
| **Enough visits** | $visits(s) \geq K$ | "I have enough data → I average my last $K$ errors" | $U(s) = \text{mean}(\text{buffer}(s))$ |

The parameters:
- **$U_{prior}$** (default 0.8): initial uncertainty. "Before seeing anything, I am 80% uncertain."
- **$decay$** (default 0.85): speed of the initial descent. Each visit multiplies the uncertainty by $decay$.
- **$K$** (default 20): buffer size. After $K$ visits, the agent switches to the average of its actual errors.

### Why 3 regimes?

- **Regime 1** (never visited): without data, we are cautious → high $U$.
- **Regime 2** (cold-start): we do not yet have $K$ errors to compute a reliable average. The formula $U_{prior} \cdot decay^{visits}$ decreases $U$ progressively, like a fading prior.
- **Regime 3** (converged): the buffer is full → the average reflects actual errors. If the agent knows the area well, errors are small → low $U$. If the area has changed, errors rise → $U$ rises too.
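
The three regimes can be sketched for a single cell as follows. This is an illustrative stand-in, not the PRISM implementation: `CellUncertainty` is a hypothetical class, and it assumes the errors passed to `record` are already normalized to $[0, 1]$ as in Section 4.

```python
from collections import deque

class CellUncertainty:
    """Uncertainty of one cell following the three regimes (hypothetical helper)."""

    def __init__(self, U_prior=0.8, decay=0.85, K=20):
        self.U_prior, self.decay, self.K = U_prior, decay, K
        self.visits = 0
        self.buffer = deque(maxlen=K)      # keeps only the last K errors

    def record(self, error_norm):
        """Store the normalized error observed on one visit to the cell."""
        self.visits += 1
        self.buffer.append(error_norm)

    def value(self):
        if self.visits == 0:               # regime 1: never visited
            return self.U_prior
        if self.visits < self.K:           # regime 2: cold start, fading prior
            return self.U_prior * self.decay ** self.visits
        return sum(self.buffer) / len(self.buffer)   # regime 3: mean of last K errors

cell = CellUncertainty()
print(f"0 visits : U = {cell.value():.3f}")
for _ in range(5):
    cell.record(0.1)
print(f"5 visits : U = {cell.value():.3f}")
for _ in range(15):
    cell.record(0.1)
print(f"20 visits: U = {cell.value():.3f}")
```

Errors are recorded from the very first visit even though they are only used once the buffer holds $K$ of them, so the switch to regime 3 relies on a full window.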

### How to choose the buffer size $K$?

The buffer is a **sliding window**: the agent keeps its last $K$ errors for each cell and averages them. Older errors (beyond $K$) are forgotten.

The choice of $K$ is a tradeoff between **stability** and **reactivity**:

- **Small $K$** (e.g. 5): the average covers only 5 errors → it is **noisy** (a single spike pushes it up), but it **reacts quickly** if the area changes.
- **Large $K$** (e.g. 50): the average is smoothed over 50 errors → it is **stable**, but it takes time to rise if the environment changes (the old small errors "dilute" the new large ones).

**In practice:** we choose $K$ based on how often the environment can change.
- **Stable** environment (walls never move) → large $K$ (30–50) for a well-smoothed $U$.
- **Changing** environment (obstacles that appear/disappear) → small $K$ (10–15) to quickly detect that "I no longer know this area".

The plot below shows $U(s)$ as a function of the number of visits for a given cell.


In [None]:
def plot_uncertainty_regimes(U_prior, decay, K):
    """Visualize the 3 regimes of U(s) as a function of the visit count."""
    visits = np.arange(0, 60)

    # Simulate decreasing per-visit errors (a realistic convergence profile)
    U_values = []
    buffer = []
    for v in visits:
        if v > 0:
            # Every visit records one simulated error: decaying trend + small noise
            simulated_error = max(0.01, 0.5 * np.exp(-0.05 * v) + 0.02 * np.sin(v))
            buffer.append(simulated_error)
            buffer = buffer[-K:]  # sliding window: keep the last K errors
        if v == 0:
            U_values.append(U_prior)
        elif v < K:
            U_values.append(U_prior * (decay ** v))
        else:
            U_values.append(np.mean(buffer))

    U_values = np.array(U_values)

    fig, ax = plt.subplots(1, 1, figsize=(9, 4))

    # Shaded bands for the regimes
    ax.axvspan(0, 0.5, alpha=0.15, color='red', label='Never visited')
    ax.axvspan(0.5, K, alpha=0.15, color='orange', label=f'Cold-start (< K={K})')
    ax.axvspan(K, 60, alpha=0.15, color='green', label=f'Converged (≥ K={K})')

    ax.plot(visits, U_values, 'b-', linewidth=2, label='U(s)')
    ax.axhline(y=U_prior, color='red', linestyle='--', alpha=0.5, label=f'U_prior={U_prior}')
    ax.axvline(x=K, color='gray', linestyle=':', alpha=0.7)

    ax.set_xlabel('Number of visits')
    ax.set_ylabel('U(s)')
    ax.set_title('Uncertainty U(s): 3 regimes')
    ax.set_ylim(-0.05, 1.05)
    ax.legend(loc='upper right', fontsize=8)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("Reading the plot:")
    print(f"  Horizontal axis = number of times the agent has visited this cell")
    print(f"  Vertical axis   = uncertainty U(s) (1 = knows nothing, 0 = knows it perfectly)")
    print(f"  Blue curve      = evolution of U over the visits")
    print()
    print("The 3 shaded bands:")
    print(f"  Red    (0 visits)  : U = {U_prior}, the agent has never been here and uses its prior uncertainty")
    print(f"  Orange (1 to {K-1} visits) : U drops quickly; each visit reduces the uncertainty")
    print(f"    Example at 5 visits: U = {U_prior} × {decay}⁵ = {U_prior * decay**5:.3f}")
    print(f"  Green  (≥ {K} visits)  : U = mean of the last {K} actual errors")
    print(f"    → U settles at a low level if the agent knows the area well")
    print()
    print("Try this:")
    print("  high U_prior (0.95) → the agent starts very uncertain, the descent takes longer")
    print("  low decay (0.7)     → each visit reduces the uncertainty a lot (fast descent)")
    print("  large K (50)        → more visits are needed before switching to the actual errors")

widgets.interact(
    plot_uncertainty_regimes,
    U_prior=widgets.FloatSlider(value=0.8, min=0.5, max=0.95, step=0.05,
                                 description='U_prior', continuous_update=False),
    decay=widgets.FloatSlider(value=0.85, min=0.7, max=0.95, step=0.05,
                               description='decay', continuous_update=False),
    K=widgets.IntSlider(value=20, min=5, max=50, step=5,
                         description='K (buffer)')
);

---
## Section 6 — Confidence Signal C(s)

### Why not use U directly?

In Section 5, we constructed $U(s) \in [0, 1]$: an uncertainty map per cell. We could use it directly for decision-making ("if $U > 0.5$, explore"). But $U$ has two problems:

1. **No sharp threshold.** $U = 0.35$ — is that uncertain or not? The boundary between "I know" and "I don't know" is fuzzy. Yet the agent needs to **decide**: "do I trust my prediction here, yes or no?"

2. **Sensitivity to parameters.** $U$ depends on $U_{prior}$, $decay$, $K$... Its raw values vary widely depending on the configuration. We want a confidence signal whose scale is always the same: 0 = no confidence, 1 = full confidence.

### The idea

We transform $U$ into a **confidence** signal $C(s)$ with two properties:
- **Inverted**: high uncertainty → low confidence (and vice versa)
- **Sharp decision**: the transition between "confident" and "not confident" is rapid, not gradual

For this, we use a **sigmoid** — an S-shaped function that squashes values toward 0 or 1:

$$C(s) = \frac{1}{1 + \exp\left(\beta \cdot (U(s) - \theta_C)\right)}$$

### The parameters

| Parameter | What it controls | Default |
|-----------|------------------|--------|
| $\theta_C$ | The **center**: at what uncertainty level confidence equals exactly 0.5. "Below $\theta_C$, I am rather confident. Above, rather not." | 0.3 |
| $\beta$ | The **slope**: how abrupt the transition is. Large $\beta$ = nearly instantaneous switch between 0 and 1. Small $\beta$ = smooth and gradual transition. | 10 |

**Concrete examples** (with $\beta = 10$, $\theta_C = 0.3$):
- $U = 0.1$ (few errors) → $C \approx 0.88$ — the agent is **confident**
- $U = 0.3$ (moderate errors) → $C = 0.50$ — the agent **hesitates**
- $U = 0.6$ (many errors) → $C \approx 0.05$ — the agent **does not trust** its prediction
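
These three values can be reproduced directly from the formula. The sketch below is a minimal check; `confidence` is a hypothetical helper name, and the defaults are the ones listed in the table above.

```python
import math

def confidence(U, beta=10.0, theta_C=0.3):
    """Inverted sigmoid: high uncertainty U gives low confidence C."""
    return 1.0 / (1.0 + math.exp(beta * (U - theta_C)))

for U in (0.1, 0.3, 0.6):
    print(f"U = {U:.1f} -> C = {confidence(U):.2f}")
# prints 0.88, 0.50 and 0.05 for the three values of U
```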


In [None]:
def plot_confidence_sigmoid(beta, theta_C):
    """Visualize the sigmoid C(U) and the heatmap of C on the grid."""
    U_range = np.linspace(0, 1, 200)
    C_range = 1 / (1 + np.exp(beta * (U_range - theta_C)))

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))

    # Sigmoid curve
    axes[0].plot(U_range, C_range, 'b-', linewidth=2)
    axes[0].axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
    axes[0].axvline(x=theta_C, color='red', linestyle='--', alpha=0.5,
                    label=f'θ_C = {theta_C}')
    axes[0].fill_between(U_range, C_range, alpha=0.1)
    axes[0].set_xlabel('U(s): uncertainty')
    axes[0].set_ylabel('C(s): confidence')
    axes[0].set_title(f'Sigmoid: β={beta}, θ_C={theta_C}')
    axes[0].set_xlim(0, 1)
    axes[0].set_ylim(0, 1)
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Simulate U after partial learning (1000 steps):
    # enough for frequently visited areas to be well learned (green),
    # but not enough for everything to be learned (green/red contrast)
    gamma = 0.95
    M_star = grid.true_sr(gamma)
    M_partial = np.eye(grid.n_states)
    n_sim = 1000
    traj = grid.random_walk(n_sim, seed=42)

    # Accumulate errors per state (as in Section 4)
    error_sum = np.zeros(grid.n_states)
    error_count = np.zeros(grid.n_states)
    all_norms = []
    for t in range(len(traj) - 1):
        s, s_next = traj[t], traj[t+1]
        delta, M_partial = grid.td_update(M_partial, s, s_next, gamma, 0.1)
        norm = np.linalg.norm(delta)
        error_sum[s] += norm
        error_count[s] += 1
        all_norms.append(norm)

    # U = normalized mean error per state
    p99 = np.percentile(all_norms, 99) if all_norms else 1.0
    mean_err = np.zeros(grid.n_states)
    visited = error_count > 0
    mean_err[visited] = error_sum[visited] / error_count[visited]
    U_sim = np.where(visited, np.clip(mean_err / max(p99, 1e-8), 0, 1), 0.8)

    C_sim = 1 / (1 + np.exp(beta * (U_sim - theta_C)))

    grid.plot(values=C_sim, ax=axes[1], show_goal=False,
              title='C(s) on the grid (after 1000 steps)', cmap='RdYlGn', vmin=0, vmax=1)

    plt.tight_layout()
    plt.show()

    n_green = (C_sim > 0.5).sum()
    n_red = (C_sim <= 0.5).sum()
    C_at_0 = 1/(1+np.exp(beta*(0-theta_C)))
    C_at_1 = 1/(1+np.exp(beta*(1-theta_C)))

    print("Reading the plots:")
    print()
    print("Left: the sigmoid C(U):")
    print(f"  Horizontal axis = uncertainty U (0 = certain, 1 = knows nothing)")
    print(f"  Vertical axis   = confidence C (0 = not confident, 1 = confident)")
    print(f"  Dashed red line = θ_C = {theta_C} → at this U, confidence is exactly 0.5")
    print(f"  Left of θ_C: C rises toward {C_at_0:.2f} (confident)")
    print(f"  Right of θ_C: C falls toward {C_at_1:.3f} (not confident)")
    print()
    print("Right: confidence on the grid:")
    print(f"  Green = high confidence (well-learned area): {n_green} cells")
    print(f"  Red   = low confidence (poorly known area): {n_red} cells")
    print(f"  Gray  = wall")
    print()
    print("Try this:")
    print("  small β (2)  → smooth transition, many cells in yellow (intermediate)")
    print("  large β (25) → abrupt transition, cells are either green or red")
    print("  low θ_C (0.1) → the agent requires very few errors to be confident → more red")
    print("  high θ_C (0.6) → the agent is lenient → more green")

widgets.interact(
    plot_confidence_sigmoid,
    beta=widgets.FloatSlider(value=10, min=2, max=25, step=1,
                              description='β (slope)', continuous_update=False),
    theta_C=widgets.FloatSlider(value=0.3, min=0.1, max=0.6, step=0.05,
                                 description='θ_C (center)', continuous_update=False)
);

---
## Section 7 — Adaptive Exploration

### The problem

In Section 3, we saw that the agent chooses its actions based on $V(s) = M(s,:) \cdot R$: it moves toward cells with high value (close to the goal). This is **exploitation**: doing what is known to be good.

But there is a trap: if the agent only exploits, it stays in the areas it already knows and **never discovers the rest of the grid**. It could miss a better path, or never learn what lies on the other side of the wall.

The agent must therefore also **explore** — visit unknown areas to improve its $M$. This is the classic exploration vs. exploitation dilemma.

### The naive solution: epsilon-greedy

The standard approach is to choose a random action with probability $\varepsilon$, and the best action otherwise. But with a fixed $\varepsilon$ (e.g. 0.1), the agent explores **equally** in areas it knows well and in unknown areas. This is wasteful.

### Explore where it matters

The idea of exploring more in uncertain areas is a classic principle in RL (UCB, intrinsic curiosity, etc.). What PRISM brings is using the map $U(s)$ built from the SR's TD errors (Section 5) as the uncertainty signal. The agent knows **where** it is uncertain, and uses it in two complementary ways:

### 1. Adaptive epsilon

$$\varepsilon(s) = \varepsilon_{min} + (\varepsilon_{max} - \varepsilon_{min}) \cdot U(s)$$

Instead of a fixed $\varepsilon$, each cell has its own exploration rate:
- **Well-known area** (low $U$) → $\varepsilon \approx \varepsilon_{min}$ → the agent exploits (it knows what to do)
- **Unknown area** (high $U$) → $\varepsilon \approx \varepsilon_{max}$ → the agent explores (it needs to learn)
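
The adaptive rate can be sketched in a few lines. This is a minimal illustration rather than PRISM's actual code: the helper names `adaptive_epsilon` and `choose_action`, the `q_values` array of per-action scores, and the example uncertainties are all assumptions made for the demo.

```python
import numpy as np

def adaptive_epsilon(U, eps_min=0.01, eps_max=0.5):
    """Per-cell exploration rate: linear interpolation between eps_min and eps_max."""
    U = np.clip(np.asarray(U, dtype=float), 0.0, 1.0)
    return eps_min + (eps_max - eps_min) * U

def choose_action(q_values, eps, rng):
    """Epsilon-greedy choice: random action with probability eps, best action otherwise."""
    if rng.random() < eps:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))

# Hypothetical uncertainties for three cells: well known, half known, unknown.
U = np.array([0.0, 0.5, 1.0])
eps = adaptive_epsilon(U)  # one exploration rate per cell

rng = np.random.default_rng(0)
q_values = np.array([0.1, 0.9, 0.3, 0.2])  # hypothetical scores for the 4 moves
action = choose_action(q_values, eps[0], rng)
```

With these defaults the three cells get $\varepsilon$ = 0.01, 0.255 and 0.5: the rate is exactly $\varepsilon_{min}$ when $U = 0$, exactly $\varepsilon_{max}$ when $U = 1$, and halfway between for $U = 0.5$.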

### 2. Exploration bonus

$$V_{explore}(s) = V(s) + \lambda \cdot U(s)$$

Uncertainty acts as a **bonus reward**: unknown areas become artificially attractive. The agent is "curious" — it is drawn to what it does not yet know.

- $\lambda = 0$ → no bonus, the agent only exploits
- Large $\lambda$ → the agent prioritizes exploration, even if the goal is elsewhere
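
To see how $\lambda$ shifts the agent's preference, here is a small sketch on three cells. The values in `V` and `U` are made up for illustration, and `v_explore` is a hypothetical helper, not a function of the `prism` package.

```python
import numpy as np

def v_explore(V, U, lam):
    """Value with exploration bonus: V(s) + lam * U(s)."""
    return np.asarray(V, dtype=float) + lam * np.asarray(U, dtype=float)

# Hypothetical cells: cell 0 is near the goal and well known,
# cell 2 is far from the goal and barely explored.
V = np.array([1.0, 0.6, 0.2])
U = np.array([0.1, 0.2, 0.9])

best_no_bonus = int(np.argmax(v_explore(V, U, lam=0.0)))    # pure exploitation
best_with_bonus = int(np.argmax(v_explore(V, U, lam=2.0)))  # strong curiosity
```

With these numbers the most attractive cell is cell 0 at $\lambda = 0$ and cell 2 at $\lambda = 2$. The two are tied at $\lambda = 1$, where $1.0 + 0.1\lambda = 0.2 + 0.9\lambda$.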


In [None]:
def plot_exploration(eps_min, eps_max, lam):
    """Visualize adaptive epsilon and V_explore."""
    gamma = 0.95
    M_star = grid.true_sr(gamma)
    R = grid.reward_vector()
    V = M_star @ R

    # Simulate U after partial learning (1000 steps, as in Section 6)
    M_partial = np.eye(grid.n_states)
    traj = grid.random_walk(1000, seed=42)
    error_sum = np.zeros(grid.n_states)
    error_count = np.zeros(grid.n_states)
    all_norms = []
    for t in range(len(traj) - 1):
        s, s_next = traj[t], traj[t+1]
        delta, M_partial = grid.td_update(M_partial, s, s_next, gamma, 0.1)
        norm = np.linalg.norm(delta)
        error_sum[s] += norm
        error_count[s] += 1
        all_norms.append(norm)
    p99 = np.percentile(all_norms, 99) if all_norms else 1.0
    mean_err = np.zeros(grid.n_states)
    visited = error_count > 0
    mean_err[visited] = error_sum[visited] / error_count[visited]
    U = np.where(visited, np.clip(mean_err / max(p99, 1e-8), 0, 1), 0.8)

    # Adaptive epsilon
    eps = eps_min + (eps_max - eps_min) * U

    # V_explore
    V_norm = V / max(V.max(), 1e-8)
    V_explore = V_norm + lam * U
    V_explore_norm = V_explore / max(V_explore.max(), 1e-8)

    fig, axes = plt.subplots(1, 3, figsize=(13, 3.5))

    grid.plot(values=eps, ax=axes[0], show_goal=False,
              title='ε(s): exploration rate',
              cmap='YlOrRd', vmin=eps_min, vmax=eps_max)

    grid.plot(values=V_norm, ax=axes[1], show_goal=False,
              title='V(s): no bonus', cmap='plasma', vmin=0, vmax=1)

    grid.plot(values=V_explore_norm, ax=axes[2], show_goal=False,
              title=f'V + λ·U: with bonus (λ={lam})', cmap='plasma', vmin=0, vmax=1)

    plt.tight_layout()
    plt.show()

    best_V = grid.idx_to_pos[np.argmax(V)]
    best_Vx = grid.idx_to_pos[np.argmax(V_explore)]

    print("Reading the plots:")
    print()
    print("Left: ε(s), the exploration rate per cell:")
    print(f"  Red = strong exploration (unknown area, high U → ε ≈ {eps_max})")
    print(f"  Pale yellow = weak exploration (known area, low U → ε ≈ {eps_min})")
    print("  → The agent takes random actions more often in red areas")
    print()
    print("Center: V(s), the value without bonus:")
    print("  Yellow = high-value cell (close to the goal)")
    print("  Purple = low-value cell (far from the goal)")
    print("  → The agent is drawn only toward the goal")
    print()
    print("Right: V + λ·U, the value with exploration bonus:")
    print("  Unknown areas (high U) receive a bonus → they become more yellow")
    print("  → The agent is drawn both toward the goal AND toward unknown areas")
    print()
    print(f"Most attractive cell: without bonus = {best_V}, with bonus (λ={lam}) = {best_Vx}")
    print()
    print("Things to try:")
    print("  λ = 0   → right panel identical to the center (no curiosity)")
    print("  Large λ (2.0) → unknown areas dominate, the agent almost ignores the goal")
    print("  High ε_max (0.9) → very aggressive exploration in red areas")

widgets.interact(
    plot_exploration,
    eps_min=widgets.FloatSlider(value=0.01, min=0.001, max=0.1, step=0.01,
                                 description='ε_min', continuous_update=False),
    eps_max=widgets.FloatSlider(value=0.5, min=0.2, max=0.9, step=0.1,
                                 description='ε_max', continuous_update=False),
    lam=widgets.FloatSlider(value=0.5, min=0.0, max=2.0, step=0.1,
                             description='λ (bonus)', continuous_update=False)
);

---
## Section 8 — Change Detection

### The problem

Until now, we assumed the environment does not change: walls stay in place, passages remain open. But what happens if a **wall appears** and blocks the passage between the two rooms?

The $M$ the agent has learned reflects the **old** structure. It predicts, for example, "from the left room, I will often pass through the corridor to reach the right room". But the passage is blocked — this prediction is now **wrong**.

The agent needs to **detect** that something has changed, so it knows it must relearn its $M$.

### How to detect it?

The mechanism is already in place thanks to the previous sections:

1. **The environment changes** → the learned $M$ no longer matches reality
2. **TD errors increase** (Section 4) → the agent makes mistakes when passing through the modified area
3. **$U(s)$ rises** (Section 5) → the error buffer fills with new, high errors
4. **$C(s)$ drops** (Section 6) → confidence falls in the affected area

All that remains is to summarize this information into a single signal: "has the environment just changed?"

### The change score

We look at the average uncertainty of **recently visited** cells:

$$\text{score} = \frac{1}{|S_{recent}|} \sum_{s \in S_{recent}} U(s)$$

- $S_{recent}$ = the cells visited in the last ~50 steps
- In normal times, these cells are well known → low $U$ → low score
- After a change, errors rise → $U$ increases → **the score spikes**

We compare this score to a threshold:

$$\text{change\_detected} = \mathbb{1}(\text{score} > \theta_{change})$$
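
The score and the threshold test can be sketched as follows. This is a simplified illustration under stated assumptions: `change_score` and `change_detected` are hypothetical helper names, the `U` values are made up, and the window of 50 steps follows the "~50 steps" mentioned above.

```python
import numpy as np
from collections import deque

def change_score(U, recent_states):
    """Mean uncertainty over the distinct cells visited recently."""
    states = sorted(set(recent_states))
    if not states:
        return 0.0
    return float(np.mean([U[s] for s in states]))

def change_detected(U, recent_states, theta_change=0.5):
    """Indicator: 1(score > theta_change)."""
    return change_score(U, recent_states) > theta_change

# Hypothetical uncertainties: cells 0-2 are well known, cells 3-4 are not.
U = np.array([0.1, 0.1, 0.2, 0.9, 0.8])

recent = deque(maxlen=50)          # sliding window over the last 50 steps
recent.extend([0, 1, 2, 1, 0])     # the agent wanders in the well-known area
score_before = change_score(U, recent)
flag_before = change_detected(U, recent)

recent.extend([3, 4] * 25)         # 50 steps in the uncertain area flush the window
score_after = change_score(U, recent)
flag_after = change_detected(U, recent)
```

Because the window has a fixed length, old visits drop out on their own: after 50 steps in the uncertain area, only cells 3 and 4 contribute and the score rises from about 0.13 to 0.85.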

### Why "recent" states?

We do not look at the entire grid, only the **recently visited** cells, because:
- The change only affects the modified area — cells far from the change keep a low $U$
- The agent can only notice the change where it passes — if it has not visited the modified area, it cannot know yet
- By averaging over recent cells, the signal is **localized**: it reflects what the agent is experiencing right now

The plot below simulates a scenario: the agent learns for 300 steps, then the passage is blocked. We observe the change score over time.


In [None]:
def plot_change_detection(theta_change):
    """Simulate learning, a perturbation, and change detection."""
    gamma, alpha = 0.95, 0.1
    K = 10  # buffer size for U

    grid_normal = ToyGrid.two_rooms()
    M = np.eye(grid_normal.n_states)
    n_states = grid_normal.n_states

    from collections import deque
    buffers = [deque(maxlen=K) for _ in range(n_states)]
    visit_counts = np.zeros(n_states)
    recent_states = deque(maxlen=50)
    all_deltas = []

    scores = []

    total_steps = 600
    change_at = 300

    traj = grid_normal.random_walk(total_steps, seed=42)

    for t in range(total_steps):
        s = traj[t]
        s_next = traj[t + 1] if t + 1 < len(traj) else s

        delta, M = grid_normal.td_update(M, s, s_next, gamma, alpha)
        norm = np.linalg.norm(delta)

        # After step 300: we simulate the passage being blocked.
        # In a real run, the transitions would change. Here we only simulate the effect:
        # TD errors get larger because M reflects the old structure.
        if t >= change_at:
            norm *= 3.0

        all_deltas.append(norm)
        p99 = np.percentile(all_deltas[-500:], 99) if len(all_deltas) > 10 else 1.0
        normalized = min(norm / max(p99, 1e-8), 1.0)

        visit_counts[s] += 1
        buffers[s].append(normalized)
        recent_states.append(s)

        # Compute the score every 5 steps
        if t % 5 == 0 and len(recent_states) > 5:
            unique_recent = set(recent_states)
            U_values = []
            for rs in unique_recent:
                if visit_counts[rs] == 0:
                    U_values.append(0.8)
                elif visit_counts[rs] < K:
                    U_values.append(0.8 * (0.85 ** visit_counts[rs]))
                else:
                    U_values.append(np.mean(buffers[rs]))
            score = np.mean(U_values)
            scores.append((t, score))

    # Plot
    fig, ax = plt.subplots(1, 1, figsize=(10, 4))

    steps, change_scores = zip(*scores)
    ax.plot(steps, change_scores, 'b-', linewidth=1.5, label='Change score')
    ax.axhline(y=theta_change, color='red', linestyle='--', linewidth=2,
               label=f'Threshold θ = {theta_change}')
    ax.axvline(x=change_at, color='orange', linestyle='-', linewidth=2,
               alpha=0.7, label='Passage blocked (step 300)')

    # Shade the detection zone (score > threshold)
    for i in range(len(steps)-1):
        if change_scores[i] > theta_change:
            ax.axvspan(steps[i], steps[i+1], alpha=0.15, color='red')

    ax.set_xlabel('Steps (time)')
    ax.set_ylabel('Change score')
    ax.set_title('Scenario: the agent learns, then the passage is blocked')
    ax.legend(loc='upper left', fontsize=8)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.05, 1.05)

    # Annotations on the plot
    ax.annotate('Normal learning\n(errors decrease)',
                xy=(150, 0.15), fontsize=9, ha='center', color='steelblue')
    ax.annotate('Passage blocked!\n(errors rise again)',
                xy=(450, 0.85), fontsize=9, ha='center', color='red')

    plt.tight_layout()
    plt.show()

    # Detection latency
    post_change = [(s, sc) for s, sc in scores if s >= change_at and sc > theta_change]

    print("Reading the plot:")
    print()
    print("Think of it as a movie: time advances from left to right (0 → 600 steps).")
    print()
    print("  Horizontal axis = time (number of steps of experience)")
    print("  Vertical axis   = the change score (= mean of U over recent cells)")
    print("  Blue curve      = the score over time")
    print(f"  Horizontal red line  = threshold θ_change = {theta_change}")
    print("  Vertical orange line = step 300, the moment the passage is blocked")
    print("  Transparent red zone = the score exceeds the threshold → change detected")
    print()
    print("The scenario in 2 phases:")
    print("  Steps 0–300   : the agent explores the grid normally.")
    print("                  It learns M, its errors decrease, the score drops.")
    print("  Step 300      : the passage between the two rooms is blocked by a wall.")
    print("                  The learned M is now wrong → errors rise again → the score goes up.")
    print()
    if post_change:
        latency = post_change[0][0] - change_at
        print(f"  → Detection after {latency} steps (the score crosses the red threshold)")
    else:
        print("  → No detection: the score never crosses the threshold (too high?)")
    print()
    print("Things to try:")
    print("  Low θ (0.2)  → very fast detection, but risk of false alarms before step 300")
    print("  High θ (0.7) → slow or missing detection, but no false alarms")

widgets.interact(
    plot_change_detection,
    theta_change=widgets.FloatSlider(value=0.5, min=0.1, max=0.8, step=0.05,
                                      description='θ_change', continuous_update=False)
);

---
## Section 9 — The "I Don't Know" Signal (IDK)

### The need

In Section 6, we constructed $C(s)$: a confidence value between 0 and 1 for each cell. But the agent must make a **binary decision**: "do I trust my prediction here, or not?"

It is like a doctor looking at an X-ray: they may be more or less sure of their diagnosis, but at some point they must decide — "I know what this is" or "I don't know, I will ask for a second opinion".

### The rule

$$\text{IDK}(s) = \mathbb{1}\left(C(s) < \theta_{idk}\right)$$

- If $C(s) \geq \theta_{idk}$ → the agent is confident enough → it **exploits** (does what $V$ tells it)
- If $C(s) < \theta_{idk}$ → the agent is not confident enough → it signals **"I don't know"** and explores instead

$\theta_{idk}$ (default 0.3) is the **minimum confidence threshold**. It is a design choice:
- **Low $\theta_{idk}$** (0.1) → the agent rarely says "I don't know" — it trusts itself even with little certainty
- **High $\theta_{idk}$** (0.6) → the agent is very cautious — it says "I don't know" as soon as confidence is not high
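
The rule is a one-line comparison once $C(s)$ is available. The sketch below uses the Section 6 sigmoid with its default $\beta = 10$ and $\theta_C = 0.3$. The helper names and the example uncertainties are assumptions for illustration. Since $C$ decreases with $U$, the condition $C(s) < \theta_{idk}$ is equivalent to $U(s) > \theta_C + \ln(1/\theta_{idk} - 1)/\beta$, which the last helper computes.

```python
import numpy as np

def confidence(U, beta=10.0, theta_C=0.3):
    """Section 6 sigmoid: C = 1 / (1 + exp(beta * (U - theta_C)))."""
    return 1.0 / (1.0 + np.exp(beta * (np.asarray(U, dtype=float) - theta_C)))

def idk(C, theta_idk=0.3):
    """True where the agent signals 'I don't know' (C below the threshold)."""
    return np.asarray(C) < theta_idk

def idk_uncertainty_cutoff(theta_idk=0.3, beta=10.0, theta_C=0.3):
    """Uncertainty level above which a cell is flagged IDK."""
    return theta_C + np.log(1.0 / theta_idk - 1.0) / beta

# Hypothetical uncertainties: well known, borderline, poorly known.
U = np.array([0.05, 0.3, 0.6])
C = confidence(U)
flags = idk(C)                      # only the last cell is flagged
cutoff = idk_uncertainty_cutoff()   # about 0.385 with the defaults
```

With the defaults, a cell sitting exactly at $U = \theta_C$ has $C = 0.5$ and is still trusted. The IDK flag only switches on once $U$ exceeds roughly 0.385.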


In [None]:
def plot_idk_signal(theta_idk, beta, theta_C):
    """Visualize the areas where the agent says 'I don't know'."""
    gamma = 0.95
    M_star = grid.true_sr(gamma)

    # Simulate U after partial learning (1000 steps, as in Sections 6-7)
    M_partial = np.eye(grid.n_states)
    traj = grid.random_walk(1000, seed=42)
    error_sum = np.zeros(grid.n_states)
    error_count = np.zeros(grid.n_states)
    all_norms = []
    for t in range(len(traj) - 1):
        s, s_next = traj[t], traj[t+1]
        delta, M_partial = grid.td_update(M_partial, s, s_next, gamma, 0.1)
        norm = np.linalg.norm(delta)
        error_sum[s] += norm
        error_count[s] += 1
        all_norms.append(norm)
    p99 = np.percentile(all_norms, 99) if all_norms else 1.0
    mean_err = np.zeros(grid.n_states)
    visited = error_count > 0
    mean_err[visited] = error_sum[visited] / error_count[visited]
    U = np.where(visited, np.clip(mean_err / max(p99, 1e-8), 0, 1), 0.8)

    # Confidence
    C = 1 / (1 + np.exp(beta * (U - theta_C)))

    # IDK signal
    idk = (C < theta_idk).astype(float)
    pct_idk = 100 * idk.mean()
    n_idk = int(idk.sum())

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))

    # Left: C(s) on the grid
    grid.plot(values=C, ax=axes[0], show_goal=False,
              title='C(s): confidence', cmap='RdYlGn', vmin=0, vmax=1)

    # Right: binary IDK (green = confident, red = IDK)
    from matplotlib.colors import ListedColormap
    cmap_idk = ListedColormap(['#2ecc71', '#e74c3c'])
    idk_grid = grid.to_grid(idk)
    axes[1].imshow(idk_grid, cmap=cmap_idk, vmin=0, vmax=1,
                   origin='upper', interpolation='nearest')
    axes[1].set_title(f'IDK (θ = {theta_idk}): {n_idk}/{grid.n_states} cells flagged IDK')
    for w in grid.walls:
        axes[1].add_patch(plt.Rectangle((w[1]-0.5, w[0]-0.5), 1, 1,
                          fill=True, color='gray', alpha=0.8))
    axes[1].set_xticks(np.arange(-0.5, grid.cols, 1), minor=True)
    axes[1].set_yticks(np.arange(-0.5, grid.rows, 1), minor=True)
    axes[1].grid(which='minor', color='black', linewidth=0.5, alpha=0.3)
    axes[1].tick_params(which='both', bottom=False, left=False,
                        labelbottom=False, labelleft=False)

    plt.tight_layout()
    plt.show()

    print("Reading the plots:")
    print()
    print("Left: C(s), the confidence per cell (same map as Section 6):")
    print("  Green = high confidence (well-learned area)")
    print("  Red   = low confidence (poorly known area)")
    print()
    print("Right: the binary IDK decision:")
    print(f"  Green = C ≥ {theta_idk} → the agent trusts its prediction → it exploits")
    print(f"  Red   = C < {theta_idk} → the agent says \"I don't know\" → it explores")
    print("  Gray  = wall")
    print()
    print(f"  → {n_idk} of {grid.n_states} cells flagged IDK ({pct_idk:.0f}%)")
    print()
    print("How the two plots relate:")
    print(f"  The threshold θ_idk = {theta_idk} splits the confidence map in two.")
    print("  Every cell with C below the threshold on the left turns red on the right (IDK).")
    print("  Every cell with C at or above the threshold stays green on the right.")
    print()
    print("Things to try:")
    print("  Low θ_idk (0.1)  → almost everything is green (the agent trusts itself easily)")
    print("  High θ_idk (0.6) → a lot of red (the agent is very cautious)")
    print("  β and θ_C change the confidence map on the left → that also changes the IDK map on the right")

widgets.interact(
    plot_idk_signal,
    theta_idk=widgets.FloatSlider(value=0.3, min=0.1, max=0.7, step=0.05,
                                   description='θ_idk', continuous_update=False),
    beta=widgets.FloatSlider(value=10, min=2, max=25, step=1,
                              description='β (cf. Sec. 6)', continuous_update=False),
    theta_C=widgets.FloatSlider(value=0.3, min=0.1, max=0.6, step=0.05,
                                 description='θ_C (cf. Sec. 6)', continuous_update=False)
);

---
## Section 10 — Interactive Summary

This dashboard combines the 4 maps from the previous sections on a single simulation. Everything is connected: the agent learns $M$ via TD (Sec. 2), computes $V = M \cdot R$ (Sec. 3), measures its errors to estimate $U$ (Sec. 4-5), and derives its confidence $C$ (Sec. 6).

| Map | Section | What it shows |
|-------|---------|------------------|
| **M[s, :]** | Sec. 1-2 | Predicted visitation frequency from state $s$ |
| **V(s)** | Sec. 3 | Value of each cell ($M \cdot R$) |
| **U(s)** | Sec. 4-5 | Uncertainty (average errors per cell) |
| **C(s)** | Sec. 6 | Confidence (decreasing sigmoid of U) |
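
Before running the dashboard, here is a compact sketch of how the four quantities chain together. It rests on several assumptions: a hypothetical 5-cell ring replaces the ToyGrid, the analytical $M^* = (I - \gamma T)^{-1}$ stands in for the TD-learned $M$, and the uncertainties in `U` are made up.

```python
import numpy as np

n_states, gamma = 5, 0.95  # hypothetical 5-cell ring, not the ToyGrid

# Random-walk transition matrix T: step left or right with probability 1/2.
T = np.zeros((n_states, n_states))
for s in range(n_states):
    T[s, (s - 1) % n_states] = 0.5
    T[s, (s + 1) % n_states] = 0.5

# M: the analytical SR, standing in for the learned one.
M = np.linalg.inv(np.eye(n_states) - gamma * T)

# V = M · R, with a reward of 1 on cell 0 (the goal).
R = np.zeros(n_states)
R[0] = 1.0
V = M @ R

# U: made-up uncertainties, increasing from cell 0 to cell 4.
U = np.array([0.05, 0.10, 0.30, 0.60, 0.80])

# C: decreasing sigmoid of U, with the Section 6 defaults.
beta, theta_C = 10.0, 0.3
C = 1.0 / (1.0 + np.exp(beta * (U - theta_C)))
```

Two sanity checks follow from the formulas: every row of $M$ sums to $1/(1-\gamma) = 20$ because $T$ is row-stochastic, and $C$ equals 0.5 exactly where $U = \theta_C$ (cell 2 here).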


In [None]:
def plot_full_dashboard(gamma, alpha_M, n_steps, state, beta, theta_C):
    """Full dashboard: M, V, U, C."""
    # TD learning
    M_star = grid.true_sr(gamma)
    M = np.eye(grid.n_states)

    traj = grid.random_walk(n_steps, seed=42)
    error_sums = np.zeros(grid.n_states)
    visit_counts = np.zeros(grid.n_states)
    all_norms = []

    for t in range(len(traj) - 1):
        s, s_next = traj[t], traj[t + 1]
        delta, M = grid.td_update(M, s, s_next, gamma, alpha_M)
        norm = np.linalg.norm(delta)
        error_sums[s] += norm
        visit_counts[s] += 1
        all_norms.append(norm)

    # V = M · R
    R = grid.reward_vector()
    V = M @ R

    # U = normalized mean error (as in Sections 4-7)
    p99 = np.percentile(all_norms, 99) if all_norms else 1.0
    mean_err = np.zeros(grid.n_states)
    visited = visit_counts > 0
    mean_err[visited] = error_sums[visited] / visit_counts[visited]
    U = np.where(visited, np.clip(mean_err / max(p99, 1e-8), 0, 1), 0.8)

    # C = sigmoid (as in Section 6)
    C = 1 / (1 + np.exp(beta * (U - theta_C)))

    # Normalize M and V for fixed color scales (as in Sections 1-3)
    row = M[state]
    row_norm = row / max(row.max(), 1e-8)
    V_norm = V / max(V.max(), 1e-8)

    fig, axes = plt.subplots(1, 4, figsize=(16, 3.5))

    # M[s, :] with the plasma colormap, normalized (as in Section 1)
    grid.plot(values=row_norm, ax=axes[0], show_goal=False,
              title=f'M from s={state}', cmap='plasma', vmin=0, vmax=1)
    pos = grid.idx_to_pos[state]
    axes[0].plot(pos[1], pos[0], 'wo', markersize=10, zorder=5,
                 markeredgecolor='black', markeredgewidth=1.5)

    # V(s) with the plasma colormap (as in Section 3)
    grid.plot(values=V_norm, ax=axes[1], show_goal=False,
              title='V = M · R', cmap='plasma', vmin=0, vmax=1)

    # U(s) with the YlOrRd colormap (as in Section 4)
    grid.plot(values=U, ax=axes[2], show_goal=False,
              title='U(s)', cmap='YlOrRd', vmin=0, vmax=1)

    # C(s) with the RdYlGn colormap (as in Section 6)
    grid.plot(values=C, ax=axes[3], show_goal=False,
              title='C(s)', cmap='RdYlGn', vmin=0, vmax=1)

    plt.tight_layout()
    plt.show()

    err_m = np.linalg.norm(M - M_star, 'fro') / grid.n_states
    n_confident = (C > 0.5).sum()

    print("Reading the 4 maps:")
    print()
    print(f"  M from s={state}: predicted visitation frequency (yellow = often, purple = rarely)")
    print("    ⚪ = chosen starting state")
    print("  V = M · R: value of each cell (yellow = close to the goal, purple = far)")
    print("  U(s): uncertainty (red = poorly known, pale yellow = well learned)")
    print("  C(s): confidence (green = confident, red = not confident)")
    print()
    print(f"Summary: M error = {err_m:.4f}, mean U = {U.mean():.3f}, "
          f"{n_confident}/{grid.n_states} confident cells")
    print()
    print("Things to try:")
    print("  Few steps (100)   → M poorly learned, V wrong, U high everywhere, C red everywhere")
    print("  Many steps (3000) → M converges, V correct, U low, C green everywhere")
    print("  Large α_M (0.5)   → unstable learning, U stays high")

widgets.interact(
    plot_full_dashboard,
    gamma=widgets.FloatSlider(value=0.95, min=0.5, max=0.99, step=0.01,
                              description='γ', continuous_update=False),
    alpha_M=widgets.FloatSlider(value=0.1, min=0.01, max=0.5, step=0.01,
                                description='α_M (cf. Sec. 2)', continuous_update=False),
    n_steps=widgets.IntSlider(value=500, min=50, max=3000, step=50,
                               description='Steps', continuous_update=False),
    state=widgets.IntSlider(value=0, min=0, max=grid.n_states-1,
                            description='s (start)'),
    beta=widgets.FloatSlider(value=10, min=2, max=25, step=1,
                              description='β (cf. Sec. 6)', continuous_update=False),
    theta_C=widgets.FloatSlider(value=0.3, min=0.1, max=0.6, step=0.05,
                                 description='θ_C (cf. Sec. 6)', continuous_update=False)
);

---
## Section 11 — Hyperparameter Sweep

### The problem

In the previous sections, we introduced **4 hyperparameters** that control PRISM's metacognitive behavior:

| Parameter | Introduced in | What it controls |
|-----------|-------------|-------------------|
| `U_prior` | Section 5 | Initial uncertainty (before any visit) |
| `decay` | Section 5 | Speed of U's descent in the cold-start regime |
| `beta` | Section 6 | Slope of the confidence sigmoid $C(s) = 1/(1+\exp(\beta(U-\theta_C)))$ |
| `theta_C` | Section 6 | Center of the confidence sigmoid |

Each section presented its parameter in isolation, with a slider. But in reality, **these parameters interact**:

- A high `beta` (steep sigmoid) amplifies the effect of a slow `decay` (uncertainty that decreases slowly): as long as U stays above `theta_C`, confidence is pushed close to 0, so the agent stays unconfident for longer
- A low `U_prior` (agent not very uncertain at the start) can mask a too-permissive `theta_C`: if U never gets high, the confidence threshold is never reached
- A fast `decay` (0.7) drives U below `theta_C` after only a few visits, so confidence rises very quickly, potentially before the agent has learned well; a low `beta` (5) only softens this transition

**Testing one parameter at a time is not enough.** You may find the "best" `beta` while keeping the others fixed, but this "best" depends on the fixed values. Changing another parameter could make this `beta` suboptimal.

### The solution: the sweep (systematic search)

A **sweep** (or grid search) explores **all combinations** of hyperparameters on a predefined grid.

**Principle**:

1. **Define a grid**: choose 3 values per parameter
   - `U_prior` ∈ {0.5, 0.8, 1.0}
   - `decay` ∈ {0.7, 0.85, 0.95}
   - `beta` ∈ {5, 10, 20}
   - `theta_C` ∈ {0.2, 0.3, 0.5}

2. **Enumerate**: 3 × 3 × 3 × 3 = **81 configurations**

3. **Evaluate** each configuration with **multiple runs** (different seeds) to measure variability

4. **Compare** configurations on a common metric (e.g. ECE in Phase 3)
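
Steps 1 and 2 can be sketched with the standard library. The value lists are the ones given above; the variable names (`sweep_values`, `configs`, `default`) are illustrative assumptions.

```python
from itertools import product

# Step 1: the grid, 3 values per parameter.
sweep_values = {
    "U_prior": [0.5, 0.8, 1.0],
    "decay": [0.7, 0.85, 0.95],
    "beta": [5, 10, 20],
    "theta_C": [0.2, 0.3, 0.5],
}

# Step 2: enumerate every combination as a dict {parameter: value}.
names = list(sweep_values)
configs = [dict(zip(names, combo)) for combo in product(*sweep_values.values())]
n_configs = len(configs)  # 3 * 3 * 3 * 3

# The default configuration is one point of this grid.
default = {"U_prior": 0.8, "decay": 0.85, "beta": 10, "theta_C": 0.3}
default_in_grid = default in configs
```

Adding a fourth value to a single parameter would raise the count from 81 to 108, which is why sweep grids are usually kept coarse.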

### Why multiple runs per configuration?

A single run can be lucky or unlucky (the agent stumbles on the goal quickly, or gets stuck in a corner). By running **10 runs** with different seeds, we obtain a distribution: median, standard deviation, confidence interval.

It is like testing a medication: you do not give it to a single person and conclude that it works. You run a trial on 10 (or 100, or 1000) patients, and look at the trend.
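
A possible way to aggregate the runs of one configuration is sketched below. The function `fake_run` is a made-up stand-in that returns noisy numbers; in a real sweep it would be replaced by an actual training and evaluation run. All names are hypothetical.

```python
import numpy as np

def summarize_runs(scores):
    """Median and sample standard deviation of the per-run scores."""
    scores = np.asarray(scores, dtype=float)
    return {
        "median": float(np.median(scores)),
        "std": float(np.std(scores, ddof=1)),
        "n": int(scores.size),
    }

def evaluate_config(run_once, config, n_runs=10):
    """Run the same configuration with seeds 0..n_runs-1 and summarize."""
    return summarize_runs([run_once(config, seed) for seed in range(n_runs)])

def fake_run(config, seed):
    """Stand-in for a real run: returns a noisy score, NOT a real ECE."""
    rng = np.random.default_rng(seed)
    return 0.1 + 0.02 * rng.standard_normal()

summary = evaluate_config(fake_run, {"beta": 10}, n_runs=10)
```

The median is used rather than the mean so that a single very lucky or unlucky run does not dominate the summary.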

### What to look for in the results

The sweep produces a table of 81 rows (one per configuration), with for each row:
- **Median ECE**: calibration error (lower = better)
- **Median MI**: metacognitive index (higher = better)
- **Standard deviation**: stability of the configuration

We look for:
1. **The best configuration**: the one with the lowest median ECE
2. **The rank of the defaults**: is the default configuration (U_prior=0.8, decay=0.85, beta=10, theta_C=0.3) in the top 10, or at the bottom?
3. **Sensitivity**: which parameter has the most impact? If changing `beta` makes the ECE vary from 0.05 to 0.40 while changing `theta_C` only varies it from 0.10 to 0.12, then `beta` is **critical** and `theta_C` is **robust**.
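
The ranking and sensitivity analyses can be sketched on a reduced two-parameter grid. The scoring function `toy_ece` is invented purely to exercise the code (it makes `beta` matter much more than `theta_C`); it is not a PRISM result. Sensitivity is measured here as the range of the mean score across a parameter's values, which is one reasonable choice among several.

```python
from itertools import product
import numpy as np

values = {"beta": [5, 10, 20], "theta_C": [0.2, 0.3, 0.5]}

def toy_ece(cfg):
    """Made-up score: strongly driven by beta, weakly by theta_C."""
    return (0.05
            + 0.35 * (cfg["beta"] - 5) / 15
            + 0.02 * (cfg["theta_C"] - 0.2) / 0.3)

names = list(values)
results = []
for combo in product(*values.values()):
    cfg = dict(zip(names, combo))
    results.append((cfg, toy_ece(cfg)))

def sensitivity(results, param):
    """Range (max - min) of the mean score across the values of one parameter."""
    by_value = {}
    for cfg, score in results:
        by_value.setdefault(cfg[param], []).append(score)
    means = [np.mean(v) for v in by_value.values()]
    return float(max(means) - min(means))

def rank_of(results, target):
    """1-based rank of a configuration when sorted by increasing score."""
    ordered = sorted(results, key=lambda item: item[1])
    return 1 + next(i for i, (cfg, _) in enumerate(ordered) if cfg == target)

sens = {p: sensitivity(results, p) for p in names}
best_cfg = min(results, key=lambda item: item[1])[0]
default_rank = rank_of(results, {"beta": 10, "theta_C": 0.3})
```

On this toy table the range is 0.35 for `beta` and 0.02 for `theta_C`, so `beta` would be labeled critical and `theta_C` robust, mirroring the reading described above.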

### Analogy: buying a bicycle

You want to buy a bicycle. There are 4 features to choose:
- Frame size (S, M, L)
- Tire type (road, hybrid, off-road)
- Number of gears (7, 14, 21)
- Brake type (rim, disc, hydraulic)

3 × 3 × 3 × 3 = 81 possible combinations. You cannot test an M frame with road tires and conclude that M is the best size — maybe with off-road tires, an L frame would be better.

The sweep amounts to trying **all 81 combinations**, scoring each on 10 different rides, then looking at which one gives the best average score.

### Visualizing the interactions

The widget below simulates a mini-sweep on the ToyGrid. We fix 2 parameters and vary the other 2 to see how the ECE changes. Red cells indicate bad configurations, green ones indicate good configurations.


In [None]:
# Re-import in case this cell is run on its own
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from collections import deque


def simulate_calibration(grid, gamma, alpha_M, n_steps, seed,
                         U_prior, decay, K, beta, theta_C):
    """Simulate learning + compute a simplified ECE-like metric."""
    M = np.eye(grid.n_states)
    traj = grid.random_walk(n_steps, seed=seed)

    buffers = [deque(maxlen=K) for _ in range(grid.n_states)]
    visit_counts = np.zeros(grid.n_states)
    all_norms = []

    for t in range(len(traj) - 1):
        s, s_next = traj[t], traj[t + 1]
        delta, M = grid.td_update(M, s, s_next, gamma, alpha_M)
        norm = np.linalg.norm(delta)
        all_norms.append(norm)
        visit_counts[s] += 1

        p99 = np.percentile(all_norms[-200:], 99) if len(all_norms) > 10 else 1.0
        normalized = min(norm / max(p99, 1e-8), 1.0)
        buffers[s].append(normalized)

    # Compute U(s) for each state
    U = np.zeros(grid.n_states)
    for s in range(grid.n_states):
        if visit_counts[s] == 0:
            U[s] = U_prior
        elif visit_counts[s] < K:
            U[s] = U_prior * (decay ** visit_counts[s])
        else:
            U[s] = np.mean(buffers[s])

    # Confidence C(s)
    C = 1.0 / (1.0 + np.exp(beta * (U - theta_C)))

    # True error: ||M[s,:] - M*[s,:]||
    M_star = grid.true_sr(gamma)
    true_err = np.array([np.linalg.norm(M[s] - M_star[s]) for s in range(grid.n_states)])
    true_err_norm = true_err / max(true_err.max(), 1e-8)

    # Simplified ECE: |C(s) - (1 - true_error(s))| averaged
    accuracy = 1.0 - true_err_norm
    ece = np.mean(np.abs(C - accuracy))
    return ece


def plot_sweep_heatmap(fixed_U_prior, fixed_theta_C):
    """Sweep beta x decay, fixing U_prior and theta_C."""
    gamma, alpha_M, K = 0.95, 0.1, 15
    n_steps = 800
    n_seeds = 3

    beta_vals = [5, 10, 20]
    decay_vals = [0.7, 0.85, 0.95]

    results = np.zeros((len(decay_vals), len(beta_vals)))

    for i, d in enumerate(decay_vals):
        for j, b in enumerate(beta_vals):
            eces = []
            for seed in range(n_seeds):
                ece = simulate_calibration(
                    grid, gamma, alpha_M, n_steps, seed=seed,
                    U_prior=fixed_U_prior, decay=d, K=K,
                    beta=b, theta_C=fixed_theta_C
                )
                eces.append(ece)
            results[i, j] = np.median(eces)

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Heatmap
    im = axes[0].imshow(results, cmap='RdYlGn_r', vmin=0, vmax=0.5,
                        origin='upper', aspect='auto')
    axes[0].set_xticks(range(len(beta_vals)))
    axes[0].set_xticklabels([str(b) for b in beta_vals])
    axes[0].set_yticks(range(len(decay_vals)))
    axes[0].set_yticklabels([str(d) for d in decay_vals])
    axes[0].set_xlabel('beta')
    axes[0].set_ylabel('decay')
    axes[0].set_title(f'ECE (U_prior={fixed_U_prior}, theta_C={fixed_theta_C})')

    for i in range(len(decay_vals)):
        for j in range(len(beta_vals)):
            axes[0].text(j, i, f'{results[i,j]:.3f}', ha='center', va='center',
                        fontsize=11, fontweight='bold',
                        color='white' if results[i,j] > 0.25 else 'black')

    plt.colorbar(im, ax=axes[0], label='ECE (lower = better)')

    # Bar chart: sensitivity per parameter
    param_ranges = {}
    center = {"U_prior": 0.8, "decay": 0.85, "beta": 10, "theta_C": 0.3}
    sweep_vals = {
        "U_prior": [0.5, 0.8, 1.0],
        "decay": [0.7, 0.85, 0.95],
        "beta": [5, 10, 20],
        "theta_C": [0.2, 0.3, 0.5],
    }

    for param, vals in sweep_vals.items():
        eces_for_param = []
        for v in vals:
            kwargs = dict(center)
            kwargs[param] = v
            ece = simulate_calibration(
                grid, gamma, alpha_M, n_steps, seed=0,
                U_prior=kwargs["U_prior"], decay=kwargs["decay"], K=K,
                beta=kwargs["beta"], theta_C=kwargs["theta_C"]
            )
            eces_for_param.append(ece)
        param_ranges[param] = max(eces_for_param) - min(eces_for_param)

    params = list(param_ranges.keys())
    ranges = [param_ranges[p] for p in params]
    colors = ['#e74c3c' if r > 0.05 else '#f39c12' if r > 0.02 else '#2ecc71'
              for r in ranges]
    axes[1].barh(params, ranges, color=colors, edgecolor='black', linewidth=0.5)
    axes[1].set_xlabel('ECE range (max - min)')
    axes[1].set_title('Sensitivity per parameter')
    axes[1].set_xlim(0, max(ranges) * 1.3 if ranges else 0.1)

    for k, (p, r) in enumerate(zip(params, ranges)):
        axes[1].text(r + 0.002, k, f'{r:.3f}', va='center', fontsize=10)

    plt.tight_layout()
    plt.show()

    best_i, best_j = np.unravel_index(results.argmin(), results.shape)
    worst_i, worst_j = np.unravel_index(results.argmax(), results.shape)

    print("Reading the plots:")
    print()
    print("Left: the beta x decay heatmap:")
    print("  Each cell = median ECE for one (beta, decay) combination")
    print("  Green = good calibration (low ECE), red = poor calibration (high ECE)")
    print(f"  Best:  beta={beta_vals[best_j]}, decay={decay_vals[best_i]} "
          f"(ECE={results[best_i, best_j]:.3f})")
    print(f"  Worst: beta={beta_vals[worst_j]}, decay={decay_vals[worst_i]} "
          f"(ECE={results[worst_i, worst_j]:.3f})")
    print()
    print("Right: the sensitivity:")
    print("  Long bar  = the parameter has a strong impact on the ECE")
    print("  Short bar = the parameter is robust (little impact)")
    most_sensitive = max(param_ranges, key=param_ranges.get)
    print(f"  Most critical parameter: {most_sensitive} "
          f"(range = {param_ranges[most_sensitive]:.3f})")
    print()
    print("Things to try:")
    print("  Change U_prior and theta_C with the sliders:")
    print("  the heatmap changes -> the interactions between beta and decay")
    print("  depend on the values fixed for the other 2 parameters.")
    print("  That is exactly why we run a full sweep (81 configs)")
    print("  rather than optimizing each parameter in isolation.")

widgets.interact(
    plot_sweep_heatmap,
    fixed_U_prior=widgets.FloatSlider(value=0.8, min=0.5, max=1.0, step=0.1,
                                       description='U_prior (fixed)',
                                       continuous_update=False),
    fixed_theta_C=widgets.FloatSlider(value=0.3, min=0.2, max=0.5, step=0.1,
                                       description='theta_C (fixed)',
                                       continuous_update=False)
);

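The printed notes above contrast a full sweep (81 configs) with tuning each parameter in isolation. The sketch below makes that difference concrete by enumerating both sets of configurations from the same candidate values. The helper names `full_grid` and `one_at_a_time` are hypothetical and chosen for illustration; they are not taken from the PRISM code base, and the project's actual sweep may be organized differently. The one-at-a-time scheme only visits 9 distinct configurations, all of which keep three parameters at their center values, so it cannot reveal interactions such as the one between `beta` and `decay` shown in the heatmap.

```python
from itertools import product

# Candidate values and center point copied from the sensitivity sweep above.
sweep_vals = {
    "U_prior": [0.5, 0.8, 1.0],
    "decay": [0.7, 0.85, 0.95],
    "beta": [5, 10, 20],
    "theta_C": [0.2, 0.3, 0.5],
}
center = {"U_prior": 0.8, "decay": 0.85, "beta": 10, "theta_C": 0.3}


def full_grid(values):
    """Yield one dict per combination of parameter values (full factorial)."""
    names = list(values)
    for combo in product(*(values[n] for n in names)):
        yield dict(zip(names, combo))


def one_at_a_time(values, center):
    """Yield configs that vary a single parameter while the others stay at `center`."""
    for name, vals in values.items():
        for v in vals:
            yield {**center, name: v}


grid_configs = list(full_grid(sweep_vals))
oat_configs = list(one_at_a_time(sweep_vals, center))
# The center point is generated once per parameter, so deduplicate.
unique_oat = {tuple(sorted(c.items())) for c in oat_configs}

print(len(grid_configs))                  # 3**4 = 81 configurations
print(len(oat_configs), len(unique_oat))  # 12 evaluations, 9 distinct configurations
```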
---
## Formula Summary

| Concept | Formula | Key Parameters |
|---------|---------|----------------|
| SR (definition) | $M(s,s') = \mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^t \mathbb{1}(s_t=s') \mid s_0=s\right]$ | $\gamma$ |
| SR (analytical) | $M^* = (I - \gamma T)^{-1}$ | $\gamma$, $T$ |
| TD error | $\delta_M = e(s') + \gamma M(s',:) - M(s,:)$ | $\gamma$ |
| SR update | $M(s,:) \leftarrow M(s,:) + \alpha_M \delta_M$ | $\alpha_M$ |
| Value | $V(s) = M(s,:) \cdot R$ | |
| Reward | $R(s) \leftarrow R(s) + \alpha_R (r - R(s))$ | $\alpha_R$ |
| Uncertainty | $U(s) = \text{mean}(\text{buffer}(s))$ | $K$, $U_{prior}$, $decay$ |
| Confidence | $C(s) = \frac{1}{1+\exp(\beta(U(s)-\theta_C))}$ | $\beta$, $\theta_C$ |
| Adaptive epsilon | $\varepsilon(s) = \varepsilon_{min} + (\varepsilon_{max}-\varepsilon_{min}) U(s)$ | $\varepsilon_{min}$, $\varepsilon_{max}$ |
| Exploration bonus | $V_{explore} = V + \lambda U$ | $\lambda$ |
| Change detection | $\text{score} = \text{mean}(U(s_{recent}))$ | $\theta_{change}$ |
| IDK | $\mathbb{1}(C(s) < \theta_{idk})$ | $\theta_{idk}$ |
| Sweep | Grid search over $(U_{prior}, decay, \beta, \theta_C)$: 81 configs × $n$ runs | Grid of values |

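As a quick numerical check of several rows in the table, the sketch below evaluates the analytical SR, the value function, the confidence sigmoid, the adaptive epsilon, and the IDK flag on a tiny example. Everything specific here is an assumption made for illustration: the 3-state ring transition matrix, $\gamma = 0.9$, the reward vector, $\varepsilon_{min} = 0.01$, $\varepsilon_{max} = 0.5$, and $\theta_{idk} = 0.5$. Only $\beta = 10$ and $\theta_C = 0.3$ match the center values used in the sweep. It is not the PRISM implementation. Note that confidence equals exactly 0.5 when $U(s) = \theta_C$, and that each row of $M^*$ sums to $1/(1-\gamma)$ because each row of $T$ sums to 1.

```python
import numpy as np

gamma = 0.9  # assumed discount factor

# Hypothetical 3-state ring: each state moves to the next one with probability 1.
T = np.array([[0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [1.0, 0.0, 0.0]])

# SR (analytical): M* = (I - gamma T)^-1
M = np.linalg.inv(np.eye(3) - gamma * T)
row_sums = M.sum(axis=1)  # each entry equals 1 / (1 - gamma)

# Value: V(s) = M(s,:) . R, with a reward of 1 placed in state 2 only.
R = np.array([0.0, 0.0, 1.0])
V = M @ R


def confidence(U, beta=10.0, theta_C=0.3):
    """C(s) = 1 / (1 + exp(beta * (U(s) - theta_C))): high uncertainty gives low confidence."""
    return 1.0 / (1.0 + np.exp(beta * (U - theta_C)))


def adaptive_epsilon(U, eps_min=0.01, eps_max=0.5):
    """eps(s) = eps_min + (eps_max - eps_min) * U(s), for U(s) in [0, 1]."""
    return eps_min + (eps_max - eps_min) * U


U = np.array([0.0, 0.3, 1.0])  # example uncertainties for the three states
C = confidence(U)
eps = adaptive_epsilon(U)
theta_idk = 0.5                # assumed IDK threshold
idk = C < theta_idk            # True where the agent would answer "I don't know"

print("row sums of M*:", row_sums)
print("V:", V)
print("C:", C)
print("epsilon:", eps)
print("IDK:", idk)
```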
**Appendices**: [Spectral](00a_spectral_deep_dive.ipynb) | [Calibration](00b_calibration_methods.ipynb)
