In [None]:
# OpenGym CartPole-v0
# -------------------
#
# This code demonstrates the use of a full DQN implementation
# to solve the OpenAI Gym CartPole-v0 problem.
#
# Made as part of the blog series "Let's make a DQN", available at:
# https://jaromiru.com/2016/09/27/lets-make-a-dqn-theory/
# 
# author: Jaromir Janisch, 2016

import random, numpy, math, gym, sys
from keras import backend as K

import tensorflow as tf

#----------
HUBER_LOSS_DELTA = 1.0      # |error| threshold at which huber_loss switches from quadratic to linear
LEARNING_RATE = 0.00025     # learning rate passed to the RMSprop optimizer when compiling each model

#----------
def huber_loss(y_true, y_pred):
    """Mean Huber loss between targets and predictions.

    Residuals whose magnitude is below HUBER_LOSS_DELTA are penalised
    quadratically (0.5 * e^2); larger ones linearly
    (delta * (|e| - 0.5 * delta)), which limits the gradient from outliers.
    The per-element losses are averaged with K.mean.
    """
    residual = y_true - y_pred
    magnitude = K.abs(residual)

    quadratic_part = 0.5 * K.square(residual)
    linear_part = HUBER_LOSS_DELTA * (magnitude - 0.5 * HUBER_LOSS_DELTA)

    # The Keras backend used here exposes no element-wise select, so use TensorFlow's directly.
    per_element = tf.where(magnitude < HUBER_LOSS_DELTA, quadratic_part, linear_part)
    return K.mean(per_element)

#-------------------- BRAIN ---------------------------
from keras.models import Sequential
from keras.layers import *
from keras.optimizers import *

class Brain:
    """Online Q-network plus a separate target network of the same architecture.

    ``model`` is the network that gets trained; ``model_`` is never fitted
    directly and only receives weights through ``updateTargetModel``.
    """

    def __init__(self, stateCnt, actionCnt):
        self.stateCnt = stateCnt
        self.actionCnt = actionCnt

        self.model = self._createModel()    # online network, fitted in train()
        self.model_ = self._createModel()   # target network, updated only by copying weights

    def _createModel(self):
        """Build and compile an MLP mapping a state vector to one Q-value per action."""
        model = Sequential()

        # Use this brain's own dimensions rather than module-level globals, so a
        # Brain can be created for any environment (and before such globals exist).
        model.add(Dense(units=64, activation='relu', input_dim=self.stateCnt))
        model.add(Dense(units=self.actionCnt, activation='linear'))

        # NOTE(review): newer Keras releases name this argument ``learning_rate``;
        # ``lr`` is kept for the Keras version this script targets — confirm.
        opt = RMSprop(lr=LEARNING_RATE)
        model.compile(loss=huber_loss, optimizer=opt)

        return model

    def train(self, x, y, epochs=1, verbose=0, batch_size=64):
        """Fit the online model on inputs ``x`` and targets ``y``.

        ``batch_size`` defaults to the previously hard-coded value of 64.
        """
        self.model.fit(x, y, batch_size=batch_size, epochs=epochs, verbose=verbose)

    def predict(self, s, target=False):
        """Predict Q-values for a batch of states ``s``.

        Uses the target network when ``target`` is true, otherwise the online one.
        """
        if target:
            return self.model_.predict(s)
        return self.model.predict(s)

    def predictOne(self, s, target=False):
        """Predict Q-values for a single state; returns a flat array of length actionCnt."""
        return self.predict(s.reshape(1, self.stateCnt), target=target).flatten()

    def updateTargetModel(self):
        """Copy the online network's weights into the target network."""
        self.model_.set_weights(self.model.get_weights())

