In [1]:
import numpy as np
from tqdm import tqdm
import gym
import math
import random
import matplotlib.pyplot as plt

In [2]:
env = gym.make('Taxi-v3')
num_actions = env.action_space
num_obs = env.observation_space
R, G, B, Y = env.unwrapped.locs
R, G, B, Y = list(R), list(G), list(B), list(Y)
env.reset()

(101, {'prob': 1.0, 'action_mask': array([1, 1, 1, 0, 0, 0], dtype=int8)})

In [3]:
def get_passenger_position(passenger_id):
    if passenger_id == 0:  
        return R
    elif passenger_id == 1:  
        return G
    elif passenger_id == 2:  
        return B
    elif passenger_id == 3:  
        return Y
    
    
def get_state(state):
    row, col, pass_id, dest_id = env.unwrapped.decode(state)
    state = np.asarray([row,col,pass_id,dest_id])
    return state

print(R,G,B,Y)

[0, 0] [0, 4] [4, 0] [4, 3]


In [4]:
curr_state = get_state(env.s)
pass_pose = get_passenger_position(curr_state[2])
dest_pose = get_passenger_position(curr_state[3])
print("Taxi at",curr_state[0:2])
print("Passenger at",dest_pose)
print("Destination",pass_pose)
next_state, reward, done, _, _ = env.step(4)
row, col, pass_id, dest_id = env.unwrapped.decode(next_state)
print(curr_state[2])

Taxi at [1 0]
Passenger at [0, 4]
Destination [0, 0]
0


In [5]:
q_r = np.zeros((500,6))
q_b = np.zeros((500,6))
q_g = np.zeros((500,6))
q_y = np.zeros((500,6))
q = np.zeros((500,10))

In [6]:
# Set the number of episodes and maximum number of steps per episode
EPS = 10000
MAX_STEPS = 100

# Set the learning rate, discount factor, and exploration rate
ALPHA = 0.05
GAMMA = 0.99
exploration_rate = 1.0
MIN_EXP = 0.01
EXP_DECAY = 0.01

In [7]:
# Four different q value functions for the four options
def choose_action_red(state,q_r):
    action = np.argmax(q_r[state])
    return action

def choose_action_blue(state,q_b):
    action = np.argmax(q_b[state])
    return action

def choose_action_green(state,q_g):
    action = np.argmax(q_g[state])
    return action

def choose_action_yellow(state,q_y):
    action = np.argmax(q_y[state])
    return action

def choose_action(q,state):
    if not q[state].any():
        return random.randint(0,7)
    action = np.argmax(q[state])
    
    if np.random.rand() < exploration_rate:
        action = np.random.randint(0,7)
        return action
    return action

    

In [8]:
# Options where policy is greedy wrt the corresponding q Value function
def Red(q_r,state):
    optdone = False
    optact = choose_action_red(state,q_r)
    state = get_state(state)
    if state[0:2] is R:
        optdone = True
    return optact,optdone

def Green(q_g,state):
    optdone = False
    optact = choose_action_green(state,q_g)
    state = get_state(state)
    if state[0:2] is G:
        optdone = True
    return optact,optdone
    
def Yellow(q_y,state):
    optdone = False
    optact = choose_action_green(state,q_y)
    state = get_state(state)
    if state[0:2] is Y:
        optdone = True
    return optact,optdone 

def Blue(q_b,state):
    optdone = False
    optact = choose_action_green(state,q_b)
    state = get_state(state)
    if state[0:2] is B:
        optdone = True
    return optact,optdone 

In [9]:
# Iterate over episodes
for episode in range(EPS):
    state, _ = env.reset()
    done = False
    total_reward = 0
    steps = 0
    while not done and steps < MAX_STEPS:
        steps += 1
        action = choose_action(q,state)
        if action < 6:
            next_state, reward, done, _, _ = env.step(action)
            q[state,action] = q[state, action] 
            + ALPHA * (reward + GAMMA * np.max(q[next_state, :])
                                                                  - q[state, action])
            total_reward += reward
            state = next_state
        
        reward_bar = 0
        if action > 5 and action < 10:
            count = 0
            optdone = False
            current_state = state
            while (optdone == False) and count < 500:
                if action == 6:
                    optact, optdone = Red(q_r,state) 
                    next_state, reward, done, _, _ = env.step(optact)
                    q_r[state,optact] = q_r[state, optact] 
                    + ALPHA * (reward + GAMMA * np.max(q_r[next_state, :])
                                                                  - q_r[state, optact])
                    
                if action == 7:
                    optact, optdone = Green(q_g,state)
                    next_state, reward, done, _, _ = env.step(optact)
                    q_g[state,optact] = q_g[state, optact] 
                    + ALPHA * (reward + GAMMA * np.max(q_g[next_state, :])
                                                                  - q_g[state, optact])
                    
                if action == 8:
                    optact, optdone = Blue(q_b,state)
                    
                    next_state, reward, done, _, _ = env.step(optact)
                    q_b[state,optact] = q_b[state, optact] 
                    + ALPHA * (reward + GAMMA * np.max(q_b[next_state, :])
                                                                  - q_b[state, optact])
                    
                if action == 9:
                    optact, optdone = Yellow(q_y,state)
                    next_state, reward, done, _, _ = env.step(optact)
                    q_y[state,optact] = q_y[state, optact] 
                    + ALPHA * (reward + GAMMA * np.max(q_y[next_state, :])
                                                                  - q_y[state, optact])
                    
                reward_bar = reward_bar + GAMMA*reward
                count += 1
                if optdone == True:
                    q[current_state, action] += ALPHA * (reward_bar 
                            - q[current_state, action] 
                            + GAMMA**count * np.max(q[next_state, :]))
                    print("DONE")
                state = next_state
        

    # Decay the exploration rate
    exploration_rate = MIN_EXP + (1 - MIN_EXP) * np.exp(-EXP_DECAY * episode)

    # Print the total reward for each episode
    print(f"Episode {episode + 1}: Total Reward = {total_reward}")

# Print the final q-table
print("Final q-table:")
print(q)

Episode 1: Total Reward = -353
Episode 2: Total Reward = -272
Episode 3: Total Reward = -321
Episode 4: Total Reward = -316
Episode 5: Total Reward = -363
Episode 6: Total Reward = -364
Episode 7: Total Reward = -335
Episode 8: Total Reward = -404
Episode 9: Total Reward = -375
Episode 10: Total Reward = -240
Episode 11: Total Reward = -256
Episode 12: Total Reward = -286
Episode 13: Total Reward = -292
Episode 14: Total Reward = -315
Episode 15: Total Reward = -234
Episode 16: Total Reward = -237
Episode 17: Total Reward = -298
Episode 18: Total Reward = -318
Episode 19: Total Reward = -240


KeyboardInterrupt: 