In [1]:
import numpy as np

In [5]:
class GridWorld:
    def __init__(self):
        self.grid_size = (3, 3)
        self.start_state = (0, 0)
        self.goal_state = (2, 2)
        self.num_actions = 4  # Up, Down, Left, Right
    
    def step(self, state, action):
        """Return the next state and reward given the current state and action."""
        x, y = state

        if action == 0:  # Up
            next_state = (max(0, x - 1), y)
        elif action == 1:  # Down
            next_state = (min(self.grid_size[0] - 1, x + 1), y)
        elif action == 2:  # Left
            next_state = (x, max(0, y - 1))
        elif action == 3:  # Right
            next_state = (x, min(self.grid_size[1] - 1, y))
        else:
            raise ValueError("Invalid action")

        # Check if the agent reaches the goal state
        reward = 10 if next_state == self.goal_state else -1
        return next_state, reward

In [6]:
def td_learning(grid_world, num_episodes, alpha, gamma):
    """Perform Temporal Difference learning to estimate the value function."""
    # Initialize the value function to 0 for all states
    values = np.zeros(grid_world.grid_size)

    for episode in range(num_episodes):
        state = grid_world.start_state

        while state != grid_world.goal_state:
            # Choose a random action (can later be improved with exploration-exploitation)
            action = np.random.choice(grid_world.num_actions)

            # Get the next state and reward from the environment
            next_state, reward = grid_world.step(state, action)

            # TD(0) update rule for the value function
            values[state] += alpha * (reward + gamma * values[next_state] - values[state])

            # Move to the next state
            state = next_state

    return values


# Create a grid world environment
grid_world = GridWorld()

# Perform Temporal Difference learning
num_episodes = 1000
alpha = 0.1  # Learning rate
gamma = 0.9  # Discount factor
values = td_learning(grid_world, num_episodes, alpha, gamma)

# Print the learned value function
print("Value function:")
print(values)

KeyboardInterrupt: 