#-------------------- MEMORY --------------------------
class Memory:   # stored as ( s, a, r, s_ )
    """Bounded FIFO replay buffer.

    Once more than ``capacity`` samples are held, the oldest ones are dropped.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        # Created per instance: a class-level ``samples = []`` would be one list
        # shared by every Memory, silently merging unrelated buffers.
        self.samples = []

    def add(self, sample):
        """Append ``sample`` and evict the oldest entries beyond ``capacity``."""
        self.samples.append(sample)

        # Trim every excess entry (not just one), so the bound also holds if
        # ``samples`` was replaced from outside with a longer list.
        overflow = len(self.samples) - self.capacity
        if overflow > 0:
            del self.samples[:overflow]

    def sample(self, n):
        """Return up to ``n`` distinct samples chosen uniformly at random."""
        n = min(n, len(self.samples))
        return random.sample(self.samples, n)

    def isFull(self):
        """True once the buffer holds at least ``capacity`` samples."""
        return len(self.samples) >= self.capacity

#-------------------- AGENT ---------------------------
MEMORY_CAPACITY = 100000    # maximum number of transitions kept in replay memory
BATCH_SIZE = 64             # number of transitions sampled from memory per replay step

GAMMA = 0.99                # discount factor applied to the target network's best Q-value

MAX_EPSILON = 1             # initial exploration probability
MIN_EPSILON = 0.01          # floor that epsilon decays towards
LAMBDA = 0.001      # exponential decay rate of epsilon per observed step

UPDATE_TARGET_FREQUENCY = 1000  # copy online weights into the target network every this many steps

class Agent:
    """DQN agent: epsilon-greedy acting, experience replay and a target network.

    Transitions are passed to ``observe`` as ``(s, a, r, s_)`` tuples, where
    ``s_`` is ``None`` for a terminal transition.
    """

    # Class-level defaults; ``observe`` rebinds both on the instance.
    steps = 0
    epsilon = MAX_EPSILON

    def __init__(self, stateCnt, actionCnt):
        self.stateCnt = stateCnt
        self.actionCnt = actionCnt

        self.brain = Brain(stateCnt, actionCnt)
        self.memory = Memory(MEMORY_CAPACITY)

    def act(self, s):
        """Return an action index for state ``s``.

        With probability ``epsilon`` pick a uniformly random action; otherwise
        pick the action with the highest predicted Q-value.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.actionCnt-1)
        return numpy.argmax(self.brain.predictOne(s))

    def observe(self, sample):  # in (s, a, r, s_) format
        """Store ``sample``, periodically sync the target network, and decay epsilon."""
        self.memory.add(sample)

        if self.steps % UPDATE_TARGET_FREQUENCY == 0:
            self.brain.updateTargetModel()

        # Debug: every 100 steps print the Q-value of action 0 at a fixed probe state S.
        # Uses this agent's own brain rather than a module-level ``agent`` global.
        # The probe has 4 components, so this assumes stateCnt == 4.
        if self.steps % 100 == 0:
            S = numpy.array([-0.01335408, -0.04600273, -0.00677248, 0.01517507])
            pred = self.brain.predictOne(S)
            print(pred[0])
            sys.stdout.flush()

        # Slowly decrease epsilon based on experience:
        # epsilon = MIN + (MAX - MIN) * exp(-LAMBDA * steps)
        self.steps += 1
        self.epsilon = MIN_EPSILON + (MAX_EPSILON - MIN_EPSILON) * math.exp(-LAMBDA * self.steps)

    def replay(self):
        """Fit the online network on one random minibatch from memory.

        The target for each sample is the network's current prediction with the
        taken action's entry replaced by ``r`` (terminal) or by
        ``r + GAMMA * max(Q_target(s_))`` (non-terminal).
        """
        batch = self.memory.sample(BATCH_SIZE)
        batchLen = len(batch)
        if batchLen == 0:
            # Nothing stored yet; numpy.array([]) would not have shape (0, stateCnt).
            return

        # Placeholder for terminal successors so they can be batched; their
        # predictions are never used because terminal targets are just ``r``.
        no_state = numpy.zeros(self.stateCnt)

        states = numpy.array([o[0] for o in batch])
        states_ = numpy.array([(no_state if o[3] is None else o[3]) for o in batch])

        p = self.brain.predict(states)
        p_ = self.brain.predict(states_, target=True)

        x = numpy.zeros((batchLen, self.stateCnt))
        y = numpy.zeros((batchLen, self.actionCnt))

        for i, (s, a, r, s_) in enumerate(batch):
            # Start from the current predictions so only the taken action has a non-zero error.
            t = p[i]
            if s_ is None:
                t[a] = r
            else:
                t[a] = r + GAMMA * numpy.amax(p_[i])

            x[i] = s
            y[i] = t

        self.brain.train(x, y)


class RandomAgent:
    """Agent that acts uniformly at random and only records transitions.

    It never learns: ``replay`` is a no-op. It is useful for pre-filling replay memory.
    """

    def __init__(self, actionCnt):
        self.actionCnt = actionCnt
        # Created per instance so that separate RandomAgents never share one
        # buffer (a class attribute would be a single Memory for all of them).
        self.memory = Memory(MEMORY_CAPACITY)

    def act(self, s):
        """Return a uniformly random action index; the state ``s`` is ignored."""
        return random.randint(0, self.actionCnt-1)

    def observe(self, sample):  # in (s, a, r, s_) format
        """Record ``sample`` in this agent's memory."""
        self.memory.add(sample)

    def replay(self):
        """No-op: a random agent has nothing to learn."""
        pass

