In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import seaborn as sns

# Environment (GridWorld) class

In [None]:
class GridWorld:
    """Deterministic rectangular grid environment with a constant step cost.

    States are the integers ``0 .. width * height - 1`` laid out row by row:
    state ``s`` sits at column ``s % width`` and row ``s // width``.
    Every step in the gridworld yields a reward of -1; this gives the agent
    an incentive to reach a terminal state as fast as possible.

    Parameters
    ----------
    width, height : int
        Number of columns and rows of the grid.
    terminal_states : list of int
        States that end an episode. Their state-value always stays zero.

    Attributes
    ----------
    n_states : int
        ``width * height``.
    states : list of int
        All state indices.
    v_values : list
        Current state-value estimate per state, initialised to zero.
    """

    # (dx, dy) per action: 0 = west, 1 = east, 2 = north, 3 = south.
    # North decreases the row index because row 0 is rendered at the top.
    _ACTION_DELTAS = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    def __init__(self, width, height, terminal_states):
        self.width = width
        self.height = height
        self.n_states = self.width * self.height
        self.states = list(range(self.n_states))
        self.v_values = [0] * self.n_states
        self.terminal_states = terminal_states
        self.validate_terminal_states()

    def validate_terminal_states(self):
        """Assert that ``terminal_states`` is a list of existing states.

        Raises
        ------
        AssertionError
            If ``terminal_states`` is not a list or contains an unknown state.
        """
        assert isinstance(self.terminal_states, list), "terminal_states must be a list."
        for state in self.terminal_states:
            assert (state in self.states), "Terminal state {} not in state set.".format(state)

    def state_to_coordinate(self, state):
        """Return the ``(xloc, yloc)`` = (column, row) of ``state``."""
        # divmod keeps the arithmetic in integers; the former
        # ``int(state / width)`` went through float division.
        yloc, xloc = divmod(state, self.width)
        return (xloc, yloc)

    def coordinate_to_state(self, xloc, yloc):
        """Return the state index of column ``xloc`` and row ``yloc``."""
        return yloc * self.width + xloc

    def is_valid_coordinate(self, xloc, yloc):
        """Return True if ``(xloc, yloc)`` lies inside the grid."""
        return 0 <= xloc < self.width and 0 <= yloc < self.height

    def step(self, state, action):
        """Apply ``action`` in ``state`` and return the transition.

        A move that would leave the grid keeps the agent in place (it still
        costs the usual reward of -1).

        Parameters
        ----------
        state : int
            Current state index.
        action : int
            0 = west, 1 = east, 2 = north, 3 = south.

        Returns
        -------
        tuple
            ``(observation, reward, terminal, info)``: the next state, the
            constant reward -1, whether the next state is terminal, and an
            empty dict.

        Raises
        ------
        ValueError
            If ``action`` is not one of the four known actions.
        """
        try:
            dx, dy = self._ACTION_DELTAS[action]
        except (KeyError, TypeError):
            # Previously an unknown action fell through every branch and
            # crashed later with an UnboundLocalError on xloc/yloc.
            raise ValueError(
                "Unknown action {!r}; expected 0 (west), 1 (east), "
                "2 (north) or 3 (south).".format(action)
            ) from None

        xloc, yloc = self.state_to_coordinate(state)
        if self.is_valid_coordinate(xloc + dx, yloc + dy):
            xloc, yloc = xloc + dx, yloc + dy

        observation = self.coordinate_to_state(xloc, yloc)
        reward, info = -1, dict()
        terminal = observation in self.terminal_states
        return observation, reward, terminal, info

    def policy_evaluation_sweep(self, policy):
        """Run one synchronous policy-evaluation sweep over all states.

        Every non-terminal state gets the expected value, under ``policy``,
        of ``reward + v(next_state)``. All updates read the values from
        before the sweep; ``self.v_values`` is replaced only at the end.

        Parameters
        ----------
        policy : indexable
            ``policy[state]`` is a sequence of action probabilities, where
            position ``a`` holds the probability of taking action ``a``.
        """
        new_v_values = self.v_values.copy()
        for state in self.states:
            if state in self.terminal_states:
                # state-value of terminal state always stays zero
                continue
            new_v = 0
            for action, pr_action in enumerate(policy[state]):
                state_prime, reward, _, _ = self.step(state, action)
                # The factor 1. is presumably the transition probability of
                # this deterministic environment; it also forces a float.
                new_v += pr_action * 1. * (reward + self.v_values[state_prime])
            new_v_values[state] = new_v
        self.v_values = new_v_values

    def render_v_values(self):
        """Draw the current state values as an annotated heatmap."""
        fig, ax = plt.subplots(figsize=(self.width, self.height))
        # One heatmap cell per state, rows and columns matching the grid.
        data = np.reshape(self.v_values, (self.height, self.width))
        sns.heatmap(data, cmap='coolwarm', annot=data, fmt='.3g', annot_kws={'fontsize': 14}, cbar=False, ax=ax)

# Agent class

