# 0. Imports

In [69]:
import gymnasium as gym
import numpy as np
import random
from tqdm import tqdm
import time

# 1. Setup an environment

In [70]:
env = gym.make("Taxi-v3", render_mode="ansi")

# 2. Q-Table Initialisation

In [71]:
# Collect all possible state and action spaces
state_space_size = env.observation_space.n
action_space_size = env.action_space.n

# Define a Q table that encompasses all possible state and action pairs
Q = np.zeros((state_space_size, action_space_size))

# 3. Hyperparameters

In [72]:
hyperparameters = {
    "alpha" : 0.1,
    "gamma" : 0.9,
    "epsilon": 1,
    "epsilon_decay" : 0.995,
    "epsilon_min" : 0.1,
    "episodes": 500,
    "stationary_penalty_factor" : 0.01
}

# 4. Training Loop

In [76]:
def train_taxi(env, 
               Q: np.array,
               episodes: int, 
               alpha: float, 
               gamma: float, 
               epsilon: float, 
               epsilon_decay: float, 
               epsilon_min: float,
               stationary_penalty_factor: float):
    
    """
    Function: Updates the Quality table for a agent over n episodes with provided hyper parameters
    Args:
        env: Gymnasium environment of the taxi agent,
        Q (np.array): The Q-table that maps & tracks the Quality of each state and action pair
        episodes (int): The number of episodes to simualte
        alpha (float): The learning rate
        gamma (float): The discount factor of future rewards
        epsilon (float): The exploration rate 
        epsilon_decay (float): Reduction factor of epsilon post an episode
        epsilon_min (float): The min epsilon that must be maintained
        stationary_penalty_factor (float): Penalty for being in teh same state for a long time
    """

    for episode in tqdm(range(episodes)):
        
        # Reset the environment for a new episode
        state, _ = env.reset()
        done = False
        trajectory = []
        while not done:

            # Trajectory update
            trajectory.append(state)

            # Choose the action based on exploration logic
            if random.uniform(0,1) < epsilon:
                action = env.action_space.sample()

            # Choose the action based on exploitation logic
            # Argmax of the given state
            else:
                action = np.argmax(Q[state])

            # Take the next step based on the action chosen
            next_state, reward, done, truncated, _ = env.step(action)
            done = done
            best_next_action_Q = np.argmax(Q[next_state])

            # Update the reward by a factor for being in the same state
            constant_state = 0
            for state in trajectory[::-1]:
                if state == next_state:
                    constant_state += 1

                else:
                    break

            constant_state_penalty = constant_state * stationary_penalty_factor
            reward = reward - constant_state_penalty

            # Update the Q state and action pair
            Q[state, action] = Q[state, action] + alpha * (reward + gamma * best_next_action_Q - Q[state, action])
            state = next_state

        epsilon = max(epsilon_min, epsilon_decay*epsilon)

    print("Training complete")

    return env, Q

In [79]:
env, Q = train_taxi(env=env,
                    Q=Q,
                    episodes=hyperparameters['episodes'],  
                    alpha=hyperparameters['alpha'],
                    gamma=hyperparameters['gamma'],
                    epsilon=hyperparameters['epsilon'],
                    epsilon_decay=hyperparameters['epsilon_decay'],
                    epsilon_min=hyperparameters['epsilon_min'],
                    stationary_penalty_factor=hyperparameters['stationary_penalty_factor'])

 36%|███▌      | 181/500 [1:18:29<2:18:19, 26.02s/it]  


KeyboardInterrupt: 

In [None]:
state, _ = env.reset()
done = False

print("Taxi navigating...\n")
while not done:
    action = np.argmax(Q[state])
    state, reward, done, truncated, _ = env.step(action)
    print(env.render())  # ASCII animation of taxi world
    time.sleep(0.5)

Taxi navigating...

+---------+
|[34;1mR[0m: | : :[35mG[0m|
| : | : : |
| : : : : |
| | : | : |
|[43mY[0m| : |B: |
+---------+
  (East)

+---------+
|[34;1mR[0m: | : :[35mG[0m|
| : | : : |
| : : : : |
| | : | : |
|[43mY[0m| : |B: |
+---------+
  (East)

+---------+
|[34;1mR[0m: | : :[35mG[0m|
| : | : : |
| : : : : |
| | : | : |
|[43mY[0m| : |B: |
+---------+
  (East)

+---------+
|[34;1mR[0m: | : :[35mG[0m|
| : | : : |
| : : : : |
| | : | : |
|[43mY[0m| : |B: |
+---------+
  (East)

+---------+
|[34;1mR[0m: | : :[35mG[0m|
| : | : : |
| : : : : |
| | : | : |
|[43mY[0m| : |B: |
+---------+
  (East)

+---------+
|[34;1mR[0m: | : :[35mG[0m|
| : | : : |
| : : : : |
| | : | : |
|[43mY[0m| : |B: |
+---------+
  (East)

+---------+
|[34;1mR[0m: | : :[35mG[0m|
| : | : : |
| : : : : |
| | : | : |
|[43mY[0m| : |B: |
+---------+
  (East)

+---------+
|[34;1mR[0m: | : :[35mG[0m|
| : | : : |
| : : : : |
| | : | : |
|[43mY[0m| : |B: |
+---------+
  (Eas

KeyboardInterrupt: 