#-------------------- ENVIRONMENT ---------------------
class Environment:
    """Wrapper around a Gym environment that plays whole episodes with an agent."""

    def __init__(self, problem):
        self.problem = problem
        self.env = gym.make(problem)

    def run(self, agent):
        """Play one episode with ``agent`` and return its total (undiscounted) reward.

        Each step is passed to ``agent.observe`` as ``(s, a, r, s_)`` with ``s_``
        set to ``None`` on the terminal step, followed by ``agent.replay()``.
        Uses the 4-value ``step()`` result (obs, reward, done, info).
        NOTE(review): newer Gym/Gymnasium releases return 5 values — confirm the
        installed version.
        """
        s = self.env.reset()
        R = 0

        while True:
            # Uncomment to watch the agent play.
            # self.env.render()

            a = agent.act(s)

            s_, r, done, _info = self.env.step(a)

            if done:  # terminal state: the agent treats a None successor as "no future reward"
                s_ = None

            agent.observe((s, a, r, s_))
            agent.replay()

            s = s_
            R += r

            if done:
                # Return the total instead of discarding it, so callers can log progress.
                return R

#-------------------- MAIN ----------------------------
PROBLEM = 'CartPole-v0'
env = Environment(PROBLEM)

stateCnt  = env.env.observation_space.shape[0]
actionCnt = env.env.action_space.n

agent = Agent(stateCnt, actionCnt)
randomAgent = RandomAgent(actionCnt)

try:
    # Warm-up: fill replay memory with transitions from a purely random policy.
    while not randomAgent.memory.isFull():
        env.run(randomAgent)

    # Hand the pre-filled buffer to the learning agent and drop the random one.
    agent.memory.samples = randomAgent.memory.samples
    randomAgent = None

    # Train indefinitely; interrupt the process (e.g. Ctrl-C) to stop.
    while True:
        env.run(agent)
finally:
    # Always persist the online network, including when training is interrupted.
    agent.brain.model.save("cartpole-dqn.h5")

Using TensorFlow backend.


-0.0170133
0.108469
0.240536
0.39619
0.564973
0.745644
0.907233
0.973099
0.96527
0.967004
0.960812
1.25502
1.53974
1.81557
1.91316
1.95481
1.96888
1.95748
1.96086
1.94815
1.94641
2.3515
2.71563
2.87623
2.92653
2.93661
2.94927
2.93944
2.94404
2.93975
2.92804
3.42182
3.80639
3.88379
3.91539
3.93492
3.92944
3.92211
3.93521
3.93128
3.93585
4.50206
4.82204
4.85374
4.90179
4.90858
4.90304
4.89574
4.90439
4.89755
4.89936
5.5237
5.7931
5.85537
5.84465
5.88209
5.86333
5.87284
5.87835
5.87634
5.8449
6.51184
6.77821
6.8158
6.79091
6.82424
6.81813
6.81679
6.83206
6.82272
6.80344
7.50766
7.72549
7.77097
7.74504
7.74381
7.78496
7.78001
7.75528
7.77314
7.77144
8.49415
8.68424
8.68848
8.71443
8.71063
8.69527
8.71659
8.70741
8.71043
8.7346
9.50004
9.65753
9.67039
9.65848
9.67336
9.66382
9.6769
9.66143
9.70172
9.68764
10.4724
10.597
10.6271
10.6248
10.6058
10.6375
10.6389
10.6064
10.6301
10.6142
11.423
11.5357
11.5248
11.53
11.5555
11.5155
11.5225
11.539
11.536
11.5204
12.3344
12.4223
12.422
12.4126
12.

