In [9]:
import numpy as np
import gym
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory

In [2]:
# Configuration: the Gym environment to solve and the single seed shared by every RNG below.
ENV_NAME = 'CartPole-v0'
SEED = 123

env = gym.make(ENV_NAME)

# Seed NumPy's global RNG and the environment with the same value to make runs more repeatable.
np.random.seed(SEED)
env.seed(SEED)

# Number of discrete actions; used as the width of the Q-network's output layer.
nb_actions = env.action_space.n

In [3]:
# Sanity check: show how many discrete actions the environment exposes.
print(nb_actions)

2


In [5]:
# Q-network: flattens the observation window and maps it to one Q-value per action.
model = Sequential()
# Input shape is (1,) + observation shape; the leading 1 presumably corresponds to the
# window_length=1 used by the replay memory — confirm if window_length changes.
model.add(Flatten(input_shape=(1, ) + env.observation_space.shape))
# Single hidden layer of 16 ReLU units.
model.add(Dense(16))
model.add(Activation('relu'))
# Linear output head: one unbounded value per action.
model.add(Dense(nb_actions))
model.add(Activation('linear'))
# summary() prints the table itself and returns None, so wrapping it in print()
# only appends a stray "None" to the cell output.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_2 (Flatten)          (None, 4)                 0         
_________________________________________________________________
dense_3 (Dense)              (None, 16)                80        
_________________________________________________________________
activation_2 (Activation)    (None, 16)                0         
_________________________________________________________________
dense_4 (Dense)              (None, 2)                 34        
_________________________________________________________________
activation_3 (Activation)    (None, 2)                 0         
Total params: 114
Trainable params: 114
Non-trainable params: 0
_________________________________________________________________
None


In [10]:
# Replay buffer capped at limit=50000 entries, with one observation per state (window_length=1).
memory = SequentialMemory(
    limit=50000,
    window_length=1,
)
# Epsilon-greedy exploration using the library's default epsilon.
policy = EpsGreedyQPolicy()

# Wire the Q-network, replay memory and exploration policy into a DQN agent.
# NOTE(review): nb_steps_warmup=10 and target_model_update=1e-2 are passed straight to
# keras-rl; presumably a short warm-up and a soft target-network update — confirm against its docs.
dqn = DQNAgent(
    model=model,
    nb_actions=nb_actions,
    memory=memory,
    policy=policy,
    nb_steps_warmup=10,
    target_model_update=1e-2,
)
# Adam with lr=1e-3; mean absolute error is reported as an extra metric in the training log.
dqn.compile(Adam(lr=1e-3), metrics=['mae'])

In [11]:
# Train for 5000 environment steps with per-episode logging (verbose=2) and live rendering.
# NOTE(review): rendering appears to cap throughput (the log shows ~60 steps per second);
# consider visualize=False if training speed matters.
# Keep the returned History object so it can be inspected later, instead of leaving
# its bare repr as the cell output.
train_history = dqn.fit(env, nb_steps=5000, visualize=True, verbose=2)

