In [1]:
import tensorflow as tf
import gym
import numpy as np
from tensorflow_probability import distributions as dists
import tensorflow.keras.layers as kl
import datetime

from rl_agents.env_utils import rollouts_generator, get_adv_vtarg
from rl_agents.vpg.agent import VPG_Agent
from rl_agents.models import GaussianActor, CategoricalActor, Critic

from gym.spaces import Box, Discrete

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%load_ext tensorboard.notebook

tf.random.set_seed(0)

## Create GYM environment
Uses CartPole-v0 (discrete actions) for now; Pendulum-v0 is the continuous-action alternative.

In [2]:
# Swap the id for 'Pendulum-v0' to try a continuous-action task instead.
env = gym.make('CartPole-v0')

# Observation shape tuple, exactly as reported by the environment.
obs_dim = env.observation_space.shape

# A Box action space is treated as continuous; anything else as discrete.
is_continuous = isinstance(env.action_space, Box)
if is_continuous:
    act_dim = env.action_space.shape  # shape tuple of the action space
else:
    act_dim = env.action_space.n  # size of the discrete action set

# Vanilla Policy Gradient

## Initialization

In [5]:
# Choose the policy model that matches the action space: Gaussian for
# continuous actions, categorical otherwise.
if is_continuous:
    actor = GaussianActor(obs_dim, act_dim)
else:
    actor = CategoricalActor(obs_dim, act_dim)
critic = Critic(obs_dim)

# Wrap both models in the VPG agent, then build the rollout generator that
# feeds the training loop (horizon=2048 per rollout).
jen = VPG_Agent(actor, critic, is_continuous, act_dim)
generator = rollouts_generator(jen, env, is_continuous, horizon=2048)

# Training loop

In [6]:
# Training hyper-parameters.
num_ite = 200  # number of rollout/update iterations
lam = 0.95     # `lam` argument forwarded to get_adv_vtarg
gamma = 0.99   # `gamma` argument forwarded to get_adv_vtarg

# Each run logs into its own timestamped directory under logs/gradient_tape.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)

for i in range(num_ite):
    # Pull the next rollout, derive advantages / value targets, run one update.
    rollout = next(generator)
    advantage, target_value = get_adv_vtarg(rollout, lam=lam, gamma=gamma)
    jen.run_ite(rollout['ob'], rollout['ac'], rollout['log_probs'], target_value, advantage, batch_size=512)

    # Log the mean of rollout["ep_rets"] (presumably per-episode returns — confirm
    # in rl_agents.env_utils). The mean of an empty sequence is NaN and emits a
    # RuntimeWarning, so skip the scalar when the rollout holds no entries.
    ep_rets = rollout["ep_rets"]
    if len(ep_rets) > 0:
        with train_summary_writer.as_default():
            tf.summary.scalar('reward mean', np.mean(ep_rets), step=i)

    # Checkpoint both models every 50 iterations and on the final iteration.
    if i % 50 == 0 or i == num_ite - 1:
        actor.save_weights(train_log_dir + '/_actor_' + str(i), save_format='tf')
        critic.save_weights(train_log_dir + '/_critic_' + str(i), save_format='tf')

In [None]:
# Open TensorBoard on port 8003, pointed at the parent directory that holds
# the timestamped training-run logs.
%tensorboard --logdir logs/gradient_tape --port=8003

In [None]:
# PPO experiment, mirroring the VPG setup above.
# NOTE(review): `Actor` and `PPO_Agent` are not imported by the import cell at
# the top of this notebook, so this cell raises NameError on a fresh kernel
# unless they are defined elsewhere — add the matching imports before running.
actor2 = Actor(obs_dim, act_dim, is_continuous)
critic2 = Critic(obs_dim)
vero2 = PPO_Agent(actor2, critic2)
# Pass `is_continuous` as the third argument, matching the VPG call
# rollouts_generator(jen, env, is_continuous, horizon=2048).
generator2 = rollouts_generator(vero2, env, is_continuous, horizon=2048)

num_ite = 200    # number of rollout/update iterations
lam = 0.95       # `lam` argument forwarded to get_adv_vtarg
gamma = 0.99     # `gamma` argument forwarded to get_adv_vtarg
num_epochs = 10  # forwarded to run_ite as `epochs`

for i in range(num_ite):
    print('#### iteration ###', i)
    rollout = next(generator2)
    advantage, target_value = get_adv_vtarg(rollout, lam=lam, gamma=gamma)
    # NOTE(review): relies on the rollout exposing a 'locs' entry — confirm the
    # generator provides it for the current action space.
    vero2.run_ite(rollout['ob'], rollout['ac'], rollout['log_probs'], rollout['locs'], target_value, advantage,
                  epochs=num_epochs)

In [None]:
# Short rollout (horizon=210) used only to inspect what the generator and
# get_adv_vtarg return in the cells below.
# NOTE(review): this rebinds `generator`; re-running the training cell after
# this one would train on horizon-210 rollouts. Re-run the Initialization cell
# first to restore the horizon-2048 generator.
# `is_continuous` is passed as the third argument, matching the Initialization
# cell's call to rollouts_generator.
generator = rollouts_generator(jen, env, is_continuous, horizon=210)

roll = next(generator)

adv, tar = get_adv_vtarg(roll, lam=0.95, gamma=0.99)

In [None]:
# List the field names available in the debug rollout.
roll.keys()

In [None]:
# Show the rollout's 'rew' entry (presumably per-step rewards — confirm in
# rl_agents.env_utils).
roll['rew']

In [None]:
# Show the rollout's 'new' entry (presumably episode-start flags — confirm in
# rl_agents.env_utils).
roll['new']

In [None]:
# Show the first value returned by get_adv_vtarg for the debug rollout.
adv