# Example 2 - Kill Shreck

This is a small maze-solving example taken from [this reinforcement learning problem set](https://www.cs.cmu.edu/~mgormley/courses/10601-f21/handouts/exam3_practice_solutions.pdf).

![Maze](../../assets/example_2.png)

## Libraries

In [1]:
import numpy as np

from rich import print

from markov_decision_process import TimeAugmentedMDP

from itertools import product

import seaborn as sns

sns.set_palette("deep")
%matplotlib inline

import logging

# Show INFO-level log records, e.g. the solver's progress messages.
logging.basicConfig(level=logging.INFO)
# Module-level logger for this notebook.
logger = logging.getLogger(__name__)

## Problem Setup

In [2]:
class MDPProblem(TimeAugmentedMDP):
    """
    Help Farquad kill Shreck.

    A state is a ``(row, col, direction)`` tuple on a 4x4 grid, where
    ``direction`` is the current heading ("N", "E", "S" or "W"). The actions
    are "R" (rotate to the next heading in ``self.directions``), "L" (rotate
    to the previous heading) and "M" (move one cell along the current
    heading). All transitions are deterministic, so ``transition`` only ever
    returns 0 or 1.
    """

    # (row, col) offset applied by a move along each heading.
    _MOVE_OFFSETS = {"N": (-1, 0), "E": (0, 1), "S": (1, 0), "W": (0, -1)}

    def __init__(self):
        super().__init__()

        # State space: every (row, col, direction) combination.
        self.rows: list[int] = [0, 1, 2, 3]
        self.cols: list[int] = [0, 1, 2, 3]
        self.directions: list[str] = ["N", "E", "S", "W"]
        self.S: list[tuple[int, int, str]] = list(
            product(self.rows, self.cols, self.directions)
        )

        # Action space
        self.A: list[str] = ["R", "L", "M"]  # Right, Left, Move

        # Times
        # We'll need about 25 time steps.
        self.T: list[int] = list(np.arange(0, 25))

    def _get_next_element(self, L, s):
        """Return the element after ``s`` in ``L``, wrapping around at the end."""
        current_index = L.index(s)
        next_index = (current_index + 1) % len(L)
        return L[next_index]

    def _get_previous_element(self, L, s):
        """Return the element before ``s`` in ``L``, wrapping around at the start."""
        current_index = L.index(s)
        previous_index = (current_index - 1) % len(L)
        return L[previous_index]

    def _is_blocked(self, row, col, direction):
        """Return True if a move from ``(row, col)`` along ``direction`` is impossible.

        A move is impossible when it would leave the grid or cross one of
        the maze's interior barriers.
        """
        d_row, d_col = self._MOVE_OFFSETS[direction]
        if row + d_row not in self.rows or col + d_col not in self.cols:
            return True

        # Interior barriers, each listed from both sides.
        return (
            (direction == "E" and col == 1 and row in [0, 1, 2])
            or (direction == "W" and col == 2 and row in [0, 1, 2])
            or (direction == "S" and row == 1 and col == 1)
            or (direction == "N" and row == 2 and col == 1)
            or (direction == "E" and col == 2 and row in [1, 2, 3])
            or (direction == "W" and col == 3 and row in [1, 2, 3])
        )

    def transition(self, s_prime, s, t: int, a: str) -> int:
        """Return the probability (0 or 1) of reaching ``s_prime`` from ``s``.

        Parameters:
        s_prime (tuple): Candidate next state ``(row, col, direction)``.
        s (tuple): Current state ``(row, col, direction)``.
        t (int): Current time step (unused; the dynamics do not depend on it).
        a (str): Action, one of "R", "L" or "M".

        Returns:
        int: 1 if taking ``a`` in ``s`` leads to ``s_prime``, otherwise 0.
        A blocked move leads nowhere, so it returns 0 for every ``s_prime``.
        """
        row, col, direction = s
        row_prime, col_prime, direction_prime = s_prime

        # Rotations keep the cell and only change the heading.
        if a in ("R", "L"):
            if (row_prime, col_prime) != (row, col):
                return 0
            if a == "R":
                turned = self._get_next_element(self.directions, direction)
            else:
                turned = self._get_previous_element(self.directions, direction)
            return int(direction_prime == turned)

        if a == "M":
            if direction not in self._MOVE_OFFSETS:
                return 0
            if self._is_blocked(row, col, direction):
                return 0

            d_row, d_col = self._MOVE_OFFSETS[direction]
            # A move keeps the heading. Without the direction check all four
            # headings in the destination cell would get probability 1 and
            # the distribution over next states would sum to 4.
            return int(
                row_prime == row + d_row
                and col_prime == col + d_col
                and direction_prime == direction
            )

        return 0

    def reward(self, s_prime, s, t: int, a: str) -> int:
        """Return 5 when the next state is in cell (3, 3), otherwise 0.

        Only the row and column of ``s_prime`` matter; ``s``, ``t`` and ``a``
        are accepted for interface compatibility and ignored.
        """
        row_prime, col_prime, _ = s_prime
        return 5 if (row_prime == 3 and col_prime == 3) else 0

In [3]:
# Build the maze MDP and solve it; the resulting policy is read from
# mdp.policy_function further down.
mdp = MDPProblem()
mdp.solve()

INFO:markov_decision_process.time_augmented_mdp:State space augmented with time
INFO:markov_decision_process.time_augmented_mdp:MDP solved




### Solution

At time 0, for each cell (i, j), find the facing directions for which the optimal action is a move.

In [4]:
# Heading -> arrow glyph used when the grid is drawn.
arrow_map = {"N": "↑", "E": "→", "S": "↓", "W": "←"}

# Time-0 states whose optimal action is a move, as (row, col, arrow) triples.
moves = [
    (state[0], state[1], arrow_map[state[2]])
    for state, action in mdp.policy_function[0].items()
    if action == "M"
]

In [5]:
# thank you claude
def draw_grid(state_tuples, rows=4, cols=4):
    """
    Render a bordered text grid with one character per occupied cell.

    Parameters:
    state_tuples (list): Tuples of the form (row, col, character). Entries
        outside the grid are ignored; a later entry for the same cell
        overwrites an earlier one.
    rows (int): Number of rows in the grid.
    cols (int): Number of columns in the grid.

    Returns:
    str: The grid as a multi-line string, ending with a newline.
    """
    # Start from a blank board.
    board = [[" "] * cols for _ in range(rows)]

    # Drop each character into its cell, skipping out-of-range positions.
    for r, c, glyph in state_tuples:
        if not (0 <= r < rows and 0 <= c < cols):
            continue
        board[r][c] = glyph

    # Every row of cells is framed by the same horizontal border line.
    border = "+" + "---+" * cols + "\n"
    pieces = [border]
    for cells in board:
        padded = [f" {cell} " for cell in cells]
        pieces.append("|" + "|".join(padded) + "|\n")
        pieces.append(border)

    return "".join(pieces)


# Render the time-0 move states as a text grid.
print(draw_grid(moves))

Pretty good! Starting in the top left, the optimal path is to go straight down, straight to the right, up the wall, right, and down.

Technically, this is the set of optimal actions at time 0, but really this problem has a time-invariant policy function.

The last cell is empty because it's never optimal to move out of the cell (it keeps giving a payoff of 5 every period); the optimal action is to just keep rotating in that cell.