Esta notebook contiene bloques de código útiles para realizar Q-learning en el entorno "Taxi"

In [3]:
#!pip install gymnasium
#!pip install gymnasium[toy-text]
import numpy as np
import random
from taxi_env_extended import TaxiEnvExtended

Collecting gymnasium
  Downloading gymnasium-0.29.1-py3-none-any.whl.metadata (10 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
   ---------------------------------------- 0.0/953.9 kB ? eta -:--:--
   ---------------------------------------- 0.0/953.9 kB ? eta -:--:--
   ---------------------------------------- 10.2/953.9 kB ? eta -:--:--
   ---------------------------------------- 10.2/953.9 kB ? eta -:--:--
   - ------------------------------------- 41.0/953.9 kB 245.8 kB/s eta 0:00:04
   ------ ------------------------------- 153.6/953.9 kB 833.5 kB/s eta 0:00:01
   ------------------- -------------------- 471.0/953.9 kB 2.1 MB/s eta 0:00:01
   -------------------------- ------------- 634.9/953.9 kB 2.4 MB/s eta 0:00:01
   -------------------------------------- - 911.4/953.9 kB 2.9 MB/s eta 0:00:01
   -----------------------------------

In [4]:
env = TaxiEnvExtended()

Obtener la cantidad de estados y acciones

In [5]:
actions = env.action_space.n
states = env.observation_space.n

Inicialización de la tabla Q

In [30]:
Q = np.zeros((states, actions))
Q

array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       ...,
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]])

Obtención de la acción a partir de la tabla Q

In [8]:
def optimal_policy(state, Q):
    action = np.argmax(Q[state])
    return action

Epsilon-Greedy Policy

In [31]:
def epsilon_greedy_policy(state, Q, epsilon=0.1):
    explore = np.random.binomial(1, epsilon)
    if explore:
        action = env.action_space.sample()
        print('explore')
    # exploit
    else:
        action = np.argmax(Q[state])
        print('exploit')
        
    return action

Ejemplo de episodio 

In [32]:
# training function
def train(env, Q, alpha=0.1, gamma=0.9, epsilon=0.1, episodes=10):
    for i in range(episodes):
        obs,_ = env.reset()
        print(obs)
        done = False
        total_reward = 0
        step_count = 0
        while not done:
            state = obs
            action = epsilon_greedy_policy(state, Q, epsilon)
            obs, reward, done, _, _ = env.step(action)
            total_reward += reward
            step_count += 1
            Q[state, action] = Q[state, action] + alpha * (reward + gamma * np.max(Q[obs]) - Q[state, action])
            print('->', state, action, reward, obs, done)
            env.render()
        print('episode', i)
        print('total_reward', total_reward)
        print('total_steps', step_count)

In [33]:
# train the agent
train(env, Q, epsilon=0.5, episodes=10)

343
exploit
-> 343 0 -1 443 False
exploit
-> 443 0 -1 443 False
exploit
-> 443 1 -1 343 False
exploit
-> 343 1 -1 243 False
explore
-> 243 4 -10 243 False
explore
-> 243 2 -1 263 False
exploit
-> 263 0 -1 363 False
explore
-> 363 3 -1 363 False
explore
-> 363 1 -1 263 False
explore
-> 263 3 -1 243 False
explore
-> 243 4 -10 243 False
exploit
-> 243 0 -1 343 False
explore
-> 343 1 -1 243 False
explore
-> 243 4 -10 243 False
exploit
-> 243 1 -1 143 False
exploit
-> 143 0 -1 243 False
exploit
-> 243 3 -1 223 False
explore
-> 223 0 -1 323 False
explore
-> 323 1 -1 223 False
explore
-> 223 2 -1 243 False
explore
-> 243 3 -1 223 False
exploit
-> 223 1 -1 123 False
explore
-> 123 2 -1 123 False
explore
-> 123 5 -10 123 False
explore
-> 123 5 -10 123 False
explore
-> 123 3 -1 103 False
exploit
-> 103 0 -1 203 False
explore
-> 203 5 -10 203 False
exploit
-> 203 0 -1 303 False
exploit
-> 303 0 -1 403 False
exploit
-> 403 0 -1 403 False
explore
-> 403 5 -10 403 False
exploit
-> 403 1 -1 303 False

In [34]:
# test the agent
obs,_ = env.reset()
done = False
while not done:
    state = obs
    action = optimal_policy(state, Q)
    obs, reward, done, _, _ = env.step(action)
    env.render()
    print('->', state, action, reward, obs, done)

-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -10 443 False
-> 443 4 -1