In [1]:
import time

import numpy as np
import tensorflow as tf
from sonnet.python.modules.basic import Linear
from sonnet.python.modules.base import AbstractModule

LEARNING_RATE = 1e-3
_EPSILON = 1e-6 # avoid nan
ENTROPY_BETA = 0.1
GAMMA = 0.99 # discount factor
VALUE_BETA = 0.5

In [2]:
def swich(tensor):
    return tensor * tf.nn.sigmoid(tensor + _EPSILON)

# shared neural network
def _build_shared_network(inputs):
    # inputs [batch_size, state_size]
    network = Linear(32, 'input_layer')(inputs)
    return swich(network)


class build_policy_network(AbstractModule):
    def __init__(self, name):
        super().__init__(name=name)
        
    def _build(self, inputs, action_size):
        shared_network = _build_shared_network(inputs)
        policy = Linear(32, 'policy_input')(shared_network)
        policy = swich(policy)
        policy = Linear(action_size, 'policy_output')(policy)
        return tf.nn.softmax(policy + _EPSILON) # avoid nan   

    
class build_value_network(AbstractModule):
    def __init__(self, name):
        super().__init__(name=name)
        
    def _build(self, inputs):
        shared_network = _build_shared_network(inputs)
        value = Linear(32, 'value_input')(shared_network)
        value = swich(value)
        value = Linear(1, 'value_output')(value)
        return value
    
    
# build approximate neural network
def _build_approximate_network(inputs, action_size):
    shared_network = _build_shared_network(inputs)
    
    policy = Linear(32, 'policy_input')(shared_network)
    policy = swich(policy)
    policy = Linear(action_size, 'policy_output')(policy)
    policy = tf.nn.softmax(policy + _EPSILON) # avoid nan   
    
    value = Linear(32, 'value_input')(shared_network)
    value = swich(value)
    value = Linear(1, 'value_output')(value)
    return policy, value

class simple_approximate_network(AbstractModule):
    def __init__(self, name):
        super().__init__(name=name)
    
    def _build(self, inputs, action_size):
        return _build_approximate_network(inputs, action_size)
        
# batch gather function from https://github.com/deepmind/dnc/blob/master/util.py
def _batch_gather(values, indices):
    """Returns batched `tf.gather` for every row in the input."""
    with tf.name_scope('batch_gather', values=[values, indices]):
        unpacked = zip(tf.unstack(values), tf.unstack(indices))
        result = [tf.gather(value, index) for value, index in unpacked]
        return tf.stack(result)

In [3]:
# global network for buffer weights and calculate gardients
class Access(object):
    def __init__(self, state_size, action_size, name='access'):
        #variable_scope for more clear graph, not necessary
        with tf.variable_scope(name):                   
            # placeholder for state and next state or you may like call it observation
            self.inputs = tf.placeholder(tf.float32, [None, state_size], 'inputs')     
            #self.network = simple_approximate_network('global_network')
            #self.policy, self.value = self.network(self.inputs, action_size)
            self.policy_network = build_policy_network('global_policy')
            self.value_network = build_value_network('global_value')
            self.policy = self.policy_network(self.inputs, action_size)
            self.value = self.value_network(self.inputs)
            
        self.optimizer_actor = tf.train.RMSPropOptimizer(LEARNING_RATE, name='optimizer_actor')
        self.optimizer_critic = tf.train.RMSPropOptimizer(LEARNING_RATE, name='optimizer_critic') 
        
    def get_trainable_variables(self):
        return [self.policy_network.get_variables(), self.value_network.get_variables()]

