In [1]:
# 2. Importing necessary libraries
from pathlib import Path

import gymnasium as gym
import numpy as np
import pandas as pd

In [2]:
# 3.a Environment creation
env = gym.make('CliffWalking-v0')

# 3.b Establishing parameters
learning_rate = 0.1  # Step size of each Q-value update
discount = 0.9       # Discount factor applied to the bootstrapped next Q-value
epsilon = 0.1        # Probability of choosing a random (exploratory) action
episodes = 1000      # Number of training episodes

# 3.c Seeding the random generators so that "Restart & Run All" reproduces the same run
SEED = 42
np.random.seed(SEED)         # Seeds NumPy's global generator (used for the epsilon draw)
env.action_space.seed(SEED)  # Seeds the generator behind env.action_space.sample()

In [3]:
# 4. Implementing SARSA

# Auxiliary function to choose the next action (epsilon-greedy policy)
def choose_next_action(Q, state, epsilon):
    """Select an action for ``state`` with an epsilon-greedy policy over ``Q``.

    Args:
        Q: Action-value table indexed as ``Q[state][action]``.
        state: Index of the current state (row of ``Q``).
        epsilon: Probability of picking a uniformly random action instead of
            the greedy one.

    Returns:
        The index of the chosen action.
    """
    # The number of actions is read from the Q-table row itself, so this helper
    # does not depend on a module-level ``env`` and always matches the table
    # it is given.
    num_actions = len(Q[state])
    if np.random.uniform(0, 1) < epsilon:
        return np.random.randint(num_actions)  # Exploration
    return np.argmax(Q[state])  # Exploitation (ties resolve to the lowest index)

# 4.a Implementing the SARSA algorithm
def SARSA(env, episodes, learning_rate, discount, epsilon):
    """Train a tabular SARSA agent on ``env``.

    Args:
        env: Environment with discrete spaces (``action_space.n`` and
            ``observation_space.n``), whose ``reset()`` returns
            ``(state, info)`` and whose ``step(action)`` returns
            ``(next_state, reward, terminated, truncated, info)``.
        episodes: Number of episodes to run.
        learning_rate: Step size of each Q-value update.
        discount: Discount factor applied to the next state-action value.
        epsilon: Exploration probability forwarded to ``choose_next_action``.

    Returns:
        A tuple ``(rewards_per_episode, Q)``: the list with the summed reward of
        every episode, and the learned action-value table of shape
        ``(num_states, num_actions)``.
    """
    num_actions = env.action_space.n
    num_states = env.observation_space.n

    rewards_per_episode = []
    Q = np.zeros((num_states, num_actions))  # Initializing the action-value function Q

    # Loop for each episode
    for _episode in range(episodes):

        state, _info = env.reset()                      # Initializing S
        action = choose_next_action(Q, state, epsilon)  # Choosing A from S using policy derived from Q
        total_reward = 0

        # Loop for each step of episode
        done = False
        while not done:
            next_state, reward, terminated, truncated, _info = env.step(action)  # Taking action A, observing R, S'
            done = terminated or truncated

            next_action = choose_next_action(Q, next_state, epsilon)  # Choosing A' from S'

            # A terminal transition has no future return, so the target is the
            # reward alone. A truncated episode still bootstraps, because the
            # task itself has not ended.
            if terminated:
                td_target = reward
            else:
                td_target = reward + discount * Q[next_state, next_action]
            Q[state, action] += learning_rate * (td_target - Q[state, action])

            state = next_state      # S <- S'
            action = next_action    # A <- A'

            total_reward += reward

        rewards_per_episode.append(total_reward)

    # 4.b Returning the rewards per episode and the action-value function Q
    return rewards_per_episode, Q

In [4]:
# Training the agent with the environment and parameters established above
rewards_per_episode, Q = SARSA(
    env=env,
    episodes=episodes,
    learning_rate=learning_rate,
    discount=discount,
    epsilon=epsilon,
)

In [5]:
# Saving the reward obtained in each episode to a CSV file
df = pd.DataFrame(rewards_per_episode, columns=['reward'])
output_path = Path('Results/SARSA_rewards.csv')
# DataFrame.to_csv does not create missing directories, so make sure it exists first
output_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(output_path, index=False)