Import libraries

In [1]:
import gym
import numpy as np
import time

Initialize the Taxi-v3 Environment

In [2]:
# Create the Taxi-v3 environment shared by every cell below.
ENV_ID = "Taxi-v3"
env = gym.make(ENV_ID)

Inspect the Environment's State and Action Spaces

In [3]:
# Observation space of the environment; displayed as the cell output.
state_space = env.observation_space
state_space

Discrete(500)

In [4]:
# Action space of the environment; displayed as the cell output.
action_space = env.action_space
action_space

Discrete(6)

In [5]:
# Sizes of the discrete state and action spaces, used to size the tables below.
num_states, num_actions = state_space.n, action_space.n

Define a Random Initial Policy (one action per state)

In [6]:
# Deterministic policy table, policy[s] = a: one action per state, drawn
# uniformly from {0, ..., num_actions - 1}.
policy = np.random.randint(0, num_actions, size=num_states)
policy

array([4, 2, 2, 4, 3, 4, 4, 2, 5, 0, 3, 1, 1, 1, 2, 1, 4, 1, 3, 0, 5, 2,
       5, 3, 1, 1, 3, 1, 2, 2, 1, 2, 3, 0, 4, 3, 4, 2, 4, 5, 1, 0, 5, 0,
       4, 2, 5, 4, 3, 1, 0, 5, 1, 1, 4, 2, 2, 4, 0, 1, 2, 3, 5, 1, 3, 5,
       4, 0, 3, 4, 2, 2, 1, 1, 5, 4, 5, 5, 4, 5, 0, 5, 2, 0, 3, 1, 4, 0,
       0, 4, 0, 1, 0, 0, 3, 3, 5, 0, 1, 5, 2, 0, 5, 2, 2, 4, 3, 5, 2, 0,
       2, 3, 5, 1, 1, 5, 2, 0, 2, 4, 3, 0, 4, 5, 5, 5, 5, 0, 1, 5, 3, 3,
       0, 2, 0, 0, 2, 3, 3, 2, 5, 2, 1, 5, 4, 5, 4, 4, 5, 2, 5, 3, 2, 1,
       0, 0, 2, 5, 0, 2, 5, 3, 2, 2, 1, 3, 1, 1, 5, 1, 0, 3, 0, 3, 1, 2,
       4, 1, 3, 2, 2, 3, 5, 2, 3, 2, 0, 0, 0, 0, 1, 4, 3, 4, 3, 2, 0, 1,
       0, 0, 1, 3, 0, 0, 0, 4, 0, 0, 5, 1, 4, 4, 2, 2, 3, 5, 3, 1, 2, 2,
       3, 0, 1, 4, 4, 0, 5, 5, 1, 3, 0, 3, 5, 3, 3, 5, 5, 1, 2, 0, 3, 4,
       4, 1, 5, 4, 0, 4, 3, 0, 3, 2, 1, 5, 5, 3, 4, 0, 2, 2, 2, 1, 4, 0,
       2, 3, 4, 4, 5, 5, 2, 4, 0, 4, 0, 0, 3, 1, 3, 2, 1, 2, 1, 4, 4, 4,
       2, 4, 3, 2, 0, 2, 1, 0, 3, 4, 1, 0, 0, 4, 4,

Write a Function that Generates an Episode Given a Policy

In [7]:
def generate_episode(policy, render=False):
  """Roll out one episode of the global `env` by following `policy`.

  Args:
    policy: indexable mapping state -> action (e.g. a 1-D integer array).
    render: when True, call env.render() after the reset and after every step.

  Returns:
    np.array with one row [state, action, reward, next_state] per step taken.
  """
  transitions = []
  state = env.reset()
  if render:
    env.render()

  done = False
  while not done:
    action = policy[state]
    next_state, reward, done, _info = env.step(action)
    transitions.append([state, action, reward, next_state])
    state = next_state
    if render:
      env.render()

  return np.array(transitions)



Q-Learning

In [8]:
def q_learning(policy, num_episodes=50000, lr=0.01, gamma=1.0, epsilon=1.0,
               epsilon_decay=0.99, epsilon_min=0.0001, environment=None):
  """Learn a deterministic policy for a discrete env with tabular Q-learning.

  At every step the action is chosen epsilon-greedily: with probability
  `epsilon` a random action sampled from the action space, otherwise the
  current greedy action from the learned policy table. Epsilon decays
  geometrically once per episode (before the episode runs) and never drops
  below `epsilon_min`.

  Args:
    policy: 1-D integer array, one action per state. It supplies the greedy
      action for states that have not received a Q update yet. It is NOT
      modified.
    num_episodes: number of training episodes.
    lr: step size of the Q update.
    gamma: discount factor. The default 1.0 keeps the undiscounted target
      this notebook originally used.
    epsilon: initial exploration probability.
    epsilon_decay: multiplicative per-episode decay applied to epsilon.
    epsilon_min: lower bound on epsilon.
    environment: env exposing `reset()`, a 4-tuple `step(a)`,
      `observation_space.n`, `action_space.n` and `action_space.sample()`.
      Defaults to the notebook's global `env`.

  Returns:
    A new array in which every state that received a Q update holds
    argmax_a Q(s, a). Every other state keeps its action from `policy`.
  """
  environment = env if environment is None else environment
  num_states = environment.observation_space.n
  num_actions = environment.action_space.n

  # Initialize the Q table. Work on a copy so the caller's policy stays intact.
  q_val_func = np.zeros((num_states, num_actions))
  learned_policy = np.copy(policy)

  for _ in range(num_episodes):
    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    st = environment.reset()
    while True:
      # Exploration only affects the action taken now, never the stored policy.
      if np.random.random() < epsilon:
        at = environment.action_space.sample()
      else:
        at = learned_policy[st]

      st1, rt, done, info = environment.step(at)

      # Bootstrap from the next state unless the episode truly terminated.
      # gym's TimeLimit wrapper (pre-0.26 API) flags a time-limit cut-off
      # with info["TimeLimit.truncated"]. That next state is not terminal,
      # so it still bootstraps.
      truncated = isinstance(info, dict) and info.get("TimeLimit.truncated", False)
      target = rt
      if not done or truncated:
        target += gamma * np.max(q_val_func[st1])
      q_val_func[st, at] += lr * (target - q_val_func[st, at])

      # Keep the stored policy greedy with respect to the updated Q row.
      learned_policy[st] = np.argmax(q_val_func[st])

      st = st1
      if done:
        break

  return learned_policy

In [9]:
# Train with Q-learning; the result is the learned per-state action table.
optimal_policy = q_learning(policy, num_episodes=50000)

In [10]:
# Roll out one rendered episode; the returned rows are [state, action, reward, next_state].
generate_episode(policy=optimal_policy, render=True)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : : |
| | : | : |
|Y| : |[34;1m[43mB[0m[0m: |
+---------+

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : : |
| | : | : |
|Y| : |[42mB[0m: |
+---------+
  (Pickup)
+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : : |
| | : |[42m_[0m: |
|Y| : |B: |
+---------+
  (North)
+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : :[42m_[0m: |
| | : | : |
|Y| : |B: |
+---------+
  (North)
+---------+
|[35mR[0m: | : :G|
| : | : : |
| : :[42m_[0m: : |
| | : | : |
|Y| : |B: |
+---------+
  (West)
+---------+
|[35mR[0m: | : :G|
| : | : : |
| :[42m_[0m: : : |
| | : | : |
|Y| : |B: |
+---------+
  (West)
+---------+
|[35mR[0m: | : :G|
| : | : : |
|[42m_[0m: : : : |
| | : | : |
|Y| : |B: |
+---------+
  (West)
+---------+
|[35mR[0m: | : :G|
|[42m_[0m: | : : |
| : : : : |
| | : | : |
|Y| : |B: |
+---------+
  (North)
+---------+
|[35m[42mR[0m[0m: | : :G|
| : | : : |
| : : : : |
| | : | : |
|Y| : |B: |
+---------+
  

array([[472,   4,  -1, 476],
       [476,   1,  -1, 376],
       [376,   1,  -1, 276],
       [276,   3,  -1, 256],
       [256,   3,  -1, 236],
       [236,   3,  -1, 216],
       [216,   1,  -1, 116],
       [116,   1,  -1,  16],
       [ 16,   5,  20,   0]])

In [11]:
def watch(policy, sleep_duration=5):
  """Render one episode of the global `env` under `policy`, pausing per frame.

  Args:
    policy: indexable mapping state -> action.
    sleep_duration: seconds to pause after each rendered frame.
  """
  def _show_frame():
    # Draw the current frame, then hold it on screen.
    env.render()
    time.sleep(sleep_duration)

  state = env.reset()
  _show_frame()

  finished = False
  while not finished:
    state, _reward, finished, _info = env.step(policy[state])
    _show_frame()
      


In [12]:
# Replay the learned policy frame by frame (5 seconds per frame).
watch(optimal_policy, sleep_duration=5)

+---------+
|R: | : :G|
| : | : : |
| : : : :[43m [0m|
| | : | : |
|[34;1mY[0m| : |[35mB[0m: |
+---------+

+---------+
|R: | : :G|
| : | : : |
| : : :[43m [0m: |
| | : | : |
|[34;1mY[0m| : |[35mB[0m: |
+---------+
  (West)
+---------+
|R: | : :G|
| : | : : |
| : :[43m [0m: : |
| | : | : |
|[34;1mY[0m| : |[35mB[0m: |
+---------+
  (West)
+---------+
|R: | : :G|
| : | : : |
| :[43m [0m: : : |
| | : | : |
|[34;1mY[0m| : |[35mB[0m: |
+---------+
  (West)
+---------+
|R: | : :G|
| : | : : |
|[43m [0m: : : : |
| | : | : |
|[34;1mY[0m| : |[35mB[0m: |
+---------+
  (West)
+---------+
|R: | : :G|
| : | : : |
| : : : : |
|[43m [0m| : | : |
|[34;1mY[0m| : |[35mB[0m: |
+---------+
  (South)
+---------+
|R: | : :G|
| : | : : |
| : : : : |
| | : | : |
|[34;1m[43mY[0m[0m| : |[35mB[0m: |
+---------+
  (South)
+---------+
|R: | : :G|
| : | : : |
| : : : : |
| | : | : |
|[42mY[0m| : |[35mB[0m: |
+---------+
  (Pickup)
+---------+
|R: | : :G|
| : | : : |
| :