In [1]:
import tensorflow as tf
import gym
import numpy as np
from tensorflow_probability import distributions as dists
import tensorflow.keras.layers as kl
import datetime

from rl_agents.env_utils import rollouts_generator, get_adv_vtarg
from rl_agents.vpg.agent import VPG_Agent
from rl_agents.ppo.agent import PPO_Agent
from rl_agents.policies.categorical import CategoricalActor
from rl_agents.policies.gaussian import GaussianActor
from rl_agents.common import Critic
from rl_agents.trainer.sensei import Sensei

from gym.spaces import Box, Discrete

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%load_ext tensorboard.notebook

tf.random.set_seed(0)

## Create Gym environment
Use MountainCar-v0 for now (alternative environments are noted in the cell below).

In [59]:
# Gym environment id used throughout the notebook.
# Alternatives previously tried: 'MountainCarContinuous-v0', 'Pendulum-v0', 'CartPole-v0'.
ENV_ID = 'MountainCar-v0'


def env_fn():
    """Create a fresh instance of the ``ENV_ID`` Gym environment.

    Passed to Sensei (presumably so it can build its own environment
    instances); the module-level ``env`` below is a separate instance that
    the rollout generators use.
    """
    return gym.make(ENV_ID)


env = env_fn()
# A Box action space selects the continuous (Gaussian) policy; anything else
# is treated as discrete (categorical policy, using `.n` below).
is_continuous = isinstance(env.action_space, Box)
obs_dim = env.observation_space.shape
# Continuous: shape tuple of the action vector; discrete: number of actions.
act_dim = env.action_space.shape if is_continuous else env.action_space.n

# Proximal Policy Optimization

## Initialization

In [None]:
# Policy: Gaussian for continuous action spaces, categorical for discrete ones.
actor_ppo = GaussianActor(obs_dim, act_dim) if is_continuous else CategoricalActor(obs_dim, act_dim)
critic_ppo = Critic(obs_dim)
jen_ppo = PPO_Agent(actor_ppo, critic_ppo, is_continuous, act_dim)

# Environment steps per rollout. Defined once and shared by the standalone
# generator and by Sensei so the two settings cannot silently drift apart.
horizon = 2048
# NOTE(review): generator_ppo is not consumed by the PPO training cell (Sensei
# receives env_fn instead) — kept for manual inspection; confirm it is needed.
generator_ppo = rollouts_generator(jen_ppo, env, is_continuous, horizon=horizon)

# Hyper-parameters. NOTE: these are plain globals and the VPG init cell reuses
# the same names, so whichever init cell ran last decides the values any later
# cell reads (e.g. lam / gamma in the get_adv_vtarg call).
alg_name = "PPO"
num_ite = 200
lam = 0.95    # passed to Sensei as gae_lambda
gamma = 0.99
epochs_actor = 20
epochs_critic = 40
sensei_ppo = Sensei(jen_ppo, alg_name, env_fn,
                    ite=num_ite, horizon=horizon,
                    epochs_actor=epochs_actor, epochs_critic=epochs_critic,
                    gamma=gamma, gae_lambda=lam,
                    log_dir='logs')

In [None]:
# Run PPO training with the configuration set up in the initialization cell.
batch_size_ppo = 256
sensei_ppo.train(batch_size=batch_size_ppo)

# Vanilla Policy Gradient

## Initialization

In [60]:
# Policy: Gaussian for continuous action spaces, categorical for discrete ones.
actor_vpg = GaussianActor(obs_dim, act_dim) if is_continuous else CategoricalActor(obs_dim, act_dim)
critic_vpg = Critic(obs_dim)
jen_vpg = VPG_Agent(actor_vpg, critic_vpg, is_continuous, act_dim)

# Environment steps per rollout. Defined once and shared by the standalone
# generator (sampled manually in later cells) and by Sensei so the two
# settings cannot silently drift apart.
horizon = 2048
generator_vpg = rollouts_generator(jen_vpg, env, is_continuous, horizon=horizon)

# Hyper-parameters. NOTE: these are plain globals shared with the PPO init
# cell, so whichever init cell ran last decides the values any later cell
# reads (e.g. lam / gamma in the get_adv_vtarg call).
alg_name = "VPG"
num_ite = 1   # NOTE(review): a single iteration looks like a debugging value — confirm before a real run
lam = 0.95    # passed to Sensei as gae_lambda
gamma = 0.99
epochs_actor = 1
epochs_critic = 40
sensei_vpg = Sensei(jen_vpg, alg_name, env_fn,
                    ite=num_ite, horizon=horizon,
                    epochs_actor=epochs_actor, epochs_critic=epochs_critic,
                    gamma=gamma, gae_lambda=lam,
                    log_dir='logs')

## Training loop

In [73]:
# Run VPG training with the configuration set up in the initialization cell.
batch_size_vpg = 256
sensei_vpg.train(batch_size=batch_size_vpg)

In [74]:
# Pull one rollout from the VPG generator for manual inspection
# (idiomatic next() rather than calling the __next__ dunder directly).
rollout = next(generator_vpg)

In [78]:
# Preview the leading per-step rewards of the sampled rollout.
n_preview = 200
rollout['rew'][:n_preview]

array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1

In [76]:
# Advantages and value targets for the sampled rollout.
# NOTE(review): lam / gamma are globals last assigned by whichever init cell
# ran most recently (PPO or VPG). Both cells currently use 0.95 / 0.99, but
# this rollout comes from the VPG generator, so the VPG values are intended.
adv, target_value = get_adv_vtarg(rollout, gamma=gamma, lam=lam)

In [77]:
# Preview the leading value targets computed above.
n_preview_vtarg = 200
target_value[:n_preview_vtarg]

array([-3.96471142, -3.9706069 , -3.97720704, -3.98444762, -3.99225975,
       -4.00039483, -4.00730507, -4.01304752, -4.01841905, -4.02412243,
       -4.03010987, -4.03619309, -4.04093503, -4.04498506, -4.04772361,
       -4.05047359, -4.05188431, -4.05341157, -4.05504403, -4.05676964,
       -4.05857579, -4.06039037, -4.0614741 , -4.06176133, -4.06051846,
       -4.05774322, -4.0534362 , -4.04760135, -4.04040065, -4.0331868 ,
       -4.02451122, -4.01439514, -4.00286724, -3.99007283, -3.97683699,
       -3.96337231, -3.95056101, -3.93826292, -3.92501852, -3.91119817,
       -3.898488  , -3.88702682, -3.87661146, -3.86573928, -3.85448015,
       -3.84309082, -3.83265714, -3.82409699, -3.81750796, -3.81257684,
       -3.80810951, -3.80538674, -3.80319289, -3.80298495, -3.80375103,
       -3.80466731, -3.80611497, -3.80951712, -3.81420978, -3.82076723,
       -3.82848848, -3.8379211 , -3.84833028, -3.86024279, -3.87289522,
       -3.88665525, -3.90011229, -3.91458036, -3.92902473, -3.94