In [None]:
import numpy as np
import tensorflow as tf
from sonnet.python.modules.basic import Linear
from sonnet.python.modules.base import AbstractModule

# --- Hyperparameters ---------------------------------------------------------
GAMMA = 0.9            # discount factor for n-step returns
VALUE_BETA = 0.5       # weight of the critic (value) loss in the total loss
ENTROPY_BETA = 1       # weight of the entropy regulariser in the policy loss
LEARNING_RATE = 1e-4   # step size of the global optimizer
_EPSILON = 1e-6        # small constant added before tf.log etc. to avoid nan

In [None]:
def swich(tensor):
    """Swish activation: ``tensor * sigmoid(tensor)``.

    The (misspelled) name is kept because the network builders call ``swich``.

    Args:
      tensor: float tensor.

    Returns:
      Tensor of the same shape as ``tensor``.
    """
    # sigmoid is bounded in (0, 1) and does not produce nan for finite input,
    # so no epsilon is needed inside it.
    return tensor * tf.nn.sigmoid(tensor)

# shared neural network
def _build_shared_network(inputs):
    """Shared trunk feeding both the policy head and the value head.

    Args:
      inputs: observation tensor; presumably [batch_size, state_size]
        (that is the shape of the placeholders that feed it) — TODO confirm.

    Returns:
      The output of a 32-unit ``Linear`` layer passed through ``swich``.
    """
    hidden = Linear(32, 'input_layer')(inputs)
    activated = swich(hidden)
    return activated

# build approximate neural network
def _build_approximate_network(inputs, action_size):
    """Builds the two-headed actor-critic network on top of the shared trunk.

    Args:
      inputs: observation tensor passed to ``_build_shared_network``.
      action_size: number of discrete actions (width of the policy head).

    Returns:
      ``(policy, value)``. ``policy`` is a softmax over ``action_size``
      actions. ``value`` is the output of a 1-unit linear layer, i.e. the
      state-value estimate with shape [batch, 1].
    """
    shared_network = _build_shared_network(inputs)

    # Actor head: 32 hidden units -> action_size logits -> probabilities.
    policy_hidden = swich(Linear(32, 'policy_input')(shared_network))
    logits = Linear(action_size, 'policy_output')(policy_hidden)
    # Softmax is shift-invariant: adding a constant to the logits cannot avoid
    # nan. A nan from log(0) has to be guarded wherever log(policy) is taken.
    policy = tf.nn.softmax(logits)

    # Critic head: 32 hidden units -> scalar state value.
    value_hidden = swich(Linear(32, 'value_input')(shared_network))
    value = Linear(1, 'value_output')(value_hidden)
    return policy, value

class simple_approximate_network(AbstractModule):
    """Sonnet module wrapping ``_build_approximate_network``.

    Calling an instance as ``module(inputs, action_size)`` builds the policy
    and value heads. The weights it creates are later collected with
    ``get_variables()``.
    """

    def __init__(self, name):
        super(simple_approximate_network, self).__init__(name=name)

    def _build(self, inputs, action_size):
        """Returns the ``(policy, value)`` pair for ``inputs``."""
        policy, value = _build_approximate_network(inputs, action_size)
        return policy, value
        
# batch gather function from https://github.com/deepmind/dnc/blob/master/util.py
def _batch_gather(values, indices):
    """Returns batched `tf.gather` for every row in the input.

    Row ``i`` of the result is ``tf.gather(values[i], indices[i])``. Rows are
    split with ``tf.unstack``, which requires a statically known batch size.
    """
    with tf.name_scope('batch_gather', values=[values, indices]):
        value_rows = tf.unstack(values)
        index_rows = tf.unstack(indices)
        gathered = []
        for row, row_index in zip(value_rows, index_rows):
            gathered.append(tf.gather(row, row_index))
        return tf.stack(gathered)

In [None]:
# global network: stores the shared weights and applies the workers' gradients
class Access(object):
    """Global network shared by the workers.

    It holds the weights that workers copy from, and the optimizer that
    workers use to apply their gradients to those weights.
    """

    def __init__(self, state_size, action_size, name='access'):
        # The variable_scope only tidies up the graph; it is not required.
        with tf.variable_scope(name):
            # Observation placeholder: [batch, state_size].
            self.inputs = tf.placeholder(tf.float32, [None, state_size], 'inputs')
            self.network = simple_approximate_network('global_network')
            outputs = self.network(self.inputs, action_size)
            self.policy, self.value = outputs

        # Created outside the scope; workers call its apply_gradients.
        self.optimizer = tf.train.AdamOptimizer(LEARNING_RATE)

    def get_trainable_variables(self):
        """Variables created by the global network module."""
        return self.network.get_variables()
    
    
