In [1]:
import numpy as np
import gym
import tensorflow as tf

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory

Using TensorFlow backend.


In [2]:
tf.__version__

'1.13.1'

In [3]:
ENV_NAME = 'CartPole-v0'

# Get the environment and extract the number of actions available in the Cartpole problem
env = gym.make(ENV_NAME)
np.random.seed(123)
env.seed(123)
nb_actions = env.action_space.n

In [4]:
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
print(model.summary())

Instructions for updating:
Colocations handled automatically by placer.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 4)                 0         
_________________________________________________________________
dense_1 (Dense)              (None, 16)                80        
_________________________________________________________________
activation_1 (Activation)    (None, 16)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 34        
_________________________________________________________________
activation_2 (Activation)    (None, 2)                 0         
Total params: 114
Trainable params: 114
Non-trainable params: 0
_________________________________________________________________
None


In [5]:
model

<keras.engine.sequential.Sequential at 0x7fcde7c98d68>

In [6]:
policy = EpsGreedyQPolicy()
memory = SequentialMemory(limit=50000, window_length=1)
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10, target_model_update=1e-2, 
               policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])

# Okay, now it's time to learn something! We visualize the training here for show, but this slows down training quite a lot. 
#dqn.fit(env, nb_steps=50000, visualize=True, verbose=2)
dqn.fit(env, nb_steps=50000, visualize=False, verbose=2)

