In [1]:
import gym
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
from LinearEpsilonExplorer import LinearEpsilonExplorer
from ReplayMemory import ReplayMemory
%matplotlib inline

  from ._conv import register_converters as _register_converters


In [3]:
class DQNAgent:
    """Double-DQN agent built on a TensorFlow 1.x graph.

    Two structurally identical convolutional networks are created: the online
    ``Q_network`` (trained with Adam on a Huber loss) and the
    ``target_network`` (only ever updated by copying the online weights).
    Transitions are stored in a ``ReplayMemory`` and sampled in minibatches.

    The constructor only builds graph nodes; it does not run any variable
    initializer, so the caller has to initialize variables in ``sess``.
    """

    def __init__(self,
                 sess,
                 input_shape,
                 action_num,
                 lr=0.00025,
                 gamma=0.99,
                 explorer=None,
                 minibatch=32,
                 memory_size=1000000,
                 target_update_interval=10000,
                 train_after=10000):
        """Build the networks, the loss/optimizer and the target-sync ops.

        Args:
            sess: TensorFlow session used for every ``run`` call of the agent.
            input_shape: sequence with the per-sample input dimensions. The
                convolutional layers need a rank-4 batch, so this is expected
                to be (height, width, channels).
            action_num: number of discrete actions (size of the output layer).
            lr: learning rate handed to the Adam optimizer.
            gamma: discount factor used in the TD target.
            explorer: object providing ``explore(step)`` and
                ``choose_random_action(action_num)``. When ``None`` a fresh
                ``LinearEpsilonExplorer(1, 0.1, 100000)`` is created for this
                agent (presumably epsilon 1 -> 0.1 over 100000 steps; confirm
                against LinearEpsilonExplorer).
            minibatch: number of transitions sampled per training step.
            memory_size: capacity handed to ``ReplayMemory``.
            target_update_interval: the target network is synced when the
                action counter is a multiple of this value.
            train_after: number of purely random actions taken before
                epsilon-greedy acting and training begin.
        """
        # Build the default explorer per instance instead of in the signature:
        # a default-argument object is created once and shared by all agents.
        if explorer is None:
            explorer = LinearEpsilonExplorer(1, 0.1, 100000)

        self.sess = sess
        self.explorer = explorer
        self.minibatch = minibatch
        self.target_update_interval = target_update_interval
        self.train_after = train_after
        self.gamma = gamma
        self.input_shape = input_shape
        self.action_num = action_num

        self.replay_memory = ReplayMemory(memory_size)
        self.num_action_taken = 0

        # Splice the dimensions into the batch shape. Wrapping the whole
        # sequence in a list would yield [None, (h, w, c)], which is not a
        # valid tensor shape.
        batch_shape = [None] + list(self.input_shape)
        self.X_Q = tf.placeholder(tf.float32, batch_shape)
        self.X_t = tf.placeholder(tf.float32, batch_shape)
        self.Q_network = self._build_network("Q_network", self.X_Q)
        self.target_network = self._build_network("target_network", self.X_t)
        self.optimizer = tf.train.AdamOptimizer(learning_rate=lr)

        self.eval_param = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Q_network")
        self.target_param = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="target_network")

        with tf.variable_scope("optimizer"):
            self.actions = tf.placeholder(tf.int32, [None], name="actions")
            # Q estimate of the action that was actually taken
            actions_one_hot = tf.one_hot(self.actions, self.action_num)
            Q_pred = tf.reduce_sum(tf.multiply(self.Q_network, actions_one_hot), axis=1)
            # td_target is computed in numpy by train() and fed in
            self.td_target = tf.placeholder(tf.float32, [None])
            # loss
            self.loss = tf.losses.huber_loss(self.td_target, Q_pred)
            # Only the online network is trained; the loss does not depend on
            # the target network's variables.
            self.train_step = self.optimizer.minimize(self.loss, var_list=self.eval_param)

        # Create the copy ops once. Building new tf.assign nodes on every sync
        # would add nodes to the graph each time the target is updated.
        self._update_target_ops = [
            tf.assign(dest_var, src_var)
            for dest_var, src_var in zip(self.target_param, self.eval_param)
        ]

    def _build_network(self, scope_name, X):
        """Create the conv net under ``scope_name`` and return its output tensor.

        Three ReLU conv layers (32 8x8 stride 4, 64 4x4 stride 2, 64 3x3
        stride 1, all "same" padding), a 512-unit ReLU dense layer and a linear
        output layer with ``action_num`` units.
        """
        with tf.variable_scope(scope_name):
            conv1 = tf.layers.conv2d(X, 32, [8,8], [4,4], "same", activation=tf.nn.relu)
            conv2 = tf.layers.conv2d(conv1, 64, [4,4], [2,2], "same", activation=tf.nn.relu)
            conv3 = tf.layers.conv2d(conv2, 64, [3,3], [1,1], "same", activation=tf.nn.relu)
            flat = tf.layers.flatten(conv3)
            fc = tf.layers.dense(flat, 512, activation=tf.nn.relu)
            out = tf.layers.dense(fc, self.action_num)
        return out

    def act(self, state):
        """Choose an action for ``state`` and increment the action counter.

        Before ``train_after`` actions have been taken the action is always
        random. Afterwards the explorer decides, based on the number of steps
        since ``train_after``, between a random action and the argmax of the
        online network's Q-values.
        """
        if self.num_action_taken >= self.train_after:
            if self.explorer.explore(self.num_action_taken - self.train_after):
                action = self.explorer.choose_random_action(self.action_num)
            else:
                # Add a leading batch axis of size 1 in front of the sample dims.
                env_history = np.reshape(state, (1,) + tuple(self.input_shape))
                Q_values = self.sess.run(self.Q_network, feed_dict={self.X_Q : env_history})
                action = np.argmax(Q_values[0])
        else:
            action = self.explorer.choose_random_action(self.action_num)

        self.num_action_taken += 1
        return action

    def observe(self, pre_state, action, reward, post_state, done):
        """Store one transition in the replay memory."""
        self.replay_memory.append(pre_state, action, reward, post_state, done)

    def process(self, observation):
        """Turn a raw frame into an 84x84 single-channel float array.

        The frame is resized to 84x84, converted to PIL mode "F" and rescaled
        with ``(x - 127) / 127``.
        """
        img = Image.fromarray(observation)
        img = img.resize((84,84))
        img = img.convert("F")
        img = np.array(img)
        img = (img - 127) / 127
        return img

    def train(self):
        """Run one Double-DQN update on a sampled minibatch.

        Nothing happens (and 0 is returned) until ``train_after`` actions have
        been taken. Otherwise the loss of the update step is returned, and the
        target network is synced whenever the action counter is a multiple of
        ``target_update_interval``.
        """
        loss = 0
        if self.num_action_taken >= self.train_after:
            # retrieve data
            pre_states, actions, rewards, post_states, terminals = self.replay_memory.sample(self.minibatch)
            # Double DQN uses Q_network to choose action for post state
            # and then use target network to evaluate that policy
            Q_eval = self.sess.run(self.Q_network, feed_dict={self.X_Q:post_states})
            best_action = np.argmax(Q_eval, axis=1)
            # evaluate the chosen actions through target_network
            Q_next = self.sess.run(self.target_network, feed_dict={self.X_t:post_states})
            Q_target = Q_next[np.arange(best_action.size), best_action]
            # Terminal transitions contribute only their reward; the explicit
            # float cast makes the mask arithmetic well defined for bool flags.
            not_done = 1.0 - np.asarray(terminals, dtype=np.float32)
            y_batch = np.asarray(rewards, dtype=np.float32) + self.gamma * Q_target * not_done
            _, loss = self.sess.run([self.train_step, self.loss], feed_dict={self.X_Q:pre_states, self.actions:actions, self.td_target:y_batch})

            if self.num_action_taken % self.target_update_interval == 0:
                self._update_target_net()

        return loss

    def _update_target_net(self):
        """Copy the online network's weights into the target network."""
        # Use the agent's own session, not a notebook-level global `sess`.
        self.sess.run(self._update_target_ops)