In [4]:
# local network for advantage actor-critic which are also know as A2C
class ACNet(object):
    def __init__(self, Access, state_size, action_size, name):
        self.Access = Access
        self.state_size = state_size
        self.action_size = action_size
        # action space, we assume that action space is range(0 to action_size-1)
        self.action_space = np.arange(action_size, dtype=np.int32)
        
        #variable_scope local graph, necessary
        with tf.variable_scope(name):
            # placeholder for state and next state or you may like call it observation
            self.inputs = tf.placeholder(tf.float32, [None, state_size], 'inputs')   
            self.action = tf.placeholder(tf.int32, [None], 'action')
            # n-step reward and discounted n next step value
            self.target = tf.placeholder(tf.float32, [None, 1], 'target')
            
            self.policy_network = build_policy_network('global_policy')
            self.value_network = build_value_network('global_value')
            self.policy = self.policy_network(self.inputs, action_size)
            self.value = self.value_network(self.inputs)
            
            self._build_loss_function()
            self.update_local, self.update_access = self._build_update()         
        
    def _build_loss_function(self):
        self.advantage = self.target - self.value
        # value loss
        self.value_loss = tf.reduce_mean(tf.square(self.advantage))
    
        # policy loss
        # get the stochastic policy action probability
        #policy_action = _batch_gather(self.policy, self.action)
        action_onehot = tf.one_hot(self.action, self.action_size)
        policy_action = tf.reduce_sum(self.policy * action_onehot, axis=1, keep_dims=True)
        log_policy_action = tf.log(policy_action + _EPSILON)
        # no grad pass through advantage in actor network 
        policy_loss = -tf.stop_gradient(self.advantage) * tf.expand_dims(log_policy_action, axis=1)
        # entropy loss
        entropy_loss = tf.reduce_sum(self.policy * tf.log(self.policy + _EPSILON), axis=1, keep_dims=True)
        self.policy_loss = tf.reduce_mean(policy_loss + ENTROPY_BETA * entropy_loss)

        # adjust some params
        self.a_policy_loss = -tf.reduce_mean(policy_loss)
        self.a_entropy_loss = tf.reduce_mean(ENTROPY_BETA * entropy_loss)
        self.a_value_loss = self.value_loss

    def _build_update(self):
        global_policy_params, global_value_params = self.Access.get_trainable_variables()
        local_policy_params = self.policy_network.get_variables()
        local_value_params = self.value_network.get_variables()
        
        policy_list = []
        for g,l in zip(global_policy_params, local_policy_params):
            policy_list.append(l.assign(g))
            
        value_list = []
        for g,l in zip(global_value_params, local_value_params):
            value_list.append(l.assign(g))        
        
        policy_grad = tf.gradients(self.policy_loss, list(local_policy_params))
        value_grad = tf.gradients(self.value_loss, list(local_value_params))
        
        policy_apply = self.Access.optimizer_actor.apply_gradients(zip(policy_grad, list(global_policy_params)))
        value_apply = self.Access.optimizer_critic.apply_gradients(zip(value_grad, list(global_value_params)))
        return [policy_list, value_list], [policy_apply, value_apply]    
    
    def choose_action(self, SESS, state):  # run by a local          
        policy = SESS.run(self.policy, {self.inputs: np.expand_dims(state, axis=0)})
        policy = np.squeeze(policy)
        action = np.random.choice(self.action_space, 1, p=policy)
        return action

In [5]:
class Worker(object):
    def __init__(self, master, name, state_size, action_size):
        self.env = gym.make(GAME).unwrapped
        self.master = master
        self.state_size = state_size
        self.action_size = action_size
        self.worker = ACNet(self.master, self.state_size, self.action_size, name)
        self.name = name
    
    def work(self, SESS):
        worker = self.worker
        env = self.env
        
        episode_score_list = []
        episode = 0
        while episode < MAX_EPISODES:

            t_start = t = 1
            state = env.reset()

            buffer_state = []
            buffer_reward = []
            buffer_next_state = []
            episode_score = 0
        
            while True:
                SESS.run(worker.update_local)
                action = worker.choose_action(SESS, state)[0]
                next_state, reward, done, info = env.step(action)
                episode_score += reward

                buffer_state.append(state)
                buffer_reward.append(reward)
                buffer_next_state.append(next_state)
                state = next_state

                if t - t_start == T_MAX or done:
                    t_start = t

                    if done:
                        state_value = 0
                    else:
                        state_value = SESS.run(worker.value, {worker.inputs:np.expand_dims(state, axis=0)})[0][0]

                    buffer_target = []
                    for r in buffer_reward[:-1][::-1]:
                        state_value = r + GAMMA * state_value
                        buffer_target.append(state_value)
                    buffer_target.reverse()

                    feed_dict = {worker.inputs: np.vstack(buffer_state[:-1]), 
                                 worker.action: np.squeeze(np.vstack(buffer_reward[:-1]), axis=1), 
                                 worker.target: np.expand_dims(np.array(buffer_target), axis=1)}
                    SESS.run(worker.update_access, feed_dict)

                    if done:
                        episode +=1
                        episode_score_list.append(episode_score)
                        
                        if self.name == 'W0':                            
                            entropy_loss, policy_loss, value_loss = SESS.run(
                                [worker.a_entropy_loss, worker.policy_loss, worker.value_loss],
                                feed_dict)   
                            #env.render()
                            print (episode_score, entropy_loss, policy_loss, value_loss)

                        buffer_state = [buffer_state[-1]]
                        buffer_reward = [buffer_reward[-1]]
                        buffer_next_state = [buffer_next_state[-1]]

                t += 1
                if done:
                    break

