## Imports and Config

In [2]:
import numpy as np
from tqdm import tqdm
import random

import gymnasium as gym

In [11]:
# Default learning hyperparameters: GAMMA is the discount applied to the
# bootstrapped next-state value, ALPHA is the step size of each Q update.
GAMMA, ALPHA = 0.9, 0.1

# Table sizes used when building the agents below.
GRID_SIZE = 25  # number of states in each option's Q-table
N_PRIMITIVE_MOVES = 4  # number of actions in each option's Q-table
N_PRIMITIVE_ACTIONS = 2 + N_PRIMITIVE_MOVES

# A throwaway environment is created only to read the coordinates of the
# four colored landmark cells; it is closed again immediately afterwards.
env = gym.make('Taxi-v3')
RED_LOC, GREEN_LOC, YELLOW_LOC, BLUE_LOC = env.unwrapped.locs
env.close()

In [3]:
def epsilonGreedyPolicy(q_value_arr, epsilon=0.1):
    """Pick an action index from a row of Q-values, epsilon-greedily.

    With probability ``epsilon`` a uniformly random index into
    ``q_value_arr`` is returned; otherwise the index of the largest
    Q-value is returned (``np.argmax`` picks the first one on ties).
    """
    exploring = random.random() < epsilon
    if not exploring:
        return np.argmax(q_value_arr)

    last_index = len(q_value_arr) - 1
    return random.randint(0, last_index)

In [4]:
class Option:
    """A temporally extended action with its own tabular Q-learning policy.

    The option keeps a ``(state_size, action_size)`` Q-table, chooses its
    actions epsilon-greedily from it, and is finished once the state it is
    given equals ``terminal_state``.
    """

    def __init__(
        self,
        state_size: int,
        action_size: int,
        terminal_state: int,
        gamma: float = GAMMA,
        alpha: float = ALPHA,
    ):
        """Create the option with an all-zero Q-table.

        Args:
            state_size: Number of rows (states) in the option's Q-table.
            action_size: Number of columns (actions) in the option's Q-table.
            terminal_state: State index at which the option terminates.
            gamma: Discount factor used in the Q-learning target.
            alpha: Step size of the Q-learning update.
        """
        self.q_value = np.zeros((state_size, action_size))
        self.terminal_state = terminal_state
        self.gamma = gamma
        self.alpha = alpha

    def get_action(self, state: int, epsilon: float = 0.1):
        """Return an epsilon-greedy action index for ``state``."""
        return epsilonGreedyPolicy(self.q_value[state], epsilon)

    def q_update(self, state: int, action: int, reward: float, next_state: int):
        """Apply one Q-learning update for the transition just observed.

        The option's terminal state is treated as absorbing: reaching it ends
        the option, so nothing is bootstrapped from it. Without this the
        Q-values stored at the terminal state (which do get written when the
        option is started there) leak into the target, and with a constant
        per-step reward every entry drifts to the same value, leaving the
        option with no preference for moving towards its target.
        """
        if self.check_done(next_state):
            future_value = 0.0
        else:
            future_value = np.max(self.q_value[next_state])

        td_target = reward + self.gamma * future_value
        self.q_value[state, action] += self.alpha * (
            td_target - self.q_value[state, action]
        )

    def check_done(self, state: int):
        """Return True when ``state`` is the option's terminal state."""
        return state == self.terminal_state

In [5]:
class HRL:
    """Base tabular learner holding a Q-table and per-entry update counts."""

    def __init__(
        self,
        state_size: int,
        action_size: int,
        gamma: float = GAMMA,
        alpha: float = ALPHA,
    ):
        """Allocate zero-initialised tables of shape (state_size, action_size).

        Args:
            state_size: Number of states (rows of the tables).
            action_size: Number of actions (columns of the tables).
            gamma: Discount factor used in the update target.
            alpha: Step size of each update.
        """
        table_shape = (state_size, action_size)
        self.gamma = gamma
        self.alpha = alpha
        self.q_values = np.zeros(table_shape)
        # How many times each (state, action) entry has been updated.
        self.update_freq = np.zeros(table_shape)

    def update_primitive(self, state: int, action: int, reward: float, next_state: int):
        """Apply a one-step Q-learning update and count it.

        The target is ``reward + gamma * max_a Q(next_state, a)``; the entry
        for ``(state, action)`` moves towards it by a fraction ``alpha``.
        """
        best_next_value = np.max(self.q_values[next_state])
        td_target = reward + self.gamma * best_next_value
        td_error = td_target - self.q_values[state, action]

        self.q_values[state, action] += self.alpha * td_error
        self.update_freq[state, action] += 1

