In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import seaborn as sns

# Environment (GridWorld) class

* Number of states is based on width and height of the grid of the environment (n_states = width * height)
* Every state has 4 actions; move west, east, north or south
* Goal is to find the exit (one of the terminal states)
* Reward is -1 for each step; this encourages the agent to find the exit as soon as possible
* We set the discount factor to 1, meaning a farsighted agent: future rewards count just as much as immediate ones
* If an action takes the agent off the grid, then the agent is put back into the same state it executed that action from

The state-values will be initialized to zero, and the agent's policy will be initialized to a random policy, meaning each action has equal probability.

In [None]:
class GridWorld:
    """Creates a grid of size width x height. terminal_states is
    a list of integers that indicate what the terminal states are.
    Each step in the gridworld, yields a reward of -1; this should
    give incentive to the agent to find the exit as fast as possible.

    States are the integers 0 .. width*height - 1, numbered row by row:
    state s sits at column s % width and row s // width.

    Actions are encoded as integers: 0 = west (x - 1), 1 = east (x + 1),
    2 = north (y - 1), 3 = south (y + 1). A move that would leave the
    grid keeps the agent in the state it started from.
    """

    def __init__(self, width, height, terminal_states):
        """Build the state set, zero-initialise the state-values and
        validate the given terminal states.
        """
        self.width = width
        self.height = height
        self.n_states = self.width * self.height
        self.states = list(range(self.n_states))
        self.v_values = [0] * self.n_states
        self.terminal_states = terminal_states
        self.validate_terminal_states()

    def validate_terminal_states(self):
        """Assert that terminal_states is a list of existing states."""
        assert isinstance(self.terminal_states, list), "terminal_states must be a list."
        for state in self.terminal_states:
            assert (state in self.states), "Terminal state {} not in state set.".format(state)

    def state_to_coordinate(self, state):
        """Return the (xloc, yloc) = (column, row) pair of ``state``."""
        xloc = state % self.width
        # Floor division stays exact for arbitrarily large integer states;
        # true division would round through a float first.
        yloc = int(state // self.width)
        return (xloc, yloc)

    def coordinate_to_state(self, xloc, yloc):
        """Return the state index of column ``xloc`` in row ``yloc``."""
        return yloc * self.width + xloc

    def is_valid_coordinate(self, xloc, yloc):
        """Return True when (xloc, yloc) lies inside the grid."""
        return bool(0 <= xloc < self.width and 0 <= yloc < self.height)

    def step(self, state, action):
        """Apply ``action`` in ``state`` and return the transition.

        Returns a tuple ``(observation, reward, terminal, info)`` where
        ``observation`` is the next state, ``reward`` is always -1,
        ``terminal`` tells whether the next state is a terminal state and
        ``info`` is an empty dict.

        Raises ValueError when ``action`` is not one of 0, 1, 2, 3.
        """
        if action == 0:    # west
            dx, dy = -1, 0
        elif action == 1:  # east
            dx, dy = 1, 0
        elif action == 2:  # north
            dx, dy = 0, -1
        elif action == 3:  # south
            dx, dy = 0, 1
        else:
            raise ValueError("Unknown action {}; expected 0, 1, 2 or 3.".format(action))

        xloc, yloc = self.state_to_coordinate(state)
        # Only move when the target cell exists; otherwise stay put.
        if self.is_valid_coordinate(xloc + dx, yloc + dy):
            xloc, yloc = xloc + dx, yloc + dy

        observation = self.coordinate_to_state(xloc, yloc)
        reward, info = -1, dict()
        terminal = observation in self.terminal_states
        return observation, reward, terminal, info

    def policy_evaluation_sweep(self, policy):
        """Do one synchronous Bellman expectation backup over all states.

        ``policy[state]`` must be a sequence of action probabilities, one
        per action. Every new value is computed from the values of the
        previous sweep; self.v_values is replaced only at the end.
        """
        transition_prob = 1.  # dynamics are deterministic
        new_v_values = self.v_values.copy()
        for state in self.states:
            if state in self.terminal_states:
                # state-value of terminal state always stays zero
                continue
            new_v = 0
            for action, pr_action in enumerate(policy[state]):
                state_prime, reward, _, _ = self.step(state, action)
                new_v += pr_action * transition_prob * (reward + self.v_values[state_prime])
            new_v_values[state] = new_v
        self.v_values = new_v_values

    def render_v_values(self, ax=None):
        """Draw the state-values as an annotated heatmap, on ``ax`` if
        given and otherwise on a newly created figure.
        """
        if ax is None:
            _, ax = plt.subplots(figsize=(self.width, self.height))
        data = np.reshape(self.v_values, (self.height, self.width))
        sns.heatmap(data, cmap='coolwarm', annot=data, fmt='.3g', annot_kws={'fontsize': 14}, cbar=False, ax=ax)

# Agent class

In [None]:
class Agent:
    """Agent for a width x height grid with four actions per state:
    move west, east, north or south (action indices 0, 1, 2, 3).

    The policy is an (n_states, 4) array of action probabilities that
    starts out uniform, i.e. every action is picked with probability 0.25.
    """

    def __init__(self, width, height, terminal_states):
        self.width = width
        self.height = height
        self.n_states = self.width * self.height
        self.terminal_states = terminal_states
        self.policy = np.full((self.n_states, 4), 0.25)

    def policy_improvement(self, env):
        """Make the policy greedy with respect to ``env.v_values``.

        For every non-terminal state the one-step lookahead value of each
        action is computed via ``env.step``; the probability mass is then
        shared equally among the actions with the highest value.
        """
        n_states, n_actions = self.policy.shape
        for state in range(n_states):
            if state in self.terminal_states:
                # no policy needed in terminal state
                continue

            lookahead = np.zeros(n_actions)
            for action in range(n_actions):
                next_state, reward, _, _ = env.step(state, action)
                # round for numerical imprecision, so near-ties count as ties
                lookahead[action] = 1. * (reward + round(env.v_values[next_state], 5))

            best = np.flatnonzero(lookahead == max(lookahead))
            greedy = np.zeros(n_actions)
            greedy[best] = 1 / len(best)
            self.policy[state, :] = greedy

    def render_policy(self, env, ax=None):
        """Draw the policy as arrows on a grid, on ``ax`` if given and
        otherwise on a newly created figure. Terminal states are shaded;
        every action with non-zero probability gets an arrow.
        """
        if ax is None:
            _, ax = plt.subplots(figsize=(self.width, self.height))

        # grid lines and frame
        ax.hlines(range(self.height+1), 0, self.width, color='black', lw=1)
        ax.vlines(range(self.width+1), 0, self.height, color='black', lw=1)
        ax.set_xlim(0, self.width)
        ax.set_ylim(0, self.height)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.invert_yaxis()  # row 0 at the top

        # (action index, dx, dy) of the arrow drawn for each action
        arrow_offsets = (
            (0, -0.3, 0),   # west
            (1, 0.3, 0),    # east
            (2, 0, -0.3),   # north
            (3, 0, 0.3),    # south
        )

        for state in range(self.policy.shape[0]):
            col, row = env.state_to_coordinate(state)
            if state in self.terminal_states:
                ax.add_patch(Rectangle((col, row), width=1, height=1, ec='black', fc='black', alpha=0.3))
                continue
            centre_x, centre_y = col + 0.5, row + 0.5
            for action, dx, dy in arrow_offsets:
                if self.policy[state, action]:
                    ax.arrow(centre_x, centre_y, dy=dy, dx=dx, head_width=0.09, fc='black', length_includes_head=True)

# Example GridWorld
* width=7
* height=3
* terminal_states=[0, 5, 20], i.e. the cells at (row, column) coordinates (0,0), (0,5) and (2,6)

In [None]:
# Example 7x3 GridWorld with three exits, shown together with an
# example agent that still follows its initial uniform random policy.
width = 7
height = 3
terminal_states = [0, 5, 20]
example_env = GridWorld(width, height, terminal_states)
example_agent = Agent(width, height, terminal_states)
example_env.render_v_values()
example_agent.render_policy(example_env)

# Create simple 4x4 GridWorld

In [None]:
def policy_iteration_illustration(env, agent, show_iterations):
    """Runs policy evaluation with respect to the random policy.
    In addition runs policy improvement based on the state-value
    estimates of the random policy.

    Parameters
    ----------
    env : environment providing ``policy_evaluation_sweep(policy)`` and
        ``render_v_values(ax)``.
    agent : agent providing ``policy`` (with a ``copy()`` method),
        ``policy_improvement(env)`` and ``render_policy(env, ax)``.
    show_iterations : iterable of int
        Sweep counts k after which a figure is drawn. Sweeps run up to the
        largest k in here. The initial state (k = 0) is always drawn, so
        an empty iterable draws only that figure and runs no sweeps.

    Note that ``env.v_values`` and ``agent.policy`` are updated in place.
    """
    # Materialise once: gives cheap membership tests and also supports
    # one-shot iterables.
    shown = set(show_iterations)
    last_iteration = max(shown, default=0)

    def plot(iteration, ax1_title=None, ax2_title=None):
        """Draw state-values (left) and greedy policy (right) for sweep k."""
        fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(9, 4))
        if ax1_title:
            ax1.text(0.5, 1.2, ax1_title, fontsize=18, ha='center', va='center', transform=ax1.transAxes)
        if ax2_title:
            ax2.text(0.5, 1.2, ax2_title, fontsize=18, ha='center', va='center', transform=ax2.transAxes)
        env.render_v_values(ax1)
        agent.render_policy(env, ax2)
        ax2.text(1.1, 0.5, "$k={}$".format(iteration), fontsize=22, ha='left', va='center', transform=ax2.transAxes)

    plot(0, "$v_{k}$ for the random policy", "Greedy policy w.r.t. $v_{k}$")
    # Evaluate against a frozen copy: agent.policy itself is overwritten
    # by every policy_improvement call below.
    random_policy = agent.policy.copy()
    for i in range(1, last_iteration + 1):
        env.policy_evaluation_sweep(random_policy)  # update estimate of state-value function with respect to the random policy
        agent.policy_improvement(env)  # improve policy greedily with respect to latest state-value function
        if i in shown:
            plot(i)

