In [27]:
import numpy as np
import gym
import collections
from torch.utils.tensorboard import SummaryWriter

### step1. load the env

In [10]:
# Environment id reused by the later cells (test, QLAgent and train each build their own env from it).
env_name = "FrozenLake-v1"
# This instance is only used by the space-inspection cells that follow.
env = gym.make(env_name)

In [11]:
# FrozenLake is a 4x4 grid: `.n` is the number of states (16), and each
# observation value is the index of the cell the agent currently occupies.
env.observation_space.n, env.observation_space.sample()

(16, 3)

In [36]:
# The action space has 4 discrete actions, one per move direction;
# sample() draws one of them at random.
env.action_space.n, env.action_space.sample()

(4, 3)

In [30]:
def test(env_name, policy_func=None, step_limit=1000, render=True):
    """Play a single episode and print how it ended.

    Args:
        env_name: environment id passed to ``gym.make``.
        policy_func: optional callable ``obs -> (action, value)``; only the
            action is used. When omitted, actions are sampled randomly from
            the environment's action space.
        step_limit: maximum number of steps to play before giving up.
        render: when True the env is created with ``render_mode="human"``,
            otherwise with ``render_mode=None``.

    The environment is always closed, including when ``policy_func`` or the
    environment itself raises.
    """
    env = gym.make(env_name, render_mode="human" if render else None)
    # Default outcome; overwritten if the episode ends before the step limit.
    message = f"Time reached limits as {step_limit} steps"
    try:
        obs, info = env.reset()
        for step in range(step_limit):
            # get action
            if policy_func:
                action, _ = policy_func(obs)
            else:  # random sample from action space as the default policy
                action = env.action_space.sample()
            # step the env
            obs, reward, terminated, truncated, info = env.step(action)
            # check if game is over
            if terminated or truncated:
                message = f"Game over with {step+1} steps"
                break
    finally:
        # Single cleanup point, reached on normal exit and on errors alike.
        env.close()
    print(message)

In [14]:
# Smoke test: no policy_func is given, so the episode is played with randomly sampled actions.
test(env_name)

Game over with 2 steps


### step2. build the Q-learning agent

In [23]:
class QLAgent:
    """Tabular Q-learning agent for an environment with discrete actions.

    Action values are stored in ``Qtable`` keyed by ``(state, action)``.
    Training transitions come from a private sampling environment that is
    stepped with randomly sampled actions, one step per ``Qvalue_update`` call.
    """

    def __init__(self, env_name, alpha=0.2, gamma=0.9):
        """Create the Q-table and the sampling environment.

        Args:
            env_name: environment id passed to ``gym.make``.
            alpha: blending coefficient between the old action value and the
                new target (learning rate).
            gamma: discount applied to the best value of the next state.
        """
        # collections.defaultdict gives a missing key the default value
        # (0.0 for float) instead of raising KeyError.
        self.Qtable = collections.defaultdict(float)  # self.Qtable[(s, a)] = action_value
        self.alpha, self.gamma = alpha, gamma
        self.sample_env = gym.make(env_name)  # for sampling the data to update Qtable
        self.action_space = list(range(self.sample_env.action_space.n))
        self.state, _ = self.sample_env.reset()

    def policy(self, state):
        """Return ``(best_action, best_value)`` for ``state`` under the Q-table.

        The comparison is strict, so ties go to the first action in
        ``self.action_space``. Returns ``(None, None)`` if there are no actions.
        Lookups go through the defaultdict, so unseen ``(state, action)``
        pairs are inserted with value 0.0 as a side effect.
        """
        best_value, best_action = None, None
        for action in self.action_space:
            action_value = self.Qtable[(state, action)]
            if best_value is None or best_value < action_value:
                best_value, best_action = action_value, action
        return best_action, best_value

    def Qvalue_update(self):
        """Sample one transition and apply one Q-learning update to the Q-table.

        Update rule:
            Q[(s, a)] <- (1 - alpha) * Q[(s, a)] + alpha * (r + gamma * max_a' Q[(s', a')])
        where s' is the state reached by taking action a in state s with
        reward r, and a' ranges over the available actions.
        """
        # step1: sample the updating data (s, a, r, next_s) plus the terminated flag
        s, a, r, next_s, terminated = self._sample_transition()

        # step2: build the target. A terminated episode has no future return,
        # so the bootstrap term is zeroed explicitly instead of relying on the
        # terminal state's table entries never having been written. A merely
        # truncated episode (e.g. time limit) still bootstraps from next_s.
        if terminated:
            best_next_value = 0.0
        else:
            _, best_next_value = self.policy(next_s)
        new_value = r + self.gamma * best_next_value

        # step3: blend the old value with the target
        old_value = self.Qtable[(s, a)]
        self.Qtable[(s, a)] = (1 - self.alpha) * old_value + self.alpha * new_value

    def _sample_transition(self):
        """Step the sampling env once with a randomly sampled action.

        Returns ``(s, a, r, next_s, terminated)``. When the episode ends
        (terminated or truncated) the env is reset and ``self.state`` becomes
        the fresh initial state; otherwise ``self.state`` advances to ``next_s``.
        """
        action = self.sample_env.action_space.sample()
        old_state = self.state
        next_state, reward, terminated, truncated, _ = self.sample_env.step(action)
        if terminated or truncated:
            self.state, _ = self.sample_env.reset()
        else:
            self.state = next_state
        return (old_state, action, reward, next_state, terminated)

    def _sample(self):
        """Play the sampling env one step and return ``(s, a, r, next_s)``."""
        # Kept with its original 4-tuple shape for any existing caller.
        return self._sample_transition()[:4]

    def play(self, play_env):
        """Play one full episode on ``play_env`` with the greedy policy.

        Returns the sum of rewards collected. The loop only stops when the
        environment reports ``terminated`` or ``truncated``, so ``play_env``
        must eventually end its episodes.
        """
        epi_reward = 0.0
        state, _ = play_env.reset()
        while True:
            action, _ = self.policy(state)
            next_state, reward, terminated, truncated, _ = play_env.step(action)
            epi_reward += reward
            if terminated or truncated:
                break
            state = next_state
        return epi_reward

    def close(self):
        """Close the sampling environment."""
        self.sample_env.close()

