In [None]:
import collections
import numpy as np
import gymnasium as gym
import plotly.graph_objects as go

from IPython.display import Video

In [None]:
# Non-slippery 4x4 FrozenLake instance; the training cell below runs episodes on it.
env = gym.make(
    'FrozenLake-v1',
    desc=None,
    map_name="4x4",
    is_slippery=False,
)

In [None]:
def play_env(env, agent, record=False):
    """Run a single episode of ``env`` driven by ``agent``.

    The agent is asked for an action at every step, is shown the resulting
    transition through ``agent.observe`` and is notified once at the end of
    the episode through ``agent.estimating``.

    Args:
        env: Environment exposing ``reset()`` -> ``(observation, info)`` and
            ``step(action)`` -> ``(observation, reward, terminated,
            truncated, info)``.
        agent: Object exposing ``action(observation)``,
            ``observe(state, next_state, action, reward, terminated)`` and
            ``estimating()``.
        record: When true, ``env.start_video_recorder()`` is called after the
            reset and ``env.render()`` is called after the reset and after
            every step.

    Returns:
        The reward returned by the last ``env.step`` call of the episode.
    """
    terminated = False
    truncated = False
    observation, info = env.reset()

    if record:
        env.start_video_recorder()
        env.render()

    # The episode is over as soon as the environment reports either flag.
    # Checking only `terminated` would keep stepping an environment that has
    # already signalled truncation (e.g. a step limit).
    while not (terminated or truncated):
        action = agent.action(observation)

        new_observation, reward, terminated, truncated, info = env.step(action)

        if record:
            env.render()

        # Only `terminated` is forwarded: a truncated episode did not reach a
        # terminal state, so the agent may still bootstrap from `new_observation`.
        agent.observe(observation, new_observation, action, reward, terminated)

        observation = new_observation

    agent.estimating()

    return reward

In [None]:
class Sarsa():
    """Tabular SARSA agent: on-policy temporal-difference control.

    Action values are stored per state in a ``defaultdict`` whose entries are
    created lazily by the ``action_space`` factory on first lookup. The action
    selected for the next state during ``observe`` is cached in
    ``next_action`` and returned by the following ``action`` call, so the
    action used in the update is the one that is actually played.
    """

    def __init__(self, action_space, gamma, alpha, policy):
        """Create the agent.

        Args:
            action_space: Zero-argument callable returning the initial,
                mutable sequence of action values for a newly seen state.
            gamma: Discount factor applied to the next state-action value.
            alpha: Step size of the update.
            policy: Callable mapping a sequence of action values to the index
                of the chosen action.
        """
        self.gamma = gamma
        self.alpha = alpha
        self.policy = policy

        self.state_action_values = collections.defaultdict(action_space)

        # Not read or written by any method of this class; kept so that
        # existing attribute access keeps working.
        self.states = []
        self.actions = []
        self.rewards = []

        # Action already chosen for the upcoming step, or None when a fresh
        # choice has to be made (start of an episode).
        self.next_action = None

    def action(self, state):
        """Return the action to play in ``state``.

        The action cached by the previous ``observe`` call is returned when
        there is one; otherwise the policy is applied to the values of
        ``state``.
        """
        if self.next_action is None:
            return self.policy(self.state_action_values[state])
        return self.next_action

    def observe(self, state, next_state, action, reward, terminated):
        """Update the value of ``(state, action)`` from one transition.

        Args:
            state: State in which ``action`` was taken.
            next_state: State reached after the step.
            action: Index of the action that was taken.
            reward: Reward received for the step.
            terminated: Whether ``next_state`` is terminal.
        """
        current_values = self.state_action_values[state]

        if terminated:
            # No action follows a terminal state: do not cache one (it would
            # otherwise be replayed as the first move of the next episode) and
            # do not create a value-table entry for the terminal state.
            self.next_action = None
            target = reward
        else:
            next_values = self.state_action_values[next_state]
            self.next_action = self.policy(next_values)
            target = reward + self.gamma * next_values[self.next_action]

        current_values[action] += self.alpha * (target - current_values[action])

    def estimating(self):
        """End-of-episode hook.

        Learning happens online in ``observe``; this only drops the cached
        action so that an episode that ended without a terminal transition
        does not leak its last choice into the next episode.
        """
        self.next_action = None

In [None]:
def build_action_space_exploring_start(env):
    """Build a factory for the initial action values of an unseen state.

    Args:
        env: Object exposing ``action_space.n``, the number of actions.

    Returns:
        A zero-argument callable that returns a new list containing one ``1``
        per action. ``env.action_space.n`` is read on every call.
    """
    def initial_action_values():
        n_actions = env.action_space.n
        # Keep the list on the left: `n * [1]` with a NumPy integer would
        # produce an array instead of a repeated list.
        return [1] * n_actions

    return initial_action_values

def epsilon_greedy_policy(state_action_value, epsilon=0.1):
    """Pick an action index epsilon-greedily.

    Args:
        state_action_value: Sequence of action values for one state.
        epsilon: A uniform draw from ``np.random.uniform(0, 1)`` below this
            threshold triggers a random action.

    Returns:
        A random index from ``np.random.randint`` when exploring, otherwise
        the result of ``np.argmax`` on the values.
    """
    explore = np.random.uniform(0, 1) < epsilon

    if not explore:
        return np.argmax(state_action_value)

    n_actions = len(state_action_value)
    return np.random.randint(0, n_actions)

# SARSA learner whose value table starts at 1 for every action and which acts
# epsilon-greedily (default epsilon of the policy function).
initial_values_factory = build_action_space_exploring_start(env)
agent = Sarsa(
    action_space=initial_values_factory,
    gamma=0.99,
    alpha=0.1,
    policy=epsilon_greedy_policy,
)

In [None]:
# Train the agent and track the mean final reward per window of episodes.
buffer_size = 10   # number of episodes averaged into one point of the curve
n_episodes = 100   # total number of training episodes
mean_last_rewards = []
last_rewards = [0] * buffer_size

for i in range(n_episodes):
    last_reward = play_env(env, agent)

    # Circular buffer: slot i % buffer_size holds the reward of episode i.
    last_rewards[i % buffer_size] = last_reward

    # Report once the window has been completely refilled. Testing
    # `i % buffer_size == 0` would report at episode 0, when the buffer still
    # holds placeholder zeros, and would never report the final episodes.
    if (i + 1) % buffer_size == 0:
        mean_last_rewards.append(sum(last_rewards) / buffer_size)

In [None]:
# Number of states that currently have an entry in the agent's value table;
# entries are created lazily by the defaultdict on first lookup.
len(agent.state_action_values)

In [None]:
# Plot the recorded mean rewards against their position in the list.
fig = go.Figure()
fig.add_trace(
    go.Scatter(x=list(range(len(mean_last_rewards))), y=mean_last_rewards)
)
fig.show()

In [None]:
# Play one episode with the current agent on a renderable copy of the
# environment and record it to ../videos with the "test-video" prefix.
env = gym.make('FrozenLake-v1', render_mode="rgb_array", desc=None, map_name="4x4", is_slippery=False)
video_env = gym.wrappers.RecordVideo(env=env, video_folder="../videos", name_prefix="test-video")

# Close the wrapper even if the episode raises, so the recorder is always released.
try:
    reward = play_env(video_env, agent)
finally:
    video_env.close()

print(reward)

In [None]:
# Relative path matching video_folder="../videos" and name_prefix="test-video"
# used when recording, instead of a machine-specific absolute path.
Video("../videos/test-video-episode-0.mp4")