In [1]:
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

In [2]:
class GridWorld:
    """Simple deterministic GridWorld environment for SARSA.

    The agent starts in the top-left cell ``(0, 0)`` and must reach the
    bottom-right cell ``(size - 1, size - 1)``. States are ``(row, col)``
    tuples; actions are the integers ``0=up, 1=down, 2=left, 3=right``.

    Args:
        size: Side length of the square grid (must be >= 1).
        obstacles: Optional iterable of ``(row, col)`` cells the agent cannot
            enter. When omitted, ``DEFAULT_OBSTACLES`` is used, restricted to
            cells that lie inside the grid and are neither the start nor the
            goal, so the goal is never itself an obstacle on small grids.

    Raises:
        ValueError: If ``size`` is smaller than 1, or if explicitly supplied
            obstacles cover the start or the goal cell.
    """

    # Obstacle layout used when the caller does not supply one.
    DEFAULT_OBSTACLES = ((1, 1), (2, 2), (3, 1))

    # (row delta, col delta) applied by each action id.
    MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    def __init__(self, size=5, obstacles=None):
        if size < 1:
            raise ValueError(f"size must be at least 1, got {size}")

        self.size = size
        self.start = (0, 0)
        self.goal = (size - 1, size - 1)

        if obstacles is None:
            # The default layout was designed for a 5x5 grid; drop the cells
            # that would fall outside a smaller grid or sit on start/goal.
            self.obstacles = [
                cell for cell in self.DEFAULT_OBSTACLES
                if self._in_bounds(cell) and cell not in (self.start, self.goal)
            ]
        else:
            self.obstacles = [tuple(cell) for cell in obstacles]
            blocked = [c for c in (self.start, self.goal) if c in self.obstacles]
            if blocked:
                raise ValueError(
                    f"obstacles must not cover the start or goal cell: {blocked}"
                )

        self.state = self.start

    def _in_bounds(self, cell):
        """Return True when ``cell`` lies inside the grid."""
        row, col = cell
        return 0 <= row < self.size and 0 <= col < self.size

    def reset(self):
        """Reset environment to start state"""
        self.state = self.start
        return self.state

    def step(self, action):
        """Execute action and return next_state, reward, done.

        Moves are clamped at the grid border. Trying to enter an obstacle
        leaves the agent where it is and costs -1; reaching the goal yields
        +10 and ends the episode; every other step costs -0.1. An action id
        that is not one of 0-3 results in no movement.
        """
        row, col = self.state
        d_row, d_col = self.MOVES.get(action, (0, 0))

        # Clamp to the grid so that walking into a wall keeps the agent in place.
        last = self.size - 1
        next_state = (min(max(row + d_row, 0), last),
                      min(max(col + d_col, 0), last))

        # The obstacle check comes first: a blocked move never changes the state.
        if next_state in self.obstacles:
            next_state = self.state  # Stay in current position
            reward = -1
        elif next_state == self.goal:
            reward = 10
        else:
            reward = -0.1  # Small penalty for each step

        self.state = next_state
        done = (next_state == self.goal)

        return next_state, reward, done

    def get_possible_actions(self):
        """Return list of possible actions (0=up, 1=down, 2=left, 3=right)."""
        return list(self.MOVES)

In [3]:
class SARSAAgent:
    """SARSA Agent for action-value estimation and policy improvement.

    Holds the hyper-parameters, a tabular action-value function and the
    per-episode training metrics. The behaviour methods (policy, update,
    training, visualisation) are attached to this class in later cells.

    Args:
        env: Environment to learn in. It must provide
            ``get_possible_actions()``; the number of actions it returns
            determines the width of every Q-table row.
        alpha: Learning rate.
        gamma: Discount factor.
        epsilon: Initial exploration rate.
    """
    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=0.1):
        self.env = env
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate

        # Q-table: Q(s, a) - action-value function.
        # Unseen states get a fresh zero vector with one entry per action;
        # the width comes from the environment rather than a literal 4.
        n_actions = len(env.get_possible_actions())
        self.Q = defaultdict(lambda: np.zeros(n_actions))

        # Track metrics
        self.episode_rewards = []
        self.episode_lengths = []

In [4]:
def epsilon_greedy_policy(self, state):
    """Choose an action for ``state`` with an epsilon-greedy rule.

    With probability ``self.epsilon`` a uniformly random action from the
    environment's action list is returned; otherwise the action with the
    highest Q-value for ``state`` is returned (``np.argmax`` picks the
    lowest index on ties).
    """
    exploring = np.random.random() < self.epsilon

    if not exploring:
        # Exploit: best action according to the current Q-values
        return np.argmax(self.Q[state])

    # Explore: uniformly random action (the Q-table is not consulted)
    candidates = self.env.get_possible_actions()
    return np.random.choice(candidates)

