In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability, report versions, and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Value Functions: Teaching Machines to Evaluate Positions

*Part 1 of the Vizuara series on Value Functions and Q-Learning*
*Estimated time: 35 minutes*

## 1. Why Does This Matter?

Every decision-making agent -- whether it is a chess engine, a robot navigating a warehouse, or a recommendation system -- needs to answer one fundamental question: **"How good is my current situation?"**

Value functions are the mathematical tool that answers this question. They assign a number to every possible state, telling the agent how much total reward it can expect from that point onward. This simple idea turns out to be one of the most powerful concepts in all of reinforcement learning.

By the end of this notebook, you will:
- Implement state value functions and action value functions from scratch
- Visualize value landscapes on grid worlds
- Understand discounting and why it matters
- See how Q-values guide action selection

Let us build the intuition, then the code.

## 2. Building Intuition

Imagine you are playing a board game on a grid. Your goal is to reach the treasure at the bottom-right corner.

At any moment, you are standing on a particular cell. A natural question arises: **"How good is it to be in this cell?"**

If you are right next to the treasure, the answer is obvious -- it is very good. If you are far away with obstacles in the way, less so. But here is the key insight: even cells that are far from the treasure can be "good" if they lie on a clear path towards it.

This is exactly what a value function captures -- a number for every position that says "this is how much reward you can expect from here."

### Think About This

If you are at a crossroads and one path leads to a reward of +10 in 2 steps, while another path leads to a reward of +100 in 50 steps, which path is "better"? Does it depend on how patient you are? This question motivates the concept of discounting, which we will explore next.

## 3. The Mathematics

### State Value Function

The state value function $V^{\pi}(s)$ tells us the expected return $G_t$ (the total future reward, discounted as defined in the next subsection) when starting from state $s$ and following policy $\pi$:

$$V^{\pi}(s) = \mathbb{E}_{\pi}[G_t \mid S_t = s]$$

Computationally, this says: "Start at state $s$. Follow policy $\pi$ to choose actions. Add up all the rewards you collect. Average over many runs. That average is the value."
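The "average over many runs" recipe can be sketched in a few lines. This is a minimal illustration on a made-up one-state problem, not the grid world used later: the policy, the +5 and -1 rewards, and the name `sample_return` are all assumptions chosen so that the exact answer is easy to check by hand (the agent continues once on average, so the value is 5 - 1 = 4).

```python
import random

def sample_return(rng):
    """Run one episode and return the sum of rewards collected."""
    total = 0.0
    while True:
        # Hypothetical policy: pick 'stop' or 'continue' with equal probability.
        if rng.random() < 0.5:
            return total + 5.0  # 'stop': collect +5 and end the episode
        total += -1.0           # 'continue': pay -1 and stay in the same state

rng = random.Random(0)
n_episodes = 20_000

# Average the returns of many independent runs to estimate V(s).
estimate = sum(sample_return(rng) for _ in range(n_episodes)) / n_episodes
print(f"Monte Carlo estimate of V(s): {estimate:.2f} (exact value: 4.00)")
```

The estimate gets closer to the exact value as `n_episodes` grows, which is the sense in which the expectation is "an average over many runs".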

### Discounted Return

Future rewards are worth less than immediate rewards. We capture this using a discount factor $\gamma \in [0, 1]$:

$$G_t = r_t + \gamma \cdot r_{t+1} + \gamma^2 \cdot r_{t+2} + \gamma^3 \cdot r_{t+3} + \cdots$$

When $\gamma = 0.9$, a reward of 10 received 5 steps later is worth only $10 \times 0.9^5 \approx 5.9$ today. The higher $\gamma$ is, the more the agent cares about the future.
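To make the arithmetic concrete, here is a small sketch that reproduces the number above and revisits the crossroads question from Section 2. The helper name `discounted_value` is hypothetical, and the sketch assumes a reward that arrives after `delay` steps is scaled by $\gamma^{\text{delay}}$.

```python
def discounted_value(reward, delay, gamma):
    """Present value of a single reward received `delay` steps from now."""
    return reward * gamma ** delay

# The worked example from the text: 10 * 0.9^5
print(f"10 in 5 steps at gamma=0.9: {discounted_value(10, 5, 0.9):.2f}")

# Crossroads: +10 in 2 steps versus +100 in 50 steps
for gamma in (0.5, 0.9, 0.99):
    near = discounted_value(10, 2, gamma)
    far = discounted_value(100, 50, gamma)
    better = "near path" if near > far else "far path"
    print(f"gamma={gamma}: near={near:.3f}, far={far:.3f} -> {better}")
```

Under these assumptions the near path wins for $\gamma = 0.5$ and $\gamma = 0.9$, while the far path wins for $\gamma = 0.99$. So the "better" path really does depend on how patient the agent is.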

### Action Value Function (Q-Function)

The action value function $Q^{\pi}(s, a)$ tells us the expected return if we start in state $s$, take action $a$, and then follow policy $\pi$:

$$Q^{\pi}(s, a) = \mathbb{E}_{\pi}[G_t \mid S_t = s, A_t = a]$$

The relationship between V and Q is:

$$V^{\pi}(s) = \sum_{a} \pi(a \mid s) \cdot Q^{\pi}(s, a)$$

This says: the value of a state is the average Q-value over all actions, weighted by how likely each action is under our policy.
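As a quick numerical illustration of this weighted average, consider a single state with four actions. The Q-values below are made-up numbers chosen for easy arithmetic, and `state_value` is a hypothetical helper, not part of the notebook's code.

```python
# Made-up Q-values for one state (illustrative only)
q_values = {"Up": 1.0, "Right": 4.0, "Down": 2.0, "Left": -1.0}

# Two policies for the same state: uniform random, and always "Right"
uniform_policy = {a: 0.25 for a in q_values}
greedy_policy = {a: (1.0 if a == "Right" else 0.0) for a in q_values}

def state_value(policy_probs, q_values):
    """V(s) = sum over actions of pi(a|s) * Q(s, a)."""
    return sum(policy_probs[a] * q_values[a] for a in q_values)

v_uniform = state_value(uniform_policy, q_values)
v_greedy = state_value(greedy_policy, q_values)
print(f"V under uniform policy: {v_uniform:.2f}")  # (1 + 4 + 2 - 1) / 4 = 1.50
print(f"V under greedy policy:  {v_greedy:.2f}")   # all weight on Q = 4.00
```

The same Q-values give different state values under different policies, because the policy decides how much weight each action receives.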

## 4. Let's Build It -- Component by Component

### 4.1 Define a Simple Grid World

We will start by building a simple grid world environment from scratch. No RL libraries -- just NumPy for the arrays and Matplotlib for the plots.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap

class GridWorld:
    """A simple grid world environment."""

    def __init__(self, rows=5, cols=5, goal=(4, 4), obstacles=None, step_reward=-0.1, goal_reward=10.0):
        self.rows = rows
        self.cols = cols
        self.goal = goal
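        # Use the default obstacles only when none are given;
        # an explicit empty list means "no obstacles"
        self.obstacles = [(2, 1), (2, 3)] if obstacles is None else list(obstacles)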
        self.step_reward = step_reward
        self.goal_reward = goal_reward

        # Actions: 0=up, 1=right, 2=down, 3=left
        self.actions = [(-1, 0), (0, 1), (1, 0), (0, -1)]
        self.action_names = ['Up', 'Right', 'Down', 'Left']
        self.n_actions = len(self.actions)
        self.n_states = rows * cols

    def is_valid(self, r, c):
        return (0 <= r < self.rows and 0 <= c < self.cols
                and (r, c) not in self.obstacles)

    def step(self, state, action):
        """Take an action, return (next_state, reward, done)."""
        r, c = state
        dr, dc = self.actions[action]
        nr, nc = r + dr, c + dc

        if not self.is_valid(nr, nc):
            nr, nc = r, c  # Stay in place if hitting wall or obstacle

        reward = self.step_reward
        done = False
        if (nr, nc) == self.goal:
            reward = self.goal_reward
            done = True

        return (nr, nc), reward, done

    def get_all_states(self):
        """Return all non-obstacle states."""
        states = []
        for r in range(self.rows):
            for c in range(self.cols):
                if (r, c) not in self.obstacles:
                    states.append((r, c))
        return states


env = GridWorld()
print(f"Grid: {env.rows}x{env.cols}")
print(f"Goal: {env.goal}")
print(f"Obstacles: {env.obstacles}")
print(f"Total valid states: {len(env.get_all_states())}")
print(f"Actions: {env.action_names}")

### 4.2 Compute Discounted Returns by Hand

Before using any algorithm, let us compute discounted returns manually to build intuition.

In [None]:
def compute_discounted_return(rewards, gamma=0.9):
    """
    Compute the discounted return from a sequence of rewards.

    G_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ...
    """
    G = 0.0
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * G
    return G


# Example 1: Simple chain A -> B -> C
rewards_abc = [1.0, 2.0]  # r from A->B = 1, r from B->C = 2
gamma = 0.9

G_from_A = compute_discounted_return(rewards_abc, gamma)
G_from_B = compute_discounted_return([2.0], gamma)

print("=== Three-State Chain (A -> B -> C) ===")
print(f"Rewards: A->{rewards_abc[0]} -> B->{rewards_abc[1]} -> C (terminal)")
print(f"Gamma: {gamma}")
print(f"V(A) = {rewards_abc[0]} + {gamma} x {rewards_abc[1]} = {G_from_A:.2f}")
print(f"V(B) = {rewards_abc[1]:.2f}")
print(f"V(C) = 0.00 (terminal)")

# Example 2: Longer trajectory
rewards_long = [1, 2, 3, 4, 5]
G_long = compute_discounted_return(rewards_long, gamma)
print(f"\n=== Longer Trajectory ===")
print(f"Rewards: {rewards_long}")
print(f"Discounted return: {G_long:.2f}")

# Show how each reward contributes
print("\nBreakdown:")
for t, r in enumerate(rewards_long):
    contribution = r * (gamma ** t)
    print(f"  t={t}: reward={r}, gamma^{t}={gamma**t:.4f}, contribution={contribution:.4f}")

### 4.3 Visualize a Value Function

In [None]:
def visualize_values(env, values, title="State Values"):
    """Visualize value function as a heatmap on the grid."""
    grid = np.full((env.rows, env.cols), np.nan)

    for (r, c), v in values.items():
        grid[r, c] = v

    fig, ax = plt.subplots(1, 1, figsize=(7, 7))

    cmap = LinearSegmentedColormap.from_list('vizuara', ['#fee0d2', '#fc9272', '#de2d26'])
    im = ax.imshow(grid, cmap=cmap, interpolation='nearest')

    for r in range(env.rows):
        for c in range(env.cols):
            if (r, c) in env.obstacles:
                ax.add_patch(plt.Rectangle((c-0.5, r-0.5), 1, 1, fill=True, color='gray', alpha=0.7))
                ax.text(c, r, 'X', ha='center', va='center', fontsize=14, fontweight='bold', color='white')
            elif (r, c) == env.goal:
                ax.text(c, r, f'{grid[r,c]:.1f}\nGOAL', ha='center', va='center', fontsize=10, fontweight='bold')
            elif not np.isnan(grid[r, c]):
                ax.text(c, r, f'{grid[r,c]:.1f}', ha='center', va='center', fontsize=12, fontweight='bold')

    ax.set_xticks(range(env.cols))
    ax.set_yticks(range(env.rows))
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, linewidth=2, color='white')
    plt.colorbar(im, ax=ax, shrink=0.8, label='Value')
    plt.tight_layout()
    plt.show()


# Create a simple hand-crafted value function for visualization
# (We will compute real ones shortly)
manual_values = {}
for r in range(env.rows):
    for c in range(env.cols):
        if (r, c) not in env.obstacles:
            # Manhattan distance heuristic (rough approximation)
            dist = abs(r - env.goal[0]) + abs(c - env.goal[1])
            manual_values[(r, c)] = max(0, 10 - dist * 1.5)

visualize_values(env, manual_values, "Approximate Values (Distance Heuristic)")

### 4.4 Policy Evaluation: Computing V(s) for a Given Policy

Now let us compute the true value function for a given policy using iterative policy evaluation.

In [None]:
def policy_evaluation(env, policy, gamma=0.9, theta=1e-6, max_iters=1000):
    """
    Compute V^pi(s) for all states using iterative policy evaluation.

    The Bellman equation for V^pi is:
    V^pi(s) = sum_a pi(a|s) * sum_{s'} p(s'|s,a) * [r + gamma * V^pi(s')]

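    Since our environment is deterministic, p(s'|s,a) = 1 for one s'.
    The policy passed in is also deterministic (one action per state),
    so both sums collapse to: V^pi(s) = r + gamma * V^pi(s').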
    """
    states = env.get_all_states()
    V = {s: 0.0 for s in states}

    for iteration in range(max_iters):
        delta = 0.0

        for s in states:
            if s == env.goal:
                continue  # Terminal state has value 0

            old_v = V[s]
            action = policy[s]
            next_s, reward, done = env.step(s, action)

            if done:
                V[s] = reward
            else:
                V[s] = reward + gamma * V.get(next_s, 0.0)

            delta = max(delta, abs(old_v - V[s]))

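        if delta < theta:
            print(f"Policy evaluation converged in {iteration + 1} iterations (delta={delta:.8f})")
            break
    else:
        # The loop finished without reaching the convergence threshold
        print(f"Warning: policy evaluation did not converge within {max_iters} iterations")

    return V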


# Define a simple policy: go RIGHT when possible, otherwise go DOWN
def simple_policy(env):
    """A simple policy: go right unless blocked by a wall or obstacle, then go down."""
    policy = {}
    for r in range(env.rows):
        for c in range(env.cols):
            if (r, c) not in env.obstacles:
                if c < env.cols - 1 and (r, c + 1) not in env.obstacles:
                    policy[(r, c)] = 1  # Right
                else:
                    policy[(r, c)] = 2  # Down
    return policy


policy = simple_policy(env)
V = policy_evaluation(env, policy, gamma=0.9)

print("\nState Values under 'go right then down' policy:")
for r in range(env.rows):
    row_str = ""
    for c in range(env.cols):
        if (r, c) in env.obstacles:
            row_str += "  XXX  "
        else:
            row_str += f" {V.get((r,c), 0.0):5.2f} "
    print(row_str)

In [None]:
visualize_values(env, V, "V^pi: 'Go Right Then Down' Policy")
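
For a very small problem, the fixed point that iterative policy evaluation converges to can also be obtained by solving the Bellman equations as a linear system, $(I - \gamma P)V = r$. The sketch below does this for the three-state chain A -> B -> C from Section 4.2, not for the grid world. The matrix `P`, the vector `r`, and the name `V_chain` are illustrative choices, and the terminal state C is modeled as a row of zeros.

```python
import numpy as np

gamma = 0.9

# States are ordered A, B, C. P[i, j] = 1 if the policy moves from i to j.
# C is terminal, so its row is all zeros (no further transitions).
P = np.array([
    [0.0, 1.0, 0.0],  # A -> B
    [0.0, 0.0, 1.0],  # B -> C
    [0.0, 0.0, 0.0],  # C: terminal
])

# Reward collected when leaving each state (A->B gives 1, B->C gives 2).
r = np.array([1.0, 2.0, 0.0])

# Bellman equations in matrix form: V = r + gamma * P @ V
V_chain = np.linalg.solve(np.eye(3) - gamma * P, r)
print(dict(zip("ABC", V_chain.round(2))))
```

This should match the hand computation from Section 4.2: $V(A) = 1 + 0.9 \times 2 = 2.8$, $V(B) = 2$, $V(C) = 0$. The direct solve is convenient for tiny state spaces; the iterative version scales better and is the stepping stone to the learning algorithms later in the series.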

### 4.5 Computing Q-Values from V

In [None]:
def compute_q_values(env, V, gamma=0.9):
    """
    Compute Q(s, a) for all state-action pairs from V(s).

    Q^pi(s, a) = r(s, a) + gamma * V^pi(s')
    """
    Q = {}
    states = env.get_all_states()

    for s in states:
        if s == env.goal:
            Q[s] = {a: 0.0 for a in range(env.n_actions)}
            continue

        Q[s] = {}
        for a in range(env.n_actions):
            next_s, reward, done = env.step(s, a)
            if done:
                Q[s][a] = reward
            else:
                Q[s][a] = reward + gamma * V.get(next_s, 0.0)

    return Q


Q = compute_q_values(env, V, gamma=0.9)

# Show Q-values for a specific state
state = (0, 0)
print(f"Q-values at state {state}:")
for a in range(env.n_actions):
    print(f"  {env.action_names[a]:>5}: Q = {Q[state][a]:.3f}")

best_action = max(Q[state], key=Q[state].get)
print(f"\nBest action at {state}: {env.action_names[best_action]} (Q = {Q[state][best_action]:.3f})")
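
One detail worth knowing about the `max(..., key=...)` idiom used above: when several actions share the highest Q-value, Python's `max` returns the first one in the dictionary's iteration order. The sketch below, with made-up Q-values and a hypothetical helper `greedy_actions`, shows one way to list all tied actions and break the tie at random instead.

```python
import random

def greedy_actions(q_for_state, tol=1e-9):
    """Return every action whose Q-value is within `tol` of the maximum."""
    best = max(q_for_state.values())
    return [a for a, q in q_for_state.items() if abs(q - best) <= tol]

# Made-up Q-values where actions 1 and 2 are tied for best
q_for_state = {0: 1.2, 1: 3.4, 2: 3.4, 3: -0.5}

first_best = max(q_for_state, key=q_for_state.get)  # first maximal key
tied = greedy_actions(q_for_state)                  # all maximal keys
choice = random.Random(0).choice(tied)              # random tie-break

print(f"max() picks action {first_best}; tied best actions: {tied}; random pick: {choice}")
```

For plotting a greedy policy the first-match behavior is usually fine, but it can make the arrows look biased toward lower-numbered actions when many states have ties.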

## 5. Your Turn

### TODO: Implement a Random Policy Evaluator

Below, you need to implement policy evaluation for a *stochastic* (random) policy where the agent picks each action with equal probability.

In [None]:
def evaluate_random_policy(env, gamma=0.9, theta=1e-6, max_iters=1000):
    """
    Compute V(s) for a uniformly random policy.

    Under a random policy, pi(a|s) = 1/n_actions for all a.
    So: V(s) = (1/n_actions) * sum_a [r(s,a) + gamma * V(s')]

    Returns:
        dict mapping state -> value
    """
    states = env.get_all_states()
    V = {s: 0.0 for s in states}
    n_actions = env.n_actions

    for iteration in range(max_iters):
        delta = 0.0

        for s in states:
            if s == env.goal:
                continue

            old_v = V[s]
            # ============ TODO ============
            # Compute the new value of state s under a random policy.
            # For each action a, compute: r + gamma * V(s')
            # Then average across all actions (since each has probability 1/n_actions)
            # Hint: Use env.step(s, a) to get (next_state, reward, done)
            # ==============================

            new_v = 0.0  # YOUR CODE HERE

            V[s] = new_v
            delta = max(delta, abs(old_v - V[s]))

        if delta < theta:
            print(f"Converged in {iteration + 1} iterations")
            break

    return V

In [None]:
# Verification
V_random = evaluate_random_policy(env, gamma=0.9)

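# The placeholder (new_v = 0.0) leaves every value at zero, so check for that first
assert any(v != 0.0 for v in V_random.values()), "V is all zeros -- did you replace the placeholder in the TODO?"
# The random policy should have lower values than the directed policy
assert V_random[(0, 0)] < V[(0, 0)], "Random policy should have lower value than directed policy!"
assert V_random[env.goal] == 0.0, "Goal state should have value 0"
print("All checks passed!")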
print(f"V_random(0,0) = {V_random[(0,0)]:.3f} < V_directed(0,0) = {V[(0,0)]:.3f}")

visualize_values(env, V_random, "V: Random Policy (Your Implementation)")

## 6. Putting It All Together

Let us now visualize the full picture: values, Q-values, and the implied greedy policy all on one grid.

In [None]:
def visualize_policy_and_values(env, V, Q, title="Policy and Values"):
    """Visualize the greedy policy arrows overlaid on value heatmap."""
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))

    # Value heatmap
    grid = np.full((env.rows, env.cols), np.nan)
    for (r, c), v in V.items():
        grid[r, c] = v

    cmap = LinearSegmentedColormap.from_list('vizuara', ['#f7fbff', '#6baed6', '#08519c'])
    im = ax.imshow(grid, cmap=cmap, interpolation='nearest')

    # Arrow directions for each action
    arrow_dx = [0, 0.3, 0, -0.3]  # Up, Right, Down, Left
    arrow_dy = [-0.3, 0, 0.3, 0]

    for (r, c), q_vals in Q.items():
        if (r, c) == env.goal:
            ax.text(c, r, 'GOAL', ha='center', va='center', fontsize=10,
                    fontweight='bold', color='gold',
                    bbox=dict(boxstyle='round', facecolor='green', alpha=0.8))
            continue
        if (r, c) in env.obstacles:
            continue

        # Draw greedy policy arrow
        best_a = max(q_vals, key=q_vals.get)
        ax.annotate('', xy=(c + arrow_dx[best_a], r + arrow_dy[best_a]),
                     xytext=(c, r),
                     arrowprops=dict(arrowstyle='->', color='red', lw=2))

        ax.text(c, r + 0.35, f'{V[(r,c)]:.1f}', ha='center', va='center',
                fontsize=9, color='white', fontweight='bold')

    for r in range(env.rows):
        for c in range(env.cols):
            if (r, c) in env.obstacles:
                ax.add_patch(plt.Rectangle((c-0.5, r-0.5), 1, 1, fill=True,
                                           color='gray', alpha=0.8))
                ax.text(c, r, 'X', ha='center', va='center', fontsize=16,
                        fontweight='bold', color='white')

    ax.set_xticks(range(env.cols))
    ax.set_yticks(range(env.rows))
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, linewidth=2, color='white', alpha=0.5)
    plt.colorbar(im, ax=ax, shrink=0.8, label='Value')
    plt.tight_layout()
    plt.show()


visualize_policy_and_values(env, V, Q, "Greedy Policy from V^pi")

## 7. Training and Results

Let us compare value functions across different discount factors to see how gamma affects the agent's perspective.

In [None]:
gammas = [0.5, 0.7, 0.9, 0.99]
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

for idx, gamma in enumerate(gammas):
    V_g = policy_evaluation(env, policy, gamma=gamma)

    grid = np.full((env.rows, env.cols), np.nan)
    for (r, c), v in V_g.items():
        grid[r, c] = v

    cmap = LinearSegmentedColormap.from_list('vizuara', ['#fff5eb', '#fd8d3c', '#d94701'])
    im = axes[idx].imshow(grid, cmap=cmap, interpolation='nearest')
    axes[idx].set_title(f'gamma = {gamma}', fontsize=13, fontweight='bold')

    for r in range(env.rows):
        for c in range(env.cols):
            if (r, c) in env.obstacles:
                axes[idx].add_patch(plt.Rectangle((c-0.5, r-0.5), 1, 1, fill=True,
                                                   color='gray', alpha=0.7))
            elif not np.isnan(grid[r, c]):
                axes[idx].text(c, r, f'{grid[r,c]:.1f}', ha='center', va='center',
                               fontsize=9, fontweight='bold')

    axes[idx].set_xticks(range(env.cols))
    axes[idx].set_yticks(range(env.rows))
    axes[idx].grid(True, linewidth=1, color='white')

plt.suptitle('How Discount Factor Affects Value Functions', fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

print("Notice how:")
print("- Low gamma (0.5): Only states very close to the goal have high values")
print("- High gamma (0.99): Distant states also have significant value")
print("- Gamma controls the 'horizon' -- how far ahead the agent looks")
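
One way to put a number on the "horizon" idea is the sum of the discount weights, $1 + \gamma + \gamma^2 + \cdots = 1/(1-\gamma)$, which is often used as a rule of thumb for how many steps ahead the agent effectively looks. The sketch below computes it for the same discount factors, along with the number of steps after which a reward's weight drops to one half. The names `horizons` and `half_lives` are illustrative.

```python
import math

gammas = [0.5, 0.7, 0.9, 0.99]

# Sum of the geometric series of discount weights: 1 / (1 - gamma)
horizons = {g: 1.0 / (1.0 - g) for g in gammas}

# Number of steps k at which gamma^k = 0.5
half_lives = {g: math.log(0.5) / math.log(g) for g in gammas}

for g in gammas:
    print(f"gamma={g}: effective horizon ~ {horizons[g]:.1f} steps, "
          f"weight halves after ~ {half_lives[g]:.1f} steps")
```

Going from $\gamma = 0.9$ to $\gamma = 0.99$ stretches the effective horizon from about 10 steps to about 100, which is consistent with distant states keeping significant value in the last panel.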

## 8. Final Output

In [None]:
# Generate the comprehensive summary figure
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Panel 1: State Values
grid_v = np.full((env.rows, env.cols), np.nan)
for (r, c), v in V.items():
    grid_v[r, c] = v
cmap1 = LinearSegmentedColormap.from_list('v', ['#f7fbff', '#2171b5'])
axes[0].imshow(grid_v, cmap=cmap1)
axes[0].set_title('State Values V(s)', fontsize=13, fontweight='bold')
for r in range(env.rows):
    for c in range(env.cols):
        if (r, c) in env.obstacles:
            axes[0].add_patch(plt.Rectangle((c-0.5, r-0.5), 1, 1, fill=True, color='gray'))
        elif not np.isnan(grid_v[r, c]):
            axes[0].text(c, r, f'{grid_v[r,c]:.1f}', ha='center', va='center', fontsize=10, fontweight='bold', color='white')

# Panel 2: Best Q-values
grid_q = np.full((env.rows, env.cols), np.nan)
for (r, c) in Q:
    if (r, c) not in env.obstacles:
        grid_q[r, c] = max(Q[(r, c)].values())
cmap2 = LinearSegmentedColormap.from_list('q', ['#fff5f0', '#cb181d'])
axes[1].imshow(grid_q, cmap=cmap2)
axes[1].set_title('Best Q-values max_a Q(s,a)', fontsize=13, fontweight='bold')
for r in range(env.rows):
    for c in range(env.cols):
        if (r, c) in env.obstacles:
            axes[1].add_patch(plt.Rectangle((c-0.5, r-0.5), 1, 1, fill=True, color='gray'))
        elif not np.isnan(grid_q[r, c]):
            axes[1].text(c, r, f'{grid_q[r,c]:.1f}', ha='center', va='center', fontsize=10, fontweight='bold', color='white')

# Panel 3: Greedy Policy Arrows
axes[2].set_xlim(-0.5, env.cols - 0.5)
axes[2].set_ylim(env.rows - 0.5, -0.5)
axes[2].set_aspect('equal')
axes[2].set_title('Greedy Policy', fontsize=13, fontweight='bold')
arrow_dx = [0, 0.35, 0, -0.35]
arrow_dy = [-0.35, 0, 0.35, 0]

for (r, c), q_vals in Q.items():
    if (r, c) == env.goal:
        axes[2].plot(c, r, 's', markersize=25, color='gold')
        axes[2].text(c, r, 'G', ha='center', va='center', fontsize=12, fontweight='bold')
    elif (r, c) in env.obstacles:
        axes[2].add_patch(plt.Rectangle((c-0.5, r-0.5), 1, 1, fill=True, color='gray'))
    else:
        best_a = max(q_vals, key=q_vals.get)
        axes[2].annotate('', xy=(c + arrow_dx[best_a], r + arrow_dy[best_a]),
                         xytext=(c, r),
                         arrowprops=dict(arrowstyle='->', color='#2171b5', lw=2.5))

for ax in axes:
    ax.set_xticks(range(env.cols))
    ax.set_yticks(range(env.rows))
    ax.grid(True, linewidth=1, color='lightgray')

plt.suptitle('Value Functions and Q-Learning -- Notebook 1 Summary', fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

print("Congratulations! You have built value functions from scratch!")
print("You now understand V(s), Q(s,a), and how policies shape values.")

## 9. Reflection and Next Steps

### Reflection Questions
1. Why does the random policy produce lower values than the directed policy? What does this tell us about the relationship between policy quality and value?
2. If gamma = 0, what would the value function look like? What kind of agent would this produce?
3. Can you think of a real-world scenario where you would want a low discount factor (gamma close to 0)?

### Optional Challenges
1. Implement a 10x10 grid world with multiple goals and obstacles. Visualize the value landscape.
2. Try a stochastic environment where the agent slips (moves in a random direction 20% of the time). How does this change the value function?
3. Implement value iteration (instead of policy evaluation) to find the optimal value function directly.