70.1351
70.1311
70.122
70.1314
70.3946
70.4582
70.4684
70.4571
70.4443
70.4699
70.4565
70.4776
70.4739
70.438
70.7406
70.7799
70.7764
70.773
70.7861
70.7333
70.7948
70.7886
70.7975
70.7828
71.0957
71.0927
71.095
71.1265
71.1061
71.126
71.0684
71.1014
71.1015
71.1493
71.4222
71.446
71.4633
71.413
71.4703
71.4539
71.4569
71.4435
71.4639
71.4501
71.7383
71.7601
71.7644
71.7737
71.7891
71.7565
71.7189
71.7208
71.784
71.7651
72.0624
72.0416
72.0359
72.0416
72.0456
72.0197
72.0641
72.0368
72.0504
72.0534
72.3321
72.3473
72.3302
72.3385
72.3301
72.3542
72.3407
72.3359
72.3147
72.3625
72.6166
72.6539
72.6308
72.6589
72.64
72.6591
72.6566
72.6724
72.653
72.6249
72.9067
72.939
72.9147
72.9039
72.9379
72.8847
72.9327
72.8889
72.9363
72.9077
73.1761
73.1861
73.1847
73.2297
73.1742
73.1851
73.1934
73.1627
73.1828
73.1549
73.4268
73.4855
73.4059
73.4571
73.4397
73.4331
73.4288
73.421
73.4304
73.3991
73.6817
73.6829
73.6562
73.6671
73.6773
73.6728
73.6844
73.644
73.6773
73.6535
73.937
73.934
73.9295


95.4814
95.4746
95.4731
95.4485
95.5014
95.58
95.5885
95.595
95.577
95.5853
95.5773
95.5712
95.6347
95.5706
95.5805
95.6728
95.6465
95.642
95.6434
95.6884
95.6489
95.6867
95.6926
95.7066
95.6834
95.7676
95.7759
95.761
95.7579
95.7999
95.7884
95.7952
95.7582
95.7775
95.7663
95.8445
95.8607
95.8604
95.8489
95.8651
95.8861
95.8815
95.8952
95.8827
95.8809
95.9793
95.969
95.9478
95.9152
95.9795
95.9681
95.9829
95.9725
95.9686
95.9515
96.0198
96.0551
96.0268
96.0313
96.0266
96.0255
96.0336
96.0657
96.0548
96.0341
96.1115
96.1317
96.1283
96.087
96.1395
96.1208
96.0709
96.0931
96.1105
96.1225
96.2016
96.2183
96.168
96.1675
96.1767
96.2045
96.1865
96.1784
96.2027
96.18
96.256
96.2755
96.3061
96.2594
96.3092
96.2613
96.2817
96.2948
96.2775
96.2473
96.3233
96.3488
96.3806
96.3658
96.3606
96.3242
96.3411
96.3849
96.3881
96.3656
96.4729
96.4765
96.4679
96.5053
96.4717
96.4804
96.4807
96.4623
96.4868
96.4817
96.5701
96.5927
96.5899
96.6121
96.5815
96.6406
96.5875
96.5866
96.5963
96.6172
96.6984
96.7

100.301
100.398
100.369
100.376
100.417
100.378
100.4
100.43
100.418
100.428
100.425
100.396
100.412
100.42
100.472
100.446
100.439
100.459
100.472
100.433
100.434
100.417
100.462
100.435
100.465
100.493
100.483
100.499
100.477
100.49
100.457
100.462
100.469
100.45
100.52
100.484
100.5
100.479
100.484
100.476
100.548
100.487
100.472
100.501
100.516
100.542
100.512
100.537
100.474
100.512
100.539
100.536
100.517
100.538
100.544
100.549
100.577
100.504
100.545
100.572
100.56
100.578
100.582
100.576
100.635
100.605
100.594
100.607
100.597
100.602
100.62
100.638
100.623
100.575
100.644
100.592
100.569
100.603
100.586
100.611
100.627
100.582
100.582
100.61
100.686
100.638
100.664
100.634
100.659
100.596
100.635
100.668
100.611
100.615
100.636
100.647
100.638
100.646
100.665
100.636
100.648
100.665
100.62
100.627
100.662
100.658
100.638
100.681
100.652
100.616
100.68
100.655
100.685
100.622
100.674
100.631
100.655
100.616
100.67
100.64
100.627
100.641
100.662
100.635
100.69
100.607
100.675
1