In [None]:
class Agent:
    """Agent that acts in a grid of size width x height.

    The agent can do four actions: 0 = move west, 1 = east, 2 = north and
    3 = south. Its policy is an ``(n_states, 4)`` array of action
    probabilities; initially every action is selected with equal
    probability (0.25).

    Parameters
    ----------
    width, height : int
        Number of columns and rows of the grid the agent lives in.
    terminal_states : list of int
        States for which no policy is needed.
    """

    def __init__(self, width, height, terminal_states):
        self.width = width
        self.height = height
        self.n_states = self.width * self.height
        self.policy = np.ones((self.n_states, 4)) * 0.25
        self.terminal_states = terminal_states

    def policy_improvement(self, env):
        """Make the policy greedy with respect to ``env.v_values``.

        For every non-terminal state, the value of each action is
        ``reward + env.v_values[next_state]`` as reported by ``env.step``.
        The probability mass is spread evenly over the actions that attain
        the maximum, and ``self.policy`` is updated in place.

        Parameters
        ----------
        env : object
            Must provide ``step(state, action)`` returning a 4-tuple whose
            first two items are the next state and the reward, and a
            ``v_values`` sequence indexed by state.
        """
        n_actions = self.policy.shape[1]
        for state in range(self.policy.shape[0]):
            if state in self.terminal_states:
                # no policy needed in terminal state
                continue
            action_values = np.zeros(n_actions)
            for action in range(n_actions):
                state_prime, reward, _, _ = env.step(state, action)
                action_values[action] = 1. * (reward + env.v_values[state_prime])
            # Ties are detected with exact equality, so all maximising
            # actions share the probability mass equally.
            top_actions = np.where(action_values == action_values.max())[0]
            new_state_policy = np.zeros(n_actions)
            new_state_policy[top_actions] = 1 / len(top_actions)
            self.policy[state, :] = new_state_policy

    def render_policy(self):
        """Draw the policy as arrows on the grid.

        Each non-terminal cell gets one arrow per action with non-zero
        probability; terminal cells are shaded instead. The grid size comes
        from ``self.width`` / ``self.height``, so the drawing does not
        depend on any notebook-level variables.
        """
        width, height = self.width, self.height
        _, ax = plt.subplots(figsize=(width, height))

        ax.hlines(range(height + 1), 0, width, color='black', lw=1)
        ax.vlines(range(width + 1), 0, height, color='black', lw=1)
        ax.set_xlim(0, width)
        ax.set_ylim(0, height)
        ax.set_xticks([])
        ax.set_yticks([])
        # Row 0 is drawn at the top, so a negative dy points north.
        ax.invert_yaxis()

        # Arrow direction (dx, dy) per action: west, east, north, south.
        arrow_deltas = ((-0.3, 0), (0.3, 0), (0, -0.3), (0, 0.3))
        arrow_style = dict(head_width=0.09, fc='black', length_includes_head=True)

        for state in range(self.policy.shape[0]):
            # Row-major layout: column = state % width, row = state // width.
            xloc, yloc = state % width, state // width
            if state in self.terminal_states:
                rect = Rectangle((xloc, yloc), width=1, height=1, ec='black', fc='black', alpha=0.3)
                ax.add_patch(rect)
                continue
            # Arrows start at the centre of the cell.
            x_center, y_center = xloc + 0.5, yloc + 0.5
            for action, (dx, dy) in enumerate(arrow_deltas):
                if self.policy[state, action]:
                    ax.arrow(x_center, y_center, dx=dx, dy=dy, **arrow_style)

# Create simple 4x4 GridWorld

In [None]:
# Build a 4x4 grid whose terminal states are the two opposite corners:
# state 0 is (column 0, row 0) and state 15 is (column 3, row 3).
width, height = 4, 4
terminal_states = [0, 15]
env = GridWorld(width, height, terminal_states)
agent = Agent(width, height, terminal_states)
# Draw the initial policy: every action has probability 0.25 in every state.
agent.render_policy()

In [None]:
# Show the initial state values; they are all zero before any evaluation sweep.
env.render_v_values()

### Policy evaluation (one sweep) followed by policy improvement

In [None]:
# Run a single policy-evaluation sweep under the agent's current policy,
# then display the updated state values.
env.policy_evaluation_sweep(agent.policy)
env.render_v_values()

In [None]:
# Make the policy greedy with respect to the current state values and draw it.
agent.policy_improvement(env)
agent.render_policy()

# Bigger 10x10 GridWorld

In [None]:
# Build a 10x10 grid with a single terminal state: 27 is (column 7, row 2).
# Rebinding env and agent replaces the 4x4 environment and agent from above.
width, height = 10, 10
terminal_states = [27]
env = GridWorld(width, height, terminal_states)
agent = Agent(width, height, terminal_states)
# Draw the initial policy: every action has probability 0.25 in every state.
agent.render_policy()

In [None]:
# Show the initial state values of the new grid; all zero at this point.
env.render_v_values()

### Policy evaluation for 100 sweeps, followed by one policy improvement step

In [None]:
# Run 100 policy-evaluation sweeps under the agent's current policy;
# each sweep starts from the values the previous one produced.
for i in range(100):
    env.policy_evaluation_sweep(agent.policy)
# Display the state values after the sweeps.
env.render_v_values()

In [None]:
# Make the policy greedy with respect to the evaluated state values and draw it.
agent.policy_improvement(env)
agent.render_policy()

### Another 100 policy evaluation sweeps with the improved policy, followed by a second policy improvement step

In [None]:
# Run another 100 policy-evaluation sweeps, now under the agent's updated
# policy. env.v_values is not reset, so the sweeps continue from the
# values computed earlier.
for i in range(100):
    env.policy_evaluation_sweep(agent.policy)
# Display the state values after the sweeps.
env.render_v_values()

In [None]:
# Make the policy greedy with respect to the re-evaluated state values and draw it.
agent.policy_improvement(env)
agent.render_policy()