# Importing all the libraries

In [None]:
import numpy as np 
import matplotlib.pyplot as plt 

from env import CliffWalking

# Creating the environment

In [None]:
# Single environment instance shared as a module-level global: both
# n_step_sarsa and test_agent call env.reset() / env.step() on it.
env = CliffWalking()

# Creating the Q(s,a) table

In [None]:
# Tabular Q(s, a) estimates, all starting at zero. The table is indexed as
# action_values[state][action]: 48 rows (states) by 4 columns (actions).
action_values = np.zeros(shape=(48, 4), dtype=float)

# Creating the policy

In [None]:
def policy(state, epsilon=0.2):
    """Pick an action epsilon-greedily from the module-level Q-table.

    Args:
        state: Row index into ``action_values``.
        epsilon: Probability of ignoring the table and drawing one of the
            4 actions uniformly at random.

    Returns:
        An integer action in ``[0, 4)``. On the greedy branch, ties between
        equally valued actions are broken uniformly at random.
    """
    explore = np.random.random() < epsilon
    if explore:
        return np.random.randint(4)

    # Greedy branch: collect every action that attains the row maximum so
    # that ties do not always resolve to the lowest index.
    q_row = action_values[state]
    best_actions = np.flatnonzero(q_row == q_row.max())
    return np.random.choice(best_actions)

# Implementing the algorithm

In [None]:
def n_step_sarsa(action_values, policy, episodes=1000, alpha=0.1, gamma=0.99, epsilon=0.2, n=8):
    """Train ``action_values`` in place with on-policy multi-step SARSA.

    Episodes are played on the module-level ``env``. Every transition
    ``(state, action, reward)`` is stored, and once a transition is ``n``
    time steps old its Q-value is moved towards a discounted return built
    from the stored rewards plus a bootstrap from the most recent
    ``(next_state, next_action)`` pair.

    Args:
        action_values: Q-table indexed as ``action_values[state][action]``;
            modified in place.
        policy: Callable ``policy(state, epsilon)`` returning an action.
        episodes: Number of training episodes.
        alpha: Step size applied to each update.
        gamma: Discount factor.
        epsilon: Exploration rate forwarded to ``policy``.
        n: Lag, in time steps, between experiencing a transition and
            updating it. The return for an update covers the lagged
            transition and up to ``n`` later ones, i.e. up to ``n + 1``
            rewards before the bootstrap term.

    Returns:
        None. The learned values are left in ``action_values``.
    """
    for episode in range(1, episodes + 1):
        state = env.reset()
        action = policy(state, epsilon)
        done, terminated = False, False
        transitions = []
        t = 0
        print(episode)
        # Keep iterating after the episode has ended until the last stored
        # transition has received its (shortened) update.
        while t - n < len(transitions):
            # Only interact with the environment while the episode is still
            # running. The parentheses matter: ``not done or terminated``
            # would parse as ``(not done) or terminated`` and keep stepping
            # an already-terminated episode.
            if not (done or terminated):
                next_state, reward, done, terminated = env.step(action)
                next_action = policy(next_state, epsilon)
                transitions.append([state, action, reward])

            if t >= n:
                # Transitions from the one being updated up to the newest.
                window = transitions[t - n:]
                # Bootstrap from the latest state-action pair unless the
                # episode reported ``done``.
                # NOTE(review): an episode that ends with ``terminated`` but
                # not ``done`` still bootstraps here - confirm that matches
                # the environment's meaning of ``terminated``.
                G = (1 - done) * action_values[next_state][next_action]
                for _, _, reward_t in reversed(window):
                    G = reward_t + gamma * G
                # The update target is the oldest transition in the window.
                state_t, action_t, _ = window[0]
                action_values[state_t][action_t] += alpha * \
                    (G - action_values[state_t][action_t])
            t += 1
            state, action = next_state, next_action

In [None]:
# Train the shared Q-table in place for 1000 episodes, keeping the default
# alpha, gamma, epsilon and n of n_step_sarsa.
n_step_sarsa(action_values, policy, 1000)

# Show results

In [None]:
# Dump the Q-table after training (indexed as action_values[state][action]).
print(action_values)

In [None]:
def test_agent(policy, episodes=3, epsilon=0):
    """Play rendered episodes on the module-level ``env`` using ``policy``.

    Args:
        policy: Callable ``policy(state, epsilon)`` returning an action.
        episodes: How many episodes to play.
        epsilon: Exploration rate forwarded to ``policy``; the default of 0
            is passed through unchanged.

    Prints the 1-based episode number each time an episode finishes.
    """
    env.pygame_init()
    for episode_index in range(episodes):
        observation = env.reset()
        finished = False
        while not finished:
            chosen_action = policy(observation, epsilon)
            observation, _reward, done, terminated = env.step(chosen_action)
            # Render after every step; the returned frame is not needed here.
            env.render()
            # Either flag ends the episode.
            finished = done or terminated
        print(episode_index + 1)

In [None]:
# Play the evaluation episodes with test_agent's defaults
# (3 episodes, epsilon=0).
test_agent(policy)

In [None]:
# Call the environment's close() hook once evaluation is over.
# NOTE(review): presumably this tears down what env.pygame_init() set up -
# confirm against env.py.
env.close()