In [1]:
import numpy as np
import gym
import gnwrapper

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten
from tensorflow.keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory

In [2]:
# Gym environment id; the environment built from it is used for both
# training and evaluation in the cells below.
ENV_NAME = "CartPole-v0"

In [3]:
# Build the environment and seed both NumPy's global RNG and the environment
# from one named constant, so the two seeds cannot drift apart when edited.
# NOTE(review): other RNG sources (e.g. Python's `random`, TensorFlow) are not
# seeded here, so runs may still differ -- confirm whether that matters.
SEED = 123
env = gym.make(ENV_NAME)
np.random.seed(SEED)
env.seed(SEED)
# Number of discrete actions; this sizes the network's output layer.
nb_actions = env.action_space.n

In [4]:
# Q-network: flatten the (1, *observation_shape) input, pass it through three
# 16-unit ReLU hidden layers, and emit one linear output per action.
# The leading 1 in the input shape presumably corresponds to the replay
# memory's window_length=1 -- keep the two in sync if either changes.
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
# Dense and Activation stay separate layers so the layer list (and therefore
# the summary table and saved-weight layout) is unchanged.
for hidden_units in (16, 16, 16):
    model.add(Dense(hidden_units))
    model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None" line after the table.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 4)                 0         
_________________________________________________________________
dense (Dense)                (None, 16)                80        
_________________________________________________________________
activation (Activation)      (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 16)                272       
_________________________________________________________________
activation_1 (Activation)    (None, 16)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 16)                272       
_________________________________________________________________
activation_2 (Activation)    (None, 16)                0

In [5]:
# Replay memory capped at 50,000 entries, with a window length of 1.
memory = SequentialMemory(
    limit=50000,
    window_length=1,
)
# Exploration policy, constructed with the library's default settings.
policy = BoltzmannQPolicy()
# Assemble the DQN agent from the network, replay memory and policy.
# NOTE(review): per keras-rl convention a target_model_update below 1 is
# treated as a soft-update coefficient -- confirm against the installed version.
dqn = DQNAgent(
    model=model,
    nb_actions=nb_actions,
    memory=memory,
    policy=policy,
    nb_steps_warmup=10,
    target_model_update=1e-2,
)
# Compile with Adam and track mean absolute error alongside the loss.
dqn.compile(Adam(learning_rate=1e-3), metrics=['mae'])

In [None]:
%%time
dqn.fit(gnwrapper.Monitor(env, directory="./cartpole"), nb_steps=50000, visualize=True, verbose=2)

