**SARSA (State-Action-Reward-State-Action)**

In [None]:
import numpy as np
import gym

**Epsilon Greedy Policy**

In [None]:
def eps_greedy(Q, s, eps=0.1):
    """Choose an action for state ``s`` with an epsilon-greedy rule.

    A uniform sample from [0, 1) is drawn first. When it falls below ``eps``
    a random action index in ``range(Q.shape[1])`` is returned; otherwise the
    choice is delegated to ``greedy(Q, s)``.

    Args:
        Q: 2-D array of action values; ``Q.shape[1]`` is used as the number
            of actions.
        s: State index, forwarded unchanged to ``greedy``.
        eps: Threshold the uniform sample is compared against (exploration
            probability).

    Returns:
        The selected action index.
    """
    explore = np.random.uniform(0, 1) < eps
    if not explore:
        # Exploit: take the best-known action for this state.
        return greedy(Q, s)

    # Explore: any action, uniformly at random.
    n_actions = Q.shape[1]
    return np.random.randint(n_actions)

**Greedy Policy**

In [None]:
def greedy(Q, s):
    """Return the index of the largest action value stored for state ``s``.

    Args:
        Q: Table of action values indexable by state; ``Q[s]`` is handed to
            ``np.argmax``.
        s: State index.

    Returns:
        Index of the maximum entry of ``Q[s]`` (``np.argmax`` returns the
        first such index when several entries tie).
    """
    action_values = Q[s]
    return np.argmax(action_values)

**Policy Testing**

In [None]:
def run_episodes(env, Q, num_episodes=100, to_print=False):
    """Play ``num_episodes`` games with the greedy policy of ``Q`` and return the mean reward.

    The environment is used through the classic Gym interface visible in this
    function: ``env.reset()`` returns the initial state and ``env.step(a)``
    returns a ``(next_state, reward, done, info)`` 4-tuple.

    Args:
        env: Environment to evaluate on. It is reset before the first episode
            and again after every finished episode.
        Q: Action-value table; actions are chosen with ``greedy(Q, state)``.
        num_episodes: Number of complete episodes to play.
        to_print: When true, print a single summary line once all episodes
            have finished.

    Returns:
        ``np.mean`` of the per-episode total rewards. With ``num_episodes=0``
        this is the mean of an empty list (NaN, with a NumPy warning).
    """
    tot_rew = []
    state = env.reset()

    for _ in range(num_episodes):
        done = False
        game_rew = 0

        while not done:
            next_state, rew, done, _ = env.step(greedy(Q, state))
            state = next_state
            game_rew += rew

        # Episode over: start the next one from a fresh state and record the return.
        state = env.reset()
        tot_rew.append(game_rew)

    mean_rew = np.mean(tot_rew)

    # Report once, after the whole evaluation run. Previously this sat inside
    # the episode loop and printed a partial mean after every single game.
    if to_print:
        print('Mean score: %.3f of %i games!' % (mean_rew, num_episodes))

    return mean_rew

In [None]:
import numpy as np

def SARSA(env, lr=0.01, num_episodes=10000, eps=0.3, gamma=0.95, eps_decay=0.00005,
          min_eps=0.01, test_every=300, test_episodes=1000):
    """Learn a tabular action-value function with on-policy SARSA.

    The environment is used through the classic Gym interface visible in this
    function: ``env.reset()`` returns the initial state, ``env.step(a)``
    returns ``(next_state, reward, done, info)``, and both
    ``env.action_space.n`` and ``env.observation_space.n`` give table sizes.

    Args:
        env: Environment to train on (also reused for periodic evaluation).
        lr: Step size of the temporal-difference update.
        num_episodes: Number of training episodes.
        eps: Initial exploration rate passed to ``eps_greedy``.
        gamma: Discount factor.
        eps_decay: Amount subtracted from ``eps`` after each episode while it
            is above ``min_eps``.
        min_eps: Floor for the decayed exploration rate. A starting ``eps``
            already at or below this value is left untouched.
        test_every: Evaluate the greedy policy every this many episodes
            (starting with episode 0). Pass 0 or None to skip evaluation.
        test_episodes: Number of evaluation games per test, forwarded to
            ``run_episodes``.

    Returns:
        The learned ``(n_states, n_actions)`` NumPy array of action values.
    """
    nA = env.action_space.n
    nS = env.observation_space.n

    Q = np.zeros((nS, nA))

    for ep in range(num_episodes):
        state = env.reset()
        done = False

        # SARSA needs the first action chosen before entering the step loop.
        action = eps_greedy(Q, state, eps)

        while not done:
            next_state, rew, done, _ = env.step(action)
            # On-policy: the action used in the target is the one actually taken next.
            next_action = eps_greedy(Q, next_state, eps)

            td_target = rew + gamma * Q[next_state][next_action]
            Q[state][action] = Q[state][action] + lr * (td_target - Q[state][action])

            state = next_state
            action = next_action

        # Linear decay, clamped so one step can never push eps below the floor
        # (the unclamped subtraction could undershoot or even go negative).
        if eps > min_eps:
            eps = max(min_eps, eps - eps_decay)

        if test_every and (ep % test_every) == 0:
            test_rew = run_episodes(env, Q, test_episodes)
            print("Episode:{:5d} Eps:{:2.4f} Rew:{:2.4f}".format(ep, eps, test_rew))

    return Q


In [29]:
if __name__ == '__main__':
    # Create the Taxi-v3 environment through gym.make.
    env = gym.make('Taxi-v3')
    print("SARSA")
    # Train for 5000 episodes with the hyperparameters below (these override
    # SARSA's defaults for lr, num_episodes, eps and eps_decay); the returned
    # action-value table is kept in Q_SARSA.
    Q_SARSA=SARSA(env, lr=.1, num_episodes=5000, eps=0.4, gamma=0.95, eps_decay=0.001)

SARSA
Episode:    0 Eps:0.3990 Rew:-216.1910
Episode:  300 Eps:0.0990 Rew:-197.8520
Episode:  600 Eps:0.0100 Rew:-180.0610
Episode:  900 Eps:0.0100 Rew:-143.9570
Episode: 1200 Eps:0.0100 Rew:-90.3100
Episode: 1500 Eps:0.0100 Rew:-62.3720
Episode: 1800 Eps:0.0100 Rew:-62.7350
Episode: 2100 Eps:0.0100 Rew:-29.8780
Episode: 2400 Eps:0.0100 Rew:-3.3540
Episode: 2700 Eps:0.0100 Rew:3.3040
Episode: 3000 Eps:0.0100 Rew:4.3770
Episode: 3300 Eps:0.0100 Rew:7.4070
Episode: 3600 Eps:0.0100 Rew:7.0800
Episode: 3900 Eps:0.0100 Rew:7.2040
Episode: 4200 Eps:0.0100 Rew:7.8690
Episode: 4500 Eps:0.0100 Rew:7.9690
Episode: 4800 Eps:0.0100 Rew:7.9590
