# In[1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns

# In[2]:
class GridWorld:
    """Simple grid world environment for TD(lambda) learning.

    The agent moves on a ``size`` x ``size`` grid whose cells are
    ``(row, col)`` tuples. Each call to ``step`` yields a reward of:

    * ``-1``   when the move would enter an obstacle (the agent stays put),
    * ``10``   when the agent lands on ``goal_state`` (the episode ends),
    * ``-0.1`` otherwise, including moves clipped by the grid boundary.

    Args:
        size: Side length of the square grid.
        goal_state: Terminal cell ``(row, col)``.
        obstacle_states: Iterable of blocked cells. ``None`` (the default)
            means a single obstacle at ``(2, 2)``; pass an empty iterable
            for a grid with no obstacles.
        start_state: Cell the agent starts in and returns to on ``reset``.
    """

    def __init__(self, size=5, goal_state=(4, 4), obstacle_states=None,
                 start_state=(0, 0)):
        self.size = size
        # Normalize to tuples: ``step`` builds tuple states and compares them
        # with ``==`` / ``in``, so a list such as ``[4, 4]`` would never match.
        self.goal_state = tuple(goal_state)
        # Use ``is None`` rather than truthiness so an explicit empty list
        # really means "no obstacles" instead of falling back to (2, 2).
        if obstacle_states is None:
            obstacle_states = [(2, 2)]
        # Build a fresh list so later mutation of the caller's list does not
        # silently change this environment.
        self.obstacle_states = [tuple(s) for s in obstacle_states]
        self.start_state = tuple(start_state)
        self.current_state = self.start_state
        self.actions = ['up', 'down', 'left', 'right']

    def reset(self):
        """Move the agent back to ``start_state`` and return that state."""
        self.current_state = self.start_state
        return self.current_state

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done)``.

        Args:
            action: One of ``'up'``, ``'down'``, ``'left'``, ``'right'``.
                Any other value leaves the position unchanged.

        Returns:
            The new ``(row, col)`` state, the scalar reward, and ``True`` if
            the new state is the goal.
        """
        row, col = self.current_state

        # Moves are clipped at the grid edges, so walking into a wall keeps
        # the agent in place.
        if action == 'up':
            next_state = (max(0, row - 1), col)
        elif action == 'down':
            next_state = (min(self.size - 1, row + 1), col)
        elif action == 'left':
            next_state = (row, max(0, col - 1))
        elif action == 'right':
            next_state = (row, min(self.size - 1, col + 1))
        else:
            # Unrecognized action: no movement.
            next_state = self.current_state

        if next_state in self.obstacle_states:
            # Blocked: stay in the current cell and pay a larger penalty.
            next_state = self.current_state
            reward = -1
        elif next_state == self.goal_state:
            reward = 10
        else:
            # Small per-step cost encourages short paths to the goal.
            reward = -0.1

        self.current_state = next_state
        done = next_state == self.goal_state

        return next_state, reward, done

    def get_state_index(self, state):
        """Convert a ``(row, col)`` state to its row-major 1D index."""
        return state[0] * self.size + state[1]

    def get_state_from_index(self, index):
        """Convert a row-major 1D index back to a ``(row, col)`` state."""
        return divmod(index, self.size)