In [None]:
"""

    algorithm: SARSA (on-policy TD control)
        Q(s, a) <- Q(s, a) + alpha * (r + gamma * Q(s_, a_) - Q(s, a))
    
    environment: FrozenLake-v0

    author: Xinchen Han
    date: 2020/7/25

"""

In [None]:
import os
import time

import gym
import matplotlib.pyplot as plt
import numpy as np

In [None]:
alg_name = 'SARSA'
env_id = 'FrozenLake-v0'
env = gym.make(env_id)

## Set hyperparameters
SEED = 0  # fixed seed so Restart & Run All reproduces the same training run
np.random.seed(SEED)  # the agent code below draws all its randomness from np.random
env.seed(SEED)  # NOTE(review): legacy gym API (gym < 0.26, which still ships FrozenLake-v0) — confirm for your gym version

epsilon = .8  # used by choose_action as the probability of acting GREEDILY (a random action is taken otherwise)
alpha = .8    # learning rate (step size) of the TD update
gamma = .9    # discount factor for future rewards
max_episodes = 20000
t0 = time.time()  # wall-clock reference taken when this cell runs

# One row per state, one column per action, all values start at zero.
Q_table = np.zeros([env.observation_space.n, env.action_space.n], dtype=np.float64)
# Exponential moving average of episode rewards, seeded with 0 so the first update has a previous value.
reward_buffer = [0]

In [None]:
def choose_action(state):
    """Return an action for `state` using the global ``Q_table``.

    A uniform random action is returned when a fresh ``np.random.rand()`` draw
    exceeds ``epsilon``, or when every Q-value for `state` is still zero.
    Otherwise the action with the highest Q-value is returned. Note that
    ``epsilon`` is therefore the probability of acting greedily here, not the
    exploration probability.
    """
    q_row = Q_table[state, :]
    explore = np.random.rand() > epsilon
    if explore or not q_row.any():
        # No preference yet (all-zero row) or an exploration step: pick uniformly.
        return np.random.choice(env.action_space.n)
    return np.argmax(q_row)

def _noisy_greedy_action(state, episode):
    """Return argmax of the Q-row for `state` after adding Gaussian noise scaled by 1 / (episode + 1)."""
    noise = np.random.randn(1, env.action_space.n) * (1. / (episode + 1))
    return np.argmax(Q_table[state, :] + noise)


def SARSA(log_every=1):
    """Train the global ``Q_table`` in place with tabular SARSA on the global ``env``.

    The behaviour policy is greedy over Q-values plus Gaussian noise whose scale
    shrinks as 1 / (episode + 1). The same policy chooses both the action taken
    and the bootstrap action ``a_``, which is what makes the update on-policy.

    Side effects:
        * updates ``Q_table`` in place;
        * appends one exponentially smoothed episode reward (0.9 / 0.1 mix)
          per episode to the global ``reward_buffer``;
        * prints a progress line every ``log_every`` episodes (and always for
          the last episode), then the mean raw episode reward of this call.

    Args:
        log_every: print progress every this many episodes. The default of 1
            prints every episode.

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be >= 1")

    # Time this training call only (the global t0 is set when the config cell runs).
    start = time.time()
    episode_rewards = []

    for episode in range(max_episodes):
        state = env.reset()
        action = _noisy_greedy_action(state, episode)
        epi_reward = 0
        done = False
        while not done:
            state_, reward, done, _ = env.step(action)
            action_ = _noisy_greedy_action(state_, episode)
            # No explicit `done` mask: a state reached with done=True is never the
            # `state` of a later update (the loop exits), so the bootstrap relies on
            # such rows staying 0. NOTE(review): holds for FrozenLake's terminal states.
            td_target = reward + gamma * Q_table[state_, action_]
            Q_table[state, action] += alpha * (td_target - Q_table[state, action])
            state, action = state_, action_
            epi_reward += reward

        episode_rewards.append(epi_reward)
        reward_buffer.append(reward_buffer[-1] * 0.9 + epi_reward * 0.1)
        if (episode + 1) % log_every == 0 or episode + 1 == max_episodes:
            print(
                'Training  | Episode: {}/{}  | Reward:{: .4f} |Running Time: {:.4f}'.format(
                    episode + 1, max_episodes, epi_reward, time.time() - start))

    # Mean of this call's raw episode rewards (not the smoothed buffer, which also
    # holds its 0 seed and any entries from earlier calls).
    mean_reward = np.mean(episode_rewards) if episode_rewards else 0.0
    print("Trained mean reward :{:.4f}".format(mean_reward))

In [None]:
SARSA()

# Exponentially smoothed episode reward recorded by SARSA(); index 0 is the buffer's seed value.
fig, ax = plt.subplots()
ax.plot(reward_buffer)
ax.set(title='{} on {}'.format(alg_name, env_id),
       xlabel='Episode',
       ylabel='Smoothed episode reward')
plt.show()