In [None]:
class SMDP_QLearning(HRL):
    """SMDP Q-learning agent with one "go to landmark" option per color.

    On top of the base Q-table, the agent owns four ``Option`` objects, one
    for each of RED_LOC, GREEN_LOC, YELLOW_LOC and BLUE_LOC (in that order).
    Every option has ``GRID_SIZE`` states and ``N_PRIMITIVE_MOVES`` actions.
    """

    def __init__(self, state_size, action_size, gamma=GAMMA, alpha=ALPHA):
        """Create the agent's Q-table and its landmark options.

        Args:
            state_size: Number of states in the agent's own Q-table.
            action_size: Number of actions in the agent's own Q-table.
            gamma: Discount factor, also handed to every option.
            alpha: Step size, also handed to every option.
        """
        super().__init__(state_size, action_size, gamma, alpha)

        # Landmarks are (row, col) pairs, but an Option compares its terminal
        # state against an integer state index, so each landmark is flattened
        # to row * width + col. Assumes the GRID_SIZE cells form a square grid.
        grid_width = round(GRID_SIZE**0.5)
        self.options = [
            Option(
                GRID_SIZE,
                N_PRIMITIVE_MOVES,
                row * grid_width + col,
                gamma,
                alpha,
            )
            for row, col in (RED_LOC, GREEN_LOC, YELLOW_LOC, BLUE_LOC)
        ]

    def update_option_end(
        self, state: int, action: int, reward: float, next_state: int, opt_duration: int
    ):
        """Apply the SMDP Q-learning update after an option has finished.

        Args:
            state: State in which the option was started.
            action: Index of the option in the agent's Q-table.
            reward: Reward accumulated while the option was running (the
                caller is expected to have discounted it already).
            next_state: State in which the option stopped.
            opt_duration: Number of steps the option ran; the bootstrapped
                value is discounted by ``gamma ** opt_duration``.
        """
        discount = self.gamma**opt_duration
        td_target = reward + discount * np.max(self.q_values[next_state])
        self.q_values[state, action] += self.alpha * (
            td_target - self.q_values[state, action]
        )
        self.update_freq[state, action] += 1

In [3]:
class Trainer:
    """Drive training episodes of a hierarchical Q-learning agent on an env.

    ``env`` must follow the Gymnasium ``reset``/``step`` API with integer
    states. ``hrl`` must expose ``q_values``, ``gamma`` and
    ``update_primitive``; if it also has an ``options`` list it must provide
    ``update_option_end``.
    """

    def __init__(self, env, hrl):
        self.env = env
        self.hrl = hrl

    def train(self, num_episodes: int = 1000):
        """Run ``num_episodes`` episodes, updating the agent after every step.

        The last ``len(hrl.options)`` columns of the agent's Q-table are
        options and all columns before them are primitive environment
        actions. Deriving the split from the table (rather than hard-coding
        4) keeps every primitive action, including non-movement ones,
        executable as a primitive and keeps option indices in range.
        """
        options = getattr(self.hrl, "options", [])
        n_primitive = self.hrl.q_values.shape[1] - len(options)

        for _ in tqdm(range(num_episodes)):
            state, _ = self.env.reset()
            done = False

            while not done:
                action = epsilonGreedyPolicy(self.hrl.q_values[state])

                if action < n_primitive:
                    next_state, reward, terminated, truncated, _ = self.env.step(
                        action
                    )
                    done = terminated or truncated
                    self.hrl.update_primitive(state, action, reward, next_state)
                    state = next_state
                else:
                    state, done = self._run_option(
                        state, action, options[action - n_primitive]
                    )

    def _run_option(self, state, action, option):
        """Run ``option`` until it or the episode ends; return (state, done).

        The option always takes at least one environment step, so every call
        advances the episode. The option's own Q-table is updated after each
        step and the agent's Q-table once, when the option stops.
        """
        # An option's Q-table may be smaller than the agent's. The env state
        # is reduced to an option state by integer division, which assumes
        # the part of the state the option cares about is the most
        # significant component of the env's state index (true for Taxi-v3,
        # where the taxi cell is state // 20). When both tables have the same
        # number of rows this is the identity.
        states_per_cell = max(
            1, self.hrl.q_values.shape[0] // option.q_value.shape[0]
        )

        start_state = state
        cell = state // states_per_cell
        opt_reward, opt_duration = 0, 0
        opt_done = done = False

        while not (opt_done or done):
            opt_action = option.get_action(cell)

            state, reward, terminated, truncated, _ = self.env.step(opt_action)
            done = terminated or truncated
            next_cell = state // states_per_cell

            option.q_update(cell, opt_action, reward, next_cell)

            # Discount each reward by how long the option had already run.
            opt_reward += reward * (self.hrl.gamma**opt_duration)
            opt_duration += 1

            cell = next_cell
            opt_done = option.check_done(cell)

        self.hrl.update_option_end(
            start_state, action, opt_reward, state, opt_duration
        )
        return state, done