![Logo](https://github.com/BartaZoltan/deep-reinforcement-learning-course/blob/main/website/assets/logo.png?raw=1)

Made by **Zoltán Barta**

[<img src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/github/BartaZoltan/deep-reinforcement-learning-course/blob/main/notebooks/sessions/session_02_mdp_dynamic_programming/session_02_mdp_dynamic_programming.ipynb)

# Markov Decision Processes

This practice session follows Chapters 3–4 of *Reinforcement Learning: An Introduction (2nd ed.)* by Sutton & Barto.

Focus: finite MDPs + planning with a known model via **Dynamic Programming**:
- Iterative policy evaluation
- Policy iteration
- Value iteration

We’ll test these on:
- A small **Gridworld** (deterministic dynamics, step cost)
- The **Gambler’s Problem** (value iteration + optimal policy)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Dict, List, Tuple, Iterable, Optional

np.set_printoptions(precision=3, suppress=True)

## A tiny tabular MDP representation

We’ll store the dynamics as a list of transitions per $(s,a)$:
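- each transition is $(p, s', r, done)$
- terminal states are absorbing: every action loops back to the same state with reward 0 and `done=True`, and any `done=True` transition contributes its reward but no bootstrapped value $\gamma V(s')$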

In [None]:
Transition = Tuple[float, int, float, bool]  # (p, s_next, reward, done)

@dataclass
class TabularMDP:
    nS: int
    nA: int
    P: List[List[List[Transition]]]  # P[s][a] = list of transitions
    gamma: float = 1.0

    def transitions(self, s: int, a: int) -> List[Transition]:
        return self.P[s][a]


def is_stochastic_policy(policy: np.ndarray, nS: int, nA: int) -> None:
    assert policy.shape == (nS, nA)
    row_sums = policy.sum(axis=1)
    if not np.allclose(row_sums, 1.0):
        raise ValueError('Policy rows must sum to 1.')
    if np.any(policy < -1e-12):
        raise ValueError('Policy must have non-negative probabilities.')


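def greedy_policy_from_q(Q: np.ndarray, tie_break: str = 'uniform', atol: float = 1e-9) -> np.ndarray:
    """Return a greedy policy from Q[s,a].

    tie_break='uniform': spread probability evenly over all actions within `atol` of the best.
    tie_break='first':   put all probability on the lowest-index best action.
    """
    nS, nA = Q.shape
    # Treat actions within `atol` of the best as ties, so that floating-point noise
    # does not decide between genuinely equal actions.
    best_actions = np.isclose(Q, Q.max(axis=1, keepdims=True), rtol=0.0, atol=atol)
    if tie_break == 'uniform':
        # (nS, nA) bool / (nS, 1) counts -> each row sums to 1 over the tied actions
        return best_actions / best_actions.sum(axis=1, keepdims=True)
    if tie_break == 'first':
        policy = np.zeros((nS, nA), dtype=float)
        policy[np.arange(nS), np.argmax(Q, axis=1)] = 1.0
        return policy
    raise ValueError("tie_break must be 'uniform' or 'first'")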
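Every DP routine below trusts `P` blindly, so a quick consistency check on a hand-built model can catch construction bugs early. A minimal sketch, assuming the same `(p, s_next, reward, done)` layout; `validate_transitions` is a hypothetical helper, not part of the original notebook:

```python
from typing import List, Tuple

# Same layout as above: P[s][a] is a list of (p, s_next, reward, done).
Transition = Tuple[float, int, float, bool]


def validate_transitions(P: List[List[List[Transition]]], atol: float = 1e-9) -> None:
    """Raise ValueError if any P[s][a] is not a valid probability distribution."""
    nS = len(P)
    for s, row in enumerate(P):
        for a, transitions in enumerate(row):
            if not transitions:
                raise ValueError(f"P[{s}][{a}] has no transitions")
            total = 0.0
            for p, s_next, _reward, _done in transitions:
                if p < 0.0:
                    raise ValueError(f"negative probability in P[{s}][{a}]")
                if not 0 <= s_next < nS:
                    raise ValueError(f"next state {s_next} out of range in P[{s}][{a}]")
                total += p
            if abs(total - 1.0) > atol:
                raise ValueError(f"P[{s}][{a}] sums to {total}, not 1")


# Two-state toy model: from state 0, action 0 terminates; action 1 terminates with
# probability 0.5. State 1 is absorbing (reward 0, done=True).
P_toy = [
    [[(1.0, 1, -1.0, True)], [(0.5, 0, -1.0, False), (0.5, 1, -1.0, True)]],
    [[(1.0, 1, 0.0, True)], [(1.0, 1, 0.0, True)]],
]
validate_transitions(P_toy)  # passes silently
```

Calling it on the output of a model builder before running any of the algorithms costs one pass over the transitions.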

## Dynamic Programming algorithms (planning with a known model)

All methods below assume the MDP dynamics $p(s',r\mid s,a)$ are known.

In [None]:
def policy_evaluation(
    mdp: TabularMDP,
    policy: np.ndarray,
    theta: float = 1e-10,
    max_iterations: int = 1_000_000,
    V0: Optional[np.ndarray] = None,
    in_place: bool = True,
):
    """Iterative policy evaluation for V^pi.

    Args:
        mdp: TabularMDP
        policy: shape (nS, nA), stochastic policy
        theta: stopping tolerance on max|V_new - V_old|
        max_iterations: safety cap
        V0: optional initial value function
        in_place: if True, update V(s) immediately (Gauss-Seidel-ish)
    """
    nS, nA = mdp.nS, mdp.nA
    is_stochastic_policy(policy, nS, nA)
    V = np.zeros(nS) if V0 is None else V0.astype(float).copy()
    for it in range(max_iterations):
        delta = 0.0
        if in_place:
            for s in range(nS):
                v_old = V[s]
                v_new = 0.0
                for a in range(nA):
                    pi = policy[s, a]
                    if pi == 0.0:
                        continue
                    for p, s2, r, done in mdp.transitions(s, a):
                        v_new += pi * p * (r + (0.0 if done else mdp.gamma * V[s2]))
                V[s] = v_new
                delta = max(delta, abs(v_old - v_new))
        else:
            V_new = V.copy()
            for s in range(nS):
                v_new = 0.0
                for a in range(nA):
                    pi = policy[s, a]
                    if pi == 0.0:
                        continue
                    for p, s2, r, done in mdp.transitions(s, a):
                        v_new += pi * p * (r + (0.0 if done else mdp.gamma * V[s2]))
                V_new[s] = v_new
            delta = np.max(np.abs(V_new - V))
            V = V_new
        if delta < theta:
            return V, it + 1
    return V, max_iterations


def q_from_v(mdp: TabularMDP, V: np.ndarray) -> np.ndarray:
    nS, nA = mdp.nS, mdp.nA
    Q = np.zeros((nS, nA), dtype=float)
    for s in range(nS):
        for a in range(nA):
            q = 0.0
            for p, s2, r, done in mdp.transitions(s, a):
                q += p * (r + (0.0 if done else mdp.gamma * V[s2]))
            Q[s, a] = q
    return Q


def policy_iteration(
    mdp: TabularMDP,
    theta: float = 1e-10,
    max_policy_iterations: int = 10_000,
    tie_break: str = 'uniform',
):
    """Classic policy iteration: evaluate -> greedy improvement until stable."""
    nS, nA = mdp.nS, mdp.nA
    policy = np.ones((nS, nA), dtype=float) / nA
    V = np.zeros(nS, dtype=float)
    eval_iters_total = 0

    for pi_it in range(max_policy_iterations):
        V, eval_iters = policy_evaluation(mdp, policy, theta=theta, V0=V, in_place=True)
        eval_iters_total += eval_iters
        Q = q_from_v(mdp, V)
        new_policy = greedy_policy_from_q(Q, tie_break=tie_break)
        if np.allclose(new_policy, policy):
            return policy, V, pi_it + 1, eval_iters_total
        policy = new_policy
    return policy, V, max_policy_iterations, eval_iters_total


def value_iteration(
    mdp: TabularMDP,
    theta: float = 1e-10,
    max_iterations: int = 1_000_000,
):
    """Value iteration for V* (Bellman optimality updates)."""
    nS, nA = mdp.nS, mdp.nA
    V = np.zeros(nS, dtype=float)
    for it in range(max_iterations):
        delta = 0.0
        for s in range(nS):
            v_old = V[s]
            q_sa = []
            for a in range(nA):
                q = 0.0
                for p, s2, r, done in mdp.transitions(s, a):
                    q += p * (r + (0.0 if done else mdp.gamma * V[s2]))
                q_sa.append(q)
            V[s] = float(np.max(q_sa))
            delta = max(delta, abs(v_old - V[s]))
        if delta < theta:
            break
    Q = q_from_v(mdp, V)
    policy = greedy_policy_from_q(Q, tie_break='uniform')
    return policy, V, it + 1

In [None]:
ACTION_SYMBOLS = {0: '↑', 1: '→', 2: '↓', 3: '←'}

def as_grid(arr: np.ndarray, rows: int, cols: int) -> np.ndarray:
    return arr.reshape(rows, cols)


def print_value_grid(V: np.ndarray, rows: int, cols: int, title: str = 'V') -> None:
    grid = as_grid(V, rows, cols)
    print(title)
    for r in range(rows):
        row = ' '.join([f"{grid[r, c]:7.2f}" for c in range(cols)])
        print(row)


def policy_to_arrow_grid(policy: np.ndarray, rows: int, cols: int, terminals: Iterable[int] = ()) -> List[List[str]]:
    nS, nA = policy.shape
    assert nA == 4, 'Arrow rendering expects 4 actions: up,right,down,left'
    terminals = set(terminals)
    arrows = [['' for _ in range(cols)] for _ in range(rows)]
    for s in range(nS):
        r, c = divmod(s, cols)
        if s in terminals:
            arrows[r][c] = 'T'
            continue
        best = np.flatnonzero(policy[s] == policy[s].max())
        if len(best) == 1:
            arrows[r][c] = ACTION_SYMBOLS[int(best[0])]
        else:
            arrows[r][c] = ''.join(ACTION_SYMBOLS[int(a)] for a in best)
    return arrows


def print_policy_grid(policy: np.ndarray, rows: int, cols: int, terminals: Iterable[int] = (), title: str = 'policy') -> None:
    arrows = policy_to_arrow_grid(policy, rows, cols, terminals=terminals)
    print(title)
    for r in range(rows):
        print(' '.join([f"{arrows[r][c]:4s}" for c in range(cols)]))

## Example 1 — 4×4 Gridworld (Sutton & Barto-style)

- States are grid cells; two terminal states in opposite corners (top-left and bottom-right).
- Actions: up/right/down/left (deterministic).
- Reward: `-1` on every transition until termination, including the step into a terminal state.
- Discount: $\gamma = 1$ (episodic).

In [None]:
def build_gridworld_mdp(
    rows: int = 4,
    cols: int = 4,
    terminal_states: Optional[Iterable[int]] = None,
    step_reward: float = -1.0,
    gamma: float = 1.0,
) -> Tuple[TabularMDP, List[int]]:
    """Deterministic gridworld with absorbing terminal states."""
    nS = rows * cols
    nA = 4  # up, right, down, left
    if terminal_states is None:
        terminal_states = [0, nS - 1]
    terminal_states = list(terminal_states)
    terminal_set = set(terminal_states)
    P: List[List[List[Transition]]] = [[[] for _ in range(nA)] for _ in range(nS)]

    def move(s: int, a: int) -> int:
        r, c = divmod(s, cols)
        if a == 0:
            r2, c2 = max(r - 1, 0), c
        elif a == 1:
            r2, c2 = r, min(c + 1, cols - 1)
        elif a == 2:
            r2, c2 = min(r + 1, rows - 1), c
        elif a == 3:
            r2, c2 = r, max(c - 1, 0)
        else:
            raise ValueError('Invalid action')
        return r2 * cols + c2

    for s in range(nS):
        for a in range(nA):
            if s in terminal_set:
                P[s][a] = [(1.0, s, 0.0, True)]
                continue
            s2 = move(s, a)
            done = s2 in terminal_set
            P[s][a] = [(1.0, s2, step_reward, done)]

    return TabularMDP(nS=nS, nA=nA, P=P, gamma=gamma), terminal_states


grid_mdp, grid_terminals = build_gridworld_mdp()
grid_rows, grid_cols = 4, 4

### 1) Iterative policy evaluation (uniform random policy)

In [None]:
random_policy = np.ones((grid_mdp.nS, grid_mdp.nA), dtype=float) / grid_mdp.nA
V_pi, n_eval_iters = policy_evaluation(grid_mdp, random_policy, theta=1e-12)
print(f"Policy evaluation iterations: {n_eval_iters}")
print_value_grid(V_pi, grid_rows, grid_cols, title='V under uniform random policy')

### 2) Policy iteration (compute an optimal policy)

In [None]:
pi_star, V_star_pi, n_pi_iters, n_eval_total = policy_iteration(grid_mdp, theta=1e-12)
print(f"Policy iteration outer loops: {n_pi_iters}, total eval sweeps: {n_eval_total}")
print_value_grid(V_star_pi, grid_rows, grid_cols, title='V* (from policy iteration)')
print_policy_grid(pi_star, grid_rows, grid_cols, terminals=grid_terminals, title='π* (policy iteration)')

### 3) Value iteration (compute an optimal value function directly)

In [None]:
pi_vi, V_vi, n_vi_iters = value_iteration(grid_mdp, theta=1e-12)
print(f"Value iteration sweeps: {n_vi_iters}")
print_value_grid(V_vi, grid_rows, grid_cols, title='V* (from value iteration)')
print_policy_grid(pi_vi, grid_rows, grid_cols, terminals=grid_terminals, title='π* (value iteration)')

print('Max |V_pi_iter - V_value_iter| =', np.max(np.abs(V_star_pi - V_vi)))

## Example 2 — The Gambler’s Problem (Value Iteration)

A gambler has capital $s\in\{0,1,\dots,100\}$.
- $s=0$ and $s=100$ are terminal.
- At each step, choose a stake $a \in \{1,\dots,\min(s, 100-s)\}$.
- With probability $p_h$ you win and $s \leftarrow s+a$, otherwise $s \leftarrow s-a$.
- Reward is 1 only upon reaching $s=100$ (so the value is the probability of eventual success).

In [None]:
def gamblers_value_iteration(p_h: float = 0.4, theta: float = 1e-12, max_iterations: int = 1_000_000):
    """Value iteration for the Gambler's Problem (Sutton & Barto, Ch. 4).
    Returns (V, policy) where policy[s] is the chosen stake for capital s.
    """
    if not (0.0 < p_h < 1.0):
        raise ValueError('p_h must be in (0,1)')

    goal = 100
    V = np.zeros(goal + 1, dtype=float)
    V[goal] = 1.0  # reaching the goal yields reward 1; value is prob of success
    policy = np.zeros(goal + 1, dtype=int)

    for it in range(max_iterations):
        delta = 0.0
        for s in range(1, goal):
            stakes = np.arange(1, min(s, goal - s) + 1)
            if stakes.size == 0:
                continue
            # expected return: no living reward; terminal value already encodes success reward
            action_returns = p_h * V[s + stakes] + (1.0 - p_h) * V[s - stakes]
            v_new = np.max(action_returns)
            delta = max(delta, abs(v_new - V[s]))
            V[s] = v_new
        if delta < theta:
            break

    # derive greedy policy from converged V
    for s in range(1, goal):
        stakes = np.arange(1, min(s, goal - s) + 1)
        if stakes.size == 0:
            continue
        action_returns = p_h * V[s + stakes] + (1.0 - p_h) * V[s - stakes]
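        # tolerance-based ties: exact == would let rounding noise pick among equally good stakes
        best = stakes[np.isclose(action_returns, action_returns.max(), rtol=0.0, atol=1e-9)]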
        policy[s] = int(best[0])  # pick smallest stake among ties for determinism
    return V, policy, it + 1


V_g, pi_g, n_g_iters = gamblers_value_iteration(p_h=0.4, theta=1e-12)
print(f"Gambler's value iteration sweeps: {n_g_iters}")

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(V_g)
ax[0].set_title('Gambler: V(s)')
ax[0].set_xlabel('capital s')
ax[0].set_ylabel('value')

ax[1].step(np.arange(len(pi_g)), pi_g, where='mid')
ax[1].set_title('Gambler: greedy stake π(s)')
ax[1].set_xlabel('capital s')
ax[1].set_ylabel('stake')
plt.tight_layout()
plt.show()
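A quick sanity check for the result above: with $p_h<0.5$ the game is unfavourable, and it is a classical result that staking as much as possible ("bold play", stake $\min(s,100-s)$) is then optimal. That gives $V^*(50)=p_h$, $V^*(25)=p_h^2$ and $V^*(75)=p_h+(1-p_h)\,p_h$, i.e. 0.4, 0.16 and 0.64 for $p_h=0.4$. The sketch below repeats the value-iteration backup in self-contained form and compares it with a direct bold-play calculation at those capitals; `gambler_values` and `bold_play_value` are hypothetical helpers written for this check.

```python
import numpy as np


def gambler_values(p_h: float, goal: int = 100, theta: float = 1e-12) -> np.ndarray:
    """In-place value iteration for the Gambler's Problem (same backup as above)."""
    V = np.zeros(goal + 1)
    V[goal] = 1.0  # terminal value encodes the +1 reward for reaching the goal
    while True:
        delta = 0.0
        for s in range(1, goal):
            stakes = np.arange(1, min(s, goal - s) + 1)
            v_new = np.max(p_h * V[s + stakes] + (1 - p_h) * V[s - stakes])
            delta = max(delta, abs(v_new - V[s]))
            V[s] = v_new
        if delta < theta:
            return V


def bold_play_value(s: int, p_h: float, goal: int = 100, depth: int = 60) -> float:
    """Success probability when always staking min(s, goal - s)."""
    if s <= 0:
        return 0.0
    if s >= goal:
        return 1.0
    if depth == 0:
        return 0.0  # truncate long chains; the neglected mass is at most max(p_h, 1-p_h)**60
    a = min(s, goal - s)
    # one branch is always terminal (0 or goal), so the recursion is a single chain
    return (p_h * bold_play_value(s + a, p_h, goal, depth - 1)
            + (1 - p_h) * bold_play_value(s - a, p_h, goal, depth - 1))


V_check = gambler_values(0.4)
for s in (25, 50, 75):
    print(s, round(V_check[s], 6), round(bold_play_value(s, 0.4), 6))
```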

## A bit more depth: Bellman equations + convergence

### Bellman expectation equation (policy evaluation)
For a fixed policy $\pi$, the value function is the unique solution to:
$$
V^{\pi}(s)=\sum_a \pi(a\mid s)\sum_{s',r} p(s',r\mid s,a)\big[r + \gamma V^{\pi}(s')\big].
$$
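This equation is linear in $V^{\pi}$. Writing $r_\pi(s)$ for the expected one-step reward under $\pi$ and $P_\pi(s,s')$ for the probability of reaching $s'$ through a transition that is not marked `done`, it reads $V^\pi = r_\pi + \gamma P_\pi V^\pi$, so for small state spaces one can solve $(I-\gamma P_\pi)V^\pi = r_\pi$ directly. That matrix is invertible for $\gamma<1$, and for $\gamma=1$ when the policy reaches a terminal state with probability 1. A minimal sketch, assuming the 4×4 Gridworld is re-encoded as arrays; `grid_transitions` and `exact_policy_evaluation` are hypothetical helpers, not part of the notebook. For the uniform random policy it should reproduce the converged values of Sutton & Barto's Figure 4.1.

```python
import numpy as np


def grid_transitions(rows=4, cols=4, terminals=(0, 15), step_reward=-1.0):
    """Deterministic gridworld as (nS, nA) arrays: next state, reward, done flag."""
    nS, nA = rows * cols, 4
    moves = [(-1, 0), (0, 1), (1, 0), (0, -1)]  # up, right, down, left
    s_next = np.zeros((nS, nA), dtype=int)
    reward = np.zeros((nS, nA))
    done = np.zeros((nS, nA), dtype=bool)
    for s in range(nS):
        r, c = divmod(s, cols)
        for a, (dr, dc) in enumerate(moves):
            if s in terminals:  # absorbing terminal: reward 0, done
                s_next[s, a], reward[s, a], done[s, a] = s, 0.0, True
                continue
            r2 = min(max(r + dr, 0), rows - 1)  # bumping into a wall keeps the agent in place
            c2 = min(max(c + dc, 0), cols - 1)
            s2 = r2 * cols + c2
            s_next[s, a], reward[s, a], done[s, a] = s2, step_reward, s2 in terminals
    return s_next, reward, done


def exact_policy_evaluation(policy, s_next, reward, done, gamma=1.0):
    """Solve (I - gamma * P_pi) V = r_pi instead of iterating the backup."""
    nS, nA = policy.shape
    r_pi = (policy * reward).sum(axis=1)
    P_pi = np.zeros((nS, nS))
    for s in range(nS):
        for a in range(nA):
            if not done[s, a]:  # done transitions do not bootstrap
                P_pi[s, s_next[s, a]] += policy[s, a]
    return np.linalg.solve(np.eye(nS) - gamma * P_pi, r_pi)


s_next, reward, done = grid_transitions()
uniform = np.full((16, 4), 0.25)
V_exact = exact_policy_evaluation(uniform, s_next, reward, done)
print(V_exact.reshape(4, 4).round(1))
```

A direct solve costs $O(|S|^3)$, which is why iterative evaluation is preferred once the state space grows.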

### Bellman optimality equation (control)
The optimal value function satisfies:
$$
V^*(s)=\max_a \sum_{s',r} p(s',r\mid s,a)\big[r + \gamma V^*(s')\big].
$$

Both policy evaluation and value iteration are iterative methods that repeatedly apply these “backup” operators until the maximum change is below a tolerance `theta`.
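Why does repeating the optimality backup converge? For $\gamma<1$ the operator $(TV)(s)=\max_a \sum_{s',r}p(s',r\mid s,a)\,[r+\gamma V(s')]$ is a $\gamma$-contraction in the max norm, $\|TV-TU\|_\infty \le \gamma\|V-U\|_\infty$, so it has a unique fixed point $V^*$ and the iterates approach it geometrically (the $\gamma=1$ Gridworld relies on termination instead). The sketch below checks the inequality numerically; the randomly generated MDP and the name `bellman_optimality` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
nS, nA, gamma = 6, 3, 0.9

# Random MDP for illustration: P[s, a, s'] sums to 1 over s'; R[s, a] is the expected reward.
P = rng.random((nS, nA, nS))
P /= P.sum(axis=2, keepdims=True)
R = rng.normal(size=(nS, nA))


def bellman_optimality(V):
    """(T V)(s) = max_a [ R(s,a) + gamma * sum_s' P(s'|s,a) V(s') ]."""
    return (R + gamma * P @ V).max(axis=1)  # P @ V has shape (nS, nA)


ratios = []
for _ in range(1000):
    V, U = rng.normal(size=nS) * 10, rng.normal(size=nS) * 10
    gap_after = np.max(np.abs(bellman_optimality(V) - bellman_optimality(U)))
    ratios.append(gap_after / np.max(np.abs(V - U)))
print(f"largest observed ratio: {max(ratios):.3f} (gamma = {gamma})")
```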

In [None]:
def policy_evaluation_with_deltas(
    mdp: TabularMDP,
    policy: np.ndarray,
    theta: float = 1e-10,
    max_iterations: int = 1_000_000,
):
    """Same as policy_evaluation, but also returns per-iteration deltas for diagnostics."""
    nS, nA = mdp.nS, mdp.nA
    is_stochastic_policy(policy, nS, nA)
    V = np.zeros(nS, dtype=float)
    deltas = []
    for _ in range(max_iterations):
        delta = 0.0
        for s in range(nS):
            v_old = V[s]
            v_new = 0.0
            for a in range(nA):
                pi = policy[s, a]
                if pi == 0.0:
                    continue
                for p, s2, r, done in mdp.transitions(s, a):
                    v_new += pi * p * (r + (0.0 if done else mdp.gamma * V[s2]))
            V[s] = v_new
            delta = max(delta, abs(v_old - v_new))
        deltas.append(delta)
        if delta < theta:
            break
    return V, np.array(deltas)


def value_iteration_with_deltas(
    mdp: TabularMDP,
    theta: float = 1e-10,
    max_iterations: int = 1_000_000,
):
    """Same as value_iteration, but returns per-iteration deltas for diagnostics."""
    nS, nA = mdp.nS, mdp.nA
    V = np.zeros(nS, dtype=float)
    deltas = []
    for _ in range(max_iterations):
        delta = 0.0
        for s in range(nS):
            v_old = V[s]
            q_sa = []
            for a in range(nA):
                q = 0.0
                for p, s2, r, done in mdp.transitions(s, a):
                    q += p * (r + (0.0 if done else mdp.gamma * V[s2]))
                q_sa.append(q)
            V[s] = float(np.max(q_sa))
            delta = max(delta, abs(v_old - V[s]))
        deltas.append(delta)
        if delta < theta:
            break
    Q = q_from_v(mdp, V)
    policy = greedy_policy_from_q(Q, tie_break='uniform')
    return policy, V, np.array(deltas)

### Convergence diagnostics on Gridworld

Let’s look at how the max-update `delta` decays over sweeps.

In [None]:
V_tmp, deltas_eval = policy_evaluation_with_deltas(grid_mdp, random_policy, theta=1e-12)
_, _, deltas_vi = value_iteration_with_deltas(grid_mdp, theta=1e-12)

plt.figure(figsize=(8, 3))
plt.semilogy(deltas_eval, label='policy evaluation (random π)')
plt.semilogy(deltas_vi, label='value iteration')
plt.xlabel('sweep')
plt.ylabel('max state update (delta)')
plt.title('Convergence diagnostics (Gridworld)')
plt.legend()
plt.tight_layout()
plt.show()
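For $\gamma<1$ the stopping threshold also bounds the remaining error: if a synchronous value-iteration sweep changes no state by more than $\theta$, then $\|V_{k+1}-V^*\|_\infty \le \frac{\gamma}{1-\gamma}\,\theta$, which follows from the contraction property plus the triangle inequality. A sketch on a random discounted MDP (an illustrative assumption); note that it uses synchronous sweeps, whereas `value_iteration` above updates in place, and the Gridworld uses $\gamma=1$, where this bound does not apply.

```python
import numpy as np

rng = np.random.default_rng(1)
nS, nA, gamma, theta = 8, 3, 0.9, 1e-3

# Random discounted MDP used only for illustration.
P = rng.random((nS, nA, nS))
P /= P.sum(axis=2, keepdims=True)
R = rng.normal(size=(nS, nA))


def backup(V):
    """One synchronous Bellman optimality sweep."""
    return (R + gamma * P @ V).max(axis=1)


# Reference V*: iterate far past convergence (0.9**2000 is negligible).
V_star = np.zeros(nS)
for _ in range(2000):
    V_star = backup(V_star)

# Same stopping rule as the notebook: stop once max |V_new - V| < theta.
V, sweeps = np.zeros(nS), 0
while True:
    V_new = backup(V)
    sweeps += 1
    delta = np.max(np.abs(V_new - V))
    V = V_new
    if delta < theta:
        break

error = np.max(np.abs(V - V_star))
bound = gamma * theta / (1 - gamma)
print(f"sweeps={sweeps}, error={error:.2e}, bound={bound:.2e}")
```

With $\gamma=0.9$ the guarantee is 9 times looser than $\theta$, which is worth keeping in mind when choosing tolerances.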

### Better visualization of the Gridworld value function

Printing numbers is fine, but a heatmap is easier to scan.

In [None]:
def plot_value_heatmap(V: np.ndarray, rows: int, cols: int, title: str):
    grid = as_grid(V, rows, cols)
    plt.figure(figsize=(4.2, 3.6))
    im = plt.imshow(grid, cmap='viridis')
    plt.title(title)
    plt.colorbar(im, fraction=0.046, pad=0.04)
    for r in range(rows):
        for c in range(cols):
            plt.text(c, r, f"{grid[r,c]:.1f}", ha='center', va='center', color='white', fontsize=9)
    plt.xticks(range(cols))
    plt.yticks(range(rows))
    plt.tight_layout()
    plt.show()


plot_value_heatmap(V_pi, grid_rows, grid_cols, title='Gridworld: V under random π')
plot_value_heatmap(V_vi, grid_rows, grid_cols, title='Gridworld: V* (value iteration)')

### Sensitivity experiment: change $\gamma$ and step reward

DP makes it easy to see how modeling choices (discounting and living reward) change the optimal values/policy.

In [None]:
settings = [
    {'gamma': 1.0, 'step_reward': -1.0, 'label': 'γ=1.0, r=-1'},
    {'gamma': 0.9, 'step_reward': -1.0, 'label': 'γ=0.9, r=-1'},
    {'gamma': 1.0, 'step_reward': -0.1, 'label': 'γ=1.0, r=-0.1'},
]

fig, axes = plt.subplots(1, len(settings), figsize=(12, 3.6))
for ax, cfg in zip(axes, settings):
    mdp_cfg, _ = build_gridworld_mdp(gamma=cfg['gamma'], step_reward=cfg['step_reward'])
    _, V_cfg, _ = value_iteration(mdp_cfg, theta=1e-12)
    grid = as_grid(V_cfg, 4, 4)
    im = ax.imshow(grid, cmap='viridis')
    ax.set_title(cfg['label'])
    ax.set_xticks(range(4))
    ax.set_yticks(range(4))
    for r in range(4):
        for c in range(4):
            ax.text(c, r, f"{grid[r,c]:.1f}", ha='center', va='center', color='white', fontsize=8)
    # each panel has its own color scale, so each gets its own colorbar
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.suptitle('Gridworld: how modeling choices change V*')
plt.tight_layout()
plt.show()
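The three panels can also be predicted by hand. With deterministic moves, no obstacles and a negative step reward $r$, an optimal policy follows a shortest path, so $V^*(s)=r\sum_{k=0}^{d(s)-1}\gamma^k$, where $d(s)$ is the Manhattan distance to the nearest terminal: $r\,d(s)$ for $\gamma=1$ and $r\,(1-\gamma^{d(s)})/(1-\gamma)$ otherwise. Discounting compresses distant values (at $d=3$: $-2.71$ instead of $-3$), while a smaller step cost only rescales them. A minimal sketch; `closed_form_v_star` and `grid_value_iteration` are hypothetical helpers written for this check.

```python
import numpy as np


def closed_form_v_star(rows, cols, terminals, gamma, step_reward):
    """V*(s) = step_reward * sum_{k<d} gamma^k, d = Manhattan distance to nearest terminal.

    Assumes an obstacle-free grid, deterministic moves and step_reward < 0.
    """
    V = np.zeros(rows * cols)
    term_cells = [divmod(t, cols) for t in terminals]
    for s in range(rows * cols):
        r, c = divmod(s, cols)
        d = min(abs(r - tr) + abs(c - tc) for tr, tc in term_cells)
        V[s] = step_reward * sum(gamma ** k for k in range(d))
    return V


def grid_value_iteration(rows, cols, terminals, gamma, step_reward, sweeps=200):
    """Plain synchronous value iteration on the same grid, for comparison."""
    nS = rows * cols
    moves = [(-1, 0), (0, 1), (1, 0), (0, -1)]
    V = np.zeros(nS)
    for _ in range(sweeps):
        V_new = np.zeros(nS)  # terminal states keep value 0
        for s in range(nS):
            if s in terminals:
                continue
            r, c = divmod(s, cols)
            best = -np.inf
            for dr, dc in moves:
                s2 = min(max(r + dr, 0), rows - 1) * cols + min(max(c + dc, 0), cols - 1)
                best = max(best, step_reward + gamma * V[s2])
            V_new[s] = best
        V = V_new
    return V


for gamma, step_reward in ((1.0, -1.0), (0.9, -1.0), (1.0, -0.1)):
    V_cf = closed_form_v_star(4, 4, (0, 15), gamma, step_reward)
    print(gamma, step_reward, V_cf.reshape(4, 4).round(2))
```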

## Suggested exercises (keep it interactive)

1. Gridworld: change terminal states (e.g., add more terminals) and see how $V^*$ changes.
2. Gridworld: make actions stochastic (e.g., 0.8 intended move, 0.2 slip) and re-run value iteration.
3. Gambler: try different $p_h$ values (e.g., 0.25, 0.55) and compare the optimal stake patterns.
4. Compare tolerances: run value iteration with `theta=1e-2, 1e-6, 1e-12` and compare sweeps + resulting value error.

## Example 3 — Stochastic Gridworld ("slip" dynamics)

A common next step is to make actions **stochastic**: with probability $1-\epsilon$ you move as intended, and with probability $\epsilon$ you “slip” into one of the other three actions, chosen uniformly (probability $\epsilon/3$ each). The problem stays tabular, but the outcome of an action is now uncertain, as in most real control problems.
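The slip rule can be written out for a single state-action pair. A minimal sketch, assuming the same 4×4 layout and action order (up, right, down, left); `slip_distribution` is a hypothetical helper. Several outcomes can land in the same cell when walls are involved, which is why the builder below accumulates probabilities per next state in a dictionary.

```python
from collections import defaultdict


def slip_distribution(s, a, rows=4, cols=4, epsilon=0.2):
    """Next-cell distribution: intended move w.p. 1 - epsilon,
    each of the other three moves w.p. epsilon / 3. Walls keep the agent in place."""
    moves = [(-1, 0), (0, 1), (1, 0), (0, -1)]  # up, right, down, left

    def step(action):
        r, c = divmod(s, cols)
        dr, dc = moves[action]
        return min(max(r + dr, 0), rows - 1) * cols + min(max(c + dc, 0), cols - 1)

    dist = defaultdict(float)
    dist[step(a)] += 1.0 - epsilon
    for other in range(4):
        if other != a:
            dist[step(other)] += epsilon / 3
    return dict(dist)


# Cell 1 (top row), action "up": the intended move hits the wall, so cell 1 keeps 0.8
# and cells 2, 5 and 0 each get 0.2 / 3.
print(slip_distribution(1, 0))
# Cell 3 (top-right corner), action "up": both "up" and "right" stay in cell 3.
print(slip_distribution(3, 0))
```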

In [None]:
def build_gridworld_mdp_slip(
    rows: int = 4,
    cols: int = 4,
    terminal_states: Optional[Iterable[int]] = None,
    step_reward: float = -1.0,
    gamma: float = 1.0,
    slip_epsilon: float = 0.2,
) -> Tuple[TabularMDP, List[int]]:
    """Gridworld where actions slip to other actions with probability slip_epsilon."""
    if not (0.0 <= slip_epsilon <= 1.0):
        raise ValueError('slip_epsilon must be in [0,1]')
    nS = rows * cols
    nA = 4
    if terminal_states is None:
        terminal_states = [0, nS - 1]
    terminal_states = list(terminal_states)
    terminal_set = set(terminal_states)
    P: List[List[List[Transition]]] = [[[] for _ in range(nA)] for _ in range(nS)]

    def move(s: int, a: int) -> int:
        r, c = divmod(s, cols)
        if a == 0:
            r2, c2 = max(r - 1, 0), c
        elif a == 1:
            r2, c2 = r, min(c + 1, cols - 1)
        elif a == 2:
            r2, c2 = min(r + 1, rows - 1), c
        elif a == 3:
            r2, c2 = r, max(c - 1, 0)
        else:
            raise ValueError('Invalid action')
        return r2 * cols + c2

    for s in range(nS):
        for a in range(nA):
            if s in terminal_set:
                P[s][a] = [(1.0, s, 0.0, True)]
                continue
            transitions: Dict[Tuple[int, float, bool], float] = {}
            # intended action
            intended = move(s, a)
            done_intended = intended in terminal_set
            key = (intended, step_reward, done_intended)
            transitions[key] = transitions.get(key, 0.0) + (1.0 - slip_epsilon)
            # slip to other actions uniformly
            if slip_epsilon > 0.0:
                others = [aa for aa in range(nA) if aa != a]
                p_other = slip_epsilon / len(others)
                for aa in others:
                    s2 = move(s, aa)
                    done = s2 in terminal_set
                    key = (s2, step_reward, done)
                    transitions[key] = transitions.get(key, 0.0) + p_other
            P[s][a] = [(p, s2, r, done) for (s2, r, done), p in transitions.items()]

    return TabularMDP(nS=nS, nA=nA, P=P, gamma=gamma), terminal_states


slip_mdp, slip_terminals = build_gridworld_mdp_slip(slip_epsilon=0.2)
pi_slip, V_slip, n_slip = value_iteration(slip_mdp, theta=1e-12)
print(f"Slippery gridworld value iteration sweeps: {n_slip}")
plot_value_heatmap(V_slip, 4, 4, title='Slippery Gridworld (ε=0.2): V*')
print_policy_grid(pi_slip, 4, 4, terminals=slip_terminals, title='Slippery Gridworld: π*')

## Modified Policy Iteration (optional, but great practice)

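Sutton & Barto (Ch. 4) point out that a policy does not need to be evaluated to full convergence before it is improved: do only a small number of evaluation sweeps, then improve the policy again. They call this family *truncated* policy iteration; it is also widely known as *modified policy iteration*. Value iteration is the extreme case of one sweep per improvement. Because early policies rarely need exact values to be improved, a few sweeps per improvement often reduce total computation compared with full evaluation.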

In [None]:
def modified_policy_iteration(
    mdp: TabularMDP,
    eval_sweeps: int = 3,
    theta: float = 1e-10,
    max_iterations: int = 10_000,
    tie_break: str = 'uniform',
):
    """Policy iteration with truncated evaluation (`eval_sweeps` sweeps per improvement).

    Stops when the greedy policy no longer changes, then evaluates that policy
    to tolerance `theta` so the returned V matches it.
    """
    nS, nA = mdp.nS, mdp.nA
    policy = np.ones((nS, nA), dtype=float) / nA
    V = np.zeros(nS, dtype=float)

    for it in range(max_iterations):
        # truncated evaluation: theta=0.0 never triggers early stopping,
        # so exactly `eval_sweeps` in-place sweeps are performed
        V, _ = policy_evaluation(mdp, policy, theta=0.0, max_iterations=eval_sweeps, V0=V, in_place=True)
        # improvement
        Q = q_from_v(mdp, V)
        new_policy = greedy_policy_from_q(Q, tie_break=tie_break)
        if np.allclose(new_policy, policy):
            # final cleanup: evaluate the stable policy to tolerance
            V, _ = policy_evaluation(mdp, policy, theta=theta, V0=V, in_place=True)
            return policy, V, it + 1
        policy = new_policy
    return policy, V, max_iterations


mpi_policy, mpi_V, mpi_iters = modified_policy_iteration(grid_mdp, eval_sweeps=3, theta=1e-12)
print(f"Modified policy iteration outer loops: {mpi_iters}")
print('Max |V_MPI - V_PI| =', np.max(np.abs(mpi_V - V_star_pi)))
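To see the trade-off in isolation, here is a compact matrix version on a random discounted MDP (an illustrative assumption, not the Gridworld). With $k=1$ evaluation sweep per improvement it coincides with value iteration, and large $k$ approaches classic policy iteration; all settings should reach the same $V^*$. It starts from $V_0=\min_{s,a} r(s,a)/(1-\gamma)$, which satisfies $TV_0\ge V_0$, a standard sufficient condition in convergence proofs for modified policy iteration. `modified_pi` and `q_values` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(2)
nS, nA, gamma = 10, 4, 0.9

# Random discounted MDP used only for illustration.
P = rng.random((nS, nA, nS))
P /= P.sum(axis=2, keepdims=True)
R = rng.normal(size=(nS, nA))


def q_values(V):
    """Q(s, a) = R(s, a) + gamma * sum_s' P(s'|s,a) V(s'), shape (nS, nA)."""
    return R + gamma * P @ V


def modified_pi(k, n_outer=500):
    """Greedy improvement followed by k synchronous evaluation sweeps, repeated."""
    V = np.full(nS, R.min() / (1 - gamma))  # lower bound with T V0 >= V0
    idx = np.arange(nS)
    for _ in range(n_outer):
        pi = q_values(V).argmax(axis=1)  # deterministic greedy policy
        for _ in range(k):
            V = q_values(V)[idx, pi]  # one sweep of the backup under pi
    return pi, V


# Reference V* from plain value iteration.
V_star = np.zeros(nS)
for _ in range(2000):
    V_star = q_values(V_star).max(axis=1)

for k in (1, 3, 20):
    _, V_k = modified_pi(k)
    print(k, np.max(np.abs(V_k - V_star)))
```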

## Example 4 — Jack’s Car Rental ("car moving" problem)

This is the classic DP control problem from Sutton & Barto Ch. 4 (Example 4.2, sometimes misremembered as a “car salesman” task):
- Two locations with car rental demand and returns.
- Each night you can move cars between locations at a cost.
- During the day, cars are rented out stochastically, producing reward.

The book uses `max_cars=20` and untruncated Poisson demand/returns, which can be slow in pure Python. Below is a **scaled-down** implementation (fewer cars, Poisson tails truncated at `poisson_max` and folded into the last bucket) that runs quickly; you can increase the sizes later.

In [None]:
import math
from functools import lru_cache

def poisson_pmf_truncated(lam: float, n_max: int) -> np.ndarray:
    """Return pmf[0..n_max] with tail probability folded into pmf[n_max]."""
    pmf = np.zeros(n_max + 1, dtype=float)
    pmf[0] = math.exp(-lam)
    for k in range(1, n_max + 1):
        pmf[k] = pmf[k - 1] * lam / k
    # fold tail into last bucket
    tail = max(0.0, 1.0 - pmf.sum())
    pmf[-1] += tail
    return pmf


def precompute_location_tables(
    max_cars: int,
    lam_req: float,
    lam_ret: float,
    poisson_max: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Precompute per-location transition probabilities and (rentals * prob) numerators.

    Returns:
        p_next[c, c2] = P(next_cars=c2 | start=c)
        rent_num[c, c2] = E[rentals * 1{next_cars=c2} | start=c]
    """
    req_pmf = poisson_pmf_truncated(lam_req, poisson_max)
    ret_pmf = poisson_pmf_truncated(lam_ret, poisson_max)
    p_next = np.zeros((max_cars + 1, max_cars + 1), dtype=float)
    rent_num = np.zeros((max_cars + 1, max_cars + 1), dtype=float)

    for c in range(max_cars + 1):
        for req, p_req in enumerate(req_pmf):
            rentals = min(c, req)
            c_after_rent = c - rentals
            for ret, p_ret in enumerate(ret_pmf):
                c2 = min(max_cars, c_after_rent + ret)
                p = p_req * p_ret
                p_next[c, c2] += p
                rent_num[c, c2] += rentals * p
        # normalize minor numeric drift
        s = p_next[c].sum()
        if not np.isclose(s, 1.0):
            p_next[c] /= s
            rent_num[c] /= s
    return p_next, rent_num


def jacks_policy_evaluation(
    V: np.ndarray,
    policy: np.ndarray,
    p1: np.ndarray,
    n1: np.ndarray,
    p2: np.ndarray,
    n2: np.ndarray,
    rental_reward: float,
    move_cost: float,
    gamma: float,
    max_move: int,
    theta: float = 1e-4,
    max_iterations: int = 10_000,
    in_place: bool = True,
) -> Tuple[np.ndarray, int]:
    """Policy evaluation for Jack's Car Rental using precomputed per-location tables."""
    max_cars = V.shape[0] - 1
    V = V.copy().astype(float)
    for it in range(max_iterations):
        delta = 0.0
        # in-place sweeps read the values being updated; otherwise read a frozen copy
        V_src = V if in_place else V.copy()
        for i in range(max_cars + 1):
            for j in range(max_cars + 1):
                v_prev = V[i, j]
                a = int(np.clip(int(policy[i, j]), -min(max_move, j), min(max_move, i)))
                # cars beyond max_cars after the move are lost (capacity limit);
                # without the cap, moving cars into a full location indexes out of range
                i1, j1 = min(i - a, max_cars), min(j + a, max_cars)
                # expected return under action a
                total = 0.0
                for i2 in range(max_cars + 1):
                    p_i = p1[i1, i2]
                    if p_i == 0.0:
                        continue
                    for j2 in range(max_cars + 1):
                        p_j = p2[j1, j2]
                        if p_j == 0.0:
                            continue
                        prob = p_i * p_j
                        reward_weighted = rental_reward * (n1[i1, i2] * p_j + n2[j1, j2] * p_i) - move_cost * abs(a) * prob
                        total += reward_weighted + gamma * prob * V_src[i2, j2]
                V[i, j] = total
                delta = max(delta, abs(v_prev - total))
        if delta < theta:
            return V, it + 1
    return V, max_iterations


def jacks_policy_improvement(
    V: np.ndarray,
    p1: np.ndarray,
    n1: np.ndarray,
    p2: np.ndarray,
    n2: np.ndarray,
    rental_reward: float,
    move_cost: float,
    gamma: float,
    max_move: int,
) -> Tuple[np.ndarray, bool]:
    """Greedy improvement step.

    Returns (new_policy, stable). `stable` is always True here; the caller decides
    stability by comparing new_policy with the previous policy.
    """
    max_cars = V.shape[0] - 1
    new_policy = np.zeros((max_cars + 1, max_cars + 1), dtype=int)
    stable = True
    for i in range(max_cars + 1):
        for j in range(max_cars + 1):
            allowed = np.arange(-min(max_move, j), min(max_move, i) + 1)
            best_a = 0
            best_val = -np.inf
            for a in allowed:
                # cars beyond max_cars after the move are lost (capacity limit)
                i1, j1 = min(i - a, max_cars), min(j + a, max_cars)
                total = 0.0
                for i2 in range(max_cars + 1):
                    p_i = p1[i1, i2]
                    if p_i == 0.0:
                        continue
                    for j2 in range(max_cars + 1):
                        p_j = p2[j1, j2]
                        if p_j == 0.0:
                            continue
                        prob = p_i * p_j
                        reward_weighted = rental_reward * (n1[i1, i2] * p_j + n2[j1, j2] * p_i) - move_cost * abs(a) * prob
                        total += reward_weighted + gamma * prob * V[i2, j2]
                # strict improvement (beyond 1e-12) keeps the first of near-tied actions
                if total > best_val + 1e-12:
                    best_val = total
                    best_a = int(a)
            new_policy[i, j] = best_a
    return new_policy, stable


def jacks_policy_iteration(
    max_cars: int = 10,
    max_move: int = 5,
    poisson_max: int = 8,
    lam_req1: float = 3.0,
    lam_req2: float = 4.0,
    lam_ret1: float = 3.0,
    lam_ret2: float = 2.0,
    rental_reward: float = 10.0,
    move_cost: float = 2.0,
    gamma: float = 0.9,
    theta: float = 1e-3,
    max_outer: int = 50,
    eval_in_place: bool = True,
) -> Tuple[np.ndarray, np.ndarray, List[float]]:
    """Run policy iteration for Jack's Car Rental (scaled-down)."""
    p1, n1 = precompute_location_tables(max_cars, lam_req1, lam_ret1, poisson_max)
    p2, n2 = precompute_location_tables(max_cars, lam_req2, lam_ret2, poisson_max)

    V = np.zeros((max_cars + 1, max_cars + 1), dtype=float)
    policy = np.zeros((max_cars + 1, max_cars + 1), dtype=int)
    history = []
    for outer in range(max_outer):
        V, eval_iters = jacks_policy_evaluation(V, policy, p1, n1, p2, n2, rental_reward, move_cost, gamma, max_move, theta=theta, in_place=eval_in_place)
        new_policy, _ = jacks_policy_improvement(V, p1, n1, p2, n2, rental_reward, move_cost, gamma, max_move)
        change = np.max(np.abs(new_policy - policy))
        history.append(float(change))
        if np.array_equal(new_policy, policy):
            break
        policy = new_policy
    return policy, V, history


def plot_jacks_policy_and_value(policy: np.ndarray, V: np.ndarray, title_prefix: str = "Jack's") -> None:
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    im1 = plt.imshow(V.T, origin='lower', cmap='viridis')
    plt.colorbar(im1, fraction=0.046, pad=0.04)
    plt.title(f"{title_prefix}: value V(i,j)")
    plt.xlabel('cars at loc 1 (i)')
    plt.ylabel('cars at loc 2 (j)')

    plt.subplot(1, 2, 2)
    im2 = plt.imshow(policy.T, origin='lower', cmap='coolwarm')
    plt.colorbar(im2, fraction=0.046, pad=0.04)
    plt.title(f"{title_prefix}: policy (move from 1→2)")
    plt.xlabel('cars at loc 1 (i)')
    plt.ylabel('cars at loc 2 (j)')
    plt.tight_layout()
    plt.show()


# Demo run (scaled-down for speed). You can increase max_cars / poisson_max later.
jacks_policy, jacks_V, jacks_hist = jacks_policy_iteration(max_cars=10, max_move=5, poisson_max=8, theta=1e-3)
print('Jack\'s policy iteration outer loops:', len(jacks_hist))
plot_jacks_policy_and_value(jacks_policy, jacks_V, title_prefix="Jack's (scaled)")
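Most of the run time goes into the double loop over $(i_2, j_2)$. Because the two locations evolve independently once the night's move is made, that sum factorizes: with post-move counts $(i_1, j_1)$ and rows of `p_next` summing to 1, the expected return of action $a$ is $\text{rental\_reward}\,\big(\sum_{i_2} n_1[i_1,i_2]+\sum_{j_2} n_2[j_1,j_2]\big) - \text{move\_cost}\,|a| + \gamma\, p_1[i_1]^\top V\, p_2[j_1]$. A sketch under that assumption: `location_tables`, `expected_return_loop` and `expected_return_vectorized` are hypothetical helpers mirroring the notebook's tables (defaults: rental reward 10, move cost 2, $\gamma=0.9$), and the random `V` is only there to compare the two formulas.

```python
import math

import numpy as np


def location_tables(max_cars, lam_req, lam_ret, poisson_max):
    """Same construction as precompute_location_tables (tail folded into last bucket)."""
    def pmf(lam):
        p = np.array([math.exp(-lam) * lam ** k / math.factorial(k) for k in range(poisson_max + 1)])
        p[-1] += max(0.0, 1.0 - p.sum())
        return p

    req, ret = pmf(lam_req), pmf(lam_ret)
    p_next = np.zeros((max_cars + 1, max_cars + 1))
    rent_num = np.zeros_like(p_next)
    for c in range(max_cars + 1):
        for q, pq in enumerate(req):
            rented = min(c, q)
            for k, pk in enumerate(ret):
                c2 = min(max_cars, c - rented + k)
                p_next[c, c2] += pq * pk
                rent_num[c, c2] += rented * pq * pk
    return p_next, rent_num


def expected_return_loop(i1, j1, a, V, p1, n1, p2, n2, rent=10.0, cost=2.0, gamma=0.9):
    """Double loop over (i2, j2), as in jacks_policy_evaluation."""
    total = 0.0
    for i2 in range(V.shape[0]):
        for j2 in range(V.shape[1]):
            prob = p1[i1, i2] * p2[j1, j2]
            total += rent * (n1[i1, i2] * p2[j1, j2] + n2[j1, j2] * p1[i1, i2]) - cost * abs(a) * prob
            total += gamma * prob * V[i2, j2]
    return total


def expected_return_vectorized(i1, j1, a, V, p1, n1, p2, n2, rent=10.0, cost=2.0, gamma=0.9):
    """Same quantity, using independence of the two locations."""
    expected_rent = rent * (n1[i1].sum() + n2[j1].sum())
    return expected_rent - cost * abs(a) + gamma * (p1[i1] @ V @ p2[j1])


rng = np.random.default_rng(0)
p1, n1 = location_tables(10, 3.0, 3.0, 8)
p2, n2 = location_tables(10, 4.0, 2.0, 8)
V = rng.normal(size=(11, 11))
print(expected_return_loop(7, 4, 2, V, p1, n1, p2, n2),
      expected_return_vectorized(7, 4, 2, V, p1, n1, p2, n2))
```

The vectorized form replaces an $O(N^2)$ Python loop per action with two small matrix-vector products, which matters most when scaling back up to `max_cars=20`.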

## More tasks (ideas to extend this notebook)

Here are additional DP/MDP practice tasks that stay within Sutton & Barto Ch. 3–4:

1. **Policy evaluation speed**: compare `in_place=True` vs `False` for Gridworld and plot deltas.
2. **Asynchronous updates**: sweep states in random order and compare convergence.
3. **State aggregation**: group Gridworld states (e.g., by Manhattan distance to terminal) and evaluate an aggregated approximate value function.
4. **Jack’s scaling study**: run Jack’s with `max_cars=8,10,12` and compare runtime and policy structure.
5. **Parameter sweeps**: for Gambler and Jack’s, sweep $p_h$ (or Poisson rates) and track how the greedy policy changes.