Training for 50000 steps ...
Instructions for updating:
Use tf.cast instead.




    79/50000: episode: 1, duration: 1.420s, episode steps: 79, steps per second: 56, episode reward: 79.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.519 [0.000, 1.000], mean observation: 0.060 [-0.402, 0.722], loss: 0.427881, mean_absolute_error: 0.495522, mean_q: 0.052736
   113/50000: episode: 2, duration: 0.088s, episode steps: 34, steps per second: 387, episode reward: 34.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.529 [0.000, 1.000], mean observation: 0.151 [-0.159, 0.753], loss: 0.348403, mean_absolute_error: 0.444050, mean_q: 0.194054
   163/50000: episode: 3, duration: 0.145s, episode steps: 50, steps per second: 344, episode reward: 50.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.520 [0.000, 1.000], mean observation: 0.082 [-0.295, 0.778], loss: 0.315736, mean_absolute_error: 0.466378, mean_q: 0.320439
   197/50000: episode: 4, duration: 0.086s, episode steps: 34, steps per second: 394, episode reward: 34.000, mean reward: 1.000 [1.000, 1.000], m

   718/50000: episode: 30, duration: 0.042s, episode steps: 16, steps per second: 380, episode reward: 16.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.438 [0.000, 1.000], mean observation: 0.104 [-0.557, 1.253], loss: 0.378808, mean_absolute_error: 2.313636, mean_q: 4.356735
   736/50000: episode: 31, duration: 0.058s, episode steps: 18, steps per second: 309, episode reward: 18.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.444 [0.000, 1.000], mean observation: 0.071 [-0.630, 1.190], loss: 0.443225, mean_absolute_error: 2.389580, mean_q: 4.505213
   749/50000: episode: 32, duration: 0.037s, episode steps: 13, steps per second: 356, episode reward: 13.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.462 [0.000, 1.000], mean observation: 0.084 [-0.641, 1.169], loss: 0.330261, mean_absolute_error: 2.385354, mean_q: 4.616232
   765/50000: episode: 33, duration: 0.051s, episode steps: 16, steps per second: 315, episode reward: 16.000, mean reward: 1.000 [1.000, 1.00

  1172/50000: episode: 61, duration: 0.048s, episode steps: 15, steps per second: 315, episode reward: 15.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.467 [0.000, 1.000], mean observation: 0.082 [-0.606, 1.056], loss: 0.842463, mean_absolute_error: 3.907808, mean_q: 7.453444
  1189/50000: episode: 62, duration: 0.052s, episode steps: 17, steps per second: 329, episode reward: 17.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.471 [0.000, 1.000], mean observation: 0.078 [-0.776, 1.181], loss: 1.017025, mean_absolute_error: 3.951949, mean_q: 7.563557
  1209/50000: episode: 63, duration: 0.053s, episode steps: 20, steps per second: 377, episode reward: 20.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.450 [0.000, 1.000], mean observation: 0.071 [-0.553, 1.036], loss: 1.029505, mean_absolute_error: 4.013451, mean_q: 7.642663
  1226/50000: episode: 64, duration: 0.045s, episode steps: 17, steps per second: 380, episode reward: 17.000, mean reward: 1.000 [1.000, 1.00

  1871/50000: episode: 91, duration: 0.122s, episode steps: 44, steps per second: 362, episode reward: 44.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.523 [0.000, 1.000], mean observation: 0.063 [-0.536, 0.827], loss: 1.502543, mean_absolute_error: 5.433537, mean_q: 10.342675
  1912/50000: episode: 92, duration: 0.108s, episode steps: 41, steps per second: 381, episode reward: 41.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.512 [0.000, 1.000], mean observation: 0.079 [-0.421, 0.751], loss: 1.078176, mean_absolute_error: 5.495960, mean_q: 10.652429
  1952/50000: episode: 93, duration: 0.111s, episode steps: 40, steps per second: 362, episode reward: 40.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.525 [0.000, 1.000], mean observation: 0.064 [-0.393, 0.643], loss: 1.594903, mean_absolute_error: 5.589735, mean_q: 10.636061
  2005/50000: episode: 94, duration: 0.140s, episode steps: 53, steps per second: 378, episode reward: 53.000, mean reward: 1.000 [1.000, 1

  3034/50000: episode: 120, duration: 0.063s, episode steps: 19, steps per second: 303, episode reward: 19.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.474 [0.000, 1.000], mean observation: -0.101 [-0.776, 0.437], loss: 2.213464, mean_absolute_error: 8.467155, mean_q: 16.640816
  3052/50000: episode: 121, duration: 0.050s, episode steps: 18, steps per second: 358, episode reward: 18.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.556 [0.000, 1.000], mean observation: -0.122 [-1.147, 0.550], loss: 2.233250, mean_absolute_error: 8.244177, mean_q: 16.308769
  3073/50000: episode: 122, duration: 0.060s, episode steps: 21, steps per second: 350, episode reward: 21.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.524 [0.000, 1.000], mean observation: -0.062 [-1.165, 0.768], loss: 5.606193, mean_absolute_error: 8.575185, mean_q: 16.537640
  3084/50000: episode: 123, duration: 0.034s, episode steps: 11, steps per second: 322, episode reward: 11.000, mean reward: 1.000 [1

  3531/50000: episode: 149, duration: 0.057s, episode steps: 11, steps per second: 191, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.727 [0.000, 1.000], mean observation: -0.110 [-1.808, 0.993], loss: 6.979950, mean_absolute_error: 9.577744, mean_q: 18.234556
  3540/50000: episode: 150, duration: 0.047s, episode steps: 9, steps per second: 191, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.164 [-2.828, 1.751], loss: 9.101558, mean_absolute_error: 9.907358, mean_q: 18.535810
  3555/50000: episode: 151, duration: 0.083s, episode steps: 15, steps per second: 182, episode reward: 15.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.600 [0.000, 1.000], mean observation: -0.097 [-1.396, 0.791], loss: 4.086903, mean_absolute_error: 9.428597, mean_q: 18.131456
  3572/50000: episode: 152, duration: 0.073s, episode steps: 17, steps per second: 232, episode reward: 17.000, mean reward: 1.000 [1.0

  5510/50000: episode: 178, duration: 0.230s, episode steps: 77, steps per second: 335, episode reward: 77.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.545 [0.000, 1.000], mean observation: 0.232 [-0.406, 1.272], loss: 4.117792, mean_absolute_error: 11.623552, mean_q: 22.812937
  5590/50000: episode: 179, duration: 0.205s, episode steps: 80, steps per second: 390, episode reward: 80.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.550 [0.000, 1.000], mean observation: 0.253 [-0.368, 1.455], loss: 5.059409, mean_absolute_error: 11.650640, mean_q: 22.723000
  5672/50000: episode: 180, duration: 0.226s, episode steps: 82, steps per second: 362, episode reward: 82.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.463 [0.000, 1.000], mean observation: -0.190 [-1.078, 0.460], loss: 4.800566, mean_absolute_error: 11.768478, mean_q: 23.016130
  5826/50000: episode: 181, duration: 0.403s, episode steps: 154, steps per second: 382, episode reward: 154.000, mean reward: 1.000

 10548/50000: episode: 207, duration: 0.531s, episode steps: 200, steps per second: 377, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.495 [0.000, 1.000], mean observation: -0.038 [-0.590, 0.530], loss: 6.258637, mean_absolute_error: 20.047333, mean_q: 40.123783
 10748/50000: episode: 208, duration: 0.756s, episode steps: 200, steps per second: 265, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.495 [0.000, 1.000], mean observation: -0.140 [-0.828, 0.543], loss: 8.077374, mean_absolute_error: 20.449593, mean_q: 40.771065
 10948/50000: episode: 209, duration: 0.558s, episode steps: 200, steps per second: 359, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.058 [-0.559, 0.608], loss: 8.003611, mean_absolute_error: 20.663189, mean_q: 41.240185
 11148/50000: episode: 210, duration: 0.520s, episode steps: 200, steps per second: 384, episode reward: 200.000, mean reward

 16348/50000: episode: 236, duration: 0.537s, episode steps: 200, steps per second: 372, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: -0.011 [-0.582, 0.550], loss: 11.185594, mean_absolute_error: 28.661762, mean_q: 57.678761
 16548/50000: episode: 237, duration: 0.614s, episode steps: 200, steps per second: 326, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.505 [0.000, 1.000], mean observation: 0.078 [-0.558, 0.598], loss: 12.355051, mean_absolute_error: 28.770132, mean_q: 57.832832
 16748/50000: episode: 238, duration: 0.566s, episode steps: 200, steps per second: 353, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.505 [0.000, 1.000], mean observation: 0.132 [-0.400, 0.862], loss: 9.980508, mean_absolute_error: 28.949112, mean_q: 58.338573
 16948/50000: episode: 239, duration: 0.531s, episode steps: 200, steps per second: 377, episode reward: 200.000, mean rewar

 22148/50000: episode: 265, duration: 0.523s, episode steps: 200, steps per second: 382, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.510 [0.000, 1.000], mean observation: 0.141 [-0.356, 0.902], loss: 16.150740, mean_absolute_error: 33.654934, mean_q: 67.821648
 22348/50000: episode: 266, duration: 0.516s, episode steps: 200, steps per second: 388, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.510 [0.000, 1.000], mean observation: 0.167 [-0.535, 1.106], loss: 9.832590, mean_absolute_error: 33.668777, mean_q: 68.089455
 22548/50000: episode: 267, duration: 0.515s, episode steps: 200, steps per second: 389, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.001 [-0.684, 0.747], loss: 12.088122, mean_absolute_error: 33.950562, mean_q: 68.559227
 22748/50000: episode: 268, duration: 0.518s, episode steps: 200, steps per second: 386, episode reward: 200.000, mean reward

 27948/50000: episode: 294, duration: 0.516s, episode steps: 200, steps per second: 388, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.510 [0.000, 1.000], mean observation: 0.133 [-0.507, 0.862], loss: 12.722240, mean_absolute_error: 36.713696, mean_q: 74.196709
 28148/50000: episode: 295, duration: 0.523s, episode steps: 200, steps per second: 382, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.505 [0.000, 1.000], mean observation: 0.092 [-0.438, 0.677], loss: 11.009936, mean_absolute_error: 36.704666, mean_q: 74.188507
 28348/50000: episode: 296, duration: 0.537s, episode steps: 200, steps per second: 372, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.510 [0.000, 1.000], mean observation: 0.142 [-0.630, 0.917], loss: 11.985439, mean_absolute_error: 36.851574, mean_q: 74.428322
 28548/50000: episode: 297, duration: 0.530s, episode steps: 200, steps per second: 377, episode reward: 200.000, mean rewar

 33748/50000: episode: 323, duration: 0.518s, episode steps: 200, steps per second: 386, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.510 [0.000, 1.000], mean observation: 0.141 [-0.831, 0.876], loss: 13.281450, mean_absolute_error: 37.765968, mean_q: 76.062294
 33948/50000: episode: 324, duration: 0.519s, episode steps: 200, steps per second: 386, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.505 [0.000, 1.000], mean observation: 0.087 [-0.636, 0.741], loss: 16.347496, mean_absolute_error: 37.758705, mean_q: 75.951546
 34148/50000: episode: 325, duration: 0.508s, episode steps: 200, steps per second: 393, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.515 [0.000, 1.000], mean observation: 0.199 [-0.663, 1.301], loss: 11.962851, mean_absolute_error: 37.700581, mean_q: 75.967682
 34348/50000: episode: 326, duration: 0.538s, episode steps: 200, steps per second: 372, episode reward: 200.000, mean rewar

 39548/50000: episode: 352, duration: 0.518s, episode steps: 200, steps per second: 386, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.019 [-0.520, 0.517], loss: 8.746973, mean_absolute_error: 38.340191, mean_q: 77.145226
 39748/50000: episode: 353, duration: 0.521s, episode steps: 200, steps per second: 384, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.510 [0.000, 1.000], mean observation: 0.153 [-0.744, 0.915], loss: 10.580903, mean_absolute_error: 38.317482, mean_q: 77.138794
 39948/50000: episode: 354, duration: 0.530s, episode steps: 200, steps per second: 377, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.505 [0.000, 1.000], mean observation: 0.081 [-0.554, 0.531], loss: 12.073756, mean_absolute_error: 38.285923, mean_q: 77.003365
 40148/50000: episode: 355, duration: 0.516s, episode steps: 200, steps per second: 388, episode reward: 200.000, mean reward

 45348/50000: episode: 381, duration: 0.545s, episode steps: 200, steps per second: 367, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.070 [-0.607, 0.538], loss: 12.858491, mean_absolute_error: 38.773956, mean_q: 77.852013
 45548/50000: episode: 382, duration: 0.678s, episode steps: 200, steps per second: 295, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.505 [0.000, 1.000], mean observation: 0.062 [-0.580, 0.642], loss: 8.389922, mean_absolute_error: 38.635986, mean_q: 77.596832
 45748/50000: episode: 383, duration: 0.622s, episode steps: 200, steps per second: 322, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: 0.048 [-0.591, 0.787], loss: 11.833396, mean_absolute_error: 38.650829, mean_q: 77.518578
 45948/50000: episode: 384, duration: 0.549s, episode steps: 200, steps per second: 364, episode reward: 200.000, mean reward

<keras.callbacks.History at 0x7fcde42f1d68>

In [7]:
dqn.test(env, nb_episodes=5, visualize=True)

Testing for 5 episodes ...
Episode 1: reward: 200.000, steps: 200
Episode 2: reward: 200.000, steps: 200
Episode 3: reward: 200.000, steps: 200
Episode 4: reward: 200.000, steps: 200
Episode 5: reward: 200.000, steps: 200


<keras.callbacks.History at 0x7fcde7cce6a0>