In [22]:
import tensorflow as tf
import gym
import numpy as np
from tensorflow_probability import distributions as dists
import tensorflow.keras.layers as kl
import datetime

from rl_agents.env_utils import rollouts_generator, get_adv_vtarg
from rl_agents.ppo.policy import Actor, Critic
from rl_agents.ppo.agent import PPO_Agent

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%load_ext tensorboard.notebook

tf.random.set_seed(0)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The tensorboard.notebook extension is already loaded. To reload it, use:
  %reload_ext tensorboard.notebook


## Create the Gym environment
Use `Pendulum-v0` for now.

In [23]:
# Instantiate the environment and read off the values used to build the models.
env = gym.make('Pendulum-v0')
obs_dim, act_dim = env.observation_space.shape, env.action_space.shape
# The task is flagged as continuous when the action space is a gym Box.
is_continuous = isinstance(env.action_space, gym.spaces.Box)

In [24]:
# Build the actor and critic models and wrap them in a PPO agent.
# NOTE(review): the global TF seed is set once at the top of the notebook, so
# the construction order below may influence the initial weights — confirm
# before reordering these lines.
actor = Actor(obs_dim, act_dim, is_continuous)
critic = Critic(obs_dim)
jen = PPO_Agent(actor, critic)
# Generator that is advanced once per training iteration.
# `horizon=2048` is presumably the number of environment steps per rollout —
# TODO confirm against rl_agents.env_utils.rollouts_generator.
generator = rollouts_generator(jen, env, horizon=2048)

# Training loop

In [25]:
# Hyperparameters consumed by the training loop:
#   num_ite    - number of rollout/update iterations
#   num_epochs - forwarded to the agent's run_ite() as `epochs`
#   lam, gamma - forwarded to get_adv_vtarg()
num_ite, num_epochs = 200, 10
lam, gamma = 0.95, 0.99

# Every run writes its summaries to a fresh, timestamped directory below
# logs/gradient_tape/.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = ''.join(('logs/gradient_tape/', current_time, '/train'))
train_summary_writer = tf.summary.create_file_writer(train_log_dir)

In [28]:
# Main training loop: per iteration, pull one rollout from the generator,
# compute advantages / value targets, and run one multi-epoch agent update.
for i in range(num_ite):
    # Use the built-in next() rather than calling the dunder method directly.
    rollout = next(generator)
    advantage, target_value = get_adv_vtarg(rollout, lam=lam, gamma=gamma)
    jen.run_ite(
        rollout['ob'], rollout['ac'], rollout['log_probs'], rollout['locs'],
        target_value, advantage,
        epochs=num_epochs,
    )

    # Log the mean of the episode returns reported in this rollout (`ep_rets`).
    # If the rollout reports none, skip logging: the mean of an empty array is
    # NaN (with a RuntimeWarning) and would pollute the TensorBoard curve.
    episode_returns = np.asarray(rollout["ep_rets"])
    if episode_returns.size > 0:
        with train_summary_writer.as_default():
            tf.summary.scalar('reward mean', episode_returns.mean(), step=i*num_epochs)

    # Checkpoint both models every 50 iterations and after the last iteration.
    if i % 50 == 0 or i == num_ite-1:
        actor.save_weights(train_log_dir+'/_actor_'+str(i), save_format='tf')
        critic.save_weights(train_log_dir+'/_critic_'+str(i), save_format='tf')

In [27]:
# Launch TensorBoard on port 8001, reading from the directory under which the
# training cell creates its timestamped summary folders.
%tensorboard --logdir logs/gradient_tape --port=8001

In [9]:
# Second, independent agent: fresh actor/critic instances trained on the same
# environment. This cell writes no summaries and saves no checkpoints.
actor2 = Actor(obs_dim, act_dim, is_continuous)
critic2 = Critic(obs_dim)
vero2 = PPO_Agent(actor2, critic2)
generator2 = rollouts_generator(vero2, env, horizon=2048)

# Hyperparameters are (re)bound here so that this cell can run on its own.
num_ite = 200
lam = 0.95
gamma = 0.99
num_epochs = 10

