----

# Dynamic Programming 
Given complete knowledge of a model of the environment as a finite Markov decision process (MDP), we can use dynamic programming algorithms to compute optimal policies. Assuming the dynamics of the MDP are known, this notebook explores the following algorithms:

* Iterative Policy Evaluation, for estimating $V \approx v_{\pi}$
* Policy Improvement, for estimating $\pi \approx \pi_*$

We study these algorithms in the *grid world* environment with deterministic dynamics. The following code cell defines a `GridWorld` class and a helper `Actions` enumeration.


----

In [None]:
from enum import Enum
import numpy as np

class Actions(Enum):
    """
    Enumeration of possible actions in the GridWorld environment.

    Attributes:
        UP (int): Represents the action of moving up.
        RIGHT (int): Represents the action of moving right.
        DOWN (int): Represents the action of moving down.
        LEFT (int): Represents the action of moving left.
    """
    UP = 0
    RIGHT = 1
    DOWN = 2
    LEFT = 3

class GridWorld:
    """
    A simple grid world environment for reinforcement learning.

    Attributes:
        rows (int): The number of rows in the grid (rows x cols).
        cols (int): The number of cols in the grid (rows x cols).

    Methods:
        is_terminal(state): Checks if a given state is terminal.
        step(state, action): Takes a step in the environment given a state and
        an action.
        render(policy=None, agent=None): Renders the grid world environment. If
        a policy is given, the policy will be displayed on the grid. If an agent
        is given, the agent's position will be displayed on the grid.
    """

    def __init__(self, rows=4, cols=4):
        """
        Initializes the GridWorld environment.

        Parameters:
            rows (int): The number of rows in the grid (default is 4).
            cols (int): The number of columns in the grid (default is 4).
        """
        self.rows = rows
        self.cols = cols
        self.move = {
            Actions.LEFT: np.array([0, -1]),
            Actions.RIGHT: np.array([0, 1]),
            Actions.UP: np.array([-1, 0]),
            Actions.DOWN: np.array([1, 0]),
        }
        self.state_space = [(i, j) for i in range(self.rows) for j in range(self.cols)]

    def is_terminal(self, state):
        """
        Checks if the given state is a terminal state.

        Parameters:
            state (tuple): The state to check.

        Returns:
            bool: True if the state is terminal, False otherwise.
        """
        x, y = state
        return (x == 0 and y == 0) or (x == self.rows - 1 and y == self.cols - 1)

    def step(self, state, action):
        """
        Performs an action in the environment from a given state.

        Parameters:
            state (tuple): The current state.
            action (Actions): The action to be performed.

        Returns:
            tuple: The next state, the reward received, and a boolean
            indicating whether the next state is terminal.
        """
        state = np.array(state)
        next_state = (state + self.move[action]).tolist()
        x, y = next_state

        # Check for boundaries of the grid.
        if x < 0 or x >= self.rows or y < 0 or y >= self.cols:
            next_state = state.tolist()

        reward = -1
        done = self.is_terminal(next_state)
        return tuple(next_state), reward, done

    # Render the grid to visualize the agent's progress.
    def render(self, policy=None, agent=None):
        """
        Renders the grid world environment.

        Parameters:
            policy (function): A function that takes a state and returns an
            action. If given, the policy will be displayed on the grid.
            agent (tuple): The position of the agent. If given, the agent's
            position will be displayed on the grid.

        Returns:
            None
        """
        # Top border of the grid
        print('+' + '---+' * self.cols)

        for i in range(self.rows):
            print('|', end='')
            for j in range(self.cols):
                cell = ' '
                if agent and (i, j) == agent:
                    # Agent's position.
                    cell = 'A'
                elif self.is_terminal((i, j)):
                    # Terminal state in red.
                    cell = '\033[91mT\033[0m'
                elif policy:
                    # Display policy direction.
                    action = policy((i, j))
                    if action == Actions.UP:
                        cell = '↑'
                    elif action == Actions.RIGHT:
                        cell = '→'
                    elif action == Actions.DOWN:
                        cell = '↓'
                    elif action == Actions.LEFT:
                        cell = '←'

                print(f' {cell} |', end='')
            # Bottom border of each row.
            print('\n+' + '---+' * self.cols)

----

Now that we have defined the action and environment classes, it is useful to explore their methods. The environment follows the usual interaction loop: we observe a state, take an action, and receive a numerical reward and a next state. Start by instantiating the `GridWorld` environment, defining the agent's state, and rendering the agent in the environment in the next code cell.

----

In [None]:
# Instantiate an instance of the GridWorld environment.
env = GridWorld()

# Example starting state.
state = (0, 1)

# Show the agent in the grid.
env.render(agent=state)

----

Next, choose an arbitrary action, say `action = Actions.RIGHT`, and ask the environment to take this action from the previous state. The environment returns a triple, `(next_state, reward, done)`: `next_state` is the state reached by taking `action` in `state`, `reward` is the numerical reward for the transition (always $-1$ in this environment), and `done` is a boolean indicating whether the agent has landed in a *terminal state*.
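
One detail worth seeing in isolation is what happens at the edge of the grid: the boundary check in `GridWorld.step` leaves the agent in place when a move would exit the grid, and the reward is still $-1$. The following is a minimal standalone sketch of that rule, assuming a 4x4 grid; `sketch_step`, `MOVES`, and `TERMINALS` are hypothetical names used only for illustration and are not part of the notebook's classes.

```python
# Hypothetical standalone helper mirroring the boundary rule of GridWorld.step
# on a 4x4 grid. It is an illustration, not part of the notebook's classes.
ROWS, COLS = 4, 4
MOVES = {"UP": (-1, 0), "RIGHT": (0, 1), "DOWN": (1, 0), "LEFT": (0, -1)}
TERMINALS = {(0, 0), (ROWS - 1, COLS - 1)}


def sketch_step(state, action):
    """Return (next_state, reward, done) for a deterministic grid move."""
    row, col = state
    d_row, d_col = MOVES[action]
    new_row, new_col = row + d_row, col + d_col
    # A move that would leave the grid keeps the agent where it is.
    if not (0 <= new_row < ROWS and 0 <= new_col < COLS):
        new_row, new_col = row, col
    next_state = (new_row, new_col)
    # Every transition costs -1, including bumping into the boundary.
    return next_state, -1, next_state in TERMINALS


print(sketch_step((0, 1), "UP"))    # off-grid move: ((0, 1), -1, False)
print(sketch_step((0, 1), "LEFT"))  # terminal corner: ((0, 0), -1, True)
```

Because the dynamics are deterministic, each `(state, action)` pair maps to exactly one `(next_state, reward)` pair, which is what later lets a policy evaluation sweep call `step` once per action instead of summing over a distribution of outcomes.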

----

In [None]:
# Move to the right.
action = Actions.RIGHT

# Step in the environment.
next_state, reward, done = env.step(state, action)

# Print the state, action, reward, next state tuple.
print(f"{state = }, {action = }, {reward = }, {next_state = }, {done = }")

env.render(agent=next_state)

----

To understand how to solve MDPs with dynamic programming, we first focus on the equiprobable random policy, which selects each action with equal probability in every state. The following code cell defines functions representing this policy.
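
A policy of this kind has two faces: a probability $\pi(a | s)$ for each action, and a sampler that draws an action according to those probabilities. The sketch below illustrates both under stated assumptions: `ACTIONS` is a hypothetical list of strings standing in for the `Actions` enumeration, and `uniform_probability` and `sample_uniform_action` are illustrative names, not the notebook's functions.

```python
import random
from collections import Counter

# Hypothetical stand-in for the Actions enumeration.
ACTIONS = ["UP", "RIGHT", "DOWN", "LEFT"]


def uniform_probability(state, action):
    """pi(a | s) for the equiprobable policy: identical in every state."""
    return 1 / len(ACTIONS)


def sample_uniform_action(state, rng):
    """Draw one action from pi(. | s)."""
    return rng.choice(ACTIONS)


state = (0, 1)

# The probabilities over all actions in a state must sum to one.
total = sum(uniform_probability(state, a) for a in ACTIONS)
print(f"{total = }")

# Sampling many times should give each action roughly a quarter of the draws.
rng = random.Random(0)
counts = Counter(sample_uniform_action(state, rng) for _ in range(10_000))
print(counts)
```

Policy evaluation only needs the probabilities, while rendering or simulating the policy needs the sampler, which is why it is convenient to keep the two as separate functions.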

----

In [None]:
# pi(s) = a
# Sample an action uniformly at random; the state is ignored.
def random_policy(state):
    return np.random.choice(list(Actions))

# pi(a | s)
# For our random policy, all actions have the same probability in every
# state s.
def random_policy_probability(state, action):
    return 1 / len(Actions)

# Show one sampled action per state, so the arrows change on every call.
env.render(policy=random_policy)

----

## Iterative Policy Evaluation, for estimating $V \approx v_{\pi}$

In reinforcement learning, we seek good policies. Our goal here is to use value functions to organize and structure the search for them. First, recall the **Bellman optimality equations**:

State Value Optimality: $v_*(s) = \max_{a}\sum_{s', r}p(s', r | s, a)[r + \gamma v_*(s')]$

State-Action Value Optimality: $q_*(s, a) = \sum_{s', r}p(s', r | s, a)[r + \gamma \max_{a'}q_*(s', a')]$
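
Iterative policy evaluation itself rests on the Bellman expectation equation for a fixed policy $\pi$ rather than on the optimality equations. As a reminder, written in the same notation, the first line below is that equation and the second is the update rule obtained by applying it repeatedly to a sequence of estimates $v_0, v_1, v_2, \dots$:

```latex
\begin{align*}
v_{\pi}(s) &= \sum_{a}\pi(a | s)\sum_{s', r}p(s', r | s, a)\left[r + \gamma v_{\pi}(s')\right] \\
v_{k+1}(s) &= \sum_{a}\pi(a | s)\sum_{s', r}p(s', r | s, a)\left[r + \gamma v_{k}(s')\right]
\end{align*}
```

In the grid world the dynamics are deterministic, so the inner sum over $(s', r)$ has a single term for each action. The code in this notebook applies no discount factor, which corresponds to the assumption $\gamma = 1$.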




----

In [None]:
def iterative_policy_evaluation(
        env,
        policy_probability,
        policy=None,
        theta=1e-4,
        show_iterations=False
    ):
    """
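    Iteratively evaluates a policy in the given GridWorld environment.

    The update is undiscounted (gamma = 1).

    Parameters:
        env (GridWorld): The grid world environment.
        policy_probability (function): A function that takes a state and an
        action and returns the probability pi(action | state).
        policy (function): Currently unused.
        theta (float): A threshold for the evaluation accuracy.
        show_iterations (bool): If True, prints the number of iterations.

    Returns:
        np.ndarray: The final value function, with shape (rows, cols).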
    """
    # Initialize value function to zeros.
    V = np.zeros((env.rows, env.cols))

    # Copy value function for synchronous updating of the values in V.
    new_v = V.copy()

    # Iteration count for tracking progress.
    iteration = 0

    while True:
        for state in env.state_space:
            # Ignore the terminal states; their values stay at zero.
            if env.is_terminal(state):
                continue
            i, j = state
            # Accumulate the expected one-step return over all actions.
            value = 0
            for action in Actions:
                # Look up the successor state and reward for this action.
                next_state, reward, _ = env.step(state, action)
                new_i, new_j = next_state
                value += policy_probability(state, action) * (reward + V[new_i, new_j])

            new_v[i, j] = value

        # Check for convergence: total absolute change summed over all states.
        if np.sum(np.abs(new_v - V)) < theta:
            V = new_v.copy()
            break

        V = new_v.copy()
        iteration += 1

    if show_iterations:
        print(f"The number of iterations = {iteration}")
    return V

In [None]:
V = iterative_policy_evaluation(env, random_policy_probability)
print(f"{V = }")

In [None]:
import matplotlib.pyplot as plt
from matplotlib.table import Table
import matplotlib

# Use the 'Agg' backend for matplotlib to avoid the need for a GUI.
matplotlib.use('Agg')


def draw_image(env, values, fig_name='value-states'):
    """
    Draws the grid world with values in each state.

    Parameters:
        env (GridWorld): The grid world environment, used for its dimensions.
        values (np.array): Array containing the values for each state.
        fig_name (str): The filename (without extension) for saving the figure.
    """
    fig, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])
    values = np.round(values, decimals=2)

    nrows, ncols = env.rows, env.cols
    width, height = 1.0 / ncols, 1.0 / nrows

    # Adding cells to the table.
    for (i, j), val in np.ndenumerate(values):
        color = 'white'
        tb.add_cell(i, j, width, height, text=val,
                    loc='center', facecolor=color)

    # Adding row and column labels (sized separately so that non-square
    # grids are labeled correctly).
    for i in range(nrows):
        tb.add_cell(i, -1, width, height, text=i+1, loc='right',
                    edgecolor='none', facecolor='none')
    for j in range(ncols):
        tb.add_cell(-1, j, width, height/2, text=j+1, loc='center',
                    edgecolor='none', facecolor='none')

    ax.add_table(tb)
    plt.savefig(f'{fig_name}.png')
    plt.close()

In [None]:
# Draw the value of each state in a grid. The figure is saved as
# value-states.png in the working directory of the Colab session.
draw_image(env, V)

In [None]:
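# Define a better policy that acts greedily with respect to the state values
# found by iterative policy evaluation of the random policy. Every step has the
# same reward, so comparing the values of the next states is enough.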
def better_policy(env, state, value_grid):
    best_action = Actions.UP
    next_state, reward, _ = env.step(state, best_action)
    i, j = next_state
    max_value = value_grid[i, j]
    for action in [Actions.DOWN, Actions.LEFT, Actions.RIGHT]:
        next_state, reward, _ = env.step(state, action)
        i, j = next_state
        if max_value < value_grid[i, j]:
            best_action = action
            max_value = value_grid[i, j]
    return best_action

# Specific instance of this.
def pi(state):
    return better_policy(env, state, V)

In [None]:
# View the new policy action for each state in the environment (excluding the
# terminal states).
for state in env.state_space:
    if env.is_terminal(state):
        continue
    action = pi(state)
    print(f"{state=}, {action=} \n")

print(f"{V=}")

In [None]:
env.render(policy=pi)

In [None]:
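def simulate_gridworld_walk(env, starting_state, policy, value_table, max_steps=10):
    """
    Rolls out a policy from a starting state and prints each transition.

    The policy is called as policy(env, state, value_table). The walk stops
    at a terminal state or after max_steps steps, whichever comes first.
    """
    states = [starting_state]
    actions = []
    rewards = []
    done = False
    for _ in range(max_steps):
        action = policy(env, states[-1], value_table)
        next_state, reward, done = env.step(states[-1], action)
        states.append(next_state)
        actions.append(action)
        rewards.append(reward)
        if done:
            break

    # zip stops at the shortest list, so the final state is rendered below.
    for t, (state, action, reward) in enumerate(zip(states, actions, rewards)):
        print("----------------------------------")
        print(f"Time: {t = } ")
        print(f"S_{t} = {state}, A_{t} = {action.name}, R_{t+1} = {reward} \n")
        env.render(agent=tuple(state))
        print()
    print("----------------------------------")
    # Only report success if the walk actually ended in a terminal state.
    if done:
        print("Terminal State Reached \n")
    else:
        print(f"Stopped after {max_steps} steps without reaching a terminal state \n")
    env.render(agent=states[-1])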

In [None]:
# Simulate an agent following the better policy in the GridWorld environment.
simulate_gridworld_walk(env, (2, 1), better_policy, V)