# In [1]:
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# In [2]:
class GridWorld:
    """Square grid-world environment for SARSA.

    States are ``(row, col)`` tuples. The agent starts at ``(0, 0)`` and the
    episode ends when it reaches ``(size - 1, size - 1)``.

    Actions: 0=up, 1=down, 2=left, 3=right. A move that would leave the grid
    is clamped to the border; a move into an obstacle leaves the agent where
    it was.

    Rewards: +10 for reaching the goal, -1 for bumping into an obstacle,
    -0.1 for every other step.
    """

    def __init__(self, size=5, obstacles=None):
        """Create the grid.

        Args:
            size: Side length of the square grid.
            obstacles: Optional iterable of ``(row, col)`` cells that block
                movement. Defaults to ``[(1, 1), (2, 2), (3, 1)]``.
        """
        self.size = size
        self.start = (0, 0)
        self.goal = (size - 1, size - 1)

        if obstacles is None:
            obstacles = [(1, 1), (2, 2), (3, 1)]

        # An obstacle sitting on the goal can never be entered, so the
        # episode would never terminate (this happened with the default
        # obstacles for size 2 and 3). Drop obstacles that coincide with the
        # start/goal cells or that lie outside the grid altogether.
        reserved = (self.start, self.goal)
        self.obstacles = [
            cell
            for cell in map(tuple, obstacles)
            if cell not in reserved and self._in_bounds(cell)
        ]

        self.state = self.start

    def _in_bounds(self, cell):
        """Return True when ``cell`` lies inside the grid."""
        row, col = cell
        return 0 <= row < self.size and 0 <= col < self.size

    def reset(self):
        """Reset environment to start state and return it."""
        self.state = self.start
        return self.state

    def step(self, action):
        """Execute ``action`` and return ``(next_state, reward, done)``.

        An action outside 0..3 does not move the agent; it is then scored
        like any other step from the current cell.
        """
        row, col = self.state

        # Actions: 0=up, 1=down, 2=left, 3=right (clamped to the grid border)
        if action == 0:  # up
            row = max(0, row - 1)
        elif action == 1:  # down
            row = min(self.size - 1, row + 1)
        elif action == 2:  # left
            col = max(0, col - 1)
        elif action == 3:  # right
            col = min(self.size - 1, col + 1)

        next_state = (row, col)

        if next_state in self.obstacles:
            # Blocked: stay in the current position and pay the bump penalty.
            next_state = self.state
            reward = -1
        elif next_state == self.goal:
            reward = 10
        else:
            reward = -0.1  # Small penalty for each step

        self.state = next_state
        done = (next_state == self.goal)

        return next_state, reward, done

    def get_possible_actions(self):
        """Return list of possible actions."""
        return [0, 1, 2, 3]  # up, down, left, right

# In [3]:
class SARSAAgent:
    """SARSA agent for action-value estimation and policy improvement.

    Holds the hyper-parameters, a lazily populated Q-table and per-episode
    metric lists.
    """

    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=0.1):
        """Set up the agent.

        Args:
            env: Environment exposing ``get_possible_actions()``.
            alpha: Learning rate.
            gamma: Discount factor.
            epsilon: Exploration rate.
        """
        self.env = env
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate

        # Q-table: Q(s, a) - action-value function. Each unseen state gets a
        # zero vector with one entry per action; the width comes from the
        # environment rather than a hard-coded 4 so both stay in sync.
        n_actions = len(env.get_possible_actions())
        self.Q = defaultdict(lambda: np.zeros(n_actions))

        # Track metrics
        self.episode_rewards = []
        self.episode_lengths = []

# In [4]:
def epsilon_greedy_policy(self, state):
    """Pick an action for ``state`` using an epsilon-greedy rule.

    With probability ``self.epsilon`` a random action from
    ``self.env.get_possible_actions()`` is returned; otherwise the index of
    the largest Q-value stored for ``state`` is returned.
    """
    should_explore = np.random.random() < self.epsilon

    if not should_explore:
        # Exploit: best action according to the current Q-values
        q_values = self.Q[state]
        return np.argmax(q_values)

    # Explore: sample one of the environment's actions
    candidates = self.env.get_possible_actions()
    return np.random.choice(candidates)

# Attach the function defined above to SARSAAgent as a method, so that
# agent.epsilon_greedy_policy(state) can be called on any instance.
SARSAAgent.epsilon_greedy_policy = epsilon_greedy_policy