In [1]:
import tensorflow as tf
import gym
import numpy as np
import datetime

# Fix TensorFlow's global random seed so TF-side randomness is repeatable across runs.
tf.random.set_seed(0)

In [4]:
from tensorflow.keras import layers as kl

In [5]:
class GaussianSample(kl.Layer):
    """Turn a batch of action means into a diagonal Gaussian policy distribution.

    The layer owns one trainable log-standard-deviation per action dimension.
    The log-std does not depend on the layer inputs: the same vector is used
    for every row of the batch.

    Args:
        action_dim: number of action dimensions (length of the log-std vector).
        init_log_std: initial value for every log-std entry. The default
            ``-0.53`` corresponds to a standard deviation of about 0.59.
    """

    def __init__(self, action_dim, init_log_std=-0.53):
        super().__init__(name='GaussianSample')

        initial_log_std = np.full(action_dim, init_log_std, dtype=np.float32)
        # Assigning a tf.Variable as a layer attribute registers it with Keras,
        # so it is part of the model's trainable variables.
        self.log_std = tf.Variable(initial_value=initial_log_std,
                                   name='log_std', trainable=True)

    def call(self, inputs):
        """Return a Normal distribution centred on ``inputs``.

        The distribution object (not a sample) is returned so callers can
        either ``sample()`` actions or evaluate ``log_prob()`` during training.

        NOTE(review): ``distributions`` is not defined in this cell; it is
        presumably ``tensorflow_probability.distributions`` imported elsewhere
        in the notebook - confirm.
        """
        # The parameter is stored in log space; exponentiate to get a strictly
        # positive scale.
        return distributions.Normal(loc=inputs, scale=tf.exp(self.log_std))

In [6]:
def get_actor(obs_dim: int, act_dim: int):
    """Assemble the policy network as a Keras ``Sequential`` model.

    The stack is: a 32-unit tanh hidden layer fed with ``obs_dim`` features,
    a linear layer producing ``act_dim`` outputs, and a ``GaussianSample``
    head applied to those outputs.

    Args:
        obs_dim: size of a single observation vector.
        act_dim: size of a single action vector.

    Returns:
        The assembled ``tf.keras.Sequential`` model.
    """
    hidden_layer = kl.Dense(
        32,
        input_shape=(obs_dim,),
        activation=tf.keras.activations.tanh,
    )
    mean_layer = kl.Dense(act_dim)
    policy_head = GaussianSample(act_dim)

    return tf.keras.Sequential([hidden_layer, mean_layer, policy_head])

In [7]:
def get_opt_fn(model, lr=3e-4):
    """Create an Adam optimizer and a policy-gradient update function for ``model``.

    Args:
        model: callable policy; ``model(obs)`` must return an object exposing
            ``log_prob(actions)`` with one value per action dimension.
        lr: learning rate passed to Adam.

    Returns:
        ``(opt, step_fn)`` where ``step_fn(obs_no, act_na, adv_n)`` performs
        one gradient step on ``mean(-log_prob * adv)`` and returns the scalar
        loss. Expected inputs (arrays, tensors or lists of per-step arrays):
            obs_no: batch of observations, shape ``[N, obs_dim]``.
            act_na: batch of actions taken, shape ``[N, act_dim]``.
            adv_n:  per-step weights (e.g. reward-to-go), shape ``[N]``.
    """
    opt = tf.keras.optimizers.Adam(lr)

    @tf.function
    def _apply_update(obs_no, act_na, adv_n):
        with tf.GradientTape() as tape:
            # The observations are already batched, so the model is called on
            # them directly. Summing over axis 1 collapses the per-dimension
            # log-probabilities into one joint log-probability per step.
            logp_n = tf.reduce_sum(model(obs_no).log_prob(act_na), axis=1)
            loss = tf.reduce_mean(-logp_n * adv_n)

        grads = tape.gradient(loss, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    def step_fn(obs_no, act_na, adv_n):
        """Run one update; inputs are packed into float32 tensors first."""
        # Casting outside the tf.function means the traced graph always
        # receives three dense float32 tensors, even when the caller passes
        # Python lists or float64 rewards.
        # NOTE(review): the batch length may differ between calls, so
        # tf.function may retrace for each new input shape.
        return _apply_update(
            tf.cast(obs_no, tf.float32),
            tf.cast(act_na, tf.float32),
            tf.cast(adv_n, tf.float32),
        )

    return opt, step_fn

In [10]:
def train(env, epochs=200, buffer_size=4096):
    assert isinstance(env.observation_space, Box), \
        "This is only for continuous observation space"
    assert isinstance(env.action_space, Box), \
        "This is only for continuous action space"
    
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    
    model = get_actor(obs_dim, act_dim)
    
    opt, step_fn = get_opt_fn(model)
    
    def train_one_epoch():
        """Collect one batch of experience with the current policy and update once.

        Whole episodes are rolled out until more than ``buffer_size`` steps
        have been stored. The batch always ends on an episode boundary, so
        every stored observation/action has a matching weight. A single call
        to ``step_fn`` then applies the policy-gradient update.

        Returns:
            ``(batch_rets, batch_lens)``: the return and the length of each
            episode completed during this batch.
        """
        batch_obs = []          # observations, one per step
        batch_acts = []         # actions taken, one per step
        batch_weights = []      # reward-to-go weight for each step
        batch_rets = []         # total return of each finished episode
        batch_lens = []         # length of each finished episode

        obs = env.reset()       # first obs comes from the starting distribution
        ep_rews = []            # rewards accrued in the current episode

        while True:
            batch_obs.append(obs.copy())

            # ``obs[None]`` adds the batch axis the model expects and ``[0]``
            # strips it from the sample. The action is converted to a NumPy
            # array so the environment and the batch buffer never hold
            # framework tensors.
            act = np.asarray(model(obs[None]).sample()[0])
            obs, rew, done, _ = env.step(act)

            batch_acts.append(act)
            ep_rews.append(rew)

            if done:
                batch_rets.append(sum(ep_rews))
                batch_lens.append(len(ep_rews))

                # The weight for each logprob(a_t|s_t) is the reward-to-go from t.
                # NOTE(review): ``reward_to_go`` is defined outside this cell; it
                # is assumed to yield one value per reward in ``ep_rews`` - confirm.
                batch_weights += list(reward_to_go(ep_rews))

                # Start the next episode.
                obs, ep_rews = env.reset(), []

                # Stop only at an episode boundary, once enough steps are stored.
                if len(batch_obs) > buffer_size:
                    break

        # Pack the per-step lists into dense float32 arrays before the update.
        # The update function is compiled with tf.function, and Python lists
        # would be treated as thousands of separate inputs rather than one batch.
        step_fn(
            np.asarray(batch_obs, dtype=np.float32),
            np.asarray(batch_acts, dtype=np.float32),
            np.asarray(batch_weights, dtype=np.float32),
        )
        return batch_rets, batch_lens