In [6]:
import gym
import multiprocessing
import threading
NUMS_CPU = multiprocessing.cpu_count()
#NUMS_CPU = 4

GAME = 'CartPole-v0'
env = gym.make(GAME)
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

MAX_EPISODES = int(5e2)
T_MAX = 32

tf.reset_default_graph()
SESS = tf.Session()
with tf.device("/cpu:0"):
    master = Access(state_size, action_size)
    worker_list = []
    for i in range(NUMS_CPU):
        worker_list.append(Worker(master, 'W%i'%i, state_size, action_size))
        
    COORD = tf.train.Coordinator()
    SESS.run(tf.global_variables_initializer())
    
    worker_threads = []
    for worker in worker_list:
        job = lambda: worker.work(SESS)
        t = threading.Thread(target=job)
        t.start()
        worker_threads.append(t)
    COORD.join(worker_threads)

67.0 -0.0690728 16.9736 929.255
8.0 -0.0681594 1.85021 16.5142
13.0 -0.0679205 3.08646 45.6668
27.0 -0.0673741 6.26292 191.54
16.0 -0.0644047 3.03586 67.937
10.0 -0.0636725 1.68992 25.4418
12.0 -0.0595169 1.56569 33.8104
20.0 -0.0599213 2.9736 100.325
13.0 -0.0537979 1.33919 39.5494
9.0 -0.0484234 0.581969 16.7977
10.0 -0.0407886 0.46293 20.4405
32.0 -0.0532922 4.12087 250.236
11.0 -0.0294003 0.298053 22.5776
18.0 -0.0405158 1.34934 77.9358
10.0 -0.0222212 0.161246 17.9435
12.0 -0.026631 0.359592 29.2258
10.0 -0.0267766 0.23591 17.7227
10.0 -0.0155875 0.0937022 15.7302
11.0 -0.0602616 2.17587 26.0406
11.0 -0.0207568 0.197015 20.1003
8.0 -0.0130463 0.0312534 7.12395
9.0 -0.0110782 0.03935 9.19829
10.0 -0.00886565 0.0382787 11.2904
9.0 -0.00848849 0.0229233 7.72034
10.0 -0.00603973 0.0219787 9.09342
9.0 -0.00545593 0.00896681 5.84954
9.0 -0.00403905 0.00600098 4.96478
10.0 -0.00310185 0.0071555 6.34283
9.0 -0.00310357 0.00297821 3.57995
10.0 -0.00211668 0.00446057 5.07829
10.0 -0.0017577

60.0 -0.0244588 1.11846 1.95231
79.0 -0.0234414 6.64773 15.747
80.0 -0.020829 0.793445 2.90335
61.0 -0.0237618 -8.10243 13.7493
52.0 -0.0229874 1.77176 1.3929
62.0 -0.0226789 -6.83869 8.218
91.0 -0.0217925 3.53705 10.2259
55.0 -0.0244538 -8.89328 14.6998
79.0 -0.0213301 7.17801 9.09276
66.0 -0.0228299 -8.79841 10.3953
98.0 -0.0257774 3.80205 13.9948
71.0 -0.0222232 0.822753 5.05082
76.0 -0.0230769 4.15497 11.3147
98.0 -0.0231503 8.26481 18.3546
64.0 -0.0237144 3.041 3.36579
66.0 -0.0221971 3.25853 4.27121
58.0 -0.0211835 -2.67936 4.90738
68.0 -0.0230291 1.17169 2.48242
114.0 -0.0223605 12.1596 39.0448
58.0 -0.0194022 -6.04338 5.96466
59.0 -0.0204425 -3.68892 5.82028
72.0 -0.0213054 -4.48288 9.7056
76.0 -0.0184606 4.97474 5.96318
106.0 -0.0162512 1.35701 7.16796
99.0 -0.0160262 4.56612 5.0121
93.0 -0.0155841 12.0195 21.8085
92.0 -0.0176723 4.68644 7.60297
66.0 -0.0208117 -4.10848 6.63492
75.0 -0.018398 -0.213886 1.05053
68.0 -0.0201589 1.98677 7.22528
62.0 -0.0180903 1.60862 5.07461
72.