## SARSA
In this TD-learning assignment, we look at SARSA, an on-policy TD control method.
The environment investigated here is the classic Cliff Walking gridworld, which is implemented from scratch in the cells below.

```python
# Import the required modules
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
```

## Cliff walking problem

- Cliff Walking is an undiscounted, episodic gridworld task with start state `S` and goal state `G`.
- Four actions are allowed: up, down, left, and right. A move that would leave the grid keeps the agent where it is.
- Action selection is $\epsilon$-greedy: with probability $\epsilon$ the agent takes a uniformly random action instead of the greedy one.
- The reward is -1 on every transition except those into the region marked "The Cliff".
- Stepping into the cliff incurs a reward of -100 and sends the agent instantly back to `S`.
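To make the layout concrete, here is a small sketch that draws the grid as text. It assumes the 4 × 12 layout used by the code cells below (row 3 is the bottom row, `S` at the bottom-left, `G` at the bottom-right, and the cliff `C` filling the bottom row in between); `render_cliff_grid` is a hypothetical helper, not part of the assignment.

```python
# Hypothetical helper: draw the Cliff Walking grid as text.
# Assumed layout: S at the bottom-left, G at the bottom-right,
# and the cliff (C) occupying the rest of the bottom row.
def render_cliff_grid(height=4, width=12):
    rows = []
    for i in range(height):
        cells = []
        for j in range(width):
            if i == height - 1 and j == 0:
                cells.append("S")
            elif i == height - 1 and j == width - 1:
                cells.append("G")
            elif i == height - 1:
                cells.append("C")
            else:
                cells.append(".")
        rows.append(" ".join(cells))
    return "\n".join(rows)


print(render_cliff_grid())
# . . . . . . . . . . . .
# . . . . . . . . . . . .
# . . . . . . . . . . . .
# S C C C C C C C C C C G
```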

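The shortest, and therefore optimal, path from `S` to `G` runs right along the cliff edge. While the agent keeps exploring $\epsilon$-greedily, however, following that path means occasionally stepping into the cliff and paying -100, so a path further from the edge can collect more reward in practice. In this assignment, we compare how Q-Learning and SARSA handle this trade-off.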
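Without exploration, the path along the cliff edge is the shortest one: each extra row of clearance from the cliff costs two extra steps (one up, one down). The sketch below (hypothetical helper `detour_return`) computes the undiscounted return of a deterministic path that climbs `k` rows, moves right 11 times, and descends `k` rows into `G`, assuming no random move ever pushes the agent into the cliff.

```python
# Return of a deterministic path: up k rows, right across the grid, down k rows.
# Reward is -1 per step and the task is undiscounted (GAMMA = 1).
# Assumes no exploration, so the agent never falls into the cliff.
def detour_return(k, width=12):
    steps = k + (width - 1) + k
    return -steps


for k in (1, 2, 3):
    print(k, detour_return(k))
# 1 -13
# 2 -15
# 3 -17
```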

![image.png](attachment:image.png)

```python
# Grid world related constants
GRID_HEIGHT = 4
GRID_WIDTH = 12

EPSILON = 0.1  # probability of taking a random (exploratory) action
ALPHA = 0.5    # step size
GAMMA = 1      # discount factor (the task is undiscounted)

UP = 0
DOWN = 1
LEFT = 2
RIGHT = 3
ACTIONS = [UP, DOWN, LEFT, RIGHT]  # all possible actions

# Start and goal states as [row, column]; row 3 is the bottom row
START = [3, 0]
GOAL = [3, 11]
```
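The Q-table used below is a 3-D NumPy array indexed by `[row, column, action]`, so a state `[i, j]` together with an action index selects a single entry. A minimal sketch, assuming the constants above (`N_ACTIONS` is a hypothetical name for the number of actions):

```python
import numpy as np

# Assumed to mirror the constants above: a 4 x 12 grid and 4 actions
GRID_HEIGHT, GRID_WIDTH, N_ACTIONS = 4, 12, 4
UP, DOWN, LEFT, RIGHT = 0, 1, 2, 3

# One value per (row, column, action) triple
q_value = np.zeros((GRID_HEIGHT, GRID_WIDTH, N_ACTIONS))

state = [3, 0]  # START, as [row, column]
# e.g. the value after one update (ALPHA = 0.5) for a step into the cliff
q_value[state[0], state[1], RIGHT] = -50.0

print(q_value.shape)  # (4, 12, 4)
print(int(np.argmax(q_value[state[0], state[1]])))  # 0, i.e. UP (first of the tied zeros)
```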

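```python
# Required helper methods
def step(state, action):
    """Apply `action` in `state` and return (next_state, reward)."""
    i, j = state
    # Moves that would leave the grid keep the agent in place
    if action == UP:
        next_state = [max(i - 1, 0), j]
    elif action == LEFT:
        next_state = [i, max(j - 1, 0)]
    elif action == RIGHT:
        next_state = [i, min(j + 1, GRID_WIDTH - 1)]
    elif action == DOWN:
        next_state = [min(i + 1, GRID_HEIGHT - 1), j]
    else:
        raise ValueError(f"unknown action: {action}")

    reward = -1
    # Entering the cliff (row 3, columns 1..10): DOWN from row 2 or RIGHT from START
    if (action == DOWN and i == 2 and 1 <= j <= 10) or (
        action == RIGHT and state == START):
        reward = -100
        next_state = START

    return next_state, reward


def epsilon_greedy(state, q_value):
    """Choose an action epsilon-greedily with respect to `q_value`."""
    if np.random.binomial(1, EPSILON) == 1:
        # Explore: uniformly random action
        return np.random.choice(ACTIONS)
    else:
        # Exploit: greedy action, ties broken uniformly at random
        values_ = q_value[state[0], state[1], :]
        return np.random.choice([action_ for action_, value_ in enumerate(values_) if value_ == np.max(values_)])
```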
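`epsilon_greedy` draws from NumPy's global random state, so results change from run to run. A seeded variant is sketched below; `epsilon_greedy_rng` is a hypothetical name, and it uses NumPy's `Generator` API (`np.random.default_rng`) instead of the legacy functions. Because the random branch can also pick the greedy action, a unique greedy action is chosen with probability $1 - \epsilon + \epsilon/4 = 0.925$ for $\epsilon = 0.1$ and four actions, which the sampling at the end checks empirically.

```python
import numpy as np

EPSILON = 0.1
ACTIONS = [0, 1, 2, 3]  # UP, DOWN, LEFT, RIGHT


def epsilon_greedy_rng(state, q_value, rng, epsilon=EPSILON):
    """Epsilon-greedy action selection driven by an explicit, seedable Generator."""
    if rng.random() < epsilon:
        # Explore: uniform over all actions (the greedy one included)
        return int(rng.choice(ACTIONS))
    values = q_value[state[0], state[1], :]
    # Exploit: break ties between maximal actions uniformly at random
    best = np.flatnonzero(values == values.max())
    return int(rng.choice(best))


rng = np.random.default_rng(0)
q = np.zeros((4, 12, 4))
q[2, 5, 3] = 1.0  # make RIGHT the unique greedy action in state [2, 5]
picks = [epsilon_greedy_rng([2, 5], q, rng) for _ in range(20_000)]
freq_right = picks.count(3) / len(picks)
print(round(freq_right, 3))  # should be close to 1 - 0.1 + 0.1 / 4 = 0.925
```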

Now we write a function that runs one episode of SARSA. It
- starts in `S` and updates the Q-values after every step until the agent reaches `G`;
- returns the total reward collected during the episode.

The key step is computing the SARSA target.

`HINT: Look at the update formula for SARSA`
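For reference, the standard one-step SARSA update is given below, where $A_{t+1}$ is the action the $\epsilon$-greedy policy actually selects in $S_{t+1}$ (`next_action` in the code), $\alpha$ is `ALPHA`, and $\gamma$ is `GAMMA`:

$$
Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \left[ R_{t+1} + \gamma \, Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t) \right]
$$

The target is the term $R_{t+1} + \gamma \, Q(S_{t+1}, A_{t+1})$; for a terminal $S_{t+1}$, its value term is taken as 0.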

```python
# One episode of SARSA
# @q_value: values for state-action pairs; updated in place
# @return: total reward collected within this episode
def sarsa(q_value):
    state = START
    action = epsilon_greedy(state, q_value)
    rewards = 0.0
    while state != GOAL:
        next_state, reward = step(state, action)
        # On-policy: A' comes from the same epsilon-greedy policy being learned
        next_action = epsilon_greedy(next_state, q_value)
        rewards += reward

        # SARSA target: R + gamma * Q(S', A'). Q(GOAL, .) is never updated and
        # stays 0, so the terminal state needs no special case.
        target = reward + GAMMA * q_value[next_state[0], next_state[1], next_action]

        # Move Q(S, A) a step of size ALPHA towards the target
        q_value[state[0], state[1], action] += ALPHA * (target - q_value[state[0], state[1], action])

        state = next_state
        action = next_action
    return rewards
```
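To see the scale of a single update with `ALPHA = 0.5` and `GAMMA = 1`, the sketch below applies one SARSA update to plain numbers (`sarsa_update` is a hypothetical helper, not part of the assignment). Starting from all-zero Q-values, an ordinary step moves $Q(S, A)$ to -0.5, while a step into the cliff moves it straight to -50.

```python
ALPHA = 0.5  # step size, as in the constants above
GAMMA = 1.0  # undiscounted task


def sarsa_update(q_sa, reward, q_next_sa, alpha=ALPHA, gamma=GAMMA):
    """Return Q(S, A) after one SARSA update, given R and Q(S', A')."""
    target = reward + gamma * q_next_sa
    return q_sa + alpha * (target - q_sa)


print(sarsa_update(0.0, -1, 0.0))    # -0.5  (ordinary step)
print(sarsa_update(0.0, -100, 0.0))  # -50.0 (step into the cliff)
print(sarsa_update(-0.5, -1, -0.5))  # -1.0  (revisit, with Q(S', A') = -0.5)
```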

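Before we proceed, we run a quick sanity check that SARSA learns on this gridworld. After 100 episodes, at least 75 episode returns should be at or above the mean return: the heavily penalized exploratory episodes at the start pull the mean down, while later episodes should do consistently better.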

```python
episodes = 100  # number of episodes in this run
rewards_sarsa = np.zeros(episodes)
q_sarsa = np.zeros((GRID_HEIGHT, GRID_WIDTH, 4))  # indexed by [row, column, action]

for i in range(episodes):
    rewards_sarsa[i] += sarsa(q_sarsa)
```

```python
# Most episodes should have a return at or above the mean return
assert np.sum(rewards_sarsa >= np.mean(rewards_sarsa)) >= 75
```

```python
episodes = 500  # episodes per run
runs = 20       # independent runs to average over

rewards_sarsa = np.zeros(episodes)

# Averaging over several independent runs only smooths the reward curve;
# the policy itself converges well within a single run
for _ in tqdm(range(runs)):
    q_sarsa = np.zeros((GRID_HEIGHT, GRID_WIDTH, 4))
    for i in range(episodes):
        rewards_sarsa[i] += sarsa(q_sarsa)
# Average the per-episode reward over the independent runs
rewards_sarsa /= runs

# Draw the reward curve
plt.plot(rewards_sarsa, label='Sarsa')
plt.xlabel('Episodes')
plt.ylabel('Sum of rewards during episode')
plt.ylim([-100, 0])
plt.legend()
plt.show()
```
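To check which path the agent has actually learned, one option is to print the greedy action in every cell of the final Q-table. The helper `greedy_policy_grid` below is a hypothetical sketch; it assumes the 4 × 12 layout above with the cliff in the bottom row, and it breaks ties by taking the first maximal action (`np.argmax`), unlike `epsilon_greedy`, which breaks ties at random. The demo uses a hand-made Q-table whose greedy path hugs the cliff edge.

```python
import numpy as np

ARROWS = {0: "^", 1: "v", 2: "<", 3: ">"}  # UP, DOWN, LEFT, RIGHT


def greedy_policy_grid(q_value, goal=(3, 11)):
    """Render the greedy action of each cell; assumes the cliff is the bottom row."""
    height, width, _ = q_value.shape
    lines = []
    for i in range(height):
        row = []
        for j in range(width):
            if (i, j) == goal:
                row.append("G")
            elif i == height - 1 and 1 <= j <= width - 2:
                row.append("C")  # cliff cells: the agent never acts from here
            else:
                # np.argmax returns the first maximal action on ties
                row.append(ARROWS[int(np.argmax(q_value[i, j]))])
        lines.append(" ".join(row))
    return "\n".join(lines)


# Hand-made Q-table: up from S, right along row 2, then down into G
q = np.zeros((4, 12, 4))
q[3, 0, 0] = 1.0    # S: UP
q[2, :11, 3] = 1.0  # row 2, columns 0..10: RIGHT
q[2, 11, 1] = 1.0   # row 2, column 11: DOWN
print(greedy_policy_grid(q))
```

Calling `print(greedy_policy_grid(q_sarsa))` after the cell above shows the greedy policy from the last of the 20 runs; if SARSA has settled on a safe path, the arrows starting at `S` should lead up and away from the cliff before turning right.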

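Note that SARSA converges to the safe path here, away from the cliff edge: its target uses the action the $\epsilon$-greedy policy actually takes next, so the risk of an exploratory step into the cliff is reflected in the learned values. Compare your results with Q-Learning from the previous assignment.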
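The contrast with Q-Learning shows up in a single target computation. The sketch below uses two hypothetical helpers, `sarsa_target` and `q_learning_target`, on made-up Q-values for a state $S'$ just above the cliff, where DOWN leads into the cliff. If the $\epsilon$-greedy policy happens to pick DOWN in $S'$, the SARSA target carries that cost, while the Q-Learning target always bootstraps from the best action.

```python
import numpy as np

GAMMA = 1.0  # undiscounted task


def sarsa_target(reward, q_next, next_action, gamma=GAMMA):
    # On-policy: bootstrap from the action actually chosen in S'
    return reward + gamma * q_next[next_action]


def q_learning_target(reward, q_next, gamma=GAMMA):
    # Off-policy: bootstrap from the greedy action in S', whatever is taken next
    return reward + gamma * np.max(q_next)


# Made-up Q(S', .) for a cell above the cliff, ordered UP, DOWN, LEFT, RIGHT
q_next = np.array([-12.0, -100.0, -14.0, -11.0])
print(sarsa_target(-1, q_next, next_action=1))  # -101.0 (exploratory DOWN into the cliff)
print(q_learning_target(-1, q_next))            # -12.0
```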