In [1]:
import random

import numpy as np
import gym
import tensorflow as tf

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory

Using TensorFlow backend.


In [2]:
# Display the installed TensorFlow version so the environment behind the results below is on record.
tf.__version__

'1.13.1'

In [3]:
ENV_NAME = 'CartPole-v0'
SEED = 123  # single source of truth for every RNG seeded in this cell

# Get the environment and extract the number of actions available in the Cartpole problem
env = gym.make(ENV_NAME)

# Seed every random number generator the run can touch, so that
# Restart Kernel -> Run All starts from the same random state each time.
# Seeding only NumPy and the environment leaves the network's initial
# weights (drawn by TensorFlow) different on every run.
random.seed(SEED)         # NOTE(review): keras-rl's replay memory appears to sample with Python's `random` -- confirm for the installed version
np.random.seed(SEED)
tf.set_random_seed(SEED)  # TF 1.x graph-level seed; must run before the model is built
env.seed(SEED)

nb_actions = env.action_space.n

In [4]:
# Q-network: flattens the observation, passes it through one hidden layer of
# 16 ReLU units, and emits one linear output per available action.
# The leading 1 in the input shape presumably matches the replay memory's
# window_length=1 used when the agent is built -- confirm if that changes.
model = Sequential([
    Flatten(input_shape=(1,) + env.observation_space.shape),
    Dense(16),
    Activation('relu'),
    Dense(nb_actions),
    Activation('linear'),
])

# summary() prints the table itself and returns None, so it must not be
# wrapped in print() (that emits a stray "None" after the table).
model.summary()

Instructions for updating:
Colocations handled automatically by placer.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 4)                 0         
_________________________________________________________________
dense_1 (Dense)              (None, 16)                80        
_________________________________________________________________
activation_1 (Activation)    (None, 16)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 34        
_________________________________________________________________
activation_2 (Activation)    (None, 2)                 0         
Total params: 114
Trainable params: 114
Non-trainable params: 0
_________________________________________________________________
None


In [5]:
# Display the repr of the Sequential model built in the previous cell.
model

<keras.engine.sequential.Sequential at 0x1287169b0>

In [6]:
# Hyperparameters, named so they can be tuned in one place.
MEMORY_LIMIT = 50000         # maximum number of transitions kept in replay memory
WINDOW_LENGTH = 1            # observations stacked per state
NB_STEPS_WARMUP = 10         # passed to DQNAgent as nb_steps_warmup
TARGET_MODEL_UPDATE = 1e-2   # NOTE(review): keras-rl documents values < 1 as a soft-update rate -- confirm
LEARNING_RATE = 1e-3
NB_TRAIN_STEPS = 50000
VISUALIZE_TRAINING = False   # True renders training for show, but this slows down training quite a lot

# Build the agent: epsilon-greedy exploration over the Q-network, with a
# sequential replay memory.
policy = EpsGreedyQPolicy()
memory = SequentialMemory(limit=MEMORY_LIMIT, window_length=WINDOW_LENGTH)
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory,
               nb_steps_warmup=NB_STEPS_WARMUP, target_model_update=TARGET_MODEL_UPDATE,
               policy=policy)
dqn.compile(Adam(lr=LEARNING_RATE), metrics=['mae'])

# Okay, now it's time to learn something!
# Keep the returned History object instead of letting it leak out as the
# cell's repr; it stays available for later inspection or plotting.
train_history = dqn.fit(env, nb_steps=NB_TRAIN_STEPS, visualize=VISUALIZE_TRAINING, verbose=2)

