In [1]:
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer as xinit

In [2]:
import numpy as np
from collections import deque
import random

In [3]:
import gym

## Build Q(s,a)

- input : state
- output : Q-values for each action

In [4]:
state_size = 4
action_size = 2
hdim = 15

In [21]:
class DQNAgent():
    
    def __init__(self, state_size=4, action_size=2, hdim=15, epsilon=1.):
        
        self.epsilon = epsilon
        self.action_size = action_size
        self.experience = deque(maxlen=2000)
        self.gamma = 0.95    # discount rate
        self.epsilon = 1.0  # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        
        tf.reset_default_graph()
        self.state_ = tf.placeholder(tf.float32, shape=[state_size], name='states')
        self.target_ = tf.placeholder(tf.float32, shape=[action_size], name='targets')
        W1 = tf.get_variable('W1', dtype=tf.float32,
                             shape=[state_size, hdim], initializer=xinit() )
        b1 = tf.get_variable('b1', dtype=tf.float32,
                             shape=[hdim], initializer=xinit() )
        h = tf.nn.relu(tf.matmul(
            tf.expand_dims(self.state_, axis=0), W1) + b1)
        W2 = tf.get_variable('W2', dtype=tf.float32,
                             shape=[hdim, action_size], initializer=xinit() )
        b2 = tf.get_variable('b2', dtype=tf.float32,
                             shape=[action_size], initializer=xinit() )
        self.q_out = tf.reshape(tf.matmul(h, W2) + b2, [action_size])
        self.loss = tf.reduce_mean(tf.pow(self.q_out - self.target_, 2))
        self.update = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate).minimize(self.loss)
        
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        

    def get_action(self, s):
        if np.random.rand() <= self.epsilon:
            return np.random.randint(0, self.action_size)
        
        q_values = self.sess.run(self.q_out, feed_dict = {
            self.state_ : s.reshape([state_size])
        })
        
        return np.argmax(q_values)
    
    def remember(self, s0, a0, r0, s1, done):
        self.experience.append((s0,a0,r0,s1,done))
        
    def exp_replay(self, batch_size):
        batch = random.sample(self.experience, batch_size)
        for s,a,r,s1,done in batch:
            s = s.reshape([state_size])
            target = r
            if not done:
                q1 = self.sess.run(self.q_out, feed_dict = {
                    self.state_ : s1
                })
                q1_max = np.max(q1)
                target = r + (self.gamma * q1_max)
            q = self.sess.run(self.q_out, feed_dict = {
                self.state_ : s
            })
            q[a] = target
            self.sess.run(self.update, feed_dict = {
                self.state_ : s,
                self.target_ : q
            })
            
            if self.epsilon > self.epsilon_min:
                self.epsilon *= self.epsilon_decay

In [14]:
env = gym.make('CartPole-v0')

[2017-06-05 15:28:11,324] Making new env: CartPole-v0


In [26]:
agent = DQNAgent()

In [27]:
episodes = 5000
for e in range(episodes):
    
    s = env.reset()
    
    for t in range(500):
        a = agent.get_action(s)
        s1, r, done, _ = env.step(a)
        # remember
        agent.remember(s,a,r,s1,done)
        # update state
        s = s1
        
        if done:
            print("episode: {}/{}, score: {}".format(e, episodes, t))
            break
    
    batch_size = 128
    if len(agent.experience) > batch_size :
        agent.exp_replay(batch_size)        

episode: 0/5000, score: 13
episode: 1/5000, score: 22
episode: 2/5000, score: 23
episode: 3/5000, score: 10
episode: 4/5000, score: 11
episode: 5/5000, score: 36
episode: 6/5000, score: 25
episode: 7/5000, score: 22
episode: 8/5000, score: 10
episode: 9/5000, score: 9
episode: 10/5000, score: 8
episode: 11/5000, score: 9
episode: 12/5000, score: 9
episode: 13/5000, score: 9
episode: 14/5000, score: 8
episode: 15/5000, score: 9
episode: 16/5000, score: 9
episode: 17/5000, score: 8
episode: 18/5000, score: 9
episode: 19/5000, score: 10
episode: 20/5000, score: 8
episode: 21/5000, score: 8
episode: 22/5000, score: 10
episode: 23/5000, score: 9
episode: 24/5000, score: 10
episode: 25/5000, score: 8
episode: 26/5000, score: 9
episode: 27/5000, score: 8
episode: 28/5000, score: 7
episode: 29/5000, score: 9
episode: 30/5000, score: 10
episode: 31/5000, score: 9
episode: 32/5000, score: 9
episode: 33/5000, score: 8
episode: 34/5000, score: 9
episode: 35/5000, score: 9
episode: 36/5000, score: 