In [24]:
# Agent with the default hyper-parameters (alpha=0.2, gamma=0.9); this also creates and resets its sampling env.
agent = QLAgent(env_name)

### step3. train with sampling data to update Q table

In [25]:
def train(env_name, agent, writer, test_episodes=20, reward_bound=0.80, max_iters=None):
    """Alternate single Q-table updates with policy evaluation until solved.

    Each iteration performs one ``agent.Qvalue_update()`` and then plays
    ``test_episodes`` episodes with ``agent.play`` on a separate test env. The
    mean episode reward is logged to ``writer`` under the tag ``"reward"``, and
    training stops once it exceeds ``reward_bound``.

    Args:
        env_name: environment id passed to ``gym.make`` for the test env.
        agent: object providing ``Qvalue_update()``, ``play(env)`` and ``close()``.
        writer: object providing ``add_scalar(tag, value, step)`` and ``close()``.
        test_episodes: number of evaluation episodes per iteration.
        reward_bound: mean test reward that must be exceeded to stop.
        max_iters: optional cap on the number of iterations; ``None`` (the
            default) keeps the original run-until-solved behaviour.

    The writer, the test env and the agent are closed when this function
    exits, including when it is interrupted or an error is raised.
    """
    # The agent samples training data from its own env;
    # test_env is only for measuring the performance of the policy.
    test_env = gym.make(env_name)
    iter_idx = 0
    try:
        while max_iters is None or iter_idx < max_iters:
            # update Q table
            iter_idx += 1
            agent.Qvalue_update()
            # test policy
            test_reward = 0.0
            for _ in range(test_episodes):
                test_reward += agent.play(test_env)
            test_reward /= test_episodes
            writer.add_scalar("reward", test_reward, iter_idx)
            if test_reward > reward_bound:
                print(f"Solved with {iter_idx} steps")
                break
        else:
            # Only reached when the max_iters budget ran out without a break.
            print(f"Not solved within {max_iters} steps")
    finally:
        # close the resources even on KeyboardInterrupt / errors
        writer.close()
        test_env.close()
        agent.close()

In [28]:
# TensorBoard writer that train() uses to log the mean test reward of every iteration.
writer = SummaryWriter()

In [29]:
# Runs until the mean test reward exceeds the default reward_bound (0.80).
# train() closes both `writer` and the agent's sampling env when it finishes,
# so re-create `agent` and `writer` before running this cell again.
train(env_name, agent, writer)

Solved with 17119 steps


### step4. test

In [35]:
# Play one rendered episode with the learned greedy policy. agent.policy only
# reads the Q-table, so it is unaffected by train() having closed the agent's env.
test(env_name, policy_func=agent.policy)

Game over with 100 steps