# Attach the function above to the SARSAAgent class so it can be called as
# a method: agent.epsilon_greedy_policy(state)
SARSAAgent.epsilon_greedy_policy = epsilon_greedy_policy

In [5]:
def sarsa_update(self, state, action, reward, next_state, next_action):
    """Apply one SARSA step to the Q-table entry for (state, action).

    Update rule:
        Q(s, a) <- Q(s, a) + alpha * (r + gamma * Q(s', a') - Q(s, a))

    ``next_action`` is the action supplied by the caller for ``next_state``;
    its Q-value is used as the bootstrap value. The table is modified in
    place and nothing is returned.
    """
    q_sa = self.Q[state][action]

    # Read the bootstrap value before writing, so a self-transition
    # (state == next_state) still uses the pre-update estimate.
    q_next = self.Q[next_state][next_action]

    # TD target minus current estimate, scaled by the learning rate
    self.Q[state][action] = q_sa + self.alpha * (reward + self.gamma * q_next - q_sa)

# Attach the function above to the SARSAAgent class so it can be called as
# a method: agent.sarsa_update(state, action, reward, next_state, next_action)
SARSAAgent.sarsa_update = sarsa_update

In [6]:
def train(self, num_episodes=1000, verbose=True, max_steps=200,
          epsilon_decay=0.995, epsilon_min=0.01):
    """Run SARSA training episodes and record per-episode metrics.

    Each episode resets the environment, then repeatedly takes the current
    action, picks the next action with the agent's epsilon-greedy policy and
    applies a SARSA update, until the environment reports ``done`` or
    ``max_steps`` steps have been taken. After every episode the total
    reward and step count are appended to ``self.episode_rewards`` and
    ``self.episode_lengths``, and ``self.epsilon`` is decayed.

    Args:
        num_episodes: Number of episodes to run.
        verbose: When true, print averages over the last 100 episodes every
            100 episodes.
        max_steps: Upper bound on the number of steps in one episode
            (prevents infinite loops when the goal is not reached).
        epsilon_decay: Factor ``self.epsilon`` is multiplied by after each
            episode.
        epsilon_min: Floor below which ``self.epsilon`` is never decayed.

    Note:
        ``self.epsilon`` is modified in place, so calling ``train`` again
        continues from the already decayed exploration rate.
    """
    for episode in range(num_episodes):
        # Reset environment
        state = self.env.reset()

        # Select initial action using policy
        action = self.epsilon_greedy_policy(state)

        episode_reward = 0
        steps = 0
        done = False

        # The cap is part of the loop condition so that exactly max_steps
        # steps (not max_steps + 1) are taken when the episode never ends.
        while not done and steps < max_steps:
            # Take action, observe reward and next state
            next_state, reward, done = self.env.step(action)

            # Select next action using policy (this is key for SARSA!)
            # NOTE(review): this also runs for the terminal transition, so the
            # update bootstraps from whatever Q-values are stored for the
            # terminal state — presumably still the initial values; confirm.
            next_action = self.epsilon_greedy_policy(next_state)

            # SARSA update
            self.sarsa_update(state, action, reward, next_state, next_action)

            # Move to next state and action
            state = next_state
            action = next_action

            episode_reward += reward
            steps += 1

        # Track metrics
        self.episode_rewards.append(episode_reward)
        self.episode_lengths.append(steps)

        # Decay epsilon (exploration rate), never below the floor
        self.epsilon = max(epsilon_min, self.epsilon * epsilon_decay)

        # Print progress
        if verbose and (episode + 1) % 100 == 0:
            avg_reward = np.mean(self.episode_rewards[-100:])
            avg_length = np.mean(self.episode_lengths[-100:])
            print(f"Episode {episode + 1}/{num_episodes} - "
                  f"Avg Reward: {avg_reward:.2f}, "
                  f"Avg Length: {avg_length:.2f}, "
                  f"Epsilon: {self.epsilon:.3f}")

# Attach the function above to the SARSAAgent class so it can be called as
# a method: agent.train(...)
SARSAAgent.train = train

In [7]:
def get_policy(self):
    """Return the greedy policy implied by the current Q-table.

    The result maps every state that has an entry in ``self.Q`` to the index
    of its highest-valued action (``np.argmax``, lowest index on ties).
    States without a Q-table entry are absent from the mapping.
    """
    return {
        state: np.argmax(action_values)
        for state, action_values in self.Q.items()
    }

# Attach the function above to the SARSAAgent class so it can be called as
# a method: agent.get_policy()
SARSAAgent.get_policy = get_policy

