In [4]:
import random
import gym
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam


In [5]:
# Number of training episodes run by main().
EPISODES = 1000

In [21]:
class DQNAgent:
    """Deep Q-Network agent with an experience-replay buffer and epsilon-greedy exploration.

    The Q-function is approximated by a small fully connected Keras network that maps a
    state vector of length ``state_size`` to one Q-value per discrete action.

    Parameters
    ----------
    state_size : int
        Length of the observation vector fed to the network.
    action_size : int
        Number of discrete actions; the network has one linear output per action.
    memory_size : int, optional
        Maximum number of transitions kept in the replay buffer; once full, appending a
        new transition discards the oldest one.
    gamma : float, optional
        Discount rate applied to the bootstrapped next-state value.
    epsilon : float, optional
        Initial exploration rate (probability of taking a random action in ``act``).
    epsilon_min : float, optional
        Floor that ``epsilon`` decays towards and never drops below.
    epsilon_decay : float, optional
        Multiplicative decay applied to ``epsilon`` after each ``replay`` call.
    learning_rate : float, optional
        Adam learning rate for the Q-network.
    """

    def __init__(self, state_size, action_size, memory_size=2000, gamma=0.95,
                 epsilon=1.0, epsilon_min=0.1, epsilon_decay=0.995,
                 learning_rate=0.001):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=memory_size)
        self.gamma = gamma                  # discount rate
        self.epsilon = epsilon              # exploration rate
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.learning_rate = learning_rate  # learning rate of the neural network
        self.model = self._build_model()

    def _build_model(self):
        """Build and compile the Q-network: two 24-unit ReLU hidden layers, linear output."""
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_size, activation="relu"))
        model.add(Dense(24, activation="relu"))
        # Linear output: Q-values are unbounded regression targets, one per action.
        model.add(Dense(self.action_size, activation="linear"))
        model.compile(loss="mse", optimizer=self._make_optimizer())
        return model

    def _make_optimizer(self):
        """Create an Adam optimizer using whichever learning-rate keyword Keras accepts."""
        try:
            # Keras >= 2.3 keyword; recent Keras releases reject the old ``lr`` name.
            return Adam(learning_rate=self.learning_rate)
        except TypeError:
            # Older Keras optimizers only understand ``lr``.
            return Adam(lr=self.learning_rate)

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Choose an action for ``state`` with an epsilon-greedy policy.

        ``state`` is passed directly to ``model.predict`` and only row 0 of the prediction
        is used, so it is expected to be a batch holding a single state (e.g. shape
        ``(1, state_size)``). With probability ``epsilon`` a uniformly random action is
        returned; otherwise the index of the largest predicted Q-value.

        Returns
        -------
        int
            Index of the chosen action.
        """
        if random.random() < self.epsilon:
            return random.randrange(self.action_size)
        # q_values[0] holds Q(s, a) for every action a of the single input state.
        q_values = self.model.predict(state, verbose=0)
        return int(np.argmax(q_values[0]))

    def replay(self, batch_size):
        """Train the Q-network on a random minibatch of stored transitions.

        Each sampled transition is fitted individually (one ``fit`` call per sample)
        toward the one-step target ``reward + gamma * max_a' Q(next_state, a')``, or just
        ``reward`` for transitions marked ``done``. Only the Q-value of the action taken is
        changed; the other outputs keep their current predictions as targets.

        If fewer than ``batch_size`` transitions are stored, all of them are used; with an
        empty buffer the call does nothing. After training, ``epsilon`` is multiplied by
        ``epsilon_decay`` but clamped so it never drops below ``epsilon_min``.
        """
        n_samples = min(batch_size, len(self.memory))
        if n_samples <= 0:
            return
        minibatch = random.sample(self.memory, n_samples)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                next_q = self.model.predict(next_state, verbose=0)[0]
                target = target + self.gamma * np.amax(next_q)
            target_f = self.model.predict(state, verbose=0)
            target_f[0][action] = target
            self.model.fit(state, target_f, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            # Clamp so repeated decay cannot undershoot the configured floor.
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
            
    

                

In [22]:
def _reset_env(env):
    """Reset ``env`` and return only the initial observation (old and new gym APIs)."""
    result = env.reset()
    # gym >= 0.26 returns (observation, info); older versions return the observation only.
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(env, action):
    """Step ``env`` and return ``(next_state, reward, done, info)`` for old and new gym APIs."""
    result = env.step(action)
    if len(result) == 5:
        # gym >= 0.26: (observation, reward, terminated, truncated, info)
        next_state, reward, terminated, truncated, info = result
        return next_state, reward, terminated or truncated, info
    return result


def main(episodes=None, batch_size=32, max_steps=500, render=True, env_name="CartPole-v0"):
    """Train a DQNAgent on a gym environment, printing one line per finished episode.

    Parameters
    ----------
    episodes : int, optional
        Number of episodes to run; defaults to the notebook-level ``EPISODES`` constant.
    batch_size : int, optional
        Replay minibatch size; replay runs after each non-terminal step once the agent's
        memory holds more than ``batch_size`` transitions.
    max_steps : int, optional
        Upper bound on the number of steps per episode.
    render : bool, optional
        Call ``env.render()`` before every step; set to False to skip rendering.
    env_name : str, optional
        Environment id passed to ``gym.make``.

    The printed score is the number of steps the episode lasted. The environment is
    closed when training finishes or is interrupted (e.g. by KeyboardInterrupt).
    """
    if episodes is None:
        episodes = EPISODES
    env = gym.make(env_name)
    try:
        state_size = env.observation_space.shape[0]
        action_size = env.action_space.n
        agent = DQNAgent(state_size, action_size)

        for episode in range(episodes):
            # The agent's network expects a batch of one state: shape (1, state_size).
            state = np.reshape(_reset_env(env), [1, state_size])
            for step in range(max_steps):
                if render:
                    env.render()
                action = agent.act(state)
                next_state, reward, done, _ = _step_env(env, action)
                # Penalise the step that ends the episode.
                # NOTE(review): this also penalises episodes cut off by a time limit;
                # consider distinguishing truncation from failure.
                if done:
                    reward = -10
                next_state = np.reshape(next_state, [1, state_size])
                agent.remember(state, action, reward, next_state, done)
                state = next_state
                if done:
                    # ``step`` is 0-based, so the episode lasted step + 1 steps.
                    print("Episode: {} / {}, score: {}, e: {:.2}".format(
                        episode, episodes, step + 1, agent.epsilon))
                    break
                if len(agent.memory) > batch_size:
                    agent.replay(batch_size)
    finally:
        # Release the environment (and any render window) even on interruption.
        env.close()

In [23]:
# Run the DQN training loop defined in main(). It loops over EPISODES episodes,
# so interrupt the kernel (KeyboardInterrupt) to stop training early.
main()

[33mWARN: gym.spaces.Box autodetected dtype as <type 'numpy.float32'>. Please provide explicit dtype.[0m
Episode: 0 / 1000, score: 34, e: 0.99
[[7.503326 6.807575]]
[[41.743156 54.60127 ]]
Episode: 1 / 1000, score: 21, e: 0.89
[[133.3578  183.34918]]
[[164.81442 230.59453]]
Episode: 2 / 1000, score: 17, e: 0.82
[[348.2873  500.19296]]
[[501.58896 741.24817]]
[[582.231  875.7204]]
[[ 901.3681 1310.4584]]
Episode: 3 / 1000, score: 8, e: 0.79
[[425.83975 588.2005 ]]
[[535.7032  677.23505]]
[[668.79626 851.16144]]
[[717.1558 913.5794]]
[[746.88086 935.7315 ]]
Episode: 4 / 1000, score: 20, e: 0.71
[[564.86975 690.1547 ]]
[[632.4439  752.36224]]
[[695.1755  824.01117]]
Episode: 5 / 1000, score: 19, e: 0.65
[[546.6185 614.1067]]
[[588.68195 655.2637 ]]
[[532.56354 584.17786]]
Episode: 6 / 1000, score: 16, e: 0.6
[[429.55252 454.99402]]
[[443.78943 471.37015]]
[[452.34058 480.43945]]
[[460.96857 490.40622]]
[[389.10886 408.34262]]
[[342.82413 351.06915]]
[[301.82422 305.0021 ]]
[[274.66577 2

[[3.268994 4.224449]]
[[3.0534713 2.534901 ]]
[[4.214996  4.7000036]]
[[2.5129144 1.5382316]]
[[2.4960139 1.0482378]]
[[2.5934598 1.3323245]]
[[2.3029256 2.1496096]]
[[0.58827096 0.7584805 ]]
[[-0.25960803 -1.0408394 ]]
[[-1.4446952 -2.153521 ]]
[[-0.85139817 -0.8592221 ]]
[[1.1007003 1.5582494]]
[[-1.7923752 -3.3907986]]
[[-0.30138722 -0.59641194]]
[[1.6881627 2.3008814]]
[[0.20912325 0.06825364]]
[[1.3237232 1.6075276]]
[[0.54087013 1.5944908 ]]
[[-0.56663007  0.4935388 ]]
[[ 0.13398084 -0.49749216]]
[[0.12895791 0.00893022]]
[[-0.7342483 -1.3204726]]
[[1.1456171 1.4711437]]
[[-0.07570969 -0.06412958]]
[[-1.2864357 -1.8993628]]
[[-1.1037508  -0.98056066]]
[[-0.9984869 -2.693931 ]]
[[-2.3766782 -3.945067 ]]
[[-0.6765395 -0.1480133]]
[[-3.6667607 -4.081829 ]]
[[-2.888361 -3.436128]]
[[-1.0235431  -0.20668676]]
[[-5.064388 -6.972949]]
Episode: 34 / 1000, score: 38, e: 0.1
[[10.693998 12.412509]]
[[11.907173 14.253602]]
[[13.522705 16.040247]]
[[12.109328 13.071547]]
[[10.999641 11.10051

[[14.020017 14.393213]]
[[8.778207 7.570455]]
[[12.085085 11.755424]]
[[12.837414  13.9197445]]
[[10.474175 10.143283]]
[[12.753305 13.855026]]
[[9.1118965 8.257755 ]]
Episode: 40 / 1000, score: 121, e: 0.1
[[23.096373 22.270824]]
[[18.555235 22.60056 ]]
[[23.36733  22.327564]]
[[25.856882 27.621368]]
[[27.326357 27.141342]]
[[22.168495 25.374722]]
[[22.68054  22.015966]]
[[23.491064 25.991667]]
[[25.188896 25.238115]]
[[23.512922 20.711357]]
[[24.69764  22.425255]]
[[16.749369 20.85499 ]]
[[24.397476 28.167744]]
[[25.613726 25.479265]]
[[25.841385 29.389282]]
[[26.744804 28.423952]]
[[28.185047 26.079062]]
[[28.512028 28.382677]]
[[24.196558 26.1455  ]]
[[27.13523  27.307251]]
[[27.630978 26.244942]]
[[27.792767 29.3385  ]]
[[29.243292 27.087193]]
[[31.17089  32.997776]]
[[32.932804 31.296251]]
[[29.677649 33.3757  ]]
[[31.81756  31.526062]]
[[30.276886 33.38652 ]]
[[31.058327 31.396278]]
[[30.559153 28.47088 ]]
[[31.956928 35.730186]]
[[34.768536 35.08151 ]]
[[33.633034 31.286003]]
[

[[27.244604 29.319225]]
[[31.201809 28.690197]]
[[27.626665 30.44039 ]]
[[28.505701 26.325106]]
[[25.260834 29.03066 ]]
[[31.127542 31.313622]]
[[31.040651 29.581991]]
[[28.974516 29.747543]]
[[25.11615  21.821667]]
[[23.34704  23.347307]]
[[25.806757 23.345818]]
[[25.5582   26.309523]]
[[25.892855 24.294296]]
[[17.421495 21.087637]]
[[23.994951 25.926521]]
[[26.238586 24.181757]]
[[24.990034 27.148817]]
[[25.162859 24.682003]]
[[24.470608 27.826982]]
[[28.223589 26.53466 ]]
[[23.062943 26.876265]]
[[25.048939 24.072248]]
[[18.555283 21.390781]]
[[21.982807 23.156143]]
[[25.038239 22.98995 ]]
[[21.959171 23.80909 ]]
[[26.648718 23.859308]]
[[27.920378 25.61504 ]]
[[25.076944 24.933712]]
[[23.953703 27.531729]]
[[23.24095  23.047855]]
[[12.0458765 17.382341 ]]
[[16.091238 21.049814]]
[[20.816849 22.697088]]
[[16.061842 19.417786]]
[[16.211552 16.297132]]
[[15.824293 12.715307]]
[[12.164357 13.593384]]
[[15.360381 12.420861]]
[[15.4809885 17.551737 ]]
[[19.670778 18.383205]]
[[5.189227 9

[[-5.8582015 -3.6788137]]
[[-2.1632602 -1.392357 ]]
[[ 0.5955827 -1.5241604]]
[[-1.2664179 -1.7068346]]
[[-9.993797 -4.804294]]
[[-8.959025 -8.315615]]
Episode: 49 / 1000, score: 77, e: 0.1
[[30.911606 29.177563]]
[[24.258778 27.719038]]
[[27.232319 27.809908]]
[[30.270012 26.706625]]
[[26.351158 27.418844]]
[[30.228355 29.844303]]
[[31.627766 30.757854]]
[[26.154634 28.180508]]
[[28.511585 25.352432]]
[[24.314987 24.56824 ]]
[[26.633984 22.391579]]
[[24.941786 24.806469]]
[[24.912823 28.791542]]
[[26.864475 27.805403]]
[[27.48149  24.808182]]
[[25.949644 26.386002]]
[[25.564552 23.52031 ]]
[[22.757153 24.773085]]
[[24.86261  22.105349]]
[[20.160156 21.190756]]
[[22.560165 24.08957 ]]
[[24.417326 22.189535]]
[[19.93946  23.694519]]
[[24.069685 26.004505]]
[[22.2143   25.303873]]
[[25.543694 25.315454]]
[[23.461924 25.969309]]
[[23.944061 24.586891]]
[[23.857918 22.01116 ]]
[[21.747093 22.534977]]
[[23.208763 20.753464]]
[[24.417234 26.104105]]
[[24.644854 22.547102]]
[[20.837093 21.775

KeyboardInterrupt: 