Training for 5000 steps ...




   23/5000: episode: 1, duration: 5.025s, episode steps: 23, steps per second: 5, episode reward: 23.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.522 [0.000, 1.000], mean observation: -0.095 [-1.161, 0.401], loss: 0.529258, mean_absolute_error: 0.596812, mean_q: 0.249399
   33/5000: episode: 2, duration: 0.165s, episode steps: 10, steps per second: 61, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.100 [0.000, 1.000], mean observation: 0.154 [-1.520, 2.585], loss: 0.474986, mean_absolute_error: 0.628987, mean_q: 0.356089




   43/5000: episode: 3, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.100 [0.000, 1.000], mean observation: 0.162 [-1.709, 2.745], loss: 0.413335, mean_absolute_error: 0.667680, mean_q: 0.325963
   54/5000: episode: 4, duration: 0.183s, episode steps: 11, steps per second: 60, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.091 [0.000, 1.000], mean observation: 0.109 [-1.783, 2.779], loss: 0.419903, mean_absolute_error: 0.670422, mean_q: 0.372935
   67/5000: episode: 5, duration: 0.216s, episode steps: 13, steps per second: 60, episode reward: 13.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.077 [0.000, 1.000], mean observation: 0.115 [-2.109, 3.254], loss: 0.350571, mean_absolute_error: 0.646997, mean_q: 0.467338
   76/5000: episode: 6, duration: 0.151s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0

  346/5000: episode: 34, duration: 0.150s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.000 [0.000, 0.000], mean observation: 0.156 [-1.712, 2.836], loss: 0.474636, mean_absolute_error: 0.935778, mean_q: 2.384576
  356/5000: episode: 35, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.100 [0.000, 1.000], mean observation: 0.136 [-1.592, 2.633], loss: 0.446316, mean_absolute_error: 0.978092, mean_q: 2.387925
  368/5000: episode: 36, duration: 0.199s, episode steps: 12, steps per second: 60, episode reward: 12.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.083 [0.000, 1.000], mean observation: 0.118 [-1.977, 3.098], loss: 0.546934, mean_absolute_error: 1.039304, mean_q: 2.447035
  381/5000: episode: 37, duration: 0.217s, episode steps: 13, steps per second: 60, episode reward: 13.000, mean reward: 1.000 [1.000, 1.000], mean actio

  654/5000: episode: 65, duration: 0.165s, episode steps: 10, steps per second: 61, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.100 [0.000, 1.000], mean observation: 0.114 [-1.772, 2.677], loss: 0.379655, mean_absolute_error: 1.726541, mean_q: 3.907489
  664/5000: episode: 66, duration: 0.167s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.100 [0.000, 1.000], mean observation: 0.132 [-1.548, 2.481], loss: 0.284939, mean_absolute_error: 1.683274, mean_q: 4.012555
  674/5000: episode: 67, duration: 0.167s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.100 [0.000, 1.000], mean observation: 0.122 [-1.762, 2.677], loss: 0.311244, mean_absolute_error: 1.699155, mean_q: 4.064929
  683/5000: episode: 68, duration: 0.149s, episode steps: 9, steps per second: 61, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean actio

  958/5000: episode: 96, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.100 [0.000, 1.000], mean observation: 0.148 [-1.521, 2.520], loss: 0.314734, mean_absolute_error: 2.178877, mean_q: 4.947865
  968/5000: episode: 97, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.100 [0.000, 1.000], mean observation: 0.121 [-1.772, 2.688], loss: 0.321871, mean_absolute_error: 2.211872, mean_q: 5.144414
  976/5000: episode: 98, duration: 0.134s, episode steps: 8, steps per second: 60, episode reward: 8.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.000 [0.000, 0.000], mean observation: 0.151 [-1.539, 2.568], loss: 0.217673, mean_absolute_error: 2.259959, mean_q: 5.238338
  985/5000: episode: 99, duration: 0.149s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action:

 1244/5000: episode: 125, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.200 [0.000, 1.000], mean observation: 0.126 [-1.375, 2.181], loss: 0.385389, mean_absolute_error: 2.855345, mean_q: 5.826180
 1252/5000: episode: 126, duration: 0.134s, episode steps: 8, steps per second: 60, episode reward: 8.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.875 [0.000, 1.000], mean observation: -0.153 [-2.246, 1.400], loss: 0.344636, mean_absolute_error: 2.808215, mean_q: 5.564146
 1260/5000: episode: 127, duration: 0.133s, episode steps: 8, steps per second: 60, episode reward: 8.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.140 [-2.569, 1.602], loss: 0.198096, mean_absolute_error: 2.941403, mean_q: 5.842766
 1270/5000: episode: 128, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean a

 1522/5000: episode: 154, duration: 0.150s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.889 [0.000, 1.000], mean observation: -0.140 [-2.464, 1.540], loss: 2.088202, mean_absolute_error: 3.757335, mean_q: 6.985361
 1532/5000: episode: 155, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.900 [0.000, 1.000], mean observation: -0.116 [-2.503, 1.577], loss: 1.791168, mean_absolute_error: 3.887366, mean_q: 7.255648
 1541/5000: episode: 156, duration: 0.150s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.159 [-2.838, 1.743], loss: 1.233248, mean_absolute_error: 3.853878, mean_q: 7.269069
 1552/5000: episode: 157, duration: 0.183s, episode steps: 11, steps per second: 60, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean 

 1821/5000: episode: 184, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.900 [0.000, 1.000], mean observation: -0.163 [-2.643, 1.521], loss: 2.950213, mean_absolute_error: 4.600986, mean_q: 8.301473
 1831/5000: episode: 185, duration: 0.167s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.129 [-3.062, 1.971], loss: 2.004537, mean_absolute_error: 4.714692, mean_q: 8.497425
 1843/5000: episode: 186, duration: 0.200s, episode steps: 12, steps per second: 60, episode reward: 12.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.917 [0.000, 1.000], mean observation: -0.114 [-3.005, 1.950], loss: 1.765394, mean_absolute_error: 4.553888, mean_q: 8.310732
 1851/5000: episode: 187, duration: 0.133s, episode steps: 8, steps per second: 60, episode reward: 8.000, mean reward: 1.000 [1.000, 1.000], mea

 2207/5000: episode: 213, duration: 0.266s, episode steps: 15, steps per second: 56, episode reward: 15.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.533 [0.000, 1.000], mean observation: -0.073 [-1.145, 0.630], loss: 1.752091, mean_absolute_error: 4.919221, mean_q: 9.030554
 2234/5000: episode: 214, duration: 0.488s, episode steps: 27, steps per second: 55, episode reward: 27.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.556 [0.000, 1.000], mean observation: -0.068 [-1.497, 0.626], loss: 1.454872, mean_absolute_error: 4.828894, mean_q: 8.859282
 2249/5000: episode: 215, duration: 0.247s, episode steps: 15, steps per second: 61, episode reward: 15.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.533 [0.000, 1.000], mean observation: -0.089 [-1.227, 0.813], loss: 1.645562, mean_absolute_error: 4.891934, mean_q: 8.995225
 2279/5000: episode: 216, duration: 0.535s, episode steps: 30, steps per second: 56, episode reward: 30.000, mean reward: 1.000 [1.000, 1.000], m

 3204/5000: episode: 242, duration: 0.483s, episode steps: 26, steps per second: 54, episode reward: 26.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.462 [0.000, 1.000], mean observation: -0.122 [-0.680, 0.346], loss: 1.894135, mean_absolute_error: 6.303790, mean_q: 11.849040
 3256/5000: episode: 243, duration: 0.879s, episode steps: 52, steps per second: 59, episode reward: 52.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.481 [0.000, 1.000], mean observation: -0.088 [-0.727, 0.432], loss: 2.221745, mean_absolute_error: 6.299969, mean_q: 11.764969
 3308/5000: episode: 244, duration: 0.951s, episode steps: 52, steps per second: 55, episode reward: 52.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.481 [0.000, 1.000], mean observation: -0.098 [-0.720, 0.554], loss: 1.818087, mean_absolute_error: 6.306190, mean_q: 11.827723
 3374/5000: episode: 245, duration: 1.181s, episode steps: 66, steps per second: 56, episode reward: 66.000, mean reward: 1.000 [1.000, 1.000]

<keras.callbacks.History at 0x7ff6dc2e7208>

In [12]:
# Close the environment after training.
# NOTE(review): the next cell passes this same `env` to dqn.test(); reusing an
# environment after close() is not guaranteed to work for every Gym env — confirm
# this intermediate close is intended, or rely on the close after testing instead.
env.close()

In [13]:
# Evaluate the trained agent for 5 rendered episodes.
# Keep the returned History object instead of leaving its bare repr as the cell output.
test_history = dqn.test(env, nb_episodes=5, visualize=True)

Testing for 5 episodes ...
Episode 1: reward: 43.000, steps: 43
Episode 2: reward: 92.000, steps: 92
Episode 3: reward: 117.000, steps: 117
Episode 4: reward: 40.000, steps: 40
Episode 5: reward: 67.000, steps: 67


<keras.callbacks.History at 0x7ff69c76d320>

In [14]:
# Close the environment now that evaluation is done (presumably this also tears down
# the render window opened by visualize=True).
env.close()