100.873
100.874
100.831
100.832
100.827
100.844
100.85
100.836
100.851
100.881
100.874
100.874
100.819
100.861
100.856
100.862
100.853
100.899
100.854
100.85
100.863
100.856
100.854
100.846
100.843
100.904
100.862
100.835
100.861
100.845
100.86
100.852
100.855
100.858
100.863
100.895
100.836
100.818
100.795
100.852
100.812
100.834
100.858
100.831
100.818
100.81
100.782
100.739
100.806
100.812
100.813
100.786
100.773
100.81
100.78
100.821
100.765
100.767
100.745
100.763
100.71
100.741
100.786
100.762
100.786
100.759
100.746
100.743
100.763
100.77
100.74
100.773
100.786
100.769
100.76
100.745
100.76
100.767
100.785
100.786
100.776
100.766
100.766
100.774
100.77
100.799
100.804
100.811
100.833
100.789
100.832
100.792
100.817
100.82
100.825
100.796
100.825
100.828
100.834
100.839
100.874
100.807
100.841
100.857
100.829
100.83
100.854
100.918
100.88
100.913
100.909
100.864
100.887
100.885
100.881
100.87
100.871
100.904
100.902
100.895
100.902
100.89
100.889
100.909
100.942
100.929
100.89
10

101.307
101.312
101.341
101.366
101.284
101.293
101.349
101.324
101.287
101.351
101.334
101.315
101.292
101.326
101.294
101.308
101.269
101.276
101.277
101.237
101.27
101.27
101.284
101.282
101.286
101.271
101.244
101.274
101.264
101.306
101.267
101.291
101.275
101.293
101.279
101.274
101.276
101.262
101.244
101.277
101.252
101.272
101.241
101.219
101.253
101.285
101.275
101.265
101.27
101.276
101.239
101.248
101.273
101.272
101.283
101.261
101.246
101.299
101.262
101.302
101.29
101.275
101.296
101.269
101.283
101.295
101.264
101.303
101.278
101.291
101.274
101.281
101.293
101.276
101.277
101.309
101.273
101.319
101.303
101.309
101.314
101.327
101.295
101.287
101.307
101.318
101.318
101.295
101.319
101.277
101.317
101.288
101.317
101.311
101.244
101.3
101.253
101.289
101.292
101.263
101.279
101.294
101.309
101.294
101.3
101.287
101.296
101.311
101.333
101.269
101.324
101.292
101.32
101.276
101.336
101.334
101.334
101.337
101.295
101.324
101.342
101.371
101.308
101.35
101.335
101.328
10

104.425
104.413
104.567
104.583
104.598
104.594
104.588
104.635
104.595
104.599
104.538
104.609
104.723
104.757
104.766
104.787
104.798
104.806
104.797
104.808
104.805
104.78
104.916
104.919
104.946
104.942
104.965
104.957
104.948
104.981
104.948
104.951
105.059
105.081
105.093
105.082
105.083
105.113
105.106
105.142
105.14
105.07
105.196
105.236
105.177
105.221
105.211
105.204
105.231
105.237
105.233
105.239
105.326
105.361
105.376
105.393
105.405
105.411
105.387
105.407
105.403
105.391
105.514
105.514
105.505
105.518
105.52
105.533
105.521
105.552
105.536
105.497
105.586
105.59
105.619
105.607
105.593
105.561
105.626
105.646
105.554
105.648
105.747
105.775
105.743
105.786
105.806
105.81
105.773
105.767
105.799
105.772
105.871
105.874
105.889
105.887
105.862
105.934
105.959
105.931
105.908
105.919
105.943
105.956
105.968
106.016
105.991
106.004
106.017
106.019
106.02
106.051
106.082
106.111
106.066
106.134
106.102
106.14
106.156
106.11
106.152
106.143
106.152
106.085
106.145
106.216
1

104.362
104.373
104.337
104.363
104.345
104.325
104.316
104.325
104.328
104.315
104.293
104.283
104.282
104.297
104.304
104.274
104.266
104.28
104.267
104.29
104.263
104.279
104.262
104.258
104.225
104.244
104.247
104.241
104.232
104.281
104.232
104.238
104.193
104.218
104.193
104.156
104.209
104.209
104.187
104.216
104.176
104.172
104.167
104.157
104.152
104.175
104.151
104.136
104.167
104.153
104.167
104.169
104.194
104.179
104.148
104.178
104.137
104.119
104.132
104.144
104.162
104.143
104.104
104.138
104.16
104.112
104.154
104.149
104.127
104.124
104.142
104.114
104.174
104.094
104.14
104.134
104.095
104.11
104.114
104.101
104.103
104.099
104.069
104.082
104.093
104.046
104.074
104.032
104.096
104.052
104.067
104.049
104.059
104.017
104.014
104.039
104.075
104.026
103.982
104.001
104.019
104.055
104.052
103.995
103.998
104.0
103.99
103.961
103.965
103.987
104.024
104.03
103.989
103.977
103.983
103.988
103.991
103.991
103.982
103.982
103.94
103.975
104.013
104.006
104.015
104.01
104