Esta notebook contiene bloques de código útiles para realizar Q-learning en el entorno "Taxi"

In [1]:
import numpy as np
import random
from taxi_env_extended import TaxiEnvExtended

In [2]:
env = TaxiEnvExtended()

Obtener la cantidad de estados y acciones

In [3]:
actions = env.action_space.n
states = env.observation_space.n

Inicialización de la tabla Q

In [4]:
Q = np.zeros((states, actions))
Q

array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       ...,
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]])

Obtención de la acción a partir de la tabla Q

In [5]:
def optimal_policy(state, Q):
    action = np.argmax(Q[state])
    return action

Epsilon-Greedy Policy

In [6]:
def epsilon_greedy_policy(state, Q, epsilon=0.1):
    explore = np.random.binomial(1, epsilon)
    if explore:
        action = env.action_space.sample()
        print('explore')
    # exploit
    else:
        action = np.argmax(Q[state])
        print('exploit')
        
    return action

Ejemplo de episodio 

In [7]:
obs,_ = env.reset()
print(obs)
done = False
total_reward = 0
step_count = 0
while not done:
    state = obs
    action = epsilon_greedy_policy(state, Q, 0.5)
    obs, reward, done, _, _ = env.step(action)
    total_reward += reward
    step_count += 1
    print('->', state, action, reward, obs, done)
    env.render()
print('total_reward', total_reward)
print('total_steps', step_count)

489
explore
-> 489 1 -1 389 False
explore
-> 389 3 -1 369 False
exploit
-> 369 0 -1 469 False
exploit
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
explore
-> 469 5 -10 469 False
explore
-> 469 5 -10 469 False
exploit
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
explore
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
explore
-> 469 5 -10 469 False
explore
-> 469 4 -10 469 False
exploit
-> 469 0 -1 469 False
explore
-> 469 3 -1 469 False
explore
-> 469 0 -1 469 False
explore
-> 469 4 -10 469 False
exploit
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
explore
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
explore
-> 469 3 -1 469 False
exploit
-> 469 0 -1 469 False
explore
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
exploit
-> 469 0 -1 469 False
explore
-> 469 4 -10 469 False


In [8]:
from Agent import Agent
agent = Agent(env, gamma=0.9, Qtable=Q, alpha=0.1, epsilon=0.5, policy_func=epsilon_greedy_policy) 

agent.train(250)

186
exploit
-> 186 0 -1 286 False
explore
-> 286 2 -1 286 False
explore
-> 286 4 -10 286 False
exploit
-> 286 0 -1 386 False
exploit
-> 386 0 -1 486 False
explore
-> 486 2 -1 486 False
exploit
-> 486 0 -1 486 False
explore
-> 486 2 -1 486 False
exploit
-> 486 1 -1 386 False
exploit
-> 386 1 -1 286 False
exploit
-> 286 1 -1 186 False
explore
-> 186 2 -1 186 False
exploit
-> 186 1 -1 86 False
exploit
-> 86 0 -1 186 False
explore
-> 186 1 -1 86 False
explore
-> 86 3 -1 66 False
explore
-> 66 2 -1 86 False
explore
-> 86 0 -1 186 False
exploit
-> 186 3 -1 166 False
explore
-> 166 3 -1 146 False
exploit
-> 146 0 -1 246 False
explore
-> 246 0 -1 346 False
exploit
-> 346 0 -1 446 False
explore
-> 446 0 -1 446 False
exploit
-> 446 1 -1 346 False
exploit
-> 346 1 -1 246 False
exploit
-> 246 1 -1 146 False
exploit
-> 146 1 -1 46 False
exploit
-> 46 0 -1 146 False
explore
-> 146 2 -1 166 False
exploit
-> 166 0 -1 266 False
explore
-> 266 0 -1 366 False
exploit
-> 366 0 -1 466 False
explore
-> 466 

In [9]:
agent.Qtable

array([[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [-2.16951465e+00, -2.12360395e+00, -2.32794773e+00,
        -2.16517818e+00, -2.15836808e+00, -9.65903910e+00],
       [-1.65130133e+00, -1.62962307e+00, -1.68610868e+00,
        -1.71769766e+00, -1.58776465e+00, -4.89559908e+00],
       ...,
       [-1.20001605e+00, -1.17981406e+00, -1.12888335e+00,
        -1.30883490e+00, -5.46571283e+00, -4.91476641e+00],
       [-1.20086744e+00, -1.38400768e+00, -1.21673173e+00,
        -1.18302018e+00, -6.12079769e+00, -4.37034958e+00],
       [-1.52466378e-03, -1.42214616e-02, -1.68827500e-01,
         4.70530903e+00, -1.90000000e+00, -1.84622500e+00]])

In [10]:
agent.play(10)

Episode finished: reward=-200, steps=200
Episode finished: reward=-200, steps=200
Episode finished: reward=-200, steps=200
Episode finished: reward=-200, steps=200
Episode finished: reward=-200, steps=200
Episode finished: reward=-200, steps=200
Episode finished: reward=-200, steps=200
Episode finished: reward=-200, steps=200
Episode finished: reward=-200, steps=200
Episode finished: reward=-200, steps=200
