# SARSA Algorithm

Welcome to the SARSA assignment! SARSA (State-Action-Reward-State-Action) is an **on-policy** temporal-difference (TD) control algorithm. By the end of this notebook, you'll be able to:

* Understand the difference between on-policy and off-policy learning
* Implement the SARSA update rule
* Compare SARSA with Q-Learning
* Understand when to use SARSA vs Q-Learning

## SARSA vs Q-Learning: The Key Difference

**Q-Learning (Off-Policy)**:
$$Q(s,a) \leftarrow Q(s,a) + \alpha \left[r + \gamma \max_{a'} Q(s',a') - Q(s,a)\right]$$

**SARSA (On-Policy)**:
$$Q(s,a) \leftarrow Q(s,a) + \alpha \left[r + \gamma Q(s',a') - Q(s,a)\right]$$

The difference:
- **Q-Learning**: Uses $\max_{a'} Q(s',a')$, the value of the greedy action in $s'$, whether or not that action is taken next
- **SARSA**: Uses $Q(s',a')$ for the action $a'$ the agent actually takes next

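Because SARSA's target uses whatever action the ε-greedy policy picks next, including the occasional random one, it learns values for the exploring policy itself. In environments where a single random step can be very costly, this makes SARSA more conservative than Q-Learning.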
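To make the difference concrete, here is a minimal sketch of the two bootstrap targets side by side. The function names `q_learning_target` and `sarsa_target` are illustrative only (they are not part of the graded code), and the toy Q-table reuses the numbers from the Exercise 1 test cell.

```python
import numpy as np

def q_learning_target(Q, reward, next_state, gamma, done):
    """Off-policy target: bootstrap from the greedy (max) value of s'."""
    if done:
        return reward
    return reward + gamma * np.max(Q[next_state])

def sarsa_target(Q, reward, next_state, next_action, gamma, done):
    """On-policy target: bootstrap from the action a' actually chosen in s'."""
    if done:
        return reward
    return reward + gamma * Q[next_state, next_action]

# Toy Q-table: in state 1, action 0 is greedy (0.5) but the agent takes action 1 (0.3)
Q = np.zeros((4, 2))
Q[1] = [0.5, 0.3]

ql = q_learning_target(Q, reward=1.0, next_state=1, gamma=0.9, done=False)
sa = sarsa_target(Q, reward=1.0, next_state=1, next_action=1, gamma=0.9, done=False)
print(f"Q-Learning target: {ql:.4f}")  # uses max Q[1] = 0.5
print(f"SARSA target:      {sa:.4f}")  # uses Q[1, 1] = 0.3
```

Both functions return just `reward` when `done` is true, since there is no future value to bootstrap from after a terminal state.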

<img src="https://miro.medium.com/max/1400/1*vq3cnSPORN6YZAHCf9uAEA.png" style="width:600px;height:300px;">

## Important Note on Submission

Please ensure:
1. No extra print statements
2. No extra code cells
3. Function parameters unchanged
4. No global variables in graded functions

## Table of Contents
- [1 - Packages](#1)
- [2 - SARSA Update](#2)
    - [Exercise 1 - sarsa_update](#ex-1)
- [3 - SARSA Training Loop](#3)
    - [Exercise 2 - train_sarsa](#ex-2)
- [4 - Comparison: SARSA vs Q-Learning](#4)
    - [4.1 - CliffWalking Environment](#4-1)
    - [4.2 - Train Both Algorithms](#4-2)
    - [4.3 - Compare Results](#4-3)

<a name='1'></a>
## 1 - Packages

In [None]:
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
from sarsa_tests import *

%matplotlib inline
plt.rcParams['figure.figsize'] = (12.0, 6.0)

np.random.seed(42)

<a name='2'></a>
## 2 - SARSA Update

The SARSA algorithm follows this sequence:
1. Start in state $s$
2. Choose action $a$ using policy (e.g., ε-greedy)
3. Take action $a$, observe reward $r$ and next state $s'$
4. **Choose next action $a'$ using the same policy** (key difference!)
5. Update: $Q(s,a) \leftarrow Q(s,a) + \alpha[r + \gamma Q(s',a') - Q(s,a)]$
6. $s \leftarrow s'$, $a \leftarrow a'$

**Why is this important?**
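- SARSA learns the value of the policy it is actually following, exploration included
- This makes it more conservative: if exploratory actions near some state are costly (e.g., stepping off the cliff in Section 4), those costs pull down the value estimates around it, so the learned policy stays away
- Q-Learning learns the value of the greedy (optimal) policy, regardless of how the agent explores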
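One way to see the "more conservative" effect: when $a'$ is drawn from an ε-greedy policy, SARSA's sampled target averages, over many visits, to $r + \gamma \sum_{a'} \pi(a' \mid s')\, Q(s',a')$ rather than $r + \gamma \max_{a'} Q(s',a')$. The sketch below compares the two for a hypothetical next state $s'$ with one safe action and one that falls off a cliff; the Q-values and the helper `epsilon_greedy_probs` are assumptions for illustration.

```python
import numpy as np

def epsilon_greedy_probs(q_values, epsilon):
    """Action probabilities of an epsilon-greedy policy over one row of Q."""
    n_actions = len(q_values)
    probs = np.full(n_actions, epsilon / n_actions)   # exploration mass, spread evenly
    probs[np.argmax(q_values)] += 1.0 - epsilon       # remaining mass on the greedy action
    return probs

# Hypothetical next state s' next to a cliff: action 0 is safe, action 1 falls off
q_next = np.array([0.0, -100.0])
epsilon = 0.1

greedy_value = np.max(q_next)                                     # what Q-Learning bootstraps from
on_policy_value = epsilon_greedy_probs(q_next, epsilon) @ q_next  # SARSA's target in expectation
print(f"max_a' Q(s',a')           = {greedy_value:.1f}")
print(f"E_eps-greedy[Q(s',a')]    = {on_policy_value:.1f}")
```

Even with only 10% exploration, the on-policy value of $s'$ is pulled down to -5 by the 5% chance of the catastrophic action, while the max ignores it entirely.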

<a name='ex-1'></a>
### Exercise 1 - sarsa_update

Implement the SARSA update rule. Note that unlike Q-Learning, you need the next action $a'$ that will actually be taken.

In [None]:
# GRADED FUNCTION: sarsa_update

def sarsa_update(Q, state, action, reward, next_state, next_action, done, alpha, gamma):
    """
    Update Q-table using SARSA rule.
    
    Arguments:
    Q -- Q-table, numpy array of shape (n_states, n_actions)
    state -- current state
    action -- action taken
    reward -- reward received
    next_state -- next state
    next_action -- next action that WILL be taken (key difference from Q-Learning!)
    done -- boolean, True if next_state is terminal
    alpha -- learning rate
    gamma -- discount factor
    
    Returns:
    Q -- updated Q-table
    td_error -- TD error
    """
    # (approx. 5-7 lines)
    # Step 1: Get current Q-value Q(s,a)
    # Step 2: Calculate TD target
    #         If done: target = reward
    #         Else: target = reward + gamma * Q(s', a')  <- Use next_action, not max!
    # Step 3: Calculate TD error
    # Step 4: Update Q-value
    
    # YOUR CODE STARTS HERE
    
    
    
    
    
    
    # YOUR CODE ENDS HERE
    
    return Q, td_error

In [None]:
# Test your implementation
Q_test = np.zeros((4, 2))
Q_test[1] = [0.5, 0.3]

# SARSA uses the actual next action (not max)
Q_updated, td_error = sarsa_update(
    Q_test.copy(), state=0, action=0, reward=1.0,
    next_state=1, next_action=1, done=False,  # next_action=1, not the greedy action argmax(Q[1]) = 0
    alpha=0.1, gamma=0.9
)

print(f"SARSA update (using next_action=1):")
print(f"  Q[0,0] = {Q_updated[0, 0]:.4f}")
print(f"  TD error = {td_error:.4f}")
print(f"  (Q-Learning would use max Q[1] = 0.5, SARSA uses Q[1,1] = 0.3)")

# Run the grader
sarsa_update_test(sarsa_update)

**Expected output (approximately):**
```
SARSA update (using next_action=1):
  Q[0,0] = 0.1270
  TD error = 1.2700
  (Q-Learning would use max Q[1] = 0.5, SARSA uses Q[1,1] = 0.3)
```
Note: Q-Learning would give Q[0,0] = 0.1450 (TD error 1.45) here, because it bootstraps from max Q[1] = 0.5 instead of Q[1,1] = 0.3, the value of the action actually taken next.
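The expected numbers follow directly from the two update rules with $Q(0,0) = 0$, $r = 1.0$, $\gamma = 0.9$, and $\alpha = 0.1$:

$$
\begin{aligned}
\text{SARSA:} \quad \delta &= r + \gamma\, Q(1,1) - Q(0,0) = 1.0 + 0.9 \times 0.3 - 0 = 1.27, &\quad Q(0,0) &\leftarrow 0 + 0.1 \times 1.27 = 0.127 \\
\text{Q-Learning:} \quad \delta &= r + \gamma \max_{a'} Q(1,a') - Q(0,0) = 1.0 + 0.9 \times 0.5 - 0 = 1.45, &\quad Q(0,0) &\leftarrow 0 + 0.1 \times 1.45 = 0.145
\end{aligned}
$$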

<a name='3'></a>
## 3 - SARSA Training Loop

The SARSA training loop is similar to Q-Learning with one key difference:

```
For each episode:
    Initialize s
    Choose a using ε-greedy from Q(s,·)
    For each step:
        Take action a, observe r, s'
        Choose a' using ε-greedy from Q(s',·)  <- Must choose before update!
        Update Q(s,a) using (s, a, r, s', a')  <- Use actual a', not max (target = r if s' is terminal)
        s ← s', a ← a'
        If s is terminal: end the episode
```
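The loop above can be sketched end to end on a tiny environment. This is an illustration under stated assumptions, not the graded `train_sarsa`: `ChainEnv` is a hypothetical 1-D corridor that only mimics the Gymnasium `reset`/`step` return signatures, epsilon is held constant, a local `np.random.default_rng` generator replaces the notebook's global seed, and the update is written inline instead of calling `sarsa_update`.

```python
import numpy as np

class ChainEnv:
    """Hypothetical corridor: start at state 0, goal at state n_states-1.
    Actions: 0 = left, 1 = right. Reward -1 per step; reaching the goal terminates.
    Mimics Gymnasium's reset()/step() return tuples but is NOT a Gymnasium env."""

    def __init__(self, n_states=6):
        self.n_states = n_states
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state, {}

    def step(self, action):
        move = 1 if action == 1 else -1
        self.state = min(max(self.state + move, 0), self.n_states - 1)
        terminated = self.state == self.n_states - 1
        return self.state, -1.0, terminated, False, {}


def eps_greedy(Q, s, epsilon, rng):
    """Random action with probability epsilon, else the greedy one."""
    if rng.random() < epsilon:
        return int(rng.integers(Q.shape[1]))
    return int(np.argmax(Q[s]))


def run_sarsa(env, n_actions=2, n_episodes=200, alpha=0.5, gamma=0.99,
              epsilon=0.1, max_steps=100, seed=0):
    rng = np.random.default_rng(seed)
    Q = np.zeros((env.n_states, n_actions))
    returns = []
    for _ in range(n_episodes):
        s, _ = env.reset()
        a = eps_greedy(Q, s, epsilon, rng)                # choose a BEFORE the step loop
        total = 0.0
        for _ in range(max_steps):
            s_next, r, terminated, truncated, _ = env.step(a)
            a_next = eps_greedy(Q, s_next, epsilon, rng)  # choose a' BEFORE the update
            target = r if terminated else r + gamma * Q[s_next, a_next]
            Q[s, a] += alpha * (target - Q[s, a])
            total += r
            s, a = s_next, a_next                         # the chosen a' is the next action taken
            if terminated or truncated:
                break
        returns.append(total)
    return Q, returns


Q, returns = run_sarsa(ChainEnv())
print("greedy action per non-terminal state:", np.argmax(Q, axis=1)[:-1])
print("mean return over last 50 episodes:", np.mean(returns[-50:]))
```

The ordering is the point: `a_next` is sampled once, used in the update, and then actually executed on the next iteration, which is what makes the method on-policy.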

<a name='ex-2'></a>
### Exercise 2 - train_sarsa

Implement the SARSA training loop. Pay attention to when you choose the next action!

In [None]:
# Helper functions (same as in the Q-Learning assignment)
def epsilon_greedy_action(Q, state, n_actions, epsilon):
    """Select action using epsilon-greedy policy."""
    if np.random.random() < epsilon:
        return np.random.randint(n_actions)
    else:
        return np.argmax(Q[state])

def initialize_q_table(n_states, n_actions, init_value=0.0):
    """Initialize Q-table."""
    return np.full((n_states, n_actions), init_value, dtype=float)
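A detail worth knowing about `epsilon_greedy_action`: `np.argmax` returns the *first* maximal index, so with a zero-initialized Q-table the greedy choice is always action 0 until the values start to differ. The sketch below shows an optional random tie-breaking variant; `argmax_random_tiebreak` is a hypothetical helper for illustration, and the graded cells should keep using the provided helper unchanged.

```python
import numpy as np

def argmax_random_tiebreak(q_values, rng):
    """Return the index of a maximal entry, choosing uniformly among ties."""
    best = np.flatnonzero(q_values == np.max(q_values))
    return int(rng.choice(best))

rng = np.random.default_rng(0)
row = np.zeros(4)  # a freshly initialized Q-table row: all actions tie
picks = [argmax_random_tiebreak(row, rng) for _ in range(1000)]
print("np.argmax always returns:", np.argmax(row))
print("random tie-break counts: ", np.bincount(picks, minlength=4))
```

With exploration already coming from ε, the first-index bias is usually harmless, but it can make early episodes look systematically skewed toward one action.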

In [None]:
# GRADED FUNCTION: train_sarsa

def train_sarsa(env, n_episodes=1000, alpha=0.1, gamma=0.99,
                epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01,
                max_steps=100):
    """
    Train SARSA agent.
    
    Arguments:
    env -- Gymnasium environment
    n_episodes -- number of episodes
    alpha -- learning rate
    gamma -- discount factor
    epsilon -- initial exploration rate
    epsilon_decay -- epsilon decay rate
    epsilon_min -- minimum epsilon
    max_steps -- max steps per episode
    
    Returns:
    Q -- trained Q-table
    rewards_history -- list of rewards per episode
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    Q = initialize_q_table(n_states, n_actions)
    
    rewards_history = []
    
    # Training loop (approx. 20-25 lines)
    # Key difference from Q-Learning:
    #   1. Choose initial action BEFORE the step loop
    #   2. Choose next_action BEFORE calling sarsa_update
    #   3. Use next_action in sarsa_update (not max!)
    #   4. Set action = next_action for next iteration
    
    # For each episode:
    #   1. Reset environment, get initial state
    #   2. Choose initial action using epsilon_greedy
    #   3. For each step:
    #      a. Take action, observe reward and next_state; add reward to the episode total
    #      b. Choose next_action using epsilon_greedy (BEFORE update!)
    #      c. Update Q using sarsa_update with next_action
    #      d. state = next_state, action = next_action
    #      e. If done, break
    #   4. Decay epsilon
    #   5. Store episode reward
    
    # YOUR CODE STARTS HERE
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    # YOUR CODE ENDS HERE
    
    return Q, rewards_history

<a name='4'></a>
## 4 - Comparison: SARSA vs Q-Learning

Let's compare SARSA and Q-Learning on the **CliffWalking** environment!

<a name='4-1'></a>
### 4.1 - CliffWalking Environment

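```
┌─────────────────────────┐
│ · · · · · · · · · · · · │   4 x 12 grid
│ · · · · · · · · · · · · │   S = Start
│ · · · · · · · · · · · · │   G = Goal
│ S C C C C C C C C C C G │   C = Cliff
└─────────────────────────┘
```

**Rewards:** every step costs -1. Stepping into a cliff cell (C) costs -100 and sends the agent back to S without ending the episode. Reaching G ends the episode.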

**The Dilemma:**
- **Shortest path**: Walk along the cliff (risky during exploration!)
- **Safe path**: Walk along the top (longer but safer)

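**Expected behavior:**
- **Q-Learning**: Learns the shortest path along the cliff edge, which is optimal for the greedy policy, but its ε-greedy behavior occasionally steps into the cliff during training
- **SARSA**: Learns a longer, safer path further from the cliff, because its value estimates include the cost of those exploratory steps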
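To put numbers on the dilemma, the sketch below walks fixed action sequences through a small re-implementation of the CliffWalking rules described above (4 x 12 grid, -1 per step, -100 and a reset to the start for the cliff). It is a hypothetical stand-in written for illustration, not the Gymnasium environment itself; the action encoding (0 = up, 1 = right, 2 = down, 3 = left) is assumed to match Gymnasium's.

```python
# Hypothetical re-implementation of the CliffWalking rules (not the Gymnasium env itself)
N_ROWS, N_COLS = 4, 12
START, GOAL = (3, 0), (3, 11)
MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}  # up, right, down, left

def is_cliff(pos):
    row, col = pos
    return row == 3 and 1 <= col <= 10

def walk(actions):
    """Follow a fixed action sequence; return (total reward, reached goal?)."""
    pos, total = START, 0
    for a in actions:
        dr, dc = MOVES[a]
        pos = (min(max(pos[0] + dr, 0), N_ROWS - 1),
               min(max(pos[1] + dc, 0), N_COLS - 1))
        if is_cliff(pos):
            total += -100
            pos = START  # stepping into the cliff sends the agent back to the start
        else:
            total += -1
        if pos == GOAL:
            return total, True
    return total, False

cliff_edge_path = [0] + [1] * 11 + [2]           # up, 11 x right along row 2, down
top_row_path = [0] * 3 + [1] * 11 + [2] * 3      # up to row 0, right, back down
slip_path = [0] + [1] * 5 + [2] + cliff_edge_path  # one accidental "down" at column 5, then retry

print("cliff-edge path:", walk(cliff_edge_path))
print("top-row path:   ", walk(top_row_path))
print("edge + one slip:", walk(slip_path))
```

The cliff-edge route is only 4 steps shorter than the top-row route, while a single slip costs over 100. An agent that keeps exploring with ε > 0 therefore has good reason to prefer the longer route.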

<a name='4-2'></a>
### 4.2 - Train Both Algorithms

In [None]:
# Q-Learning from the previous assignment, redefined here for comparison
def train_q_learning(env, n_episodes=1000, alpha=0.1, gamma=0.99,
                     epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01,
                     max_steps=100):
    """Q-Learning training (for comparison)."""
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    Q = initialize_q_table(n_states, n_actions)
    rewards_history = []
    
    for episode in range(n_episodes):
        state, _ = env.reset()
        episode_reward = 0
        
        for step in range(max_steps):
            action = epsilon_greedy_action(Q, state, n_actions, epsilon)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            # Q-Learning update (uses max)
            current_q = Q[state, action]
            if done:
                target_q = reward
            else:
                target_q = reward + gamma * np.max(Q[next_state])
            Q[state, action] = current_q + alpha * (target_q - current_q)
            
            episode_reward += reward
            state = next_state
            
            if done:
                break
        
        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        rewards_history.append(episode_reward)
    
    return Q, rewards_history

In [None]:
# Create CliffWalking environment
env = gym.make('CliffWalking-v0')

print("Training both algorithms on CliffWalking...")
print("This will show the key difference between SARSA (safe) and Q-Learning (risky)\n")

# Train Q-Learning
np.random.seed(42)
Q_qlearning, rewards_qlearning = train_q_learning(
    env, n_episodes=500, alpha=0.5, gamma=0.99,
    epsilon=0.1, epsilon_decay=1.0, epsilon_min=0.1
)

# Train SARSA
np.random.seed(42)
Q_sarsa, rewards_sarsa = train_sarsa(
    env, n_episodes=500, alpha=0.5, gamma=0.99,
    epsilon=0.1, epsilon_decay=1.0, epsilon_min=0.1
)

print("Training completed!")

<a name='4-3'></a>
### 4.3 - Compare Results

In [None]:
# Plot comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Rewards comparison
window = 20
if len(rewards_qlearning) >= window:
    ma_qlearning = np.convolve(rewards_qlearning, np.ones(window)/window, mode='valid')
    ma_sarsa = np.convolve(rewards_sarsa, np.ones(window)/window, mode='valid')
    
    ax1.plot(range(window-1, len(rewards_qlearning)), ma_qlearning, 
             label='Q-Learning', linewidth=2, alpha=0.8)
    ax1.plot(range(window-1, len(rewards_sarsa)), ma_sarsa,
             label='SARSA', linewidth=2, alpha=0.8)
    ax1.set_xlabel('Episode')
    ax1.set_ylabel('Average Reward')
    ax1.set_title('Learning Progress Comparison')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

# Final performance comparison
final_ql = np.mean(rewards_qlearning[-100:])
final_sarsa = np.mean(rewards_sarsa[-100:])

ax2.bar(['Q-Learning', 'SARSA'], [final_ql, final_sarsa], 
        color=['#ff7f0e', '#2ca02c'], alpha=0.7)
ax2.set_ylabel('Average Reward (last 100 episodes)')
ax2.set_title('Final Performance')
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print(f"\nFinal Performance (last 100 episodes):")
print(f"Q-Learning: {final_ql:.2f}")
print(f"SARSA: {final_sarsa:.2f}")
print(f"\nInterpretation:")
print("- Q-Learning's greedy path hugs the cliff edge, so with epsilon > 0 it occasionally falls off (-100) during training")
print("- SARSA accounts for its own exploration and learns a longer path away from the cliff")
print("- With greedy evaluation (epsilon=0), Q-Learning's path is the shorter, optimal one")
print("- With epsilon > 0 (as in this training run), SARSA typically earns a higher average reward")

env.close()
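The plots compare rewards, but they don't show *which* path each agent learned. The sketch below follows the greedy action of a Q-table from the start state and reports which rows the path uses while crossing the interior columns. It assumes CliffWalking's state encoding `state = row * 12 + col` and action encoding (0 = up, 1 = right, 2 = down, 3 = left); `greedy_path` and the synthetic `Q_demo` table are hypothetical and exist only to demonstrate the function.

```python
import numpy as np

N_ROWS, N_COLS = 4, 12
START, GOAL = 3 * N_COLS + 0, 3 * N_COLS + 11          # state = row * N_COLS + col
MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}  # up, right, down, left

def greedy_path(Q, max_len=100):
    """Follow argmax_a Q(s, a) from the start; return the list of visited states."""
    state, path = START, [START]
    for _ in range(max_len):
        row, col = divmod(state, N_COLS)
        dr, dc = MOVES[int(np.argmax(Q[state]))]
        row = min(max(row + dr, 0), N_ROWS - 1)
        col = min(max(col + dc, 0), N_COLS - 1)
        state = row * N_COLS + col
        if row == 3 and 1 <= col <= 10:  # stepped into the cliff: back to the start
            state = START
        path.append(state)
        if state == GOAL:
            break
    return path

# Synthetic Q-table (hypothetical) whose greedy policy takes the top-row route
Q_demo = np.zeros((N_ROWS * N_COLS, 4))
for s in range(N_ROWS * N_COLS):
    row, col = divmod(s, N_COLS)
    if col == 0 and row > 0:
        Q_demo[s, 0] = 1.0   # climb the left column
    elif col == N_COLS - 1:
        Q_demo[s, 2] = 1.0   # descend the right column
    else:
        Q_demo[s, 1] = 1.0   # otherwise move right

path = greedy_path(Q_demo)
interior_rows = sorted({s // N_COLS for s in path if 1 <= s % N_COLS <= 10})
print("steps:", len(path) - 1, "| rows used to cross:", interior_rows)
```

Calling `greedy_path(Q_qlearning)` and `greedy_path(Q_sarsa)` after training would show which row each greedy policy uses to cross; the outcome depends on the run, so no result is claimed here.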

## Congratulations!

You've successfully implemented SARSA and compared it with Q-Learning! Here's what you've learned:

✅ The difference between on-policy (SARSA) and off-policy (Q-Learning)

✅ How to implement the SARSA update rule

✅ When SARSA is preferable to Q-Learning

✅ How exploration affects learning in both algorithms

### Key Takeaways:

| Algorithm | Type | Update Rule | Behavior |
|-----------|------|-------------|----------|
| **Q-Learning** | Off-policy | Uses $\max_{a'} Q(s',a')$ | Learns the greedy (optimal) policy; riskier while exploring |
| **SARSA** | On-policy | Uses $Q(s',a')$ for the action actually taken | Learns the best policy given its own exploration; safer while exploring |

**When to use SARSA:**
- When exploration is costly or dangerous
- When you want the learned policy to account for exploration
- In environments with severe penalties for mistakes

**When to use Q-Learning:**
- When you want to learn the optimal policy
- When exploration costs are acceptable
- When you want to learn from experience generated by a different behavior policy (e.g., reusing old or logged transitions)

### Next Steps:
- Explore Expected SARSA, which bootstraps from the expected value of $Q(s',\cdot)$ under the current policy instead of a single sampled $a'$
- Learn about n-step methods
- Move on to Deep RL with DQN!