In [8]:
def visualize_policy(self):
    """Print the learned policy as a text grid.

    One row of the environment is printed per line. The goal cell is shown
    as ``G``, obstacle cells as ``X``, cells with a policy entry as an arrow
    for their greedy action, and all remaining cells as ``·``.
    """
    greedy = self.get_policy()

    # Arrow symbol for each action id (0=up, 1=down, 2=left, 3=right)
    arrow_for = {0: '↑', 1: '↓', 2: '←', 3: '→'}

    def cell_text(cell):
        # Goal and obstacles take precedence over any policy entry.
        if cell == self.env.goal:
            return " G "
        if cell in self.env.obstacles:
            return " X "
        if cell in greedy:
            return f" {arrow_for[greedy[cell]]} "
        return " · "

    side = self.env.size
    rule = "-" * (side * 4)

    print("\nLearned Policy:")
    print(rule)
    for r in range(side):
        print("".join(cell_text((r, c)) for c in range(side)))
    print(rule)
    print("G = Goal, X = Obstacle")

# Attach the function above to the SARSAAgent class so it can be called as
# a method: agent.visualize_policy()
SARSAAgent.visualize_policy = visualize_policy

In [9]:
def _plot_metric_with_moving_avg(ax, values, window, ylabel, title):
    """Draw one per-episode metric and its moving average on ``ax``."""
    ax.plot(values, alpha=0.3, label='Raw')

    # A 'valid' convolution needs at least one full window of data.
    if len(values) >= window:
        moving_avg = np.convolve(values,
                                 np.ones(window) / window,
                                 mode='valid')
        # The first averaged point corresponds to episode index window - 1.
        ax.plot(range(window - 1, len(values)),
                moving_avg,
                label=f'{window}-Episode Moving Avg',
                linewidth=2)

    ax.set_xlabel('Episode')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_training_progress(self, window=50):
    """Plot training metrics.

    Shows two side-by-side panels: total reward per episode and number of
    steps per episode. Each panel contains the raw series and, when at least
    ``window`` episodes have been recorded, a moving average over ``window``
    episodes.

    Args:
        window: Number of episodes in the moving-average window.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be at least 1, got {window}")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    _plot_metric_with_moving_avg(ax1, self.episode_rewards, window,
                                 'Total Reward', 'SARSA: Episode Rewards')
    _plot_metric_with_moving_avg(ax2, self.episode_lengths, window,
                                 'Steps to Goal', 'SARSA: Episode Lengths')

    plt.tight_layout()
    plt.show()

# Attach the function above to the SARSAAgent class so it can be called as
# a method: agent.plot_training_progress()
SARSAAgent.plot_training_progress = plot_training_progress

In [10]:
print("=" * 60)
print("SARSA Algorithm - Action-Value Estimation & Policy Improvement")
print("=" * 60)

# Seed NumPy's global RNG so that "Restart Kernel -> Run All" reproduces the
# same run: the epsilon-greedy policy draws its exploration decisions from
# np.random.
SEED = 42
np.random.seed(SEED)

# Create environment
env = GridWorld(size=5)

# Create SARSA agent
agent = SARSAAgent(
    env=env,
    alpha=0.1,      # Learning rate
    gamma=0.99,     # Discount factor
    epsilon=0.3     # Initial exploration rate
)

print("\nTraining SARSA agent...")
print("-" * 60)

SARSA Algorithm - Action-Value Estimation & Policy Improvement

Training SARSA agent...
------------------------------------------------------------


In [11]:
# Train the agent for 1000 episodes with progress output enabled
# (verbose=True)
agent.train(num_episodes=1000, verbose=True)

Episode 100/1000 - Avg Reward: 7.78, Avg Length: 17.03, Epsilon: 0.182
Episode 200/1000 - Avg Reward: 9.07, Avg Length: 9.30, Epsilon: 0.110
Episode 300/1000 - Avg Reward: 9.15, Avg Length: 8.81, Epsilon: 0.067
Episode 400/1000 - Avg Reward: 9.23, Avg Length: 8.48, Epsilon: 0.040
Episode 500/1000 - Avg Reward: 9.28, Avg Length: 8.19, Epsilon: 0.024
Episode 600/1000 - Avg Reward: 9.27, Avg Length: 8.16, Epsilon: 0.015
Episode 700/1000 - Avg Reward: 9.29, Avg Length: 8.12, Epsilon: 0.010
Episode 800/1000 - Avg Reward: 9.29, Avg Length: 8.06, Epsilon: 0.010
Episode 900/1000 - Avg Reward: 9.29, Avg Length: 8.11, Epsilon: 0.010
Episode 1000/1000 - Avg Reward: 9.28, Avg Length: 8.07, Epsilon: 0.010


In [12]:
# Visualize learned policy: print the greedy action for each grid cell
# as a text grid
agent.visualize_policy()


Learned Policy:
--------------------
 →  →  →  ↓  ↓ 
 ↑  X  →  ↓  ↓ 
 ↓  ←  X  →  ↓ 
 ↓  X  →  ↓  ↓ 
 →  →  →  →  G 
--------------------
G = Goal, X = Obstacle
