# Importing all the libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from env import CliffWalking

# Creating the environment

In [None]:
# Build the environment, reset it, and render it once.
# `env` is a module-level global: sarsa() and test_agent() below read it directly
# instead of taking it as a parameter.
env = CliffWalking()
env.reset()
env.render()

# Creating the Q(s,a) table

In [None]:
# Q(s, a) table: rows are indexed by state, columns by action; every estimate starts at 0.
action_values = np.zeros(shape=(48, 4), dtype=np.float64)

In [None]:
# Show the table before any training has been run.
print(action_values)

# Creating the policy

In [None]:
def policy(state, epsilon=0.2):
    """Choose an action for ``state`` with an epsilon-greedy rule.

    Reads the module-level ``action_values`` table (rows indexed by state,
    columns by action).

    Args:
        state: Row index into ``action_values``.
        epsilon: Probability of returning a uniformly random action instead
            of the greedy one. With ``epsilon=0`` the choice is always greedy.

    Returns:
        The index of the chosen action, i.e. a column of ``action_values``.
    """
    if np.random.random() < epsilon:
        # Explore: take the number of actions from the table's width so this
        # branch stays in sync with the table instead of hard-coding 4.
        return np.random.choice(action_values.shape[1])
    # Exploit: np.argmax returns the first maximal index, so ties are
    # resolved in favour of the lowest-numbered action.
    return np.argmax(action_values[state])

In [None]:
# Sanity check: sample one action for state 0 using the default epsilon (0.2).
print(f"The action chosen at state 0 is {policy(0)}")

# Implementing the algorithm

In [None]:
def sarsa(action_values, policy, episodes=10000, alpha=0.1, gamma=0.99, epsilon=0.2):
    """Run tabular SARSA, updating ``action_values`` in place.

    Interacts with the module-level ``env``: each episode starts with
    ``env.reset()`` and advances with ``env.step(action)``, which is expected
    to return ``(next_state, reward, done, terminated)``.

    Args:
        action_values: Table indexed as ``action_values[state][action]``;
            modified in place.
        policy: Callable ``policy(state, epsilon)`` returning an action.
        episodes: Number of episodes to run.
        alpha: Step size applied to the temporal-difference error.
        gamma: Discount applied to the next state-action value.
        epsilon: Exploration parameter forwarded to ``policy``.

    Returns:
        None. The learned values are left in ``action_values``.
    """
    for episode in range(1, episodes + 1):
        state = env.reset()
        done, terminated = False, False
        action = policy(state, epsilon)
        print(episode)
        # The episode is over as soon as EITHER flag is set. The parentheses
        # matter: `not done or terminated` parses as `(not done) or terminated`,
        # which keeps looping forever once `terminated` becomes True.
        while not (done or terminated):
            next_state, reward, done, terminated = env.step(action)
            # On-policy: the action used in the target is the one that will
            # actually be taken on the next step.
            next_action = policy(next_state, epsilon)

            qsa = action_values[state][action]
            next_qsa = action_values[next_state][next_action]

            # SARSA update: Q(s,a) += alpha * (r + gamma * Q(s',a') - Q(s,a))
            action_values[state][action] = qsa + alpha * (
                reward + gamma * next_qsa - qsa
            )
            state, action = next_state, next_action

In [None]:
# Train with the default hyperparameters; sarsa() writes its updates into
# `action_values` in place and returns nothing.
sarsa(action_values, policy)

# Showing the results

In [None]:
# Show the table after training.
print(action_values)

In [None]:
def test_agent(policy, episodes=1, epsilon=0):
    """Play episodes on the module-level ``env``, rendering after every step.

    Args:
        policy: Callable ``policy(state, epsilon)`` returning an action.
        episodes: Number of episodes to play.
        epsilon: Exploration parameter forwarded to ``policy``.

    Prints the 1-based episode number after each episode finishes.
    """
    env.pygame_init()
    for episode_number in range(1, episodes + 1):
        state = env.reset()
        finished = False
        while not finished:
            chosen_action = policy(state, epsilon)
            # The reward is not needed when only watching the agent.
            state, _reward, done, terminated = env.step(chosen_action)
            env.render()
            finished = done or terminated
        print(episode_number)

In [None]:
# Play one episode with the defaults (episodes=1, epsilon=0); with epsilon=0
# policy() always returns the argmax action for the current state.
test_agent(policy)

In [None]:
# Close the environment now that the demo run is finished.
env.close()