In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# The Bellman Equation: Breaking the Future into One Step at a Time

*Part 2 of the Vizuara series on Value Functions and Q-Learning*
*Estimated time: 40 minutes*

## 1. Why Does This Matter?

In the previous notebook, we computed value functions by running policy evaluation until convergence. But we never asked the deeper question: **why does that iterative process converge at all?**

The answer lies in the Bellman equation -- a recursive relationship discovered by Richard Bellman in the 1950s that says: "You do not need to think about the entire future. Just think about the next step, and let the recursion handle the rest."

This single insight is the mathematical foundation of most value-based reinforcement learning algorithms, from tabular Q-learning to the deep RL agents that play Atari and control robots.

By the end of this notebook, you will:
- Derive and implement the Bellman equation for V and Q
- Solve the Bellman optimality equation to find the best policy
- Implement value iteration from scratch
- Watch optimal policies emerge from pure computation

## 2. Building Intuition

Imagine you are playing Pac-Man with a fixed strategy: if a ghost is nearby, move away; if there is a pellet, move towards it. This strategy is your policy.

Now, someone asks: "If you start at this specific maze location with this strategy, how many points will you score on average?"

You could simulate the entire game from that point thousands of times and average the scores. But the Bellman equation offers a shortcut: **just look at what happens in the next step, and then add the value of wherever that step takes you.**

This is like asking for directions. Instead of getting a complete route from NYC to LA, you just ask: "Which highway should I take next?" and then ask again when you reach the next junction. One step at a time.

### Think About This

If you know the value of every state you can reach in one step, can you figure out the value of your own state? If so, and if every state does this simultaneously, would the values eventually stabilize? They do: with a discount factor below 1, each round of Bellman updates shrinks the remaining error, so the values settle to a single fixed point.
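
To make the stabilization idea concrete, here is a minimal sketch on a hypothetical two-state loop (not the grid world used later): each state leads to the other, every move costs one point, and both states update at the same time from their neighbor's current value. The names `neighbor`, `REWARD`, and `sweep` are illustrative assumptions.

```python
import numpy as np

GAMMA = 0.9
neighbor = np.array([1, 0])   # state 0 leads to state 1 and vice versa
REWARD = -1.0                 # every move costs one point

def sweep(V):
    """All states update at once from their neighbor's current value."""
    return REWARD + GAMMA * V[neighbor]

V_a = np.zeros(2)                 # neutral starting guess
V_b = np.array([50.0, -50.0])     # wildly wrong starting guess
for k in range(200):
    V_a, V_b = sweep(V_a), sweep(V_b)

print(V_a, V_b)
```

Both starting guesses end up at the same values, close to $-1 / (1 - 0.9) = -10$, which is the fixed point of $V = -1 + 0.9 \, V$.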

## 3. The Mathematics

### The Bellman Equation for V

$$V^{\pi}(s) = \sum_{a} \pi(a \mid s) \sum_{s', r} p(s', r \mid s, a) \left[ r + \gamma \, V^{\pi}(s') \right]$$

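In words: the value of state $s$ is an expectation taken in two stages. The outer sum averages over actions, weighted by the policy $\pi(a \mid s)$. The inner sum averages over next states and rewards, weighted by the transition probabilities $p(s', r \mid s, a)$. The quantity being averaged is the immediate reward plus the discounted value of the next state.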
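
As a rough illustration of the two nested sums, the sketch below applies one backup to a hypothetical 2-state, 2-action MDP stored as NumPy arrays. The array layout (`P[s, a, s2]`, `R[s, a, s2]`, `pi[s, a]`) and the function name are assumptions made for this example, and it simplifies the equation by attaching one fixed reward to each transition instead of summing over a reward distribution.

```python
import numpy as np

def bellman_expectation_backup(P, R, pi, V, gamma=0.9):
    """One synchronous backup of V^pi for a tabular MDP.

    P[s, a, s2] : probability of landing in s2 after action a in state s
    R[s, a, s2] : reward received on that transition (assumed deterministic)
    pi[s, a]    : probability that the policy picks action a in state s
    """
    # Inner sum: expected one-step return for every (s, a) pair
    q = np.sum(P * (R + gamma * V[None, None, :]), axis=2)
    # Outer sum: weight each action by the policy
    return np.sum(pi * q, axis=1)

# Hypothetical MDP: state 1 is absorbing and pays nothing
P = np.array([[[0.8, 0.2], [0.1, 0.9]],
              [[0.0, 1.0], [0.0, 1.0]]])
R = np.zeros((2, 2, 2))
R[0, :, 1] = 1.0                 # reward 1 for reaching state 1 from state 0
pi = np.full((2, 2), 0.5)        # uniform random policy
V = np.zeros(2)

V_new = bellman_expectation_backup(P, R, pi, V)
print(V_new)
```

Starting from `V = 0`, the backup for state 0 is `0.5 * 0.2 + 0.5 * 0.9 = 0.55`: the two actions reach the rewarding state with probability 0.2 and 0.9, and the policy picks each half the time.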

For a deterministic environment with a deterministic policy, this simplifies to:

$$V^{\pi}(s) = r(s, \pi(s)) + \gamma \cdot V^{\pi}(s')$$

This is the recursive structure: the value of the current state depends on the value of the next state $s'$, the state reached by taking action $\pi(s)$ in $s$.
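
The recursion can be unrolled by hand on a straight path to a goal. The sketch below borrows the default rewards of the grid world defined later in this notebook (a step reward of -1 and a goal reward of 10) with $\gamma = 0.9$; the helper `value_at_distance` is a hypothetical name used only for this illustration.

```python
GAMMA = 0.9
STEP_REWARD, GOAL_REWARD = -1.0, 10.0

def value_at_distance(d):
    """Value of a state that is d moves from the goal on a straight path."""
    if d == 1:
        return GOAL_REWARD          # the next move ends the episode
    # One step of reward now, plus the discounted value of the next state
    return STEP_REWARD + GAMMA * value_at_distance(d - 1)

for d in range(1, 5):
    print(d, round(value_at_distance(d), 3))
```

Each value is computed from the one before it: 10 at distance 1, then $-1 + 0.9 \times 10 = 8$, then $-1 + 0.9 \times 8 = 6.2$, and so on.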

### The Bellman Optimality Equation

$$V^*(s) = \max_{a} \sum_{s', r} p(s', r \mid s, a) \left[ r + \gamma \, V^*(s') \right]$$

Instead of averaging over actions according to a policy, we take the **maximum** -- always pick the best action. This gives us the optimal value function.
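
The difference between the two backups is easiest to see on a single state. The sketch below uses four one-step returns $r + \gamma V(s')$, one per action. The numbers are chosen to match the demo at state (2, 2) in section 4.2, but the variable names are assumptions for this example.

```python
import numpy as np

# One-step returns r + gamma * V(s') for the actions Up, Right, Down, Left
action_returns = np.array([-1.0, 3.5, 1.7, -1.0])
uniform_policy = np.full(4, 0.25)

# Bellman equation for V^pi: average the returns under the policy
expectation_backup = float(uniform_policy @ action_returns)
# Bellman optimality equation: keep only the best return
optimality_backup = float(action_returns.max())
greedy_action = int(action_returns.argmax())

print(expectation_backup, optimality_backup, greedy_action)
```

The random policy averages the good and bad actions together (about 0.8), while the optimality backup keeps only the best one (3.5, from moving Right).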

### The Bellman Optimality Equation for Q

$$Q^*(s, a) = \sum_{s', r} p(s', r \mid s, a) \left[ r + \gamma \, \max_{a'} Q^*(s', a') \right]$$

This is the key equation for Q-learning: the optimal Q-value for taking action $a$ in state $s$ equals the expected immediate reward plus the discounted maximum Q-value in the next state, averaged over the possible transitions.
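
It may help to see how the two optimality equations connect. Using the same symbols as above, the optimal value of a state is the best Q-value available in it, and the optimal Q-value is a one-step lookahead on $V^*$:

$$V^*(s) = \max_{a} Q^*(s, a), \qquad Q^*(s, a) = \sum_{s', r} p(s', r \mid s, a) \left[ r + \gamma \, V^*(s') \right]$$

Substituting the first identity (evaluated at $s'$) into the second recovers the Bellman optimality equation for Q, and substituting the second into the first recovers the one for V.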

## 4. Let's Build It -- Component by Component

### 4.1 Setup: Our Grid World

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

class GridWorld:
    """A simple deterministic grid world."""

    def __init__(self, rows=4, cols=4, goal=(3, 3), obstacles=None,
                 step_reward=-1.0, goal_reward=10.0):
        self.rows = rows
        self.cols = cols
        self.goal = goal
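        # Fall back to the default obstacle only when none is given,
        # so that an explicit empty list means "no obstacles"
        self.obstacles = [(1, 1)] if obstacles is None else list(obstacles)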
        self.step_reward = step_reward
        self.goal_reward = goal_reward
        self.actions = [(-1, 0), (0, 1), (1, 0), (0, -1)]
        self.action_names = ['Up', 'Right', 'Down', 'Left']
        self.action_symbols = ['^', '>', 'v', '<']
        self.n_actions = 4

    def is_valid(self, r, c):
        return (0 <= r < self.rows and 0 <= c < self.cols
                and (r, c) not in self.obstacles)

    def step(self, state, action):
        r, c = state
        dr, dc = self.actions[action]
        nr, nc = r + dr, c + dc
        if not self.is_valid(nr, nc):
            nr, nc = r, c
        reward = self.goal_reward if (nr, nc) == self.goal else self.step_reward
        done = (nr, nc) == self.goal
        return (nr, nc), reward, done

    def get_all_states(self):
        return [(r, c) for r in range(self.rows)
                for c in range(self.cols) if (r, c) not in self.obstacles]


env = GridWorld(rows=4, cols=4, goal=(3, 3), obstacles=[(1, 1)])
print(f"Grid: {env.rows}x{env.cols}, Goal: {env.goal}, Obstacles: {env.obstacles}")

### 4.2 The Bellman Equation in Action: One-Step Backup

Let us implement the core of the Bellman equation -- the one-step backup.

In [None]:
def bellman_backup_v(env, V, state, policy, gamma=0.9):
    """
    Compute one Bellman backup for V^pi(s).

    V^pi(s) = sum_a pi(a|s) * [r(s,a) + gamma * V^pi(s')]

    For a deterministic policy: V^pi(s) = r(s, pi(s)) + gamma * V^pi(s')
    """
    if state == env.goal:
        return 0.0

    action = policy[state]
    next_state, reward, done = env.step(state, action)

    if done:
        return reward
    else:
        return reward + gamma * V.get(next_state, 0.0)


def bellman_backup_v_stochastic(env, V, state, gamma=0.9):
    """
    Compute one Bellman backup for V(s) under a uniform random policy.

    V(s) = (1/|A|) * sum_a [r(s,a) + gamma * V(s')]
    """
    if state == env.goal:
        return 0.0

    total = 0.0
    for a in range(env.n_actions):
        next_state, reward, done = env.step(state, a)
        if done:
            total += reward
        else:
            total += reward + gamma * V.get(next_state, 0.0)

    return total / env.n_actions


# Demonstrate the backup step by step
V = {s: 0.0 for s in env.get_all_states()}

# Set some initial values to show the backup
V[(2, 3)] = 5.0
V[(3, 2)] = 3.0

state = (2, 2)
print(f"State: {state}")
print(f"Neighbors' values: right={(2,3)}:{V[(2,3)]:.1f}, down={(3,2)}:{V[(3,2)]:.1f}")
print()

# Show the Bellman backup for a random policy at (2,2)
new_v = bellman_backup_v_stochastic(env, V, state, gamma=0.9)
print(f"Bellman backup (random policy) at {state}:")
print(f"  V({state}) = (1/4) * sum of [r + 0.9 * V(s')]")
for a in range(env.n_actions):
    ns, r, d = env.step(state, a)
    v_next = V.get(ns, 0.0)
    print(f"  Action {env.action_names[a]:>5}: next={ns}, r={r:.1f}, V(next)={v_next:.1f}, "
          f"contribution={r + 0.9 * v_next:.2f}")
print(f"  Average = {new_v:.2f}")

### 4.3 Value Iteration: Finding the Optimal V*

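Value iteration turns the Bellman optimality equation into an update rule: sweep over all states, replace each $V(s)$ with the best one-step lookahead value, and repeat until the largest change in a sweep falls below a small threshold `theta`.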
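
For comparison with the grid-world implementation in the next cell, here is a sketch of the same idea on a generic tabular MDP stored as arrays. It is a synchronous variant (every state is updated from the previous sweep's values), and the array layout `P[s, a, s2]`, `R[s, a, s2]`, the function name, and the toy chain are all assumptions made for this example.

```python
import numpy as np

def value_iteration_sync(P, R, gamma=0.9, theta=1e-6, max_iters=1000):
    """Synchronous value iteration on a tabular MDP.

    P[s, a, s2] are transition probabilities, R[s, a, s2] transition rewards.
    Returns the value estimate and the number of sweeps performed.
    """
    V = np.zeros(P.shape[0])
    for k in range(1, max_iters + 1):
        # One-step lookahead for every (state, action) pair
        q = np.sum(P * (R + gamma * V[None, None, :]), axis=2)
        V_new = q.max(axis=1)                 # Bellman optimality backup
        delta = np.max(np.abs(V_new - V))     # largest change in this sweep
        V = V_new
        if delta < theta:
            break
    return V, k

# Hypothetical 3-state chain; state 2 is absorbing and pays nothing.
# Action 0 stays put (reward -1); action 1 moves right
# (reward -1, or +10 when the move enters state 2).
P = np.zeros((3, 2, 3))
R = np.zeros((3, 2, 3))
for s in (0, 1):
    P[s, 0, s] = 1.0
    R[s, 0, s] = -1.0
    P[s, 1, s + 1] = 1.0
    R[s, 1, s + 1] = 10.0 if s + 1 == 2 else -1.0
P[2, :, 2] = 1.0

V, sweeps = value_iteration_sync(P, R)
print(V, sweeps)
```

The values for states 0 and 1 come out as 8 and 10, the same numbers a two-step path to the goal produces in the grid world with these rewards.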

In [None]:
def value_iteration(env, gamma=0.9, theta=1e-6, max_iters=1000):
    """
    Find V*(s) using value iteration.

    Update rule: V(s) <- max_a [r(s,a) + gamma * V(s')]

    This applies the Bellman optimality equation as an update.
    """
    states = env.get_all_states()
    V = {s: 0.0 for s in states}
    history = []

    for iteration in range(max_iters):
        delta = 0.0
        # Updates are in place: states visited later in this sweep
        # already see the values written earlier in the same sweep

        for s in states:
            if s == env.goal:
                continue

            old_v = V[s]

            # Bellman optimality backup: max over all actions
            best_value = float('-inf')
            for a in range(env.n_actions):
                next_s, reward, done = env.step(s, a)
                if done:
                    value = reward
                else:
                    value = reward + gamma * V.get(next_s, 0.0)
                best_value = max(best_value, value)

            V[s] = best_value
            delta = max(delta, abs(old_v - V[s]))

        history.append(dict(V))

        if delta < theta:
            print(f"Value iteration converged in {iteration + 1} iterations (delta={delta:.8f})")
            break

    return V, history


V_star, history = value_iteration(env, gamma=0.9)

print("\nOptimal Values V*(s):")
for r in range(env.rows):
    row_str = ""
    for c in range(env.cols):
        if (r, c) in env.obstacles:
            row_str += "  XXX  "
        else:
            row_str += f" {V_star.get((r,c), 0.0):5.1f} "
    print(row_str)

In [None]:
# Visualization: Watch value iteration converge
# Deduplicate in case the final sweep coincides with one of the fixed indices
snapshots_to_show = sorted({i for i in [0, 1, 2, 5, len(history) - 1]
                            if 0 <= i < len(history)})

fig, axes = plt.subplots(1, len(snapshots_to_show), figsize=(4*len(snapshots_to_show), 4))
if len(snapshots_to_show) == 1:
    axes = [axes]

cmap = LinearSegmentedColormap.from_list('vi', ['#f7fbff', '#2171b5'])

for idx, snap_idx in enumerate(snapshots_to_show):
    V_snap = history[snap_idx]
    grid = np.full((env.rows, env.cols), np.nan)
    for (r, c), v in V_snap.items():
        grid[r, c] = v

    axes[idx].imshow(grid, cmap=cmap, vmin=-10, vmax=10)
    axes[idx].set_title(f'Iteration {snap_idx + 1}', fontsize=12, fontweight='bold')

    for r in range(env.rows):
        for c in range(env.cols):
            if (r, c) in env.obstacles:
                axes[idx].add_patch(plt.Rectangle((c-0.5, r-0.5), 1, 1, fill=True, color='gray'))
            elif not np.isnan(grid[r, c]):
                # Dark cells (high values) get white text; light cells get black
                axes[idx].text(c, r, f'{grid[r,c]:.1f}', ha='center', va='center',
                               fontsize=10, fontweight='bold',
                               color='white' if grid[r,c] > 3 else 'black')

    axes[idx].set_xticks(range(env.cols))
    axes[idx].set_yticks(range(env.rows))
    axes[idx].grid(True, linewidth=1, color='white')

plt.suptitle('Value Iteration: Watch V* Emerge', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

### 4.4 Extract the Optimal Policy from V*

In [None]:
def extract_policy(env, V_star, gamma=0.9):
    """
    Extract the optimal policy from V*.

    pi*(s) = argmax_a [r(s,a) + gamma * V*(s')]
    """
    policy = {}
    for s in env.get_all_states():
        if s == env.goal:
            policy[s] = 0  # Arbitrary for terminal
            continue

        best_action = 0
        best_value = float('-inf')

        for a in range(env.n_actions):
            next_s, reward, done = env.step(s, a)
            if done:
                value = reward
            else:
                value = reward + gamma * V_star.get(next_s, 0.0)

            if value > best_value:
                best_value = value
                best_action = a

        policy[s] = best_action

    return policy


optimal_policy = extract_policy(env, V_star, gamma=0.9)

print("Optimal Policy (arrows show best action):")
for r in range(env.rows):
    row_str = ""
    for c in range(env.cols):
        if (r, c) in env.obstacles:
            row_str += "  X  "
        elif (r, c) == env.goal:
            row_str += "  G  "
        else:
            row_str += f"  {env.action_symbols[optimal_policy[(r,c)]]}  "
    print(row_str)

## 5. Your Turn

### TODO: Implement Q-Value Iteration

Instead of finding V*, find Q* directly using the Bellman optimality equation for Q.

In [None]:
def q_value_iteration(env, gamma=0.9, theta=1e-6, max_iters=1000):
    """
    Find Q*(s, a) for all state-action pairs using value iteration.

    Update rule: Q(s, a) <- r(s,a) + gamma * max_{a'} Q(s', a')

    This directly solves the Bellman optimality equation for Q.

    Returns:
        dict mapping state -> dict mapping action -> Q-value
    """
    states = env.get_all_states()
    Q = {s: {a: 0.0 for a in range(env.n_actions)} for s in states}

    for iteration in range(max_iters):
        delta = 0.0

        for s in states:
            if s == env.goal:
                continue

            for a in range(env.n_actions):
                old_q = Q[s][a]

                # ============ TODO ============
                # Compute the new Q-value using the Bellman optimality equation:
                # Q(s, a) = r(s,a) + gamma * max_{a'} Q(s', a')
                #
                # Step 1: Use env.step(s, a) to get (next_state, reward, done)
                # Step 2: If done, Q(s,a) = reward
                # Step 3: Otherwise, Q(s,a) = reward + gamma * max over a' of Q(s', a')
                # ==============================

                new_q = 0.0  # YOUR CODE HERE

                Q[s][a] = new_q
                delta = max(delta, abs(old_q - new_q))

        if delta < theta:
            print(f"Q-value iteration converged in {iteration + 1} iterations")
            break

    return Q

In [None]:
# Verification
Q_star = q_value_iteration(env, gamma=0.9)

# Check: V*(s) should equal max_a Q*(s, a)
for s in env.get_all_states():
    v_from_q = max(Q_star[s].values())
    v_from_vi = V_star[s]
    assert abs(v_from_q - v_from_vi) < 0.01, f"Mismatch at {s}: Q gives {v_from_q:.3f}, V gives {v_from_vi:.3f}"

print("All checks passed! Q* is consistent with V*.")
print(f"\nQ* at state (0,0):")
for a in range(env.n_actions):
    print(f"  {env.action_names[a]:>5}: {Q_star[(0,0)][a]:.3f}")

## 6. Putting It All Together

In [None]:
# Side-by-side comparison: suboptimal vs optimal policy

# Suboptimal: always go right
suboptimal_policy = {s: 1 for s in env.get_all_states()}  # Always right


def policy_evaluation_full(env, policy, gamma=0.9, theta=1e-6, max_iters=1000):
    """Evaluate a deterministic policy by sweeping Bellman backups until stable."""
    states = env.get_all_states()
    V = {s: 0.0 for s in states}
    for _ in range(max_iters):
        delta = 0.0
        for s in states:
            if s == env.goal:
                continue
            old_v = V[s]
            ns, r, done = env.step(s, policy[s])
            # Terminal transitions contribute no future value
            V[s] = r if done else r + gamma * V.get(ns, 0.0)
            delta = max(delta, abs(old_v - V[s]))
        if delta < theta:
            break
    return V

V_subopt = policy_evaluation_full(env, suboptimal_policy)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

for idx, (title, V_show, pol) in enumerate([
    ("Suboptimal: Always Right", V_subopt, suboptimal_policy),
    ("Optimal: Value Iteration", V_star, optimal_policy),
]):
    grid = np.full((env.rows, env.cols), np.nan)
    for (r, c), v in V_show.items():
        grid[r, c] = v

    color = 'Reds' if idx == 0 else 'Blues'
    axes[idx].imshow(grid, cmap=color, interpolation='nearest')
    axes[idx].set_title(title, fontsize=14, fontweight='bold')

    arrow_dx = [0, 0.35, 0, -0.35]
    arrow_dy = [-0.35, 0, 0.35, 0]

    for (r, c) in env.get_all_states():
        if (r, c) == env.goal:
            axes[idx].text(c, r, 'G', ha='center', va='center', fontsize=14,
                          fontweight='bold', color='gold')
            continue

        a = pol[(r, c)]
        axes[idx].annotate('', xy=(c + arrow_dx[a], r + arrow_dy[a]),
                           xytext=(c, r),
                           arrowprops=dict(arrowstyle='->', lw=2,
                                           color='darkred' if idx == 0 else 'darkblue'))
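        # Dark backing box keeps the white label readable on light cells
        axes[idx].text(c, r + 0.35, f'{V_show[(r,c)]:.1f}', ha='center', va='center',
                      fontsize=9, fontweight='bold', color='white',
                      bbox=dict(boxstyle='round,pad=0.15', facecolor='black', alpha=0.6))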

    for r in range(env.rows):
        for c in range(env.cols):
            if (r, c) in env.obstacles:
                axes[idx].add_patch(plt.Rectangle((c-0.5, r-0.5), 1, 1, color='gray'))

    axes[idx].set_xticks(range(env.cols))
    axes[idx].set_yticks(range(env.rows))
    axes[idx].grid(True, linewidth=1, color='white')

plt.suptitle('Suboptimal vs Optimal Policy', fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

## 7. Training and Results

Let us track how value iteration converges over iterations.
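
As a reference point for the plots below, the standard contraction bound for the synchronous form of value iteration with $\gamma < 1$ says that each sweep shrinks the worst-case error by at least a factor of $\gamma$:

$$\lVert V_{k+1} - V^* \rVert_\infty \le \gamma \, \lVert V_k - V^* \rVert_\infty$$

Here $V_k$ denotes the estimate after $k$ sweeps, a symbol introduced only for this note. On a log-scale plot, geometric decay of this kind appears as a roughly straight downward line. The implementation in this notebook updates values in place and the grid is small and deterministic, so the observed curve may drop faster than the bound suggests.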

In [None]:
# Convergence analysis
max_value_per_iter = []
for V_snap in history:
    vals = [v for s, v in V_snap.items() if s != env.goal]
    max_value_per_iter.append(max(vals) if vals else 0)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(range(1, len(max_value_per_iter)+1), max_value_per_iter, 'b-o', markersize=4)
ax1.set_xlabel('Iteration', fontsize=12)
ax1.set_ylabel('Max V(s)', fontsize=12)
ax1.set_title('Value Iteration Convergence', fontsize=13, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Show delta (change) per iteration
deltas = []
for i in range(1, len(history)):
    d = max(abs(history[i][s] - history[i-1][s]) for s in env.get_all_states())
    deltas.append(d)

ax2.semilogy(range(2, len(history)+1), deltas, 'r-o', markersize=4)
ax2.set_xlabel('Iteration', fontsize=12)
ax2.set_ylabel('Max Delta (log scale)', fontsize=12)
ax2.set_title('Convergence Speed', fontsize=13, fontweight='bold')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Converged in {len(history)} iterations")
print(f"Final max delta: {deltas[-1] if deltas else 0:.2e}")

## 8. Final Output

In [None]:
# Comprehensive summary visualization
fig = plt.figure(figsize=(16, 8))

# Left: Optimal values with policy
ax1 = fig.add_subplot(121)
grid = np.full((env.rows, env.cols), np.nan)
for (r, c), v in V_star.items():
    grid[r, c] = v

cmap = LinearSegmentedColormap.from_list('opt', ['#f7fbff', '#08519c'])
ax1.imshow(grid, cmap=cmap)
ax1.set_title('Optimal Policy (V* + arrows)', fontsize=14, fontweight='bold')

arrow_dx = [0, 0.35, 0, -0.35]
arrow_dy = [-0.35, 0, 0.35, 0]

for (r, c) in env.get_all_states():
    if (r, c) == env.goal:
        ax1.text(c, r, 'GOAL', ha='center', va='center', fontsize=10,
                fontweight='bold', color='gold',
                bbox=dict(boxstyle='round', facecolor='green', alpha=0.8))
    elif (r, c) in env.obstacles:
        ax1.add_patch(plt.Rectangle((c-0.5, r-0.5), 1, 1, color='gray'))
    else:
        a = optimal_policy[(r, c)]
        ax1.annotate('', xy=(c + arrow_dx[a], r + arrow_dy[a]),
                     xytext=(c, r),
                     arrowprops=dict(arrowstyle='->', color='white', lw=2.5))
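        # Dark backing box keeps the white label readable on light cells
        ax1.text(c, r + 0.35, f'{V_star[(r,c)]:.1f}', ha='center', va='center',
                fontsize=10, fontweight='bold', color='white',
                bbox=dict(boxstyle='round,pad=0.15', facecolor='black', alpha=0.6))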

ax1.set_xticks(range(env.cols))
ax1.set_yticks(range(env.rows))
ax1.grid(True, linewidth=1, color='white', alpha=0.3)

# Right: Q* heatmap for all actions
ax2 = fig.add_subplot(122)
q_grid = np.zeros((env.rows * 2, env.cols * 2))
for (r, c), q_vals in Q_star.items():
    for a, q in q_vals.items():
        # Map each action to one quadrant of the state's 2x2 block:
        # Up -> top-left, Right -> top-right, Down -> bottom-left, Left -> bottom-right
        sub_r = r * 2 + a // 2
        sub_c = c * 2 + a % 2
        q_grid[sub_r, sub_c] = q

im = ax2.imshow(q_grid, cmap='RdYlBu', interpolation='nearest')
ax2.set_title('Q* values (4 sub-cells per state)', fontsize=14, fontweight='bold')
plt.colorbar(im, ax=ax2, shrink=0.8)

plt.suptitle('Bellman Optimality -- Complete Solution', fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

print("Congratulations! You have solved the Bellman equations from scratch!")
print("Next up: learning these values from experience with Q-Learning.")

## 9. Reflection and Next Steps

### Reflection Questions
1. Value iteration converges because the Bellman operator is a contraction. What does "contraction" mean intuitively?
2. Why does the Bellman optimality equation take a max over actions instead of the policy-weighted sum used in the standard Bellman equation?
3. In our implementation, we needed the transition dynamics (env.step). What if we did not have access to them?

### Optional Challenges
1. Implement policy iteration (alternating policy evaluation and policy improvement) and compare convergence speed to value iteration.
2. Add stochastic transitions (the agent slips 20% of the time) and re-run value iteration. How does the optimal policy change?
3. Implement the Bellman equation for a continuous state space using function approximation (a simple neural network).