In [2]:
import gym
import numpy as np
import random
from tqdm import tqdm

In [3]:
# Create the Taxi-v3 environment with text ("ansi") rendering, start an
# episode, and print the resulting board so the layout is visible.
env = gym.make("Taxi-v3", render_mode="ansi")
env.reset()
board = env.render()
print(board)

+---------+
|[34;1mR[0m: | : :G|
| : | : : |
| : : : : |
| | : | : |
|Y| : |[35m[43mB[0m[0m: |
+---------+




In [4]:
# Sizes of the discrete action and observation spaces.
action_space = env.action_space.n
state_space = env.observation_space.n

# Q-table: one row per discrete state, one column per discrete action.
q_table = np.zeros((state_space, action_space))

alpha = 0.1    # learning rate
gamma = 0.6    # discount factor
epsilon = 0.1  # probability of taking a random (exploratory) action

# Seed every RNG the notebook draws from (epsilon-greedy coin flips,
# action-space sampling, environment resets) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
env.action_space.seed(SEED)
env.reset(seed=SEED)

In [5]:
# Tabular Q-learning with an epsilon-greedy behaviour policy.
n_train_episodes = 100_000

for _ in tqdm(range(n_train_episodes)):
    state, _info = env.reset()
    done = False

    while not done:
        if random.uniform(0, 1) < epsilon:  # explore with probability epsilon
            action = env.action_space.sample()
        else:  # exploit the current greedy action
            action = np.argmax(q_table[state])

        # gym >= 0.26: step() returns (obs, reward, terminated, truncated, info).
        next_state, reward, terminated, truncated, _info = env.step(action)

        # A terminal state has no future return, so do not bootstrap from it.
        # A time-limit truncation is not a true terminal, so bootstrapping stays valid.
        future = 0.0 if terminated else np.max(q_table[next_state])
        q_table[state, action] += alpha * (reward + gamma * future - q_table[state, action])

        state = next_state
        # End the episode on either signal; ignoring `truncated` lets it run past the time limit.
        done = terminated or truncated

print("Training finished")

100%|██████████| 100000/100000 [01:09<00:00, 1433.18it/s]

Training finished





In [6]:
# test: roll out the learned greedy policy (no exploration) and tally
# per-episode timesteps and penalties.
total_epoch, total_penalties = 0, 0
episodes = 100

for _ in tqdm(range(episodes)):
    state, _info = env.reset()
    epochs, penalties = 0, 0
    done = False

    while not done:
        action = np.argmax(q_table[state])
        # gym >= 0.26: step() returns (obs, reward, terminated, truncated, info).
        state, reward, terminated, truncated, _info = env.step(action)

        # Taxi-v3 assigns -10 to an illegal pickup/drop-off action.
        if reward == -10:
            penalties += 1

        epochs += 1
        # Honour truncation too: a greedy policy stuck in a cycle would otherwise never stop.
        done = terminated or truncated

    total_epoch += epochs
    total_penalties += penalties

100%|██████████| 100/100 [00:00<00:00, 917.57it/s]


In [7]:
# Summarise the evaluation run as per-episode averages.
averages = [
    ("Average timesteps per episode: ", total_epoch),
    ("Average penalties per episode: ", total_penalties),
]
print("Result after {} episodes".format(episodes))
for label, total in averages:
    print(label, total / episodes)

Result after 100 episodes
Average timesteps per episode:  12.81
Average penalties per episode:  0.0