Training for 50000 steps ...




    26/50000: episode: 1, duration: 0.989s, episode steps:  26, steps per second:  26, episode reward: 26.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.538 [0.000, 1.000],  loss: 0.459331, mae: 0.518412, mean_q: 0.111933




    60/50000: episode: 2, duration: 0.454s, episode steps:  34, steps per second:  75, episode reward: 34.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.618 [0.000, 1.000],  loss: 0.272843, mae: 0.550465, mean_q: 0.505310
    79/50000: episode: 3, duration: 0.162s, episode steps:  19, steps per second: 118, episode reward: 19.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.579 [0.000, 1.000],  loss: 0.107739, mae: 0.677224, mean_q: 1.149584
   118/50000: episode: 4, duration: 0.193s, episode steps:  39, steps per second: 202, episode reward: 39.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.564 [0.000, 1.000],  loss: 0.082309, mae: 0.755051, mean_q: 1.363234
   126/50000: episode: 5, duration: 0.039s, episode steps:   8, steps per second: 203, episode reward:  8.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.875 [0.000, 1.000],  loss: 0.061475, mae: 0.852638, mean_q: 1.543159
   146/50000: episode: 6, duration: 0.094s, episode steps:  20, step

   909/50000: episode: 39, duration: 0.180s, episode steps:  39, steps per second: 217, episode reward: 39.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.513 [0.000, 1.000],  loss: 0.443329, mae: 3.863922, mean_q: 7.412047
   960/50000: episode: 40, duration: 0.241s, episode steps:  51, steps per second: 212, episode reward: 51.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.451 [0.000, 1.000],  loss: 0.468748, mae: 4.041516, mean_q: 7.780587
   979/50000: episode: 41, duration: 0.082s, episode steps:  19, steps per second: 230, episode reward: 19.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.421 [0.000, 1.000],  loss: 0.451813, mae: 4.214100, mean_q: 8.134248
   993/50000: episode: 42, duration: 0.062s, episode steps:  14, steps per second: 224, episode reward: 14.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.429 [0.000, 1.000],  loss: 0.349115, mae: 4.254445, mean_q: 8.283470
  1012/50000: episode: 43, duration: 0.101s, episode steps:  19,

  3944/50000: episode: 74, duration: 0.682s, episode steps: 146, steps per second: 214, episode reward: 146.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.534 [0.000, 1.000],  loss: 3.041062, mae: 16.640734, mean_q: 33.997326
  4120/50000: episode: 75, duration: 0.806s, episode steps: 176, steps per second: 218, episode reward: 176.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.528 [0.000, 1.000],  loss: 2.687468, mae: 17.326159, mean_q: 35.397049
  4320/50000: episode: 76, duration: 0.900s, episode steps: 200, steps per second: 222, episode reward: 200.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.520 [0.000, 1.000],  loss: 2.802815, mae: 18.093483, mean_q: 37.111401
  4502/50000: episode: 77, duration: 0.830s, episode steps: 182, steps per second: 219, episode reward: 182.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.478 [0.000, 1.000],  loss: 2.779619, mae: 19.051878, mean_q: 38.973263
  4664/50000: episode: 78, duration: 0.713s, episode

 10814/50000: episode: 109, duration: 0.853s, episode steps: 200, steps per second: 234, episode reward: 200.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.510 [0.000, 1.000],  loss: 4.620298, mae: 38.537010, mean_q: 78.208405
 11014/50000: episode: 110, duration: 0.856s, episode steps: 200, steps per second: 234, episode reward: 200.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.500 [0.000, 1.000],  loss: 6.973832, mae: 38.902481, mean_q: 78.721306
 11214/50000: episode: 111, duration: 0.908s, episode steps: 200, steps per second: 220, episode reward: 200.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.490 [0.000, 1.000],  loss: 7.685471, mae: 39.032612, mean_q: 79.070557
 11414/50000: episode: 112, duration: 0.880s, episode steps: 200, steps per second: 227, episode reward: 200.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.490 [0.000, 1.000],  loss: 7.963250, mae: 39.842098, mean_q: 80.659767
 11589/50000: episode: 113, duration: 0.813s, ep

 17716/50000: episode: 144, duration: 0.906s, episode steps: 200, steps per second: 221, episode reward: 200.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.505 [0.000, 1.000],  loss: 10.550279, mae: 47.437767, mean_q: 95.755974
 17916/50000: episode: 145, duration: 0.977s, episode steps: 200, steps per second: 205, episode reward: 200.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.510 [0.000, 1.000],  loss: 12.362832, mae: 47.431183, mean_q: 95.690857
 18116/50000: episode: 146, duration: 0.995s, episode steps: 200, steps per second: 201, episode reward: 200.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.500 [0.000, 1.000],  loss: 12.924763, mae: 47.988945, mean_q: 96.793770
 18316/50000: episode: 147, duration: 0.899s, episode steps: 200, steps per second: 223, episode reward: 200.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.495 [0.000, 1.000],  loss: 12.110049, mae: 48.431427, mean_q: 97.700409
 18510/50000: episode: 148, duration: 0.897s

 24485/50000: episode: 179, duration: 0.754s, episode steps: 170, steps per second: 225, episode reward: 170.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.535 [0.000, 1.000],  loss: 6.049616, mae: 52.770805, mean_q: 106.562820
 24683/50000: episode: 180, duration: 0.940s, episode steps: 198, steps per second: 211, episode reward: 198.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.520 [0.000, 1.000],  loss: 13.107698, mae: 52.732796, mean_q: 106.215034
 24870/50000: episode: 181, duration: 0.859s, episode steps: 187, steps per second: 218, episode reward: 187.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.524 [0.000, 1.000],  loss: 9.849440, mae: 53.120071, mean_q: 107.072945
 25070/50000: episode: 182, duration: 0.952s, episode steps: 200, steps per second: 210, episode reward: 200.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.520 [0.000, 1.000],  loss: 12.907976, mae: 53.387596, mean_q: 107.459396
 25270/50000: episode: 183, duration: 0.94

 30823/50000: episode: 214, duration: 0.734s, episode steps: 166, steps per second: 226, episode reward: 166.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.524 [0.000, 1.000],  loss: 6.520683, mae: 51.479252, mean_q: 103.584831
 31011/50000: episode: 215, duration: 0.861s, episode steps: 188, steps per second: 218, episode reward: 188.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.521 [0.000, 1.000],  loss: 9.333827, mae: 50.881317, mean_q: 102.279930
 31185/50000: episode: 216, duration: 0.905s, episode steps: 174, steps per second: 192, episode reward: 174.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.523 [0.000, 1.000],  loss: 9.463153, mae: 51.038868, mean_q: 102.550583
 31372/50000: episode: 217, duration: 1.807s, episode steps: 187, steps per second: 104, episode reward: 187.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.524 [0.000, 1.000],  loss: 4.512249, mae: 50.877693, mean_q: 102.387115
 31567/50000: episode: 218, duration: 0.935s

 36973/50000: episode: 249, duration: 1.079s, episode steps: 200, steps per second: 185, episode reward: 200.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.515 [0.000, 1.000],  loss: 5.406145, mae: 46.950356, mean_q: 94.426888
 37173/50000: episode: 250, duration: 0.956s, episode steps: 200, steps per second: 209, episode reward: 200.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.535 [0.000, 1.000],  loss: 6.306237, mae: 47.219955, mean_q: 94.932228
 37331/50000: episode: 251, duration: 0.806s, episode steps: 158, steps per second: 196, episode reward: 158.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.519 [0.000, 1.000],  loss: 4.761568, mae: 46.720585, mean_q: 93.959915
 37492/50000: episode: 252, duration: 0.706s, episode steps: 161, steps per second: 228, episode reward: 161.000, mean reward:  1.000 [ 1.000,  1.000], mean action: 0.553 [0.000, 1.000],  loss: 4.482238, mae: 46.144348, mean_q: 92.831444
 37651/50000: episode: 253, duration: 0.711s, ep

In [None]:
# Persist the agent's weights; overwrite=True replaces any existing file of
# the same name instead of stopping to ask.
dqn.save_weights(
    'cartpole.h5f',
    overwrite=True,
)

In [None]:
# Evaluate the agent for five episodes on `env` wrapped by
# gnwrapper.Animation, with rendering enabled.
dqn.test(
    gnwrapper.Animation(env),
    nb_episodes=5,
    visualize=True,
)