episode: 290/5000, score: 86
episode: 291/5000, score: 199
episode: 292/5000, score: 199
episode: 293/5000, score: 199
episode: 294/5000, score: 38
episode: 295/5000, score: 35
episode: 296/5000, score: 42
episode: 297/5000, score: 55
episode: 298/5000, score: 128
episode: 299/5000, score: 80
episode: 300/5000, score: 52
episode: 301/5000, score: 43
episode: 302/5000, score: 53
episode: 303/5000, score: 56
episode: 304/5000, score: 71
episode: 305/5000, score: 52
episode: 306/5000, score: 44
episode: 307/5000, score: 46
episode: 308/5000, score: 41
episode: 309/5000, score: 135
episode: 310/5000, score: 42
episode: 311/5000, score: 146
episode: 312/5000, score: 63
episode: 313/5000, score: 115
episode: 314/5000, score: 79
episode: 315/5000, score: 64
episode: 316/5000, score: 35
episode: 317/5000, score: 55
episode: 318/5000, score: 66
episode: 319/5000, score: 94
episode: 320/5000, score: 117
episode: 321/5000, score: 199
episode: 322/5000, score: 43
episode: 323/5000, score: 69
episo

episode: 570/5000, score: 51
episode: 571/5000, score: 45
episode: 572/5000, score: 57
episode: 573/5000, score: 120
episode: 574/5000, score: 199
episode: 575/5000, score: 49
episode: 576/5000, score: 114
episode: 577/5000, score: 26
episode: 578/5000, score: 92
episode: 579/5000, score: 55
episode: 580/5000, score: 23
episode: 581/5000, score: 199
episode: 582/5000, score: 101
episode: 583/5000, score: 74
episode: 584/5000, score: 58
episode: 585/5000, score: 199
episode: 586/5000, score: 39
episode: 587/5000, score: 199
episode: 588/5000, score: 80
episode: 589/5000, score: 38
episode: 590/5000, score: 53
episode: 591/5000, score: 58
episode: 592/5000, score: 34
episode: 593/5000, score: 58
episode: 594/5000, score: 31
episode: 595/5000, score: 34
episode: 596/5000, score: 50
episode: 597/5000, score: 39
episode: 598/5000, score: 39
episode: 599/5000, score: 71
episode: 600/5000, score: 31
episode: 601/5000, score: 33
episode: 602/5000, score: 19
episode: 603/5000, score: 31
episode

episode: 848/5000, score: 199
episode: 849/5000, score: 199
episode: 850/5000, score: 199
episode: 851/5000, score: 199
episode: 852/5000, score: 199
episode: 853/5000, score: 199
episode: 854/5000, score: 199
episode: 855/5000, score: 199
episode: 856/5000, score: 199
episode: 857/5000, score: 199
episode: 858/5000, score: 199
episode: 859/5000, score: 199
episode: 860/5000, score: 199
episode: 861/5000, score: 199
episode: 862/5000, score: 199
episode: 863/5000, score: 199
episode: 864/5000, score: 199
episode: 865/5000, score: 199
episode: 866/5000, score: 199
episode: 867/5000, score: 199
episode: 868/5000, score: 199
episode: 869/5000, score: 199
episode: 870/5000, score: 199
episode: 871/5000, score: 199
episode: 872/5000, score: 199
episode: 873/5000, score: 199
episode: 874/5000, score: 199
episode: 875/5000, score: 199
episode: 876/5000, score: 199
episode: 877/5000, score: 199
episode: 878/5000, score: 199
episode: 879/5000, score: 199
episode: 880/5000, score: 199
episode: 8

episode: 1120/5000, score: 199
episode: 1121/5000, score: 199
episode: 1122/5000, score: 199
episode: 1123/5000, score: 199
episode: 1124/5000, score: 199
episode: 1125/5000, score: 199
episode: 1126/5000, score: 199
episode: 1127/5000, score: 199
episode: 1128/5000, score: 199
episode: 1129/5000, score: 199
episode: 1130/5000, score: 199
episode: 1131/5000, score: 199
episode: 1132/5000, score: 199
episode: 1133/5000, score: 199
episode: 1134/5000, score: 169
episode: 1135/5000, score: 171
episode: 1136/5000, score: 159
episode: 1137/5000, score: 180
episode: 1138/5000, score: 143
episode: 1139/5000, score: 177
episode: 1140/5000, score: 199
episode: 1141/5000, score: 199
episode: 1142/5000, score: 199
episode: 1143/5000, score: 199
episode: 1144/5000, score: 198
episode: 1145/5000, score: 172
episode: 1146/5000, score: 193
episode: 1147/5000, score: 199
episode: 1148/5000, score: 184
episode: 1149/5000, score: 199
episode: 1150/5000, score: 182
episode: 1151/5000, score: 199
episode:

KeyboardInterrupt: 

In [30]:
avg_reward = 0

for i in range(100):
    
    s = env.reset()
    
    for t in range(500):
        a = agent.get_action(s)
        s1, r, done, _ = env.step(a)
        # remember
        #agent.remember(s,a,r,s1,done)
        # update state
        s = s1
        if done:
            avg_reward += t
            break
print(avg_reward/100)

198.72