# local worker network for advantage actor-critic, also known as A2C
class ACNet(object):
    """Local advantage actor-critic worker network.

    The worker keeps its own copy of the network weights and builds the
    actor-critic loss on them. It exposes two update ops:

      * ``update_local``: list of assign ops copying the global (``Access``)
        weights into this worker;
      * ``update_access``: applies the gradients of ``total_loss``, taken with
        respect to the local weights, to the global weights via the global
        optimizer.

    Feed the following to evaluate the losses or run ``update_access``:

      * ``inputs``: states, [batch, state_size];
      * ``action``: int32 actions taken, [batch];
      * ``target``: n-step returns, [batch, 1].

    ``entropy_beta`` is a scalar that may optionally be fed to anneal the
    entropy bonus. It defaults to the module-level ``ENTROPY_BETA`` read at
    construction time.
    """

    # Scale applied to the (gradient-stopped) advantage in the policy loss.
    ADVANTAGE_SCALE = 0.1

    def __init__(self, Access, state_size, action_size, name):
        self.Access = Access
        self.state_size = state_size
        self.action_size = action_size
        # Action space is assumed to be range(action_size).
        self.action_space = np.arange(action_size, dtype=np.int32)

        # Group all of this worker's placeholders, variables and ops under `name`.
        with tf.variable_scope(name):
            # Observation placeholder: [batch, state_size].
            self.inputs = tf.placeholder(tf.float32, [None, state_size], 'inputs')
            self.action = tf.placeholder(tf.int32, [None], 'action')
            # n-step discounted reward plus the discounted bootstrap value.
            self.target = tf.placeholder(tf.float32, [None, 1], 'target')
            # Entropy weight. Feed a value to schedule it; otherwise the
            # module-level ENTROPY_BETA at build time is used.
            self.entropy_beta = tf.placeholder_with_default(
                float(ENTROPY_BETA), shape=[], name='entropy_beta')

            self.network = simple_approximate_network('ACNet')
            self.policy, self.value = self.network(self.inputs, action_size)

            self._build_loss_function()
            self.update_local, self.update_access = self._build_update()

    def _build_loss_function(self):
        """Builds the value, policy, entropy and total losses.

        The actor term is ``-A * log pi(a|s)``, with the advantage
        ``A = target - value`` treated as a constant. Minimising it therefore
        raises the probability of actions with positive advantage.

        The entropy term is ``mean_a(pi * log pi)``, i.e. the negative entropy
        divided by action_size. Minimising it encourages exploration.
        """
        # target and value are both [batch, 1], so the advantage is [batch, 1].
        self.advantage = self.target - self.value
        # Critic: mean squared advantage.
        self.value_loss = tf.reduce_mean(tf.square(self.advantage))

        # Probability of the action actually taken, kept as [batch, 1] so it
        # lines up element-wise with the advantage. Do NOT expand_dims it:
        # [batch, 1] * [batch, 1, 1] would broadcast to a [batch, batch, 1]
        # outer product and mix unrelated samples.
        action_onehot = tf.one_hot(self.action, self.action_size)
        policy_action = tf.reduce_sum(self.policy * action_onehot, axis=1, keep_dims=True)
        log_policy_action = tf.log(policy_action + _EPSILON)

        # Actor: stop_gradient so the critic is not trained through this term.
        # Negated because the optimizer minimises the loss.
        scaled_advantage = tf.stop_gradient(self.advantage * self.ADVANTAGE_SCALE)
        policy_loss = -scaled_advantage * log_policy_action
        entropy_loss = tf.reduce_mean(
            self.policy * tf.log(self.policy + _EPSILON), axis=1, keep_dims=True)
        self.policy_loss = tf.reduce_mean(policy_loss + self.entropy_beta * entropy_loss)

        self.total_loss = VALUE_BETA * self.value_loss + self.policy_loss
        # Scalar diagnostics for monitoring / tuning.
        self.a_policy_loss = tf.reduce_mean(policy_loss)
        self.a_entropy_loss = tf.reduce_mean(entropy_loss)
        self.a_value_loss = self.value_loss

    def _build_update(self):
        """Returns ``(update_local, update_access)``.

        ``update_local`` is a list of assign ops copying each global variable
        into the matching local one. ``update_access`` applies the gradients
        of ``total_loss`` (taken with respect to the local variables) to the
        global variables.

        Variables are paired by position, so both modules must list them in
        the same order.

        Raises:
          ValueError: if the two networks expose different numbers of variables.
        """
        global_params = list(self.Access.get_trainable_variables())
        local_params = list(self.get_trainable_variables())
        if len(global_params) != len(local_params):
            raise ValueError(
                'global and local networks have different numbers of variables: '
                '%d vs %d' % (len(global_params), len(local_params)))

        # Copy global weights into the local network.
        update_local = [local.assign(glob) for glob, local in zip(global_params, local_params)]

        # Gradients w.r.t. local weights are applied to the global weights.
        local_grads = tf.gradients(self.total_loss, local_params)
        update_access = self.Access.optimizer.apply_gradients(zip(local_grads, global_params))
        return update_local, update_access

    def get_trainable_variables(self):
        """Variables created by this worker's network module."""
        return self.network.get_variables()

    def choose_action(self, SESS, state):  # run by a local
        """Samples one action from the current policy for a single ``state``.

        Returns:
          A length-1 int32 numpy array containing the sampled action.
        """
        policy = SESS.run(self.policy, {self.inputs: np.expand_dims(state, axis=0)})
        # Take the single row (robust even when action_size == 1) and
        # renormalise in float64 so float32 rounding cannot trip
        # np.random.choice's sum-to-one check.
        probabilities = np.asarray(policy[0], dtype=np.float64)
        probabilities /= probabilities.sum()
        return np.random.choice(self.action_space, 1, p=probabilities)

