In [1]:
import numpy as np
import random
from enum import IntEnum
import matplotlib.pyplot as plt
import matplotlib.patches as patches

In [2]:
class Action(IntEnum):
    """The four moves available to the agent; member values 0-3 index the action space."""

    up = 0
    right = 1
    down = 2
    left = 3


# Human-readable label per action; each label is simply the member's name.
action_to_str = {move: move.name for move in Action}

# (vertical, horizontal) displacement per action. The vertical delta comes first and
# "up" is negative, i.e. the table is written in matrix convention (row 0 on top).
action_to_offset = {
    Action.up:    (-1, 0),
    Action.right: (0, 1),
    Action.down:  (1, 0),
    Action.left:  (0, -1),
}


In [8]:
class GridWorld:
    """A small deterministic grid world with a goal cell, danger cells and walls.

    States are integers in ``range(height * width)``. Positions are ``(col, row)``
    pairs; ``render`` prints rows from the highest index down, so row 0 is the
    bottom line and ``(0, 0)`` is the bottom-left cell (the start state). Under
    ``state_from_pos`` state 0 is the top-right cell and state numbers increase
    leftwards, then downwards.

    Parameters
    ----------
    height, width : int
        Grid dimensions.
    goal : int
        State number of the terminal goal cell.
    goal_value : float
        Reward for entering the goal.
    danger : iterable of int, optional
        State numbers of terminal danger cells (default: none).
    danger_value : float
        Reward for entering a danger cell.
    blocked : iterable of int, optional
        State numbers of wall cells the agent can never enter (default: none).
    max_steps : int
        Episode length limit; ``step`` reports ``finished`` once it is reached.
    step_reward : float
        Reward for a move that ends neither in the goal nor in a danger cell.
    """

    def __init__(self, height, width, goal, goal_value=5.0, danger=None, danger_value=-5.0,
                 blocked=None, max_steps=30, step_reward=-0.1):
        self._width = width
        self._height = height
        # Per-state values, initialised to 0 (not read by this class itself).
        self._grid_values = [0 for _ in range(height * width)]
        self._goal_value = goal_value
        self._danger_value = danger_value
        self._goal_cell = goal
        # Fresh lists: avoids the shared-mutable-default pitfall and decouples the
        # environment from later mutation of the caller's containers.
        self._danger_cells = list(danger) if danger is not None else []
        self._blocked_cells = list(blocked) if blocked is not None else []
        self._max_steps = max_steps
        self._step_reward = step_reward
        self.calculate_next_value()  # Precompute the transition table.
        self.start_state = self.state_from_pos((0, 0))
        self.reset()

    def state_from_pos(self, pos):
        """Convert a ``(col, row)`` position into its state number.

        The numbering follows the gridworld instructions: ``(width-1, height-1)``
        (top-right) is state 0 and ``(0, 0)`` (bottom-left) is the last state.
        """
        col, row = pos
        return (self._height - 1 - row) * self._width + (self._width - 1 - col)

    def pos_from_state(self, state):
        """Inverse of ``state_from_pos``: return the ``(col, row)`` of ``state``."""
        row = self._height - 1 - state // self._width
        col = self._width - 1 - state % self._width
        return (col, row)

    def calculate_next_value(self):
        """Build the deterministic transition table ``self.next_value``.

        ``self.next_value[(state, action)]`` is the state reached by taking
        ``action`` in ``state`` for every state and every ``Action``. A move that
        would leave the grid or enter a blocked cell keeps the agent in place,
        and a blocked cell maps every action back to itself.
        """
        self.next_value = {}
        for state in range(self._height * self._width):
            if state in self._blocked_cells:
                for action in Action:
                    self.next_value[(state, action)] = state
                continue

            col, row = self.pos_from_state(state)
            for action in Action:
                # action_to_offset holds (vertical, horizontal) deltas in matrix
                # convention ("up" is -1). Positions here are (col, row) with row 0
                # at the bottom (see render), so the vertical delta is negated.
                d_row, d_col = action_to_offset[action]
                next_col, next_row = col + d_col, row - d_row
                if 0 <= next_col < self._width and 0 <= next_row < self._height:
                    next_state = self.state_from_pos((next_col, next_row))
                    if next_state in self._blocked_cells:
                        next_state = state  # Walls bounce the agent back.
                else:
                    next_state = state  # Leaving the grid is a no-op.
                self.next_value[(state, action)] = next_state

    def reset(self):
        """Return the agent to the start state, zero the step counter, return the state."""
        self.current_state = self.start_state
        self.steps = 0
        return self.current_state

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, finished)``.

        ``finished`` is True when the agent is on the goal or a danger cell, or
        once ``max_steps`` steps have been taken since the last ``reset``.
        """
        next_state = self.next_value.get((self.current_state, action), self.current_state)
        self.steps += 1
        self.current_state = next_state

        if next_state == self._goal_cell:
            reward = self._goal_value
        elif next_state in self._danger_cells:
            reward = self._danger_value
        else:
            reward = self._step_reward

        finished = (next_state == self._goal_cell
                    or next_state in self._danger_cells
                    or self.steps >= self._max_steps)
        return next_state, reward, finished

    def render(self):
        """Print the grid to stdout.

        Symbols: '+' goal, '-' danger, '#' wall, '.' free, 'X' the agent (drawn
        over whatever cell it occupies). Rows are printed from the highest row
        index down, so row 0 is the last line, followed by a blank line.
        """
        gridworld = [["" for _ in range(self._width)] for _ in range(self._height)]
        for i in range(self._height * self._width):
            col, row = self.pos_from_state(i)
            symbol = "."
            if i == self._goal_cell:
                symbol = "+"
            elif i in self._danger_cells:
                symbol = "-"
            elif i in self._blocked_cells:
                symbol = "#"
            gridworld[row][col] = symbol

        agent_col, agent_row = self.pos_from_state(self.current_state)
        gridworld[agent_row][agent_col] = "X"

        for row_cells in reversed(gridworld):
            print(" ".join(row_cells))

        print()
    

In [17]:
if __name__ == "__main__":
    # A 3-row x 4-column grid. With GridWorld's state numbering, state 0 is the
    # top-right cell and numbers increase leftwards, then downwards, so:
    #   goal    = state 3 (top-left corner),
    #   danger  = state 4 (right end of the middle row),
    #   blocked = state 1 (third cell of the top row, a wall).
    # The agent starts in the bottom-left cell.
    env = GridWorld(height=3, width=4, goal=3, goal_value=1.0, danger=[4], danger_value=-1.0, blocked=[1])

    print("Initial GridWorld:")
    env.render()

Initial GridWorld:
+ . # .
. . . -
X . . .

