In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import random

In [2]:
class GridWorld:
    """Deterministic grid-world environment for tabular Q-Learning.

    The agent starts in the top-left cell ``(0, 0)`` and the goal is the
    bottom-right cell ``(size-1, size-1)``. States are ``(row, col)`` tuples.

    Args:
        size: Side length of the square grid; must be >= 1.
        obstacles: Iterable of ``(row, col)`` cells the agent cannot enter.
            Defaults to ``DEFAULT_OBSTACLES``. Entries are converted to tuples
            (so membership tests against tuple states work), and cells that
            fall outside the grid or coincide with the start/goal are dropped.

    Raises:
        ValueError: If ``size`` is smaller than 1.
    """

    # Obstacle layout used when none is supplied (designed for size=5).
    DEFAULT_OBSTACLES = ((1, 1), (2, 2), (3, 1))

    def __init__(self, size=5, obstacles=None):
        if size < 1:
            raise ValueError(f"size must be >= 1, got {size}")
        self.size = size
        self.start = (0, 0)
        self.goal = (size - 1, size - 1)
        if obstacles is None:
            obstacles = self.DEFAULT_OBSTACLES
        # An out-of-bounds obstacle would crash grid indexing when rendering,
        # and an obstacle on the goal would make the goal unreachable, so only
        # keep in-bounds cells that do not block the start or the goal.
        self.obstacles = [
            (r, c)
            for r, c in obstacles
            if 0 <= r < size and 0 <= c < size and (r, c) not in (self.start, self.goal)
        ]
        self.state = self.start

    def reset(self):
        """Move the agent back to the start cell and return that state."""
        self.state = self.start
        return self.state

In [3]:
def step(self, action):
    """Apply ``action`` and advance the environment by one time step.

    Actions: 0=up, 1=right, 2=down, 3=left. Moving off the grid edge leaves
    the agent in place and costs the normal step penalty. Moving into an
    obstacle also leaves the agent in place but costs -10.

    Args:
        action: Action id, one of 0, 1, 2, 3.

    Returns:
        ``(next_state, reward, done)``. ``next_state`` is a ``(row, col)``
        tuple. ``reward`` is -10 (hit obstacle), 100 (reached goal) or -1
        (any other move). ``done`` is True when ``next_state`` is the goal.

    Raises:
        ValueError: If ``action`` is not one of the four valid ids.
    """
    # (row delta, col delta) per action; row 0 is the top of the grid.
    moves = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}
    if action not in moves:
        # Reject unknown actions loudly. Treating them as a silent no-op would
        # hide bugs in action selection (e.g. an off-by-one randint range).
        raise ValueError(f"invalid action {action!r}; expected one of 0, 1, 2, 3")

    d_row, d_col = moves[action]
    row, col = self.state
    # Clamp to the grid so walking into a wall keeps the agent in place.
    row = min(max(row + d_row, 0), self.size - 1)
    col = min(max(col + d_col, 0), self.size - 1)
    next_state = (row, col)

    if next_state in self.obstacles:
        next_state = self.state  # Bounce back: stay in the current cell.
        reward = -10
    elif next_state == self.goal:
        reward = 100
    else:
        reward = -1  # Small per-step penalty to favour short paths.

    self.state = next_state
    done = next_state == self.goal
    return next_state, reward, done

# GridWorld is defined in an earlier cell; bind the module-level `step`
# function onto the class so instances can call env.step(action).
setattr(GridWorld, "step", step)

In [4]:
def render(self, q_table=None, policy=None):
    """Draw the grid world, and optionally a policy, with matplotlib.

    Cell colours: obstacles red, free cells white, agent blue, goal green.
    The agent is painted last, so it stays visible when standing on the goal.

    Args:
        q_table: Accepted for call compatibility but not used here.
            NOTE(review): presumably intended for deriving ``policy``;
            confirm its expected shape before wiring it in.
        policy: Optional object indexable as ``policy[row, col]`` that yields
            an action id 0-3. An arrow is drawn in every cell that is neither
            an obstacle nor the goal.
    """
    # Cell codes. BoundaryNorm assigns colours in ascending code order, so the
    # colour list must be ordered obstacle(-1), empty(0), agent(1), goal(2).
    OBSTACLE, EMPTY, AGENT, GOAL = -1, 0, 1, 2
    grid = np.full((self.size, self.size), EMPTY)
    for obs in self.obstacles:
        grid[obs] = OBSTACLE
    grid[self.goal] = GOAL
    grid[self.state] = AGENT

    cmap = colors.ListedColormap(['red', 'white', 'blue', 'green'])
    norm = colors.BoundaryNorm([-1.5, -0.5, 0.5, 1.5, 2.5], cmap.N)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(grid, cmap=cmap, norm=norm)

    # imshow centres cells on integer coordinates, so borders lie at k - 0.5.
    for k in range(self.size + 1):
        ax.axhline(k - 0.5, color='black', linewidth=1)
        ax.axvline(k - 0.5, color='black', linewidth=1)

    if policy is not None:
        arrows = {0: '↑', 1: '→', 2: '↓', 3: '←'}
        for i in range(self.size):
            for j in range(self.size):
                if (i, j) in self.obstacles or (i, j) == self.goal:
                    continue
                # Axes.text takes (x, y) = (column, row).
                ax.text(j, i, arrows[policy[i, j]],
                        ha='center', va='center',
                        fontsize=20, color='red')

    ax.set_title('Grid World')
    ax.set_xticks([])
    ax.set_yticks([])
    fig.tight_layout()
    plt.show()

# GridWorld is defined in an earlier cell; bind the module-level `render`
# function onto the class so instances can call env.render(...).
setattr(GridWorld, "render", render)