In [None]:
import gym

# Training length and rollout length (steps collected between global updates).
MAX_EPISODES = 100000
T_MAX = 10

# Environment, plus the sizes of its observation vector and discrete action set.
GAME = 'CartPole-v0'
env = gym.make(GAME)
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

In [None]:
# n-step advantage actor-critic training with a single worker ('W0').

# Entropy-bonus schedule: a larger weight early on to encourage exploration.
ENTROPY_WARMUP_EPISODES = 700
ENTROPY_BETA_WARMUP = 10.0
ENTROPY_BETA_FINAL = 1.0
# Print loss diagnostics at the end of episodes after this index.
DEBUG_AFTER_EPISODE = MAX_EPISODES - 10


def _n_step_targets(rewards, bootstrap_value, gamma=GAMMA):
    """Returns the discounted n-step targets for one rollout, oldest step first.

    target_k = r_k + gamma * r_{k+1} + ... + gamma**(n-k) * bootstrap_value,
    where ``bootstrap_value`` is V(s_n) for a truncated rollout and 0 when the
    episode terminated.
    """
    targets = []
    running = bootstrap_value
    for reward in reversed(rewards):
        running = reward + gamma * running
        targets.append(running)
    targets.reverse()
    return targets


tf.reset_default_graph()
SESS = tf.Session()
with tf.device("/cpu:0"):
    master = Access(state_size, action_size)
    worker = ACNet(master, state_size, action_size, 'W0')
    SESS.run(tf.global_variables_initializer())

# The entropy weight can only be scheduled if the worker exposes a feedable
# `entropy_beta` tensor. Otherwise the value baked into the graph at
# construction time is used for every update (reassigning the module-level
# ENTROPY_BETA here would have no effect).
entropy_beta_input = getattr(worker, 'entropy_beta', None)

episode_score_list = []
for episode in range(MAX_EPISODES):
    if episode < ENTROPY_WARMUP_EPISODES:
        entropy_beta = ENTROPY_BETA_WARMUP
    else:
        entropy_beta = ENTROPY_BETA_FINAL

    state = env.reset()
    episode_score = 0
    buffer_state, buffer_action, buffer_reward = [], [], []

    done = False
    while not done:
        if not buffer_state:
            # New rollout: start it from the latest global weights.
            SESS.run(worker.update_local)

        action = worker.choose_action(SESS, state)[0]
        next_state, reward, done, _ = env.step(action)
        episode_score += reward

        buffer_state.append(state)
        buffer_action.append(action)
        buffer_reward.append(reward)
        state = next_state

        # Update after T_MAX steps, or earlier when the episode ends.
        if len(buffer_state) < T_MAX and not done:
            continue

        if done:
            bootstrap_value = 0.0
        else:
            # `state` is the successor of the last buffered transition.
            bootstrap_value = SESS.run(
                worker.value, {worker.inputs: np.expand_dims(state, axis=0)})[0][0]
        buffer_target = _n_step_targets(buffer_reward, bootstrap_value)

        feed_dict = {worker.inputs: np.vstack(buffer_state),
                     worker.action: np.array(buffer_action, dtype=np.int32),
                     worker.target: np.array(buffer_target, dtype=np.float32)[:, None]}
        if entropy_beta_input is not None:
            feed_dict[entropy_beta_input] = entropy_beta
        SESS.run(worker.update_access, feed_dict)

        if done and episode > DEBUG_AFTER_EPISODE:
            entropy_loss, policy_loss, value_loss, total_loss = SESS.run(
                [worker.a_entropy_loss, worker.a_policy_loss, worker.a_value_loss, worker.total_loss],
                feed_dict)
            print(entropy_loss, policy_loss, value_loss, total_loss)

            policy, advantage = SESS.run([worker.policy, worker.advantage], feed_dict)
            print(policy)
            print(advantage)

        buffer_state, buffer_action, buffer_reward = [], [], []

    episode_score_list.append(episode_score)
            

In [None]:
%matplotlib inline
import pandas as pd

# Score per episode, with a 100-episode moving average to make the trend visible.
episode_scores = pd.Series(episode_score_list, name='episode_score')
ax = episode_scores.plot(figsize=(16, 9), alpha=0.3, label='episode score')
episode_scores.rolling(100, min_periods=1).mean().plot(ax=ax, label='100-episode moving average')
ax.set(title='Training progress: score per episode', xlabel='episode', ylabel='score')
ax.legend();

In [None]:
# Inspect every trainable variable in the default graph (global + worker copies).
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

In [None]:
# Summary statistics instead of dumping every episode score (up to MAX_EPISODES values).
import pandas as pd

pd.Series(episode_score_list, name='episode_score').describe()