In [None]:
# 4x4 grid; the first and the last state are the exits
width = height = 4
terminal_states = [0, 15]
env = GridWorld(width, height, terminal_states)
agent = Agent(width, height, terminal_states)

# values of k (number of evaluation sweeps) for which a figure is drawn
show_iterations = [0, 1, 2, 3, 10, 200]
policy_iteration_illustration(env, agent, show_iterations)

So how do we do one step of policy evaluation?

Note the following:
* the policy is the random policy: $\pi(a|s) = 0.25$ for all $a \in \{\text{west, east, north, south}\}$
* remember that $\gamma = 1$
* the environment dynamics are deterministic, meaning that for each action, we know exactly what the next state will be.<br/>
* Recall the Bellman update equation:
$$
v_{k+1}(s) 
= \sum_{a} \pi(a|s) \sum_{s', r} p(s',r | s,a) [r + \gamma v_{k}(s')]
$$


Now suppose we want to update $v_{200}(s)$ for the state at coordinate (2,1), where coordinates are written as (row, column). Using the above equation, we get:
$$
\begin{equation}
\begin{aligned}
v_{201}((2,1)) = & 
\pi(\text{west}|(2,1)) * \sum_{s', r} p(s',r | (2,1),\text{west}) [r + 1 * v_{200}(s')] \\&
+ \pi(\text{east}|(2,1)) * \sum_{s', r} p(s',r | (2,1),\text{east}) [r + 1 * v_{200}(s')] \\&
+ \pi(\text{north}|(2,1)) * \sum_{s', r} p(s',r | (2,1),\text{north}) [r + 1 * v_{200}(s')] \\&
+ \pi(\text{south}|(2,1)) * \sum_{s', r} p(s',r | (2,1),\text{south}) [r + 1 * v_{200}(s')] \\&
= \pi(\text{west}|(2,1)) * p((2,0),-1 | (2,1),\text{west}) [-1 + 1 * v_{200}((2,0))] \\&
+ \pi(\text{east}|(2,1)) * p((2,2),-1 | (2,1),\text{east}) [-1 + 1 * v_{200}((2,2))] \\&
+ \pi(\text{north}|(2,1)) * p((1,1),-1 | (2,1),\text{north}) [-1 + 1 * v_{200}((1,1))] \\&
+ \pi(\text{south}|(2,1)) * p((3,1),-1 | (2,1),\text{south}) [-1 + 1 * v_{200}((3,1))] \\&
= (0.25 * 1 * (-1 -20)) + (0.25 * 1 * (-1 -18)) + (0.25 * 1 * (-1 -18)) + (0.25 * 1 * (-1 -20)) \\&
= -20
\end{aligned}
\end{equation}
$$

# Bigger 10x10 GridWorld

In [None]:
# 10x10 grid with a single exit at state 27 (row 2, column 7)
width = height = 10
terminal_states = [27]
env = GridWorld(width, height, terminal_states)
agent = Agent(width, height, terminal_states)

# starting point: all state-values zero, uniform random policy
env.render_v_values()
agent.render_policy(env)

### Policy evaluation for 1000 steps (estimate $v_{\pi_{0}}$) and 1 policy improvement step (estimate $\pi_{1}$)

In [None]:
# Repeat the evaluation sweep under the agent's current policy, then show the state-values.
n_sweeps = 1000
for sweep in range(n_sweeps):
    env.policy_evaluation_sweep(agent.policy)
env.render_v_values()

In [None]:
# One greedy policy-improvement step with respect to the current
# state-values in env, followed by a drawing of the resulting policy.
agent.policy_improvement(env)
agent.render_policy(env)

### Policy evaluation for 1000 steps with $\pi_{1}$ (estimate $v_{\pi_{1}}$) and again one policy improvement step (estimate $\pi_{2}$)

In [None]:
# Repeat the evaluation sweep under the improved policy, then show the state-values.
n_sweeps = 1000
for sweep in range(n_sweeps):
    env.policy_evaluation_sweep(agent.policy)
env.render_v_values()

In [None]:
# Second greedy policy-improvement step, now with respect to the
# state-values of the improved policy, followed by a drawing of the result.
agent.policy_improvement(env)
agent.render_policy(env)

Thus, after 2 policy evaluation and 2 policy improvement steps, we have converged to our final (optimal) solution.

$$
\pi_{0} 
\xrightarrow{\text{E}} v_{\pi_{0}}
\xrightarrow{\text{I}} \pi_{1}
\xrightarrow{\text{E}} v_{\pi_{1}}
\xrightarrow{\text{I}} \pi_{2}
$$

where $v_{\pi_{1}} = v_{\pi_{*}}$ and $\pi_{2} = \pi_{*}$.

Officially, we'd have to do one more policy evaluation step to obtain $v_{\pi_{2}}$, but we know and can see that this will be equal to $v_{\pi_{1}}$.