In [63]:
import numpy as np
import gym

In [64]:
# Environment id, kept in a constant so it is spelled out in one place.
ENV_NAME = "Taxi-v3"
env = gym.make(ENV_NAME)
# Start an episode and draw the resulting board as a quick visual check.
# NOTE(review): this reset runs before any seed is set (seeding happens in a
# later cell), so the layout drawn here presumably differs between runs -- confirm.
env.reset()
env.render()

+---------+
|[34;1mR[0m: | :[43m [0m:G|
| : | : : |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+



In [65]:
# Report how many discrete actions and states the environment exposes.
n_actions = env.action_space.n
n_states = env.observation_space.n
print("Number of actions: %d" % n_actions)
print("Number of states: %d" % n_states)

Number of actions: 6
Number of states: 500


In [66]:
# Problem dimensions, read from the environment instead of being hard-coded.
action_size = env.action_space.n
state_size = env.observation_space.n

# Single seed shared by every generator seeded in this cell.
SEED = 123
np.random.seed(SEED)
# Seed the action space explicitly: env.seed() is not relied on to cover it, and
# env.action_space.sample() is called in the next cell.
env.action_space.seed(SEED)
# NOTE(review): the network's weight initialisation is not seeded here -- confirm
# whether a backend-level seed is also wanted for full reproducibility.
# Kept as the last expression so the cell still displays the list returned by env.seed().
env.seed(SEED)

[123]

In [67]:
# Keras building blocks for the Q-network, plus the optimizer later handed to the agent.
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Embedding, Reshape
# The optimizer is taken from the `optimizer_v1` module rather than `keras.optimizers`.
from keras.optimizer_v1 import Adam
# NOTE(review): class-level monkey-patch that gives every Adam instance a `_name`
# attribute. Presumably a workaround for library code that reads `optimizer._name`
# (keras-rl?) -- confirm, and consider a meaningful value instead of the placeholder.
Adam._name = 'hey'

In [68]:
# Smoke test: from a fresh episode, take one randomly sampled action and show
# the first element of what env.step() returns (the resulting observation).
env.reset()
random_action = env.action_space.sample()
step_result = env.step(random_action)
step_result[0]

441

In [69]:
import keras

# Q-network with no hidden layers: the Embedding holds one trainable row of
# `action_size` values per discrete state, and Reshape drops the length-1
# sequence axis so the output is (batch, action_size).
# Sizes come from the environment (state_size / action_size) rather than
# literals, so the model follows the environment configured above.
model = keras.Sequential([
    Embedding(state_size, action_size, input_length=1),
    Reshape((action_size,)),
])
# summary() prints the table itself and returns None, so it is called directly
# instead of being wrapped in print(), which emitted a stray "None" line.
model.summary()

Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_6 (Embedding)     (None, 1, 6)              3000      
                                                                 
 reshape_6 (Reshape)         (None, 6)                 0         
                                                                 
Total params: 3,000
Trainable params: 3,000
Non-trainable params: 0
_________________________________________________________________
None


In [70]:
from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory

# Replay memory and exploration policy handed to the agent below.
memory = SequentialMemory(window_length=1, limit=50000)
policy = EpsGreedyQPolicy()

# Wrap the Keras model in a DQN agent and attach the optimizer.
dqn = DQNAgent(
    model=model,
    policy=policy,
    memory=memory,
    nb_actions=action_size,
    nb_steps_warmup=500,
    target_model_update=1e-2,
)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])

# Train. log_interval equals nb_steps, so the whole run is reported as one
# interval. The call is left as the cell's last expression, so the returned
# History object is still displayed.
dqn.fit(
    env,
    nb_steps=1000000,
    nb_max_episode_steps=99,
    log_interval=1000000,
    visualize=False,
    verbose=1,
)

Training for 1000000 steps ...
Interval 1 (0 steps performed)
done, took 19242.742 seconds


<keras.callbacks.History at 0x7feb5cc9d220>

In [71]:
# Watch the trained agent for a few rendered episodes, using the same
# per-episode step cap (99) that was passed to dqn.fit().
dqn.test(
    env,
    nb_episodes=5,
    nb_max_episode_steps=99,
    visualize=True,
)

Testing for 5 episodes ...
+---------+
|R: | : :G|
| : | : : |
| : :[43m [0m: : |
| | : | : |
|[34;1mY[0m| : |[35mB[0m: |
+---------+
  (West)
+---------+
|R: | : :G|
| : | : : |
| :[43m [0m: : : |
| | : | : |
|[34;1mY[0m| : |[35mB[0m: |
+---------+
  (West)
+---------+
|R: | : :G|
| : | : : |
|[43m [0m: : : : |
| | : | : |
|[34;1mY[0m| : |[35mB[0m: |
+---------+
  (West)
+---------+
|R: | : :G|
| : | : : |
| : : : : |
|[43m [0m| : | : |
|[34;1mY[0m| : |[35mB[0m: |
+---------+
  (South)
+---------+
|R: | : :G|
| : | : : |
| : : : : |
| | : | : |
|[34;1m[43mY[0m[0m| : |[35mB[0m: |
+---------+
  (South)
+---------+
|R: | : :G|
| : | : : |
| : : : : |
| | : | : |
|[42mY[0m| : |[35mB[0m: |
+---------+
  (Pickup)
+---------+
|R: | : :G|
| : | : : |
| : : : : |
|[42m_[0m| : | : |
|Y| : |[35mB[0m: |
+---------+
  (North)
+---------+
|R: | : :G|
| : | : : |
|[42m_[0m: : : : |
| | : | : |
|Y| : |[35mB[0m: |
+---------+
  (North)
+---------+
|R: | : :G|


<keras.callbacks.History at 0x7feb5e735eb0>

In [72]:
# Persist the trained weights. The file name is derived from ENV_NAME instead of
# a repeated "Taxi-v3" literal, so it follows the environment configured at the top.
dqn.save_weights('dqn_{}_weights.h5f'.format(ENV_NAME), overwrite=True)

In [73]:
# Restore previously saved weights. The path is built from ENV_NAME rather than
# a hard-coded file name; for the current environment it resolves to
# 'dqn_Taxi-v3_weights.h5f', exactly as before.
dqn.load_weights('dqn_{}_weights.h5f'.format(ENV_NAME))

In [74]:
# Replay one rendered episode after reloading the weights, this time with a
# much larger per-episode step budget than the earlier evaluation.
dqn.test(
    env,
    nb_episodes=1,
    nb_max_episode_steps=1000,
    visualize=True,
)

Testing for 1 episodes ...
+---------+
|[35mR[0m: | : :G|
| : | : : |
| : :[43m [0m: : |
| | : | : |
|Y| : |[34;1mB[0m: |
+---------+
  (East)
+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : :[43m [0m: |
| | : | : |
|Y| : |[34;1mB[0m: |
+---------+
  (East)
+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : : |
| | : |[43m [0m: |
|Y| : |[34;1mB[0m: |
+---------+
  (South)
+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : : |
| | : | : |
|Y| : |[34;1m[43mB[0m[0m: |
+---------+
  (South)
+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : : |
| | : | : |
|Y| : |[42mB[0m: |
+---------+
  (Pickup)
+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : : |
| | : |[42m_[0m: |
|Y| : |B: |
+---------+
  (North)
+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : :[42m_[0m: |
| | : | : |
|Y| : |B: |
+---------+
  (North)
+---------+
|[35mR[0m: | : :G|
| : | : : |
| : :[42m_[0m: : |
| | : | : |
|Y| : |B: |
+---------+
  (West)
+---------+
|[35mR[0m: | : :G|
| 

<keras.callbacks.History at 0x7feb5e68e490>