In [1]:
import numpy as np
import random
from tqdm import tqdm

In [2]:
class FrozenLake:
    def __init__(self, nrow: int, ncol: int, holes: list[int]):
        self.nrow = nrow
        self.ncol = ncol
        self.n_states = nrow * ncol + 1
        # 0 for a empty slot, 1 for the terminal, -1 for a hole, 2 for an absorption
        self.states = np.zeros(self.n_states, dtype=np.int32)
        self.states[holes] = -1
        self.states[self.n_states - 2] = 1
        self.states[self.n_states - 1] = 2
        self.absorption = self.n_states - 1
        # for action: 0 for up, 1 for down, 2 for left, 3 for right.
        self.n_actions = 4
        # Initialize transition matrix and reward
        self.transition = np.zeros(
            (self.n_states, self.n_actions, self.n_states), dtype=np.float64
        )
        self.reward = np.zeros(
            (self.n_states, self.n_actions), dtype=np.float64
        )
        for action in range(self.n_actions):
            self.transition[self.absorption, action, self.absorption] = 1
        
        for x in range(nrow):
            for y in range(ncol):
                state = x * ncol + y
                if self.states[state] != 0:
                    if self.states[state] == -1:
                        reward = -100
                    elif self.states[state] == 1:
                        reward = 1000
                    # Terminate
                    for action in range(self.n_actions):
                        self.transition[state, action, self.absorption] = 1
                        self.reward[state, action] = reward
                    continue
                can_up = can_down = can_left = can_right = True
                if x == 0:
                    can_up = False
                if x == nrow - 1:
                    can_down = False
                if y == 0:
                    can_left = False
                if y == ncol - 1:
                    can_right = False
                for action in range(self.n_actions):
                    if can_up and action != 1:
                        next_state = (x - 1) * ncol + y
                        self.transition[state, action, next_state] = 1
                    if can_down and action != 0:
                        next_state = (x + 1) * ncol + y
                        self.transition[state, action, next_state] = 1
                    if can_left and action != 3:
                        next_state = x * ncol + y - 1
                        self.transition[state, action, next_state] = 1
                    if can_right and action != 2:
                        next_state = x * ncol + y + 1
                        self.transition[state, action, next_state] = 1
                    if (s := np.sum(self.transition[state, action])) != 0:
                        self.transition[state, action] *= 1 / s
    
    def step(self, state: int, action: int) -> tuple[int, float, bool]:
        """
        return (state, reward, is_end)
        """
        reward = self.reward[state, action]
        if self.states[state] != 0:
            return (-1, reward, True)
        distribution = self.transition[state, action]
        next_state = random.choices(np.arange(self.n_states), distribution)[0]
        return (next_state, reward, False)


In [3]:
n_states = 17
n_actions = 4
epsilon = 0.2
gamma = 0.95
alpha = 0.1

In [4]:
env = FrozenLake(4, 4, [5, 7, 11, 12])

In [5]:
def espilon_greedy(state: int, epsilon: int, q_table: np.ndarray):
    if random.random() < epsilon:
        action = random.choice(list(range(n_actions)))
    else:
        action = np.argmax(q_table[state])
    return action

In [6]:
q_table = np.random.random((n_states, n_actions))
q_table

array([[0.25491036, 0.77544495, 0.24071575, 0.0845058 ],
       [0.78742329, 0.99947557, 0.81749056, 0.25958351],
       [0.66585525, 0.82939271, 0.41254837, 0.14451091],
       [0.06202279, 0.30303769, 0.4470478 , 0.42327232],
       [0.78664478, 0.32911941, 0.37687248, 0.1520113 ],
       [0.0331673 , 0.29848245, 0.74189371, 0.11847933],
       [0.64723036, 0.25703982, 0.24822253, 0.77120818],
       [0.52113806, 0.8753327 , 0.63794396, 0.85828247],
       [0.06310311, 0.40367328, 0.15411999, 0.64543417],
       [0.96704758, 0.98239206, 0.81370209, 0.04766594],
       [0.62361502, 0.32431166, 0.04254495, 0.49339122],
       [0.35481324, 0.84851958, 0.67491983, 0.15771248],
       [0.27223153, 0.8929767 , 0.88813526, 0.26147043],
       [0.36483432, 0.6726103 , 0.11192658, 0.79139059],
       [0.25512323, 0.00913324, 0.47924283, 0.84423774],
       [0.39702672, 0.1508288 , 0.9080625 , 0.8722814 ],
       [0.31122746, 0.15908299, 0.04615175, 0.09967401]])

In [11]:
n_epochs = 10000
max_steps = 10000

for epoch in tqdm(range(n_epochs)):
    state = 0
    action = espilon_greedy(state, epsilon, q_table)
    is_end = False
    while not is_end:
        next_state, reward, is_end = env.step(state, action)
        next_action = espilon_greedy(next_state, epsilon, q_table)
        q_table[state, action] += \
            alpha * (reward + gamma * q_table[next_state, next_action] - q_table[state, action])
        state, action = next_state, next_action


100%|██████████| 10000/10000 [00:00<00:00, 13031.14it/s]


In [12]:
state

-1

In [9]:
policy = np.argmax(q_table, axis=1)
np.array(list(map(lambda x: "^v<>"[x], policy))[:-1]).reshape(4, 4)

array([['<', '^', '<', '^'],
       ['<', '>', '>', '<'],
       ['^', 'v', '<', '<'],
       ['<', '>', 'v', '<']], dtype='<U1')

In [10]:
q_table

array([[ 2.00553229e+01,  2.38286226e+01,  6.75698429e+01,
         2.43956466e+01],
       [ 3.65043072e+01, -3.21435440e+01, -5.10137734e+01,
        -4.16584977e+01],
       [ 5.15217411e+00,  8.71519675e+00,  3.60204409e+01,
         9.80962294e+00],
       [ 1.73179613e+01, -4.91583226e+01, -3.42472315e+01,
        -4.46582798e+01],
       [-3.78497279e+01, -1.75656474e+00,  7.97111902e+01,
        -3.64976142e+00],
       [-9.97352680e+01, -9.97537708e+01, -9.97371259e+01,
        -9.97057157e+01],
       [-4.00420187e+01, -1.49549127e+01,  1.52948613e+01,
         2.98644523e+01],
       [-9.54896687e+01, -9.58886755e+01, -9.50349953e+01,
        -9.58747769e+01],
       [ 1.15808105e+02,  2.51102180e+01, -3.81810183e+01,
         3.86749641e+01],
       [ 3.45680774e+01,  2.11287107e+02,  6.48619070e+01,
         1.35707072e+02],
       [ 3.74161243e+01,  1.14592567e+02,  2.43852065e+02,
         8.68149013e+01],
       [-9.08493396e+01, -9.17020965e+01, -9.08325442e+01,
      