for i in range(num_ite):
    print('#### iteration ###', i)
    # Use the built-in next() rather than calling the dunder method directly.
    rollout = next(generator2)
    advantage, target_value = get_adv_vtarg(rollout, lam=lam, gamma=gamma)
    vero2.run_ite(
        rollout['ob'], rollout['ac'], rollout['log_probs'], rollout['locs'],
        target_value, advantage,
        epochs=num_epochs,
    )

#### iteration ### 0
#### iteration ### 1
#### iteration ### 2


In [13]:
# Collect one short rollout (horizon=210) with the trained agent for manual
# inspection in the cells below.
# A dedicated name is used so the training generator bound to `generator` is
# not overwritten; re-running the training cell afterwards would otherwise
# silently train on horizon=210 rollouts.
debug_generator = rollouts_generator(jen, env, horizon=210)

# Use the built-in next() rather than calling the dunder method directly.
roll = next(debug_generator)

# Advantages and value targets for the inspected rollout.
adv, tar = get_adv_vtarg(roll, lam=0.95, gamma=0.99)

In [14]:
# List the fields available in the rollout dictionary.
roll.keys()

dict_keys(['ob', 'ac', 'rew', 'new', 'vpred', 'next_vpred', 'ep_rets', 'ep_lens', 'log_probs', 'locs'])

In [20]:
# Inspect the `rew` entry of the rollout (presumably the per-step rewards —
# TODO confirm against rollouts_generator).
roll['rew']

array([ -4.46303753,  -4.76982944,  -5.25897669,  -6.07663208,
        -7.03088186,  -8.18696987,  -9.31243049, -10.69244418,
       -10.95152102,  -9.7369286 ,  -8.52403479,  -7.33262405,
        -6.33323444,  -5.46338617,  -4.81855076,  -4.31588751,
        -3.96274289,  -3.78471415,  -3.82132724,  -4.06488331,
        -4.4387805 ,  -5.12325702,  -5.94074073,  -6.92105055,
        -8.1546692 ,  -9.62457392, -11.01847499, -11.03876518,
        -9.67710229,  -8.35781227,  -7.23547672,  -6.23038877,
        -5.35118174,  -4.74135091,  -4.34147534,  -4.15673628,
        -4.20633771,  -4.40548795,  -4.77051669,  -5.4166298 ,
        -6.2522507 ,  -7.21453707,  -8.30923869,  -9.48258746,
       -10.94903283, -10.77487004,  -9.51685608,  -8.30952711,
        -7.22602033,  -6.23932421,  -5.43136004,  -4.77019914,
        -4.24789933,  -3.9483908 ,  -3.8455787 ,  -3.96949791,
        -4.38951722,  -5.04225743,  -5.89600863,  -6.99138743,
        -8.27813924,  -9.6631865 , -11.05091581, -10.94

In [18]:
# Inspect the `new` entry of the rollout (presumably a 0/1 flag marking the
# first step of each episode — TODO confirm against rollouts_generator).
roll['new']

array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)

In [19]:
# Display the first value returned by get_adv_vtarg for the inspected rollout
# (bound to `advantage` in the training loop).
adv

array([-113.16341038, -115.40258509, -117.48752021, -119.16545706,
       -120.15082696, -120.20539407, -119.07717818, -116.67422881,
       -112.67010627, -108.1655998 , -104.68843727, -102.31101047,
       -101.05575439, -100.81008263, -101.46417902, -102.85795727,
       -104.88871227, -107.46541567, -110.40464296, -113.49761451,
       -116.48118296, -119.30063722, -121.49202718, -122.92090023,
       -123.38521643, -122.55147516, -120.07482376, -115.95505363,
       -111.53577012, -108.26915478, -106.19721396, -105.16927249,
       -105.10260466, -105.95817729, -107.49167368, -109.5259315 ,
       -111.84807242, -114.32451998, -116.73359017, -118.87747008,
       -120.52273595, -121.42541423, -121.3834648 , -120.19573887,
       -117.67749763, -113.47425845, -109.21955684, -106.05148271,
       -103.97664095, -102.94436596, -102.90065337, -103.72532617,
       -105.34794524, -107.6275779 , -110.39059986, -113.47042896,
       -116.65268601, -119.52653943, -121.83333528, -123.34574