In [75]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm

## Prepare the environment

In [93]:
class FrozenLake:
    def __init__(
        self, 
        max_steps:int=16, 
        is_slippery:bool=True, 
        render:bool=True, 
        custom_reward: bool=True
    ):
        self.max_steps = max_steps
        self.render = render
        self.custom_reward = custom_reward
        self.frozen_lake = gym.make(
            'FrozenLake-v1',
            desc=None,
            map_name="4x4",
            is_slippery=is_slippery,
            render_mode='rgb_array' if render else None
        )

    def generate_episode(
        self, 
        policy: np.ndarray, 
        epsilon:float=0.1
    ):
        state, _ = self.frozen_lake.reset()
        if not self.custom_reward:
            reward = 0.0
        else:
            reward = -np.sqrt(18)
        trajectory = []
        # Render RGB trajectory to observe the states visually
        if self.render:
            render = [self.frozen_lake.render()]
            
        terminated = False
        while not terminated:
            if np.random.rand() <= epsilon:
                action = np.random.randint(0, self.frozen_lake.action_space.n)
            else:
                action = np.argmax(policy[state])
            new_state, new_reward, terminated, _, _ = self.frozen_lake.step(action)
            if self.custom_reward:
                x = new_state / 4
                y = new_state % 4
                new_reward = -np.sqrt((x-3)**2 + (y-3)**2)

            trajectory.append({'reward':reward, 'state':state, 'action':action})

            # Render RGB trajectory to observe the states visually
            if self.render:
                render.append(self.frozen_lake.render())
                
            reward = new_reward
            state = new_state

            if self.max_steps <= len(trajectory) - 1:
                break

        trajectory.append({'reward':reward})
        if self.render:
            return trajectory, render
        return trajectory

    def generate_video(self, frames, output_name:str='output.mp4', fps:float=1.5):
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(output_name, fourcc, fps, (256, 256))
        for frame in render:
            frame = frame[:, :, ::-1]
            video_writer.write(frame)
        video_writer.release()

## On-policy Monte Carlo Control

In [96]:
epsilon = 0.1
num_actions = 4
state_set = 16
gamma = 1.0
render_every_episodes = 1000
episodes = 1000000
policy = np.ones([state_set, num_actions], dtype=np.float32) * epsilon / num_actions
policy[:, 0] = 1 - epsilon + epsilon / num_actions 
Q = np.zeros([state_set, num_actions], dtype=np.float32)
N = np.zeros([state_set, num_actions], dtype=np.float32)
frozen_lake = FrozenLake()

with tqdm(range(episodes)) as prog:
    for episode in prog:
        trajectory, frames = frozen_lake.generate_episode(policy=policy, epsilon=epsilon)
        G = 0.0

        for i in reversed(range(len(trajectory) - 1)):
            reward = trajectory[i + 1]['reward']
            state = trajectory[i]['state']
            action = trajectory[i]['action']
            G = gamma * G + reward

            is_visited = any([t['state'] == state and t['action'] == action for t in trajectory[:i]])
            if not is_visited:
                state, action = int(state), int(action)
                N[state, action] += 1
                Q[state, action] = Q[state, action] + (1 / N[state, action]) * (G - Q[state, action])
                greedy_action = np.argmax(Q[state])
                for a in range(num_actions):
                    if a == greedy_action:
                        policy[state, a] = 1 - epsilon + epsilon / num_actions
                    else:
                        policy[state, a] = epsilon / num_actions

        if episode % render_every_episodes == 0:
            frozen_lake.generate_video(frames, fps=2)
        prog.set_postfix({'G': np.mean(Q)})

  2%|▍                      | 20362/1000000 [01:23<1:06:44, 244.61it/s, G=-5.46]


KeyboardInterrupt: 