### On-Policy TD Learning: SARSA

**Key Formula**: $Q(s,a) \leftarrow Q(s,a) + \alpha[r + \gamma Q(s',a') - Q(s,a)]$

**SARSA** (State-Action-Reward-State-Action) is a temporal difference learning algorithm that learns Q-values while following the same policy it's improving.

In [1]:
import numpy as np
import gymnasium as gym

#### Epsilon-Greedy Action Selection

Implement $\epsilon$-greedy policy: explore with probability $\epsilon$, exploit (choose best action) with probability $1-\epsilon$.

In [2]:
def epsilon_greedy(Q, state, epsilon, env):

    # Exploration
    if np.random.rand() < epsilon:
        return env.action_space.sample()

    # Expliotaton
    else:
        return np.argmax(Q[state])

#### SARSA Algorithm

Implement on-policy SARSA: select action using current policy, update Q-values using the same policy's next action.

In [3]:
def sarsa(env,
         num_episodes=10000,
         alpha=0.1,
         gamma=0.99,
         epsilon=0.1):

    Q = np.zeros((env.observation_space.n, env.action_space.n))

    for ep in range(num_episodes):
        state, _ = env.reset()
        action = epsilon_greedy(Q, state, epsilon, env)

        done=False
        while not done:
            next_state, reward, done, _, _ = env.step(action)

            next_action = epsilon_greedy(Q, next_state, epsilon, env)

            Q[state, action] = Q[state, action] + alpha * (reward + gamma * Q[next_state, next_action]- Q[state, action])

            state, action = next_state, next_action

    return Q

#### Train SARSA Agent

Run SARSA on stochastic FrozenLake environment and extract learned policy from Q-values.

In [4]:
env = gym.make("FrozenLake-v1", is_slippery=True)
Q_sarsa = sarsa(env)

policy = np.argmax(Q_sarsa, axis=-1)
print(policy)

[0 3 0 0 0 0 0 0 3 1 0 0 0 2 1 0]


#### Evaluate Learned Policy

Test the learned policy's performance by measuring success rate over multiple episodes.

In [5]:
def test_policy(policy, env, num_episodes=500):
    success_count = 0

    for _ in range(num_episodes):
        state, _ = env.reset()
        done = False

        while not done:
            action = policy[state]
            state, reward, done, _, _ = env.step(action)

            if done and reward == 1.0:  # Reached the goal
                success_count += 1

    success_rate = success_count / num_episodes
    print(f"Policy Success Rate: {success_rate * 100:.2f}%")

test_policy(policy, env)

Policy Success Rate: 79.80%
