# GridWorld Environment

In this tutorial, we’ll explore some ideas in reinforcement learning by using a simple example called GridWorld. In this environment, an “agent” (imagine it as a character) learns how to move through a grid to reach certain goals. This setup will help us understand a mathematical framework called a Markov Decision Process (MDP), which underpins many AI systems that make decisions.

## Markov Decision Processes

![title](figures/MDP-model.png)

Before jumping into the math, let’s break down the idea of a Markov Decision Process.
Imagine you’re in a maze. You start somewhere in the maze and can take actions (like moving forward or turning left) that change your position. At each point, you might receive a reward or a penalty (for instance, if you find a prize or fall into a trap). Over time, you want to learn how to move through the maze in a way that maximizes your total reward. Importantly, the outcome of an action depends only on that action and the state you were in when you took it; it does not depend on the history of earlier actions. This process of making decisions based on where you are and what actions are available is what a Markov Decision Process describes.
Now let’s look at the formal definition.
### MDP Components
An MDP is defined by a tuple: $(S, A, P, R, \gamma)$. A tuple just means that we’re listing a set of related things that describe our situation. Here’s what each part means:
- $S$ (Set of States): Think of states as all the possible situations or places the agent can be in. For GridWorld, each square in the grid is a unique state.
- $A$ (Set of Actions): These are the choices the agent has at each state. For example, in GridWorld, the actions might be moving Up, Down, Left, or Right.
- $P(s'|s,a)$ (Transition Probability): This part tells us the likelihood of ending up in a new state, $s'$, if we’re in state $s$ and take action $a$. For example, if the agent moves “Up” from a particular square, it might reach the square directly above it with probability 0.8, and veer off in another direction the rest of the time.
- $R(s,a,s')$ (Reward): This represents the reward (a positive or negative value) that the agent receives for taking action $a$ in state $s$ and ending up in state $s'$. Rewards are like points: some states might give a big positive reward (like reaching the goal), while others give penalties (like stepping into a pit).
- $\gamma$ (Discount Factor): The discount factor $\gamma$ is a number between 0 and 1 that determines how much importance we give to future rewards. If $\gamma$ is closer to 1, the agent cares more about long-term rewards. If it’s closer to 0, it focuses more on immediate rewards.
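
To make the tuple concrete, here is a minimal sketch of how the five components could be stored in plain Python for a tiny two-state problem. The state names, action names, probabilities, and rewards are made up for illustration; they are not the GridWorld used later in this tutorial.

```python
# A toy MDP with two states and two actions (all values are illustrative).
states = ["A", "B"]                      # S
actions = ["stay", "move"]               # A

# P[(s, a)] maps each possible next state s' to its probability P(s'|s,a).
P = {
    ("A", "stay"): {"A": 1.0},
    ("A", "move"): {"B": 0.8, "A": 0.2},
    ("B", "stay"): {"B": 1.0},
    ("B", "move"): {"A": 0.8, "B": 0.2},
}

def R(s, a, s_next):
    """Reward R(s,a,s'): +10 for arriving in B from A, -1 otherwise."""
    return 10.0 if (s == "A" and s_next == "B") else -1.0

gamma = 0.9                              # discount factor

# Sanity check: every P(.|s,a) must be a probability distribution.
for (s, a), dist in P.items():
    assert abs(sum(dist.values()) - 1.0) < 1e-12, (s, a)

mdp = (states, actions, P, R, gamma)
```

Storing $P$ as a dictionary keyed by $(s, a)$ keeps the Markov property visible: the distribution over next states is looked up from the current state and action alone.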

## Markov Property
One last piece of the puzzle is the Markov property. This is a special rule that simplifies our calculations: in an MDP, the next state and reward depend only on the current state and action. The history of past states and actions doesn’t matter.
This is why we can represent the MDP with only a current state and action, rather than keeping track of the entire sequence of actions that led us there.

## Applying MDPs to GridWorld

![title](figures/gridworld_sample.png)

To make this clearer, let’s see how these concepts apply to our GridWorld example:
- States ($S$): Each square in the grid is a unique state.
- Actions ($A$): The agent can choose to move Up, Down, Left, or Right.
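- Transitions ($P$): There’s a bit of randomness in the agent’s actions. For example, if it tries to go “Up,” it might end up moving in one of the other directions by mistake. We control this randomness with a “noise” parameter: the intended action is executed with probability $1 - \text{noise}$, and otherwise one of the other three actions is executed, each equally likely.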
- Rewards ($R$):
    - Reaching the goal square: +10 points
    - Stepping into a pit: -10 points
    - Moving to any other square: -1 point (a small penalty to encourage finding the shortest path)
- $\gamma$: We can adjust this to see how it affects the agent’s choices between immediate rewards and future rewards.
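
As a rough sketch of how the noise parameter turns an intended action into a distribution over executed actions, the snippet below follows the same convention as the `GridWorld` class defined in the setup section (probability `1 - noise` for the intended action, the remainder split evenly among the other three). The function name `action_distribution` is hypothetical and is not part of the tutorial code.

```python
ACTION_NAMES = {0: "right", 1: "down", 2: "left", 3: "up"}

def action_distribution(intended, noise):
    """Return {action: probability} for the action that is actually executed."""
    others = [a for a in ACTION_NAMES if a != intended]
    dist = {intended: 1.0 - noise}
    for a in others:
        dist[a] = noise / len(others)   # noise is shared equally
    return dist

dist = action_distribution(intended=3, noise=0.2)  # agent tries to go up
for a, p in dist.items():
    print(f"{ACTION_NAMES[a]:>5}: {p:.3f}")
```

With `noise=0.2`, "up" is executed 80% of the time and each other direction roughly 6.7% of the time.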

### Setup GridWorld

Run the code cells below to set up the GridWorld environment.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import deque

class GridWorld:
    def __init__(self, size=4, seed=None, noise=0.1):
        np.random.seed(seed)

        self.size = size
        self.rng = np.random.RandomState(seed)
        self.grid = np.zeros((size, size), dtype=int)
        self._generate_grid()
        self.start_pos = (0, 0)
        self.agent_pos = self.start_pos
        
        # Stochastic transition parameters
        self.noise = noise  # Probability of taking a random action instead of chosen action
        
        self.step_cost = -1.
        self.terminal_rewards = {
            2: 10.0,    # goal reward
            -1: -10.0   # pit reward
        }
        
        # Action mappings
        self.actions = {
            0: (0, 1),   # right
            1: (1, 0),   # down
            2: (0, -1),  # left
            3: (-1, 0)   # up
        }
        
    def is_terminal(self, state):
        """Check if state is terminal"""
        r, c = state
        return self.grid[r, c] in [2, -1]  # goal or pit
        
    def get_terminal_value(self, state):
        """Get the value for terminal state"""
        r, c = state
        return self.terminal_rewards[self.grid[r, c]]
        
    def _get_reachable_cells(self, start, grid):
        """Get all reachable non-wall, non-terminal cells from start position using BFS"""
        visited = set()
        queue = deque([start])
        visited.add(start)
        
        while queue:
            r, c = queue.popleft()
            for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
                new_r, new_c = r + dr, c + dc
                if (0 <= new_r < self.size and 
                    0 <= new_c < self.size and 
                    grid[new_r, new_c] not in [1, 2, -1] and  # Not wall or terminal
                    (new_r, new_c) not in visited):
                    queue.append((new_r, new_c))
                    visited.add((new_r, new_c))
        return visited
        
    def _generate_grid(self):
        """Generate grid ensuring all non-terminal states are reachable"""
        while True:
            # Reset grid
            self.grid.fill(0)
            
            # Place walls first (20% chance)
            for i in range(self.size):
                for j in range(self.size):
                    if (i, j) != (0, 0) and self.rng.random() < 0.2:
                        self.grid[i, j] = 1
            
            # Get all non-wall cells
            non_wall_cells = set((i, j) for i in range(self.size) 
                               for j in range(self.size) 
                               if self.grid[i, j] != 1)
            
            # Check initial reachability
            reachable = self._get_reachable_cells((0, 0), self.grid)
            if len(reachable) != len(non_wall_cells):
                continue
            
            # Remove start position from candidate positions
            non_wall_cells.remove((0, 0))
            candidates = list(non_wall_cells)
            
            if len(candidates) < 2:  # Need at least space for goal and pit
                continue
                
            # Place goal in a reachable cell (preferably far from start)
            distances = [(abs(r-0) + abs(c-0), (r, c)) for r, c in candidates]
            distances.sort(reverse=True)
            goal_pos = distances[0][1]
            self.grid[goal_pos] = 2
            candidates.remove(goal_pos)
            
            # Place pits
            num_pits = max(1, len(candidates) // 10)
            pit_candidates = []
            
            # Try placing each pit and check reachability
            for _ in range(num_pits):
                if not candidates:
                    break
                    
                valid_pit_placed = False
                np.random.shuffle(candidates)  # Randomize order
                
                for pit_pos in candidates[:]:  # Use copy for iteration
                    # Temporarily place pit
                    self.grid[pit_pos] = -1
                    
                    # Get non-terminal, non-wall states
                    non_terminal_cells = set((i, j) for i in range(self.size) 
                                          for j in range(self.size) 
                                          if self.grid[i, j] not in [1, 2, -1])
                    
                    # Check if all non-terminal states are still reachable
                    reachable = self._get_reachable_cells((0, 0), self.grid)
                    
                    if len(reachable) == len(non_terminal_cells):
                        valid_pit_placed = True
                        candidates.remove(pit_pos)
                        pit_candidates.append(pit_pos)
                        break
                    else:
                        # Remove invalid pit placement
                        self.grid[pit_pos] = 0
                
                if not valid_pit_placed:
                    break
            
            # If we placed at least one pit successfully, accept the grid
            if pit_candidates:
                break

    def reset(self):
        self.agent_pos = self.start_pos
        return self.agent_pos

    def get_transition_probs(self, state, action):
        """
        Get transition probabilities for a given state-action pair.
        Returns a list of (next_state, probability) tuples.
        """
        if self.noise == 0.:
            next_state = self._get_next_state(state, action)
            return [(next_state, 1.0)]
            
        transitions = []
        # Intended action
        main_prob = 1.0 - self.noise
        next_state = self._get_next_state(state, action)
        transitions.append((next_state, main_prob))
        
        # Random actions due to noise
        noise_prob = self.noise / 3  # Split noise probability among other actions
        for a in range(4):
            if a != action:
                next_state = self._get_next_state(state, a)
                transitions.append((next_state, noise_prob))
                
        return transitions
    
    def _get_next_state(self, state, action):
        """Helper method to compute next state given current state and action"""
        r, c = state
        dr, dc = self.actions[action]
        new_r = max(0, min(r + dr, self.size - 1))
        new_c = max(0, min(c + dc, self.size - 1))
        
        # If wall, stay in current position
        if self.grid[new_r, new_c] == 1:
            return (r, c)
        return (new_r, new_c)
    
    def step(self, action):
        """
        Take a step in the environment.
        Returns (next_state, reward, done)
        """
        if self.rng.random() < self.noise:
            # With probability noise, take a random action instead
            other_actions = [a for a in range(4) if a != action]
            action = self.rng.choice(other_actions)
        
        next_pos = self._get_next_state(self.agent_pos, action)
        self.agent_pos = next_pos
        
        # Get state type at new position
        cell_type = self.grid[self.agent_pos]
        
        if self.is_terminal(self.agent_pos):
            reward = self.terminal_rewards[cell_type]
            done = True
        else:
            reward = self.step_cost
            done = False
            
        return self.agent_pos, reward, done

    def render(self, ax=None):
        """
        Render the GridWorld using Unicode characters and consistent colors
        
        Args:
            ax: matplotlib axes to render on. If None, creates new figure
        """
        # Remember whether this call created the figure; `ax` is reassigned
        # below, so checking `ax is None` again later would never be true.
        created_fig = ax is None
        if created_fig:
            plt.figure(figsize=(6, 6))
            ax = plt.gca()

        ax.grid(True)

        # Define cell styles using Unicode characters
        cell_styles = {
            0: {'color': 'white', 'symbol': None},       # empty cell
            1: {'color': 'gray', 'symbol': '■'},         # wall
            2: {'color': 'lightgreen', 'symbol': '⚑'},   # goal
            -1: {'color': 'pink', 'symbol': '☠'}         # pit
        }

        # Draw grid cells
        for i in range(self.size):
            for j in range(self.size):
                cell_type = self.grid[i, j]
                style = cell_styles[cell_type]

                # Determine cell color
                if (i, j) == self.start_pos:
                    cell_color = 'lightblue'  # Blue for start position
                else:
                    cell_color = style['color']

                # Fill cell with color
                ax.fill([j, j+1, j+1, j], [i, i, i+1, i+1], cell_color)

        # Draw start position
        sr, sc = self.start_pos
        ax.text(sc + 0.5, sr + 0.5, '▶',
            ha='center', va='center', fontsize=20)

        ax.set_xlim(0, self.size)
        ax.set_ylim(self.size, 0)
        ax.set_title('GridWorld')

        # Only show the figure if this method created it
        if created_fig:
            plt.show()

In [3]:
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt
import traceback

def create_interactive_visualization(env_class, func, theta=1e-9, draw_agent=True):
    """
    Create streamlined interactive visualization for a planning algorithm.
    """
    # Create widgets
    seed_dropdown = widgets.Dropdown(
        options=[('Grid 1', 0), ('Grid 2', 1), ('Grid 3', 21)],
        value=0,
        description='Grid:',
        style={'description_width': 'initial'}
    )
    
    gamma_slider = widgets.FloatSlider(
        value=0.95,
        min=0.5,
        max=0.99,
        step=0.01,
        description='Gamma:',
        style={'description_width': 'initial'},
        readout_format='.2f'
    )
    
    noise_slider = widgets.FloatSlider(
        value=0.,
        min=0.0,
        max=0.4,
        step=0.05,
        description='Noise:',
        style={'description_width': 'initial'},
        readout_format='.2f'
    )
    
    output = widgets.Output()
    
    def update_plots(*args):
        with output:
            output.clear_output(wait=True)
            
            try:
                # Create new environment with current parameters
                env = env_class(
                    size=5, 
                    seed=seed_dropdown.value,
                    noise=noise_slider.value
                )
                
                # Create figure with two subplots
                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
                
                # First subplot: Original grid
                env.render(ax=ax1)
                
                # Second subplot: Values and policy
                V, policy = func(env, gamma=gamma_slider.value, theta=theta)
                plot_values_and_policy(V, policy, env, 
                                     f"{func.__name__}\nγ={gamma_slider.value:.2f}, noise={noise_slider.value:.2f}",
                                     ax=ax2)
                
                plt.tight_layout()
                plt.show()
                
            except Exception as e:
                print(f"Error: {str(e)}")
                traceback.print_exc()
    
    # Connect widget changes to update function
    seed_dropdown.observe(update_plots, 'value')
    gamma_slider.observe(update_plots, 'value')
    noise_slider.observe(update_plots, 'value')
    
    # Layout
    controls = widgets.HBox([
        seed_dropdown,
        gamma_slider,
        noise_slider
    ])
    
    # Display widgets and output
    display(controls)
    display(output)
    
    # Initial plot
    update_plots()

def create_qlearning_visualization(env_class, q_learning_func):
    """
    Create interactive visualization for Q-learning that uses an externally defined Q-learning function.
    
    Args:
        env_class: The environment class (GridWorld)
        q_learning_func: Function that implements Q-learning algorithm
            Expected signature: 
            q_learning_func(env, episodes, alpha, gamma, epsilon) -> (Q, rewards, V, policy)
    """
    # Create widgets
    seed_dropdown = widgets.Dropdown(
        options=[('Grid 1', 0), ('Grid 2', 1), ('Grid 3', 21)],
        value=0,
        description='Grid:',
        style={'description_width': 'initial'}
    )
    
    alpha_slider = widgets.FloatSlider(
        value=0.1,
        min=0.01,
        max=0.5,
        step=0.01,
        description='Learning Rate:',
        style={'description_width': 'initial'},
        readout_format='.2f'
    )
    
    gamma_slider = widgets.FloatSlider(
        value=0.95,
        min=0.5,
        max=0.99,
        step=0.01,
        description='Gamma:',
        style={'description_width': 'initial'},
        readout_format='.2f'
    )
    
    noise_slider = widgets.FloatSlider(
        value=0.1,
        min=0.0,
        max=0.4,
        step=0.05,
        description='Noise:',
        style={'description_width': 'initial'},
        readout_format='.2f'
    )
    
    episodes_slider = widgets.IntSlider(
        value=50000,
        min=10000,
        max=100000,
        step=10000,
        description='Episodes:',
        style={'description_width': 'initial'}
    )
    
    epsilon_slider = widgets.FloatSlider(
        value=0.1,
        min=0.01,
        max=0.5,
        step=0.01,
        description='Epsilon:',
        style={'description_width': 'initial'},
        readout_format='.2f'
    )
    
    run_button = widgets.Button(
        description='Run Q-Learning',
        button_style='success'
    )
    
    output = widgets.Output()
    
    def run_qlearning(b):
        """Run Q-learning with current parameters and update plots"""
        with output:
            output.clear_output(wait=True)
            
            try:
                # Create environment
                env = env_class(
                    size=5, 
                    seed=seed_dropdown.value,
                    noise=noise_slider.value
                )
                
                # Run Q-learning using provided function
                Q, rewards, V, policy = q_learning_func(
                    env,
                    episodes=episodes_slider.value,
                    alpha=alpha_slider.value,
                    gamma=gamma_slider.value,
                    epsilon=epsilon_slider.value
                )
                
                # Create plots
                fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 6))
                
                # Plot original grid
                env.render(ax=ax1)
                
                # Plot training progress
                window_size = 100
                moving_avg = np.convolve(rewards, np.ones(window_size)/window_size, mode='valid')
                ax2.plot(rewards, alpha=0.3, color='blue', label='Episode Rewards')
                ax2.plot(range(window_size-1, len(rewards)), moving_avg, 
                        color='blue', label=f'{window_size}-Episode Moving Average')
                ax2.set_xlabel('Episode')
                ax2.set_ylabel('Reward')
                ax2.legend()
                ax2.set_title('Training Progress')
                
                # Plot learned values and policy with adjusted color scale
                sns.heatmap(V, annot=True, fmt='.1f', cmap='RdYlGn',
                          cbar_kws={'label': 'Value'}, ax=ax3, square=True,
                          mask=V == 0, vmin=-10, vmax=10)
                
                # Add arrows for policy
                for i in range(env.size):
                    for j in range(env.size):
                        if env.grid[i, j] in [1, 2, -1]:  # Skip walls and terminal states
                            continue
                        
                        action = policy[i, j]
                        if action == 0:    # right
                            ax3.arrow(j + 0.5, i + 0.5, 0.3, 0, head_width=0.1, color='black')
                        elif action == 1:  # down
                            ax3.arrow(j + 0.5, i + 0.5, 0, 0.3, head_width=0.1, color='black')
                        elif action == 2:  # left
                            ax3.arrow(j + 0.5, i + 0.5, -0.3, 0, head_width=0.1, color='black')
                        elif action == 3:  # up
                            ax3.arrow(j + 0.5, i + 0.5, 0, -0.3, head_width=0.1, color='black')
                
                ax3.set_title(f"Q-Learning Results\nα={alpha_slider.value:.2f}, γ={gamma_slider.value:.2f}, ε={epsilon_slider.value:.2f}")
                
                plt.tight_layout()
                plt.show()
                
            except Exception as e:
                print(f"Error: {str(e)}")
                traceback.print_exc()
    
    # Connect run button to function
    run_button.on_click(run_qlearning)
    
    # Layout
    controls = widgets.VBox([
        widgets.HBox([seed_dropdown, alpha_slider, gamma_slider]),
        widgets.HBox([noise_slider, episodes_slider, epsilon_slider]),
        run_button
    ])
    
    # Display widgets and output
    display(controls)
    display(output)


def plot_values_and_policy(V, policy, env, title, ax=None):
    """Plot values and policy arrows on a heatmap."""
    if ax is None:
        plt.figure(figsize=(6, 6))
        ax = plt.gca()
    
    # Plot values as heatmap
    
    sns.heatmap(V, annot=True, fmt='.2f', cmap='RdYlGn',
                cbar_kws={'label': 'Value'}, ax=ax, square=True, mask = V == 0.)
    
    # Add arrows for policy
    for i in range(env.size):
        for j in range(env.size):
            if env.grid[i, j] in [1, 2, -1]:  # Skip walls and terminal states
                continue
            
            action = policy[i, j]
            if action == 0:    # right
                ax.arrow(j + 0.5, i + 0.5, 0.3, 0, head_width=0.1, color='black')
            elif action == 1:  # down
                ax.arrow(j + 0.5, i + 0.5, 0, 0.3, head_width=0.1, color='black')
            elif action == 2:  # left
                ax.arrow(j + 0.5, i + 0.5, -0.3, 0, head_width=0.1, color='black')
            elif action == 3:  # up
                ax.arrow(j + 0.5, i + 0.5, 0, -0.3, head_width=0.1, color='black')
    
    ax.set_title(title)
    ax.set_aspect('equal')
    
    return ax


## Discounted Returns and Value Functions

In reinforcement learning, our goal is to train an **agent** to make decisions that lead to high rewards over time. However, rewards may not all come at once, and they might be spread across future steps. To help balance short-term and long-term rewards, we use something called **discounted returns**. Let’s go over this concept carefully.

### Understanding Discounted Returns

Imagine that you’re earning rewards as you go through different steps in an environment (like moving through a maze). Let’s say each step from time $t$ onward has a reward associated with it: $\{r_t, r_{t+1}, r_{t+2}, ...\}$. But as time passes, these rewards in the future might not be as valuable to us as immediate rewards. This is where the **discount factor**, $\gamma$, comes in.

The **discounted return** starting at a specific time step $t$, denoted $G_t$, is a way of adding up all the future rewards we expect to get, but with each one being worth slightly less as it gets further in the future. The formula for this is:

$ G_t = \sum_{k=0}^{\infty} \gamma^k r_{t+k} $

Here:
- $G_t$ is the **discounted return** starting from time step $t$.
- $\gamma$ is the **discount factor**, a number between 0 and 1.
- $r_{t+k}$ is the reward received $k$ steps after time $t$.
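
The formula can be checked by hand on a short episode. The sketch below computes $G_t$ for a finite list of rewards; the helper name `discounted_return` and the reward sequence (two step costs followed by a goal reward) are chosen for illustration.

```python
def discounted_return(rewards, gamma):
    """Compute G_t = sum_k gamma^k * r_{t+k} for a finite list of rewards."""
    g = 0.0
    # Work backwards so that G_t = r_t + gamma * G_{t+1}.
    for r in reversed(rewards):
        g = r + gamma * g
    return g

rewards = [-1.0, -1.0, 10.0]   # two step costs, then the goal reward
print(discounted_return(rewards, gamma=0.9))
print(discounted_return(rewards, gamma=0.5))
```

With $\gamma = 0.9$ the return is about $-1 - 0.9 + 8.1 = 6.2$, while with $\gamma = 0.5$ it drops to about $-1 - 0.5 + 2.5 = 1.0$, because the goal reward two steps away is discounted more heavily.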

### Why Do We Use a Discount Factor?

The discount factor, $\gamma$, has two important roles:
1. **Prioritizing Immediate Rewards**: If $\gamma$ is close to 1, the agent will consider future rewards almost as important as immediate rewards. If $\gamma$ is closer to 0, the agent will focus more on immediate rewards and almost ignore distant rewards.
2. **Making the Sum Finite**: When we add up rewards into the distant future, $\gamma$ helps ensure that the total doesn’t become infinitely large. For example, if there were rewards at every step and no discounting, the sum could grow indefinitely. With $\gamma < 1$ and rewards bounded in magnitude by some $r_{\max}$, the discounted sum is at most $r_{\max} / (1 - \gamma)$, so it stays finite even if we consider rewards stretching infinitely into the future.
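
A small numerical sketch of the second point: with a constant reward of 1 at every step and $\gamma < 1$, the discounted sum is a geometric series that approaches $1 / (1 - \gamma)$, whereas with $\gamma = 1$ it keeps growing with the number of steps. The helper `partial_sum` is a hypothetical name used only here.

```python
def partial_sum(reward, gamma, n_steps):
    """Sum of gamma^k * reward for k = 0 .. n_steps-1."""
    return sum(reward * gamma**k for k in range(n_steps))

gamma = 0.9
limit = 1.0 / (1.0 - gamma)      # closed form of the infinite geometric series
for n in (10, 100, 1000):
    print(n, round(partial_sum(1.0, gamma, n), 4), "limit:", round(limit, 4))

# Without discounting, the sum simply grows with the horizon.
print(partial_sum(1.0, 1.0, 1000))
```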

### Policies and Their Goal

A **policy** is a rule or strategy that tells the agent what action to take in each state. We denote a policy as $\pi$, and it can either be:
- **Deterministic**: Always chooses the same action for a given state.
- **Stochastic**: Randomly chooses actions based on a set of probabilities.
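
One possible way to represent the two kinds of policy in code is sketched below. The states, probabilities, and the helper `act` are illustrative assumptions; the action encoding (0 to 3) matches the one used by the `GridWorld` class.

```python
import random

ACTIONS = [0, 1, 2, 3]  # right, down, left, up (same encoding as GridWorld)

# Deterministic policy: one fixed action per state.
deterministic_policy = {(0, 0): 0, (0, 1): 1}

# Stochastic policy: a probability for each action in each state.
stochastic_policy = {
    (0, 0): [0.7, 0.1, 0.1, 0.1],
    (0, 1): [0.25, 0.25, 0.25, 0.25],
}

def act(policy, state, rng):
    """Pick an action from either kind of policy."""
    choice = policy[state]
    if isinstance(choice, int):
        return choice                                  # deterministic
    return rng.choices(ACTIONS, weights=choice)[0]     # stochastic

rng = random.Random(0)
print(act(deterministic_policy, (0, 0), rng))
print([act(stochastic_policy, (0, 0), rng) for _ in range(10)])
```

The deterministic policy returns the same action every time it is queried, while repeated queries of the stochastic policy produce a mix of actions weighted by the stored probabilities.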

The ultimate goal in reinforcement learning is to find an **optimal policy**, $\pi^*$, that maximizes the expected discounted return $G_t$ from any starting state. In other words, we want to find a way for the agent to consistently choose actions that yield the highest possible cumulative reward over time.

### Value Functions: Quantifying Good Choices

To help find this optimal policy, we introduce two types of functions, called **value functions**, that estimate how good it is to be in a particular state or to take a particular action from that state.

#### 1. State-Value Function ($V^\pi(s)$)

The **state-value function** is a way of estimating the expected return (total future rewards) if the agent starts from a specific state $s$ and follows a particular policy $\pi$. We denote this as $V^\pi(s)$.

Formally, we write:

$ V^\pi(s) = \mathbb{E}_\pi[G_t | S_t = s] $

In words:
- $V^\pi(s)$ is the **expected discounted return** starting from state $s$ and following policy $\pi$.
- This expectation $\mathbb{E}_\pi$ means we’re averaging over the possible outcomes (since the rewards might vary depending on the environment’s randomness and the policy).
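
To illustrate what "averaging over the possible outcomes" means, the sketch below estimates a state value by sampling many episodes and averaging their returns. The toy setup (a single step that ends the episode with +10 or -10) and its probabilities are assumptions made for this example only.

```python
import random

def sample_return(rng):
    """One episode from a toy state s: the single step ends the episode,
    giving +10 with probability 0.8 and -10 otherwise (illustrative numbers)."""
    return 10.0 if rng.random() < 0.8 else -10.0

rng = random.Random(0)
n_episodes = 10_000
estimate = sum(sample_return(rng) for _ in range(n_episodes)) / n_episodes

exact = 0.8 * 10.0 + 0.2 * (-10.0)   # the expectation written out by hand
print(f"Monte Carlo estimate: {estimate:.2f}, exact value: {exact:.2f}")
```

The sample average should land close to the exact expectation of 6, and the gap shrinks as the number of episodes grows.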

When we find the **optimal policy**, $\pi^*$, we get the **optimal state-value function** $V^*(s)$, which tells us the maximum possible return we can expect from each state. This optimal state-value function satisfies the following equation:

$ V^*(s) = \max_a \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V^*(s')] $

Breaking it down:
- $P(s'|s,a)$ is the probability of ending up in state $s'$ after taking action $a$ in state $s$.
- $R(s,a,s')$ is the reward received for taking action $a$ in state $s$ and ending up in state $s'$.
- $\gamma V^*(s')$ is the discounted future value, given by the optimal value function for the next state $s'$.

The term $\max_a$ means we’re choosing the action $a$ that maximizes our expected return from state $s$.

#### 2. Action-Value Function ($Q^\pi(s,a)$)

The **action-value function** estimates the expected return if the agent starts in state $s$, takes action $a$, and then follows policy $\pi$. This is written as $Q^\pi(s, a)$.

Formally:

$ Q^\pi(s,a) = \mathbb{E}_\pi[G_t | S_t = s, A_t = a] $

Here, $Q^\pi(s,a)$ represents the **expected discounted return** starting from state $s$, taking action $a$, and then following policy $\pi$.

When we find the optimal policy, we get the **optimal action-value function** $Q^*(s, a)$, which represents the maximum expected return starting from $s$ and taking action $a$. It satisfies the following equation:

$ Q^*(s, a) = \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma \max_{a'} Q^*(s', a')] $

Breaking it down:
- The **sum** is over all possible next states $s'$.
- $P(s'|s,a)$ is the probability of reaching state $s'$ from state $s$ by taking action $a$.
- The term $\max_{a'} Q^*(s', a')$ tells us the maximum expected return from the next state $s'$, assuming we act optimally from there onward.
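
As a worked instance of this equation, the sketch below evaluates the right-hand side for one state-action pair. The successor states, their action values, and the transition outcomes are hypothetical numbers chosen so the arithmetic is easy to follow.

```python
gamma = 0.9

# Hypothetical optimal action values at two successor states.
Q_next = {"s1": [5.0, 7.0], "s2": [0.0, -3.0]}

# Hypothetical outcomes of taking action a in state s: (next_state, prob, reward).
outcomes = [("s1", 0.8, -1.0), ("s2", 0.2, -1.0)]

# Q*(s,a) = sum_s' P(s'|s,a) * [R(s,a,s') + gamma * max_a' Q*(s',a')]
q_sa = sum(p * (r + gamma * max(Q_next[s_next])) for s_next, p, r in outcomes)
print(round(q_sa, 4))

# The optimal state value is the best action value: V*(s') = max_a' Q*(s',a').
V_next = {s: max(qs) for s, qs in Q_next.items()}
print(V_next)
```

Here the result is $0.8 \times (-1 + 0.9 \times 7) + 0.2 \times (-1 + 0.9 \times 0) = 4.04$. The last two lines show how the two optimal value functions relate: taking the maximum of $Q^*$ over actions recovers $V^*$.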

### Why We Need Iterative Methods

Finding the exact values of these optimal functions, $V^*(s)$ and $Q^*(s, a)$, can be challenging due to:
1. The **stochastic nature of transitions** (the randomness in outcomes).
2. The **recursive structure** of the Bellman equations (where each state’s value depends on the values of future states).
3. The **max operator**, which adds complexity in calculating the best action at each state.

Because of these challenges, we use **iterative methods** to approximate the optimal values and policies:
1. **Value Iteration**: This method directly finds $V^*$ by repeatedly updating estimates of state values.
2. **Policy Iteration**: This approach alternates between evaluating a policy (estimating its value function) and improving the policy based on those estimates.
3. **Q-Learning**: This method learns $Q^*$ by exploring the environment and updating action-value estimates without requiring knowledge of $P$ (transition probabilities) and $R$ (rewards).

## Value Iteration

**Value Iteration** is a method for finding the optimal value function in a Markov Decision Process (MDP) by repeatedly updating value estimates for each state until they converge to the optimal values. This process is based on something called the **Bellman optimality equation**. 

### The Value Iteration Update Rule

Value Iteration begins with an arbitrary initial guess for the value function, often denoted $V_0$. Then, it continuously improves these value estimates by updating each state $s$ according to the following formula:

$ V_{k+1}(s) = \max_a \sum_{s'} P(s'|s,a) \left[ R(s,a,s') + \gamma V_k(s') \right] $

Let’s break down each component of this update rule:

1. **Immediate Reward ($R(s,a,s')$)**: This term represents the reward we receive immediately when we move from state $s$ to state $s'$ by taking action $a$.

2. **Discounted Future Value ($\gamma V_k(s')$)**: This part takes into account the expected future rewards, starting from the next state $s'$. The discount factor $\gamma$ (where $0 \leq \gamma \leq 1$) makes future rewards slightly less valuable than immediate ones.

3. **Expectation Over Next States ($\sum_{s'} P(s'|s,a)[\cdot]$)**: Since actions can lead to multiple possible next states, we use the probabilities $P(s'|s,a)$ to calculate a weighted average. This accounts for the chance of ending up in each potential next state $s'$ from state $s$ after taking action $a$.

4. **Maximizing Over Actions ($\max_a$)**: Finally, we choose the action $a$ that maximizes the expected reward. This ensures that each state value $V(s)$ represents the maximum achievable return starting from that state, assuming we always act optimally.
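
The four components can be seen in a single update for one state. In the sketch below, the state has two hypothetical actions, "safe" and "risky"; the next-state names, probabilities, rewards, and current estimates $V_k$ are illustrative, and terminal states are assumed to have value 0.

```python
gamma = 0.9

# Current estimates V_k for the possible next states (illustrative numbers;
# terminal states are given value 0 in this sketch).
V_k = {"hallway": 5.0, "goal": 0.0, "pit": 0.0}

# For one state s: each action maps to a list of (prob, next_state, reward).
transitions = {
    "safe":  [(1.0, "hallway", -1.0)],
    "risky": [(0.8, "goal", 10.0), (0.2, "pit", -10.0)],
}

def action_value(outcomes, V, gamma):
    """Expectation over next states of [immediate reward + discounted future value]."""
    return sum(p * (r + gamma * V[s_next]) for p, s_next, r in outcomes)

action_values = {a: action_value(o, V_k, gamma) for a, o in transitions.items()}
V_next_s = max(action_values.values())   # maximize over actions

print(action_values)
print(V_next_s)
```

The "safe" action is worth $-1 + 0.9 \times 5 = 3.5$ and the "risky" action is worth $0.8 \times 10 + 0.2 \times (-10) = 6$, so the updated value $V_{k+1}(s)$ is 6.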

### Convergence and the Optimal Policy

Value Iteration keeps updating the values of all states until they converge, meaning they no longer change significantly between updates. In practice, we stop when the largest change in any state’s value during a full sweep is smaller than a small threshold $\theta$, and we then consider the values to have **converged**. At this point, we can determine the **optimal policy** $\pi^*$ by choosing the action that gives the highest expected return for each state $s$:

$\pi^*(s) = \arg\max_a \sum_{s'} P(s'|s,a) \left[ R(s,a,s') + \gamma V^*(s') \right]$

### Value Iteration Pseudocode

The following pseudocode summarizes the Value Iteration process:

**Input**: MDP $(S, A, P, R)$, discount factor $\gamma$, tolerance $\theta$  
**Output**: Value function $V$, policy $\pi$

1. **Initialize**: Set $V(s) \leftarrow 0$ for all $s \in S$
2. **Repeat**:
    - Set $\Delta \leftarrow 0$ (to track the maximum change in values for this iteration)
    - **For each** state $s \in S$:
        - Store the current value: $v \leftarrow V(s)$
        - Update the value function using the Bellman optimality equation:
          $V(s) \leftarrow \max_a \sum_{s'} P(s'|s,a) \left[ R(s,a,s') + \gamma V(s') \right]$
        - Update $\Delta$: $\Delta \leftarrow \max(\Delta, |v - V(s)|)$
3. **Until** $\Delta < \theta$ (indicating convergence)
4. **Extract Policy**: For each state, define the policy as:
   $\pi(s) \leftarrow \arg\max_a \sum_{s'} P(s'|s,a) \left[ R(s,a,s') + \gamma V(s') \right]$

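This procedure yields a close approximation of the optimal value function $V^*$ (how close is controlled by $\theta$) together with the policy that is greedy with respect to it. For a sufficiently small $\theta$, this greedy policy is typically the optimal policy $\pi^*$.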
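
The pseudocode can be sketched in a few lines on a generic tabular problem. The example below is not the GridWorld exercise: it uses a hypothetical four-cell corridor, a made-up function name `value_iteration_toy`, and the convention that terminal states have value 0 with the reward paid on the transition into them, which may differ from how the exercise below treats terminal states.

```python
def value_iteration_toy(states, actions, transitions, gamma=0.9, theta=1e-8):
    """Tabular value iteration following the pseudocode (in-place updates).

    transitions(s, a) returns [(prob, next_state, reward), ...];
    an empty list marks s as terminal.
    """
    V = {s: 0.0 for s in states}

    def q_value(s, a):
        return sum(p * (r + gamma * V[s2]) for p, s2, r in transitions(s, a))

    while True:
        delta = 0.0
        for s in states:
            if not transitions(s, actions[0]):   # terminal: value stays 0
                continue
            v = V[s]
            V[s] = max(q_value(s, a) for a in actions)
            delta = max(delta, abs(v - V[s]))
        if delta < theta:
            break

    # Extract the greedy policy for every non-terminal state.
    policy = {
        s: max(actions, key=lambda a: q_value(s, a))
        for s in states if transitions(s, actions[0])
    }
    return V, policy


# Toy corridor 0-1-2-3 with the goal at state 3 (hypothetical example).
GOAL = 3

def corridor(s, a):
    if s == GOAL:
        return []                                # terminal state
    s2 = min(max(s + a, 0), GOAL)                # walls clamp the move
    reward = 10.0 if s2 == GOAL else -1.0
    return [(1.0, s2, reward)]

V, policy = value_iteration_toy([0, 1, 2, 3], [-1, +1], corridor)
print(V)
print(policy)
```

In this corridor the values increase toward the goal (about 6.2, 8, and 10 for cells 0, 1, and 2), and the extracted policy moves right (`+1`) in every non-terminal cell.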


### TODO

- Initialize Value (V) and Policy (policy)
- Skip Walls and Terminal States from the calculation
- Compute the expected value in both cases: when the next state is terminal and when it is not.

In [12]:
def value_iteration(env, gamma=0.99, theta=1e-6):
    """
    Value Iteration algorithm that uses environment's transition probabilities
    
    Args:
        env: GridWorld environment
        gamma: Discount factor
        theta: Convergence threshold
        
    Returns:
        V: Converged value function
        policy: Optimal policy
    """
    # TODO: Initialize Value 
    V = ...

    # TODO: Initialize policy 
    policy = ...
    
    # Initialize terminal state values
    for i in range(env.size):
        for j in range(env.size):
            if env.is_terminal((i, j)):
                V[i, j] = env.get_terminal_value((i, j))
                policy[i, j] = -1
    
    while True:
        delta = 0
        for i in range(env.size):
            for j in range(env.size):
            
                # TODO: Skip walls and terminal states
                ... 
                
                old_v = V[i, j]
                state = (i, j)
                
                # Try all actions
                action_values = []
                for action in range(4):
                    # Get transition probabilities for this state-action pair
                    transitions = env.get_transition_probs(state, action)
                    
                    # Calculate expected value for this action
                    expected_value = 0
                    for next_state, prob in transitions:
                        if env.is_terminal(next_state):
                            # TODO: Update expected value when state is terminal
                            expected_value ...
                        else:
                            # TODO: Update expected value when state is non terminal
                            expected_value ...
                            
                    action_values.append(expected_value)
                
                # Update value and policy
                V[i, j] = max(action_values)
                policy[i, j] = np.argmax(action_values)
                
                delta = max(delta, abs(old_v - V[i, j]))
        
        if delta < theta:
            break
    
    return V, policy

# Test value iteration
create_interactive_visualization(GridWorld, value_iteration)


## Policy Iteration

**Policy Iteration** is a method for finding the optimal policy in a Markov Decision Process (MDP). It alternates between two steps:
1. **Policy Evaluation**: Calculates the value function for the current policy.
2. **Policy Improvement**: Updates the policy to be greedy with respect to the current value function, meaning it chooses actions that maximize expected returns.

This process repeats until the policy stabilizes, meaning it no longer changes from one iteration to the next, indicating that we have found an optimal policy.

### Policy Evaluation

In the **policy evaluation** step, we calculate the value of each state $s$ under a given policy $\pi$, denoted $V^\pi(s)$. This value function $V^\pi(s)$ represents the expected return when starting from state $s$ and following policy $\pi$ thereafter.

To compute $V^\pi(s)$, we solve the **Bellman equation** for each state under the current policy:

$V^\pi(s) = \sum_{s'} P(s'|s, \pi(s)) \left[ R(s, \pi(s), s') + \gamma V^\pi(s') \right]$

Breaking down this equation:
- $P(s'|s, \pi(s))$: The probability of reaching next state $s'$ from state $s$ under the action prescribed by policy $\pi$.
- $R(s, \pi(s), s')$: The immediate reward for transitioning from state $s$ to state $s'$ by taking the action specified by $\pi$.
- $\gamma V^\pi(s')$: The discounted value of the next state $s'$.
- The **sum** over $s'$ takes into account all possible next states, weighted by their transition probabilities.

### Policy Evaluation Pseudocode

The following pseudocode outlines the policy evaluation process:

**Input**: Policy $\pi$, MDP $(S, A, P, R)$, discount factor $\gamma$, tolerance $\theta$  
**Output**: Value function $V^\pi$

1. **Initialize**: Set $V(s) \leftarrow 0$ for all $s \in S$
2. **Repeat**:
    - Set $\Delta \leftarrow 0$ (to track the maximum change in values for this iteration)
    - **For each** state $s \in S$:
        - Store the current value: $v \leftarrow V(s)$
        - Update the value function according to the Bellman equation:
          $V(s) \leftarrow \sum_{s'} P(s'|s, \pi(s)) \left[ R(s, \pi(s), s') + \gamma V(s') \right]$
        - Update $\Delta$: $\Delta \leftarrow \max(\Delta, |v - V(s)|)$
3. **Until** $\Delta < \theta$ (indicating that the values have converged within the specified tolerance)
4. **Return** $V$

This policy evaluation step provides us with the value function $V^\pi$ for the current policy $\pi$. This process will be used in the next step, **Policy Improvement**, to see if we can improve the policy based on the computed value function.
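
As an illustration, the sketch below evaluates two fixed policies on a tiny three-state chain. The MDP, the layout `P[s][a] = [(prob, next_state, reward), ...]`, and the name `policy_evaluation_sketch` are assumptions made for this example and are separate from the `GridWorld` exercise code.

```python
def policy_evaluation_sketch(P, policy, gamma=0.9, theta=1e-8):
    """Iteratively evaluate a deterministic policy on a small tabular MDP.

    P[s][a] is a list of (prob, next_state, reward) triples (hypothetical
    layout). Terminal states have an empty action dict and keep value 0.
    """
    V = {s: 0.0 for s in P}
    while True:
        delta = 0.0
        for s in P:
            if not P[s]:
                continue
            v = V[s]
            # Bellman expectation backup: only the policy's action is used
            V[s] = sum(p * (r + gamma * V[s2]) for p, s2, r in P[s][policy[s]])
            delta = max(delta, abs(v - V[s]))
        if delta < theta:
            return V

# Illustrative chain: 0 <-> 1 -> 2 (terminal). Action 0 = left, 1 = right.
P = {
    0: {0: [(1.0, 0, -1.0)], 1: [(1.0, 1, -1.0)]},
    1: {0: [(1.0, 0, -1.0)], 1: [(1.0, 2, 10.0)]},
    2: {},
}

V_right = policy_evaluation_sketch(P, {0: 1, 1: 1})  # always move right
V_stuck = policy_evaluation_sketch(P, {0: 0, 1: 1})  # state 0 keeps moving left
print(V_right, V_stuck)
```

Under the second policy the agent never leaves state 0, so its value approaches $-1 / (1 - 0.9) = -10$, which shows how the same state can have very different values under different policies.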


### TODO

- Compute the expected value in both cases: when the next state is terminal and when it is not.

In [4]:
def policy_evaluation(env, policy, V, gamma=0.99, theta=1e-6):
    """
    Iteratively evaluate a policy using environment's transition probabilities
    """
    while True:
        delta = 0
        for i in range(env.size):
            for j in range(env.size):
                if env.grid[i, j] == 1 or env.is_terminal((i, j)):
                    continue
                    
                old_v = V[i, j]
                state = (i, j)
                action = policy[i, j]
                
                # Get transition probabilities for the current policy
                transitions = env.get_transition_probs(state, action)
                
                # Calculate expected value
                expected_value = 0
                for next_state, prob in transitions:
                    if env.is_terminal(next_state):
                        ## TODO: Update expected value when state is terminal
                        expected_value...
                    else:
                        ## TODO: Update expected value when state is non terminal
                        expected_value...
                
                V[i, j] = expected_value
                delta = max(delta, abs(old_v - V[i, j]))
        
        if delta < theta:
            break
    
    return V

### Policy Improvement

In the **Policy Improvement** step, we take the current value function $V^\pi$ (calculated in the policy evaluation step) and update our policy to make it **greedy** with respect to this value function. This means that, for each state $s$, we select the action that maximizes the expected return based on $V^\pi$.

The new policy, $\pi'$, is defined by:

$\pi'(s) = \arg\max_a \sum_{s'} P(s'|s,a) \left[ R(s,a,s') + \gamma V^\pi(s') \right]$

This equation selects, for each state $s$, the action $a$ that yields the highest expected return according to the current value function $V^\pi$. By the **Policy Improvement Theorem**, the new policy $\pi'$ is guaranteed to perform at least as well as $\pi$ in every state, and unless $\pi$ is already optimal, $\pi'$ is strictly better in at least one state.

If the policy no longer changes between iterations (i.e., it becomes stable), then we know that the policy is **optimal**.

### Policy Improvement Pseudocode

The pseudocode below outlines the steps for policy improvement:

**Input**: Value function $V$, MDP $(S, A, P, R)$, discount factor $\gamma$  
**Output**: Improved policy $\pi$, stability flag

1. **Initialize**: Set $\text{policy\_stable} \leftarrow \text{True}$
2. **For each** state $s \in S$:
    - Set $\text{old\_action} \leftarrow \pi(s)$ (store the current action in the policy for state $s$)
    - Update the policy to be greedy with respect to $V$:
      $\pi(s) \leftarrow \arg\max_a \sum_{s'} P(s'|s,a) \left[ R(s,a,s') + \gamma V(s') \right]$
    - **If** the action for $s$ has changed ($\text{old\_action} \neq \pi(s)$):
        - Set $\text{policy\_stable} \leftarrow \text{False}$ (indicating that further policy improvement is possible)
3. **Return** the updated policy $\pi$ and the stability flag $\text{policy\_stable}$

If $\text{policy\_stable}$ is `True` after an iteration, this means that no further changes were made to the policy, and we have found the optimal policy $\pi^*$.
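
A minimal sketch of this step on a tiny chain MDP follows. The MDP, the value function `V_pi`, the layout `P[s][a] = [(prob, next_state, reward), ...]`, and the function name are hypothetical. One deliberate difference from a plain `argmax`: the sketch keeps the old action when it ties with the best one, which is an assumption made here to avoid flipping between equally good actions.

```python
def policy_improvement_sketch(P, V, policy, gamma=0.9):
    """Return a policy that is greedy w.r.t. V, plus a stability flag."""
    policy_stable = True
    new_policy = dict(policy)
    for s in P:
        if not P[s]:              # terminal state: no action to choose
            continue
        # Expected return of every action under the given value function
        q = {a: sum(p * (r + gamma * V[s2]) for p, s2, r in outcomes)
             for a, outcomes in P[s].items()}
        best = max(q, key=q.get)
        # Change the action only if it is strictly better than the old one
        if q[best] > q[policy[s]] + 1e-12:
            new_policy[s] = best
            policy_stable = False
    return new_policy, policy_stable

# Illustrative chain: 0 <-> 1 -> 2 (terminal). Action 0 = left, 1 = right.
P = {
    0: {0: [(1.0, 0, -1.0)], 1: [(1.0, 1, -1.0)]},
    1: {0: [(1.0, 0, -1.0)], 1: [(1.0, 2, 10.0)]},
    2: {},
}
policy = {0: 0, 1: 1}                  # state 0 moves left, state 1 moves right
V_pi = {0: -10.0, 1: 10.0, 2: 0.0}     # value of that policy for gamma = 0.9

improved, stable = policy_improvement_sketch(P, V_pi, policy)
print(improved, stable)
```

Here state 0 switches from left (worth $-10$) to right (worth $8$), so the flag comes back `False` and another round of evaluation is needed.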


### TODO

- Compute the expected value in both cases: when the next state is terminal and when it is not.
- Update Policy 

In [6]:
def policy_improvement(env, V, policy, gamma=0.99):
    """
    Make policy greedy with respect to current value function using transition probabilities
    """
    policy_stable = True
    
    for i in range(env.size):
        for j in range(env.size):
            if env.grid[i, j] == 1 or env.is_terminal((i, j)):
                continue
            
            old_action = policy[i, j]
            state = (i, j)
            
            # Try all actions
            action_values = []
            for action in range(4):
                # Get transition probabilities
                transitions = env.get_transition_probs(state, action)
                
                # Calculate expected value
                expected_value = 0
                for next_state, prob in transitions:
                    if env.is_terminal(next_state):
                        expected_value += prob * (env.step_cost + env.get_terminal_value(next_state))
                    else:
                        expected_value += prob * (env.step_cost + gamma * V[next_state])
                
                action_values.append(expected_value)
            
            # TODO: Update policy
            policy[i, j]...
            
            if old_action != policy[i, j]:
                policy_stable = False
    
    return policy, policy_stable


## Complete Policy Iteration

By combining **Policy Evaluation** and **Policy Improvement**, we create the full **Policy Iteration** algorithm. Policy iteration alternates between these two steps until the policy becomes stable, meaning that no further improvements can be made and an optimal policy is found. 

The algorithm works as follows:
1. **Policy Evaluation**: Given the current policy $\pi$, compute the value function $V^\pi$.
2. **Policy Improvement**: Update the policy to make it greedy with respect to $V^\pi$. If the policy does not change during this step, we have found an optimal policy.

### Policy Iteration Pseudocode

The pseudocode below summarizes the steps in policy iteration.

**Input**: MDP $(S, A, P, R)$, discount factor $\gamma$, tolerance $\theta$  
**Output**: Optimal value function $V^*$, optimal policy $\pi^*$

1. **Initialize**: Set $\pi(s)$ arbitrarily for all $s \in S$.
2. **Repeat**:
    - **Policy Evaluation**: Compute the value function $V$ for the current policy $\pi$ using $\text{PolicyEvaluation}(\pi, S, A, P, R, \gamma, \theta)$.
    - **Policy Improvement**: Update the policy to be greedy with respect to $V$:
      $\pi, \text{policy\_stable} \leftarrow \text{PolicyImprovement}(V, S, A, P, R, \gamma)$
3. **Until** $\text{policy\_stable}$ is `True` (indicating that the policy has converged to an optimal policy).
4. **Return** the optimal value function $V^*$ and the optimal policy $\pi^*$.

In each iteration, policy iteration improves both the value function and the policy. When the policy no longer changes between iterations, it has reached the optimal policy $\pi^*$, and $V$ becomes the optimal value function $V^*$.
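
The loop can be sketched compactly with NumPy arrays on a three-state chain. Note one swapped-in technique: instead of the iterative evaluation described above, this sketch evaluates each policy exactly by solving the linear system $(I - \gamma P_\pi) V = r_\pi$, which is only practical for small state spaces. The arrays `P[s, a, s']` and `R[s, a, s']` and all function names are assumptions for illustration.

```python
import numpy as np

def evaluate(P, R, policy, gamma):
    """Exact policy evaluation by solving (I - gamma * P_pi) V = r_pi."""
    idx = np.arange(P.shape[0])
    P_pi = P[idx, policy]                        # (S, S) transitions under pi
    r_pi = (P_pi * R[idx, policy]).sum(axis=1)   # expected one-step reward
    return np.linalg.solve(np.eye(P.shape[0]) - gamma * P_pi, r_pi)

def improve(P, R, V, gamma):
    """Greedy policy with respect to V."""
    Q = (P * (R + gamma * V)).sum(axis=2)        # (S, A) action values
    return Q.argmax(axis=1)

def policy_iteration_sketch(P, R, gamma=0.9):
    policy = np.zeros(P.shape[0], dtype=int)     # arbitrary initial policy
    while True:
        V = evaluate(P, R, policy, gamma)
        new_policy = improve(P, R, V, gamma)
        if np.array_equal(new_policy, policy):   # policy is stable
            return V, policy
        policy = new_policy

# Illustrative chain: 0 <-> 1 -> 2 (terminal). Action 0 = left, 1 = right.
# The terminal state has all-zero transition rows, so its value stays 0.
P = np.zeros((3, 2, 3))
R = np.zeros((3, 2, 3))
P[0, 0, 0], R[0, 0, 0] = 1.0, -1.0
P[0, 1, 1], R[0, 1, 1] = 1.0, -1.0
P[1, 0, 0], R[1, 0, 0] = 1.0, -1.0
P[1, 1, 2], R[1, 1, 2] = 1.0, 10.0

V, policy = policy_iteration_sketch(P, R)
print(V, policy)
```

The action reported for the terminal state is meaningless in this layout; only the entries for states 0 and 1 matter.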


### TODO

- Compute Policy Evaluation
- Compute Policy Improvement

In [7]:
def policy_iteration(env, gamma=0.99, theta=1e-6):
    """
    Policy Iteration algorithm using environment's transition probabilities
    """
    # Initialize values and policy
    V = np.zeros((env.size, env.size))
    policy = np.zeros((env.size, env.size), dtype=int)
    
    # Initialize terminal states
    for i in range(env.size):
        for j in range(env.size):
            if env.is_terminal((i, j)):
                V[i, j] = env.get_terminal_value((i, j))
                policy[i, j] = -1
    
    while True:
        # TODO: Compute policy evaluation
        ... 
        
        # TODO: Compute Policy improvement
        ...

        if policy_stable:
            break
    
    return V, policy

# Test policy iteration
create_interactive_visualization(GridWorld, policy_iteration)


## Limitations of Model-based Methods

Both **Value Iteration** and **Policy Iteration** are model-based methods, meaning they rely on complete knowledge of the environment's dynamics to find the optimal policy. Specifically, these methods require:
1. **Transition probabilities** $P(s'|s, a)$: The probability of moving to state $s'$ when taking action $a$ in state $s$.
2. **Reward function** $R(s, a, s')$: The immediate reward received after transitioning from state $s$ to state $s'$ by taking action $a$.

However, in many real-world scenarios:
- The **environment dynamics** (the transition probabilities and reward function) are **unknown** or too complex to model precisely.
- The **state space** (all possible states) is too large to store or calculate all transition probabilities, making it impractical to apply these model-based methods.

These limitations create challenges for model-based methods, as they become impractical or infeasible when we lack a complete model of the environment or face large and complex state spaces.

## The Need for Model-free Learning

The limitations of model-based methods motivate the development of **model-free methods**, which can learn optimal policies directly from experience, without needing prior knowledge of the environment's transition probabilities and reward function. Instead of computing expected values with known transitions, as in:

$Q^*(s, a) = \sum_{s'} P(s'|s, a) \left[ R(s, a, s') + \gamma \max_{a'} Q^*(s', a') \right]$

we can estimate **Q-values** using samples of experience collected by interacting with the environment. Each sample of experience consists of a tuple $(s, a, r, s')$, where:
- $s$ is the current state,
- $a$ is the action taken,
- $r$ is the reward received,
- $s'$ is the next state.

Using these samples, model-free methods can update Q-value estimates based on the actual transitions experienced, rather than relying on predefined probabilities.
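
The idea that samples can stand in for known probabilities can be illustrated with a small experiment. In this sketch the outcome distribution for one state-action pair is made up, and the values of $\max_{a'} Q(s', a')$ are held fixed; a real learner would be updating them at the same time.

```python
import random

random.seed(0)
gamma = 0.9

# Hypothetical outcomes for one fixed (s, a) pair. Each entry is
# (probability, reward, max over a' of Q(s', a')). A model-based method can
# read the probabilities; a model-free learner only sees sampled outcomes.
outcomes = [(0.8, -1.0, 10.0), (0.2, -1.0, 0.0)]
weights = [p for p, _, _ in outcomes]

# Model-based: exact expectation using the known probabilities
exact = sum(p * (r + gamma * q_next) for p, r, q_next in outcomes)

# Model-free: running average of sampled targets r + gamma * max Q(s', a')
estimate = 0.0
for n in range(1, 20001):
    _, r, q_next = random.choices(outcomes, weights=weights)[0]
    target = r + gamma * q_next
    estimate += (target - estimate) / n   # incremental mean, step size 1/n

print(exact, estimate)
```

The running average has the same shape as a Q-value update with a step size of $1/n$: each sample nudges the estimate toward the observed target.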

This approach leads us to **Q-learning**, a popular model-free algorithm that:
1. **Learns** directly from interaction with the environment, using observed experiences instead of a predefined model.
2. **Updates Q-value estimates** by adjusting them based on observed rewards and transitions.
3. **Converges to optimal Q-values** over time (given sufficient exploration and a suitable learning rate), meaning it can find an optimal policy without needing any prior knowledge of the environment's dynamics.

In this way, model-free methods, such as Q-learning, provide a flexible and powerful approach to reinforcement learning, allowing agents to learn optimal behaviors in environments where model-based methods would be impractical or impossible to use.


## Q-Learning 

**Q-learning** is a model-free reinforcement learning algorithm that learns the optimal action-value function, denoted by $Q^*$, directly from experience. Unlike model-based approaches, Q-learning does not need information about transition probabilities or rewards beforehand; it learns purely from interactions with the environment.

### Algorithm Description

Q-learning maintains an action-value table, $Q(s, a)$, which stores estimates of the expected return (cumulative reward) for each state-action pair. The values in this table are updated using experience tuples $(s, a, r, s')$, where:
- $s$ is the current state,
- $a$ is the action taken in state $s$,
- $r$ is the reward received after taking action $a$,
- $s'$ is the resulting state after taking action $a$.

The **Q-learning update rule** for each experience tuple is:

$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$

where:
- $\alpha$ is the **learning rate**, which controls how much of the new information (target value) we consider in each update.
- $\gamma$ is the **discount factor**, which determines the importance of future rewards.
- $r + \gamma \max_{a'} Q(s', a')$ is the **target value**—it represents the expected return for taking action $a$ in state $s$ and then following the best actions from $s'$ onward.
- $r + \gamma \max_{a'} Q(s', a') - Q(s, a)$ is known as the **temporal difference error**—it measures the difference between the current estimate and the observed outcome.

**Implementation Note**: This update rule can be algebraically rearranged to an equivalent form that's often used in code:

$Q(s, a) \leftarrow (1 - \alpha)Q(s, a) + \alpha(r + \gamma \max_{a'} Q(s', a'))$

Both forms are mathematically identical: expanding the first gives $Q(s, a) - \alpha Q(s, a) + \alpha\left(r + \gamma \max_{a'} Q(s', a')\right)$, and collecting the $Q(s, a)$ terms gives the second. The second form makes explicit that the new estimate is a weighted average of the old Q-value and the target value, with weight $\alpha$ on the target. Either form can be used in code.
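
A quick numeric check of one update, using made-up numbers, shows the target, the temporal difference error, and the agreement between the two forms:

```python
import math

alpha, gamma = 0.1, 0.9
q_old = 2.0        # current estimate Q(s, a)
reward = -1.0      # observed reward r
next_max_q = 5.0   # max over a' of Q(s', a')

target = reward + gamma * next_max_q   # about 3.5
td_error = target - q_old              # about 1.5

# Form 1: move the old estimate by alpha times the TD error
q_incremental = q_old + alpha * td_error
# Form 2: weighted average of the old estimate and the target
q_weighted = (1 - alpha) * q_old + alpha * target

print(q_incremental, q_weighted)
```

Both expressions give approximately $2.15$, up to floating-point rounding.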

### Action Selection: $\epsilon$-greedy Policy

To ensure the agent explores different actions and states while learning, Q-learning typically uses an **$\epsilon$-greedy policy**:
- With probability $1 - \epsilon$, the agent **exploits** its current knowledge by selecting the action $a = \arg\max_a Q(s, a)$ (i.e., the action with the highest Q-value in state $s$).
- With probability $\epsilon$, the agent **explores** by choosing a random action, allowing it to gather new information.

The exploration rate $\epsilon$ can start high and decrease over time, allowing the agent to explore initially and then gradually focus on exploitation as it learns the optimal policy.
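
One possible way to implement a decaying exploration rate is sketched below. The exponential schedule and the parameter names `eps_start`, `eps_end`, and `decay` are assumptions chosen for illustration; the tutorial's own `q_learning` function uses a fixed $\epsilon$.

```python
import numpy as np

def epsilon_greedy(q_values, epsilon, rng):
    """Pick a random action with probability epsilon, else the greedy one."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))   # explore
    return int(np.argmax(q_values))               # exploit

def decayed_epsilon(episode, eps_start=1.0, eps_end=0.05, decay=0.99):
    """Exponential decay from eps_start toward a floor of eps_end."""
    return max(eps_end, eps_start * decay ** episode)

rng = np.random.default_rng(0)
q_values = np.array([0.0, 5.0, 1.0, -2.0])   # illustrative Q-values for one state

for episode in (0, 100, 1000):
    eps = decayed_epsilon(episode)
    action = epsilon_greedy(q_values, eps, rng)
    print(episode, round(eps, 3), action)
```

Keeping a small floor such as `eps_end` means the agent never stops exploring entirely, which is a common design choice rather than a requirement.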

### Q-Learning Pseudocode

The following pseudocode summarizes the steps in the Q-learning algorithm.

**Input**: Learning rate $\alpha$, discount factor $\gamma$, exploration rate $\epsilon$  
**Output**: Q-function $Q$, policy $\pi$

1. **Initialize** $Q(s, a) \leftarrow 0$ for all $s \in S$, $a \in A$
2. **For each episode**:
    - Initialize the starting state $s$
    - **While** $s$ is not a terminal state:
        - **If** `random() < ε`:
            - Select a random action $a$
        - **Else**:
            - Select action $a \leftarrow \arg\max\limits_{a'} Q(s, a')$
        - Take action $a$, observe reward $r$ and next state $s'$
        - Update Q-value:
          $Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$
        - Set $s \leftarrow s'$
3. **Extract policy**: $\pi(s) \leftarrow \arg\max\limits_a Q(s, a)$ for all $s \in S$

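After enough episodes, and provided every state-action pair keeps being visited and the learning rate is chosen appropriately, the Q-values converge, meaning $Q(s, a)$ approximates the optimal action-value function $Q^*(s, a)$, and the extracted policy $\pi(s)$ approaches the optimal policy for the environment.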
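
The pseudocode can be tried end to end on a toy problem. The sketch below uses a hypothetical five-state corridor (start at state 0, goal at state 4, a reward of $-1$ per step and $+10$ on reaching the goal) rather than the `GridWorld` class; the `step` function, the step cap `max_steps`, and all hyperparameter values are assumptions for illustration.

```python
import numpy as np

N_STATES, GOAL = 5, 4   # corridor 0..4, episode ends at state 4

def step(state, action):
    """Deterministic corridor: action 0 moves left, action 1 moves right."""
    next_state = max(0, state - 1) if action == 0 else state + 1
    done = next_state == GOAL
    reward = 10.0 if done else -1.0
    return next_state, reward, done

def q_learning_sketch(episodes=1000, alpha=0.5, gamma=0.9, epsilon=0.1,
                      max_steps=200, seed=0):
    rng = np.random.default_rng(seed)
    Q = np.zeros((N_STATES, 2))
    for _ in range(episodes):
        state = 0
        for _ in range(max_steps):          # cap episode length as a safeguard
            # Epsilon-greedy action selection
            if rng.random() < epsilon:
                action = int(rng.integers(2))
            else:
                action = int(np.argmax(Q[state]))
            next_state, reward, done = step(state, action)
            # Terminal transitions have no future value to bootstrap from
            target = reward if done else reward + gamma * np.max(Q[next_state])
            Q[state, action] += alpha * (target - Q[state, action])
            state = next_state
            if done:
                break
    return Q

Q = q_learning_sketch()
greedy_policy = np.argmax(Q[:GOAL], axis=1)   # actions for non-terminal states
print(np.round(Q, 2), greedy_policy)
```

In this corridor the greedy policy should end up moving right in every non-terminal state, and $Q(3, \text{right})$ should approach the goal reward of $10$.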


### TODO

- Implement the epsilon greedy action selection
- Implement Q-Learning Update Rule

In [8]:
from tqdm.notebook import tqdm

def epsilon_greedy_policy(Q, state, epsilon):
    """Epsilon-greedy policy for action selection"""
    if np.random.random() < epsilon:
        return np.random.randint(4)
    return np.argmax(Q[state[0], state[1]])

def q_learning(env, episodes=1000, alpha=0.1, gamma=0.99, epsilon=0.1):
    """
    Q-Learning algorithm with fixed exploration rate.
    
    Args:
        env: GridWorld environment
        episodes: Number of training episodes
        alpha: Learning rate
        gamma: Discount factor
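        epsilon: Fixed exploration rate

    Returns:
        Q: Learned action-value table of shape (size, size, 4)
        episode_rewards: Total reward collected in each episode
        V: State values, taken as the maximum of Q over actions
        policy: Greedy policy derived from Q (-1 for walls and terminal states)
    """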
    # Initialize Q-values
    Q = np.zeros((env.size, env.size, 4))
    
    # Initialize walls with -inf
    for i in range(env.size):
        for j in range(env.size):
            if env.grid[i, j] == 1:
                Q[i, j].fill(-np.inf)
    
    # Training metrics
    episode_rewards = []
    
    # Training loop with progress bar
    for episode in tqdm(range(episodes), desc='Training Progress'):
        state = env.reset()
        total_reward = 0
        done = False
        
        while not done:


            # TODO: Epsilon-greedy action selection
            if ...
                ...
            else ...
                ...
            
            # Take action
            next_state, reward, done = env.step(action)
            total_reward += reward
            
            # Q-learning update
            if done:
                target = reward
            else:
                next_max_q = np.max(Q[next_state[0], next_state[1]])
                target = reward + gamma * next_max_q
                
            # TODO: Update Q-value with learning rate alpha
            Q[state[0], state[1], action] = ...
            
            state = next_state
        
        # Store metrics
        episode_rewards.append(total_reward)
    
    # Extract final policy and value function
    policy = np.argmax(Q, axis=2)
    V = np.max(Q, axis=2)
    
    # Set terminal state policies and values
    for i in range(env.size):
        for j in range(env.size):
            if env.grid[i, j] in [2, -1]:  # Goal or pit
                V[i, j] = env.get_terminal_value((i, j))
                policy[i, j] = -1
            elif env.grid[i, j] == 1:  # Wall
                policy[i, j] = -1
    
    return Q, episode_rewards, V, policy

### Visualize Results

In [9]:
create_qlearning_visualization(GridWorld, q_learning)