Training for 50000 steps ...
Instructions for updating:
Use tf.cast instead.




    79/50000: episode: 1, duration: 0.549s, episode steps: 79, steps per second: 144, episode reward: 79.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.519 [0.000, 1.000], mean observation: 0.060 [-0.402, 0.722], loss: 0.428825, mean_absolute_error: 0.496283, mean_q: 0.052027
   113/50000: episode: 2, duration: 0.060s, episode steps: 34, steps per second: 567, episode reward: 34.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.529 [0.000, 1.000], mean observation: 0.151 [-0.159, 0.753], loss: 0.351756, mean_absolute_error: 0.445386, mean_q: 0.190670
   163/50000: episode: 3, duration: 0.088s, episode steps: 50, steps per second: 570, episode reward: 50.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.520 [0.000, 1.000], mean observation: 0.082 [-0.295, 0.778], loss: 0.315484, mean_absolute_error: 0.465839, mean_q: 0.317009
   197/50000: episode: 4, duration: 0.060s, episode steps: 34, steps per second: 566, episode reward: 34.000, mean reward: 1.000 [1.000, 1.000], 

   744/50000: episode: 37, duration: 0.026s, episode steps: 12, steps per second: 469, episode reward: 12.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.250 [0.000, 1.000], mean observation: 0.110 [-1.353, 2.130], loss: 0.671093, mean_absolute_error: 2.600737, mean_q: 4.910402
   752/50000: episode: 38, duration: 0.019s, episode steps: 8, steps per second: 423, episode reward: 8.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.125 [0.000, 1.000], mean observation: 0.137 [-1.350, 2.190], loss: 0.443758, mean_absolute_error: 2.523529, mean_q: 4.900494
   765/50000: episode: 39, duration: 0.031s, episode steps: 13, steps per second: 426, episode reward: 13.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.231 [0.000, 1.000], mean observation: 0.104 [-1.566, 2.422], loss: 0.539094, mean_absolute_error: 2.632820, mean_q: 5.093865
   774/50000: episode: 40, duration: 0.021s, episode steps: 9, steps per second: 433, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], 

  1033/50000: episode: 66, duration: 0.024s, episode steps: 11, steps per second: 462, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.273 [0.000, 1.000], mean observation: 0.114 [-1.158, 1.884], loss: 0.622828, mean_absolute_error: 3.528227, mean_q: 6.696102
  1047/50000: episode: 67, duration: 0.029s, episode steps: 14, steps per second: 480, episode reward: 14.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.429 [0.000, 1.000], mean observation: 0.082 [-0.979, 1.428], loss: 0.925514, mean_absolute_error: 3.632098, mean_q: 6.783350
  1061/50000: episode: 68, duration: 0.030s, episode steps: 14, steps per second: 468, episode reward: 14.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.429 [0.000, 1.000], mean observation: 0.111 [-0.571, 1.058], loss: 0.665091, mean_absolute_error: 3.595948, mean_q: 6.716099
  1077/50000: episode: 69, duration: 0.032s, episode steps: 16, steps per second: 498, episode reward: 16.000, mean reward: 1.000 [1.000, 1.00

  1987/50000: episode: 98, duration: 0.068s, episode steps: 34, steps per second: 499, episode reward: 34.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.471 [0.000, 1.000], mean observation: -0.058 [-0.891, 0.616], loss: 1.896024, mean_absolute_error: 6.126431, mean_q: 11.564066
  2026/50000: episode: 99, duration: 0.074s, episode steps: 39, steps per second: 524, episode reward: 39.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.487 [0.000, 1.000], mean observation: -0.065 [-0.766, 0.379], loss: 1.034939, mean_absolute_error: 6.101964, mean_q: 11.843329
  2046/50000: episode: 100, duration: 0.039s, episode steps: 20, steps per second: 510, episode reward: 20.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: -0.117 [-0.891, 0.361], loss: 1.555776, mean_absolute_error: 6.334466, mean_q: 12.278546
  2063/50000: episode: 101, duration: 0.034s, episode steps: 17, steps per second: 503, episode reward: 17.000, mean reward: 1.000 [1.0

  2538/50000: episode: 133, duration: 0.039s, episode steps: 17, steps per second: 434, episode reward: 17.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.529 [0.000, 1.000], mean observation: -0.072 [-1.053, 0.624], loss: 2.725843, mean_absolute_error: 7.563269, mean_q: 14.387551
  2553/50000: episode: 134, duration: 0.030s, episode steps: 15, steps per second: 494, episode reward: 15.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.533 [0.000, 1.000], mean observation: -0.097 [-1.091, 0.617], loss: 2.835762, mean_absolute_error: 7.690825, mean_q: 14.671184
  2565/50000: episode: 135, duration: 0.024s, episode steps: 12, steps per second: 505, episode reward: 12.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.583 [0.000, 1.000], mean observation: -0.115 [-1.225, 0.583], loss: 2.819621, mean_absolute_error: 7.755838, mean_q: 14.798855
  2581/50000: episode: 136, duration: 0.030s, episode steps: 16, steps per second: 531, episode reward: 16.000, mean reward: 1.000 [1

  3177/50000: episode: 162, duration: 0.088s, episode steps: 48, steps per second: 544, episode reward: 48.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.500 [0.000, 1.000], mean observation: -0.038 [-1.163, 0.410], loss: 3.407467, mean_absolute_error: 8.594710, mean_q: 16.469498
  3228/50000: episode: 163, duration: 0.089s, episode steps: 51, steps per second: 573, episode reward: 51.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.451 [0.000, 1.000], mean observation: -0.128 [-0.945, 0.329], loss: 3.244341, mean_absolute_error: 8.684105, mean_q: 16.731190
  3287/50000: episode: 164, duration: 0.104s, episode steps: 59, steps per second: 568, episode reward: 59.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.458 [0.000, 1.000], mean observation: -0.072 [-0.882, 0.412], loss: 3.819069, mean_absolute_error: 8.858717, mean_q: 16.961966
  3335/50000: episode: 165, duration: 0.089s, episode steps: 48, steps per second: 538, episode reward: 48.000, mean reward: 1.000 [1

  6073/50000: episode: 191, duration: 0.293s, episode steps: 156, steps per second: 532, episode reward: 156.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.468 [0.000, 1.000], mean observation: -0.247 [-2.002, 0.397], loss: 4.138729, mean_absolute_error: 12.221498, mean_q: 24.064806
  6198/50000: episode: 192, duration: 0.236s, episode steps: 125, steps per second: 530, episode reward: 125.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.456 [0.000, 1.000], mean observation: -0.336 [-1.994, 0.430], loss: 3.642039, mean_absolute_error: 12.470331, mean_q: 24.628428
  6326/50000: episode: 193, duration: 0.250s, episode steps: 128, steps per second: 511, episode reward: 128.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.453 [0.000, 1.000], mean observation: -0.351 [-2.204, 0.532], loss: 4.107232, mean_absolute_error: 12.749136, mean_q: 25.177803
  6471/50000: episode: 194, duration: 0.321s, episode steps: 145, steps per second: 452, episode reward: 145.000, mean rewar

 10545/50000: episode: 220, duration: 0.343s, episode steps: 147, steps per second: 428, episode reward: 147.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.456 [0.000, 1.000], mean observation: -0.386 [-2.362, 0.483], loss: 5.830125, mean_absolute_error: 19.289249, mean_q: 38.791122
 10715/50000: episode: 221, duration: 0.318s, episode steps: 170, steps per second: 535, episode reward: 170.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.465 [0.000, 1.000], mean observation: -0.333 [-2.219, 0.479], loss: 5.675663, mean_absolute_error: 19.544792, mean_q: 39.370735
 10858/50000: episode: 222, duration: 0.291s, episode steps: 143, steps per second: 491, episode reward: 143.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.455 [0.000, 1.000], mean observation: -0.385 [-2.381, 0.380], loss: 6.045901, mean_absolute_error: 19.681993, mean_q: 39.713520
 11007/50000: episode: 223, duration: 0.305s, episode steps: 149, steps per second: 488, episode reward: 149.000, mean rewar

 14865/50000: episode: 249, duration: 0.245s, episode steps: 124, steps per second: 506, episode reward: 124.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.452 [0.000, 1.000], mean observation: -0.476 [-2.418, 0.474], loss: 6.082700, mean_absolute_error: 24.642704, mean_q: 50.220360
 15004/50000: episode: 250, duration: 0.302s, episode steps: 139, steps per second: 461, episode reward: 139.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.453 [0.000, 1.000], mean observation: -0.439 [-2.415, 0.415], loss: 6.434256, mean_absolute_error: 24.741554, mean_q: 50.462799
 15173/50000: episode: 251, duration: 0.406s, episode steps: 169, steps per second: 416, episode reward: 169.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.462 [0.000, 1.000], mean observation: -0.365 [-2.538, 0.402], loss: 6.950211, mean_absolute_error: 24.767298, mean_q: 50.530823
 15305/50000: episode: 252, duration: 0.282s, episode steps: 132, steps per second: 468, episode reward: 132.000, mean rewar

 19244/50000: episode: 278, duration: 0.308s, episode steps: 171, steps per second: 555, episode reward: 171.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.462 [0.000, 1.000], mean observation: -0.359 [-2.418, 0.433], loss: 6.987628, mean_absolute_error: 27.997816, mean_q: 57.145035
 19427/50000: episode: 279, duration: 0.319s, episode steps: 183, steps per second: 574, episode reward: 183.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.470 [0.000, 1.000], mean observation: -0.297 [-2.127, 0.569], loss: 5.891976, mean_absolute_error: 28.357197, mean_q: 57.964230
 19593/50000: episode: 280, duration: 0.290s, episode steps: 166, steps per second: 572, episode reward: 166.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.464 [0.000, 1.000], mean observation: -0.369 [-2.409, 0.649], loss: 6.690246, mean_absolute_error: 28.659645, mean_q: 58.518566
 19729/50000: episode: 281, duration: 0.251s, episode steps: 136, steps per second: 543, episode reward: 136.000, mean rewar

 23744/50000: episode: 307, duration: 0.351s, episode steps: 178, steps per second: 507, episode reward: 178.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.466 [0.000, 1.000], mean observation: -0.340 [-2.422, 0.505], loss: 7.768727, mean_absolute_error: 30.563114, mean_q: 61.983025
 23896/50000: episode: 308, duration: 0.292s, episode steps: 152, steps per second: 520, episode reward: 152.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.454 [0.000, 1.000], mean observation: -0.409 [-2.571, 0.465], loss: 7.838018, mean_absolute_error: 30.197214, mean_q: 61.249897
 24084/50000: episode: 309, duration: 0.354s, episode steps: 188, steps per second: 531, episode reward: 188.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.463 [0.000, 1.000], mean observation: -0.321 [-2.522, 0.465], loss: 6.626745, mean_absolute_error: 30.336670, mean_q: 61.721077
 24218/50000: episode: 310, duration: 0.256s, episode steps: 134, steps per second: 523, episode reward: 134.000, mean rewar

 28099/50000: episode: 336, duration: 0.356s, episode steps: 152, steps per second: 427, episode reward: 152.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.454 [0.000, 1.000], mean observation: -0.404 [-2.596, 0.363], loss: 5.856310, mean_absolute_error: 31.501432, mean_q: 63.736149
 28263/50000: episode: 337, duration: 0.318s, episode steps: 164, steps per second: 515, episode reward: 164.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.463 [0.000, 1.000], mean observation: -0.345 [-2.353, 0.576], loss: 9.216780, mean_absolute_error: 31.393921, mean_q: 63.357075
 28444/50000: episode: 338, duration: 0.376s, episode steps: 181, steps per second: 481, episode reward: 181.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.464 [0.000, 1.000], mean observation: -0.335 [-2.405, 0.458], loss: 6.239495, mean_absolute_error: 31.405087, mean_q: 63.380257
 28592/50000: episode: 339, duration: 0.260s, episode steps: 148, steps per second: 569, episode reward: 148.000, mean rewar

 32661/50000: episode: 365, duration: 0.312s, episode steps: 159, steps per second: 509, episode reward: 159.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.459 [0.000, 1.000], mean observation: -0.387 [-2.508, 0.396], loss: 6.134260, mean_absolute_error: 32.085850, mean_q: 64.649902
 32808/50000: episode: 366, duration: 0.263s, episode steps: 147, steps per second: 560, episode reward: 147.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.463 [0.000, 1.000], mean observation: -0.384 [-2.359, 0.415], loss: 5.454414, mean_absolute_error: 31.821579, mean_q: 64.154449
 32981/50000: episode: 367, duration: 0.331s, episode steps: 173, steps per second: 523, episode reward: 173.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.457 [0.000, 1.000], mean observation: -0.352 [-2.697, 0.440], loss: 6.197710, mean_absolute_error: 32.293144, mean_q: 65.286903
 33131/50000: episode: 368, duration: 0.262s, episode steps: 150, steps per second: 573, episode reward: 150.000, mean rewar

 37215/50000: episode: 394, duration: 0.314s, episode steps: 153, steps per second: 487, episode reward: 153.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.477 [0.000, 1.000], mean observation: -0.263 [-1.679, 0.475], loss: 7.262548, mean_absolute_error: 32.268951, mean_q: 65.275116
 37380/50000: episode: 395, duration: 0.301s, episode steps: 165, steps per second: 548, episode reward: 165.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.461 [0.000, 1.000], mean observation: -0.377 [-2.413, 0.477], loss: 4.169618, mean_absolute_error: 32.418903, mean_q: 65.790161
 37580/50000: episode: 396, duration: 0.354s, episode steps: 200, steps per second: 564, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.475 [0.000, 1.000], mean observation: -0.203 [-1.840, 0.465], loss: 5.825756, mean_absolute_error: 32.332832, mean_q: 65.333031
 37736/50000: episode: 397, duration: 0.302s, episode steps: 156, steps per second: 517, episode reward: 156.000, mean rewar

 42123/50000: episode: 423, duration: 0.336s, episode steps: 144, steps per second: 428, episode reward: 144.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.458 [0.000, 1.000], mean observation: -0.365 [-2.191, 0.731], loss: 4.692348, mean_absolute_error: 32.346493, mean_q: 65.653893
 42296/50000: episode: 424, duration: 0.381s, episode steps: 173, steps per second: 454, episode reward: 173.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.468 [0.000, 1.000], mean observation: -0.268 [-2.004, 0.532], loss: 4.668477, mean_absolute_error: 32.470051, mean_q: 65.878265
 42428/50000: episode: 425, duration: 0.245s, episode steps: 132, steps per second: 539, episode reward: 132.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.462 [0.000, 1.000], mean observation: -0.352 [-1.836, 0.318], loss: 5.759565, mean_absolute_error: 32.387623, mean_q: 65.692978
 42575/50000: episode: 426, duration: 0.259s, episode steps: 147, steps per second: 567, episode reward: 147.000, mean rewar

 47027/50000: episode: 452, duration: 0.336s, episode steps: 190, steps per second: 566, episode reward: 190.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.474 [0.000, 1.000], mean observation: -0.309 [-2.214, 0.742], loss: 5.018560, mean_absolute_error: 33.289719, mean_q: 67.728096
 47227/50000: episode: 453, duration: 0.368s, episode steps: 200, steps per second: 544, episode reward: 200.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.470 [0.000, 1.000], mean observation: -0.267 [-2.154, 0.737], loss: 3.970732, mean_absolute_error: 32.746159, mean_q: 66.661438
 47397/50000: episode: 454, duration: 0.317s, episode steps: 170, steps per second: 537, episode reward: 170.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.465 [0.000, 1.000], mean observation: -0.326 [-2.226, 0.536], loss: 2.642750, mean_absolute_error: 32.936569, mean_q: 67.055870
 47597/50000: episode: 455, duration: 0.363s, episode steps: 200, steps per second: 551, episode reward: 200.000, mean rewar

<keras.callbacks.History at 0x128b09f98>

In [7]:
# Run the trained agent for 5 rendered episodes. The returned History object
# is kept in a variable so it does not leak out as the cell's repr.
try:
    test_history = dqn.test(env, nb_episodes=5, visualize=True)
finally:
    # visualize=True opens a render window; close it even if testing is
    # interrupted so it does not linger after the cell finishes.
    env.close()

Testing for 5 episodes ...
Episode 1: reward: 200.000, steps: 200
Episode 2: reward: 187.000, steps: 187
Episode 3: reward: 180.000, steps: 180
Episode 4: reward: 200.000, steps: 200
Episode 5: reward: 190.000, steps: 190


<keras.callbacks.History at 0x128716828>