Add cliprange for value fn (PPO2) (#343)
araffin authored and hill-a committed Jun 5, 2019
1 parent fc9853c commit fefff48
Showing 4 changed files with 93 additions and 28 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
@@ -9,7 +9,7 @@ For download links, please look at `Github release page <https://github.com/hill
Pre-Release 2.6.0a0 (WIP)
-------------------------

-**Hindsight Experience Replay (HER) - Reloaded**
+**Hindsight Experience Replay (HER) - Reloaded | get/load parameters**

- revamped HER implementation: clean re-implementation from scratch, now supports DQN, SAC and DDPG
- **deprecated** ``memory_limit`` and ``memory_policy`` in DDPG, please use ``buffer_size`` instead. (will be removed in v3.x.x)
@@ -27,6 +27,8 @@ Pre-Release 2.6.0a0 (WIP)
- added ``load_parameters`` and ``get_parameters`` to base RL class.
  With these methods, users are able to load and get parameters to/from an existing model, without touching tensorflow. (@Miffyli)
- **important change** switched to using dictionaries rather than lists when storing parameters, with tensorflow Variable names being the keys. (@Miffyli)
+- added specific hyperparameter for PPO2 to clip the value function (``cliprange_vf``)
+- fixed the ``num_timesteps`` (total_timesteps) variable in PPO2, which was computed incorrectly

**Breaking Change:** DDPG replay buffer was unified with DQN/SAC replay buffer. As a result,
when loading a DDPG model trained with stable_baselines<2.6.0, it throws an import error.
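For orientation, a minimal usage sketch of the new option (illustrative values; it mirrors the new test added below):

    from stable_baselines import PPO2

    # Default (cliprange_vf=None): the value function reuses the policy clipping range
    model = PPO2('MlpPolicy', 'CartPole-v1', cliprange=0.2)

    # Dedicated clipping range for the value function
    model = PPO2('MlpPolicy', 'CartPole-v1', cliprange=0.2, cliprange_vf=0.3)

    # Negative value: deactivate value function clipping (original PPO behavior)
    model = PPO2('MlpPolicy', 'CartPole-v1', cliprange=0.2, cliprange_vf=-1)

    # Callables act as schedules of the remaining training progress (1.0 -> 0.0)
    model = PPO2('MlpPolicy', 'CartPole-v1', cliprange_vf=lambda frac: 0.3 * frac)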
84 changes: 64 additions & 20 deletions stable_baselines/ppo2/ppo2.py
@@ -33,6 +33,12 @@ class PPO2(ActorCriticRLModel):
the number of environments run in parallel should be a multiple of nminibatches.
    :param noptepochs: (int) Number of epochs when optimizing the surrogate
:param cliprange: (float or callable) Clipping parameter, it can be a function
+    :param cliprange_vf: (float or callable) Clipping parameter for the value function, it can be a function.
+        This is a parameter specific to the OpenAI implementation. If None is passed (default),
+        then `cliprange` (the value used for the policy) will be used.
+        IMPORTANT: this clipping depends on the reward scaling.
+        To deactivate value function clipping (and recover the original PPO implementation),
+        you have to pass a negative value (e.g. -1).
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
@@ -42,15 +48,16 @@
"""

def __init__(self, policy, env, gamma=0.99, n_steps=128, ent_coef=0.01, learning_rate=2.5e-4, vf_coef=0.5,
-                 max_grad_norm=0.5, lam=0.95, nminibatches=4, noptepochs=4, cliprange=0.2, verbose=0,
-                 tensorboard_log=None, _init_setup_model=True, policy_kwargs=None,
+                 max_grad_norm=0.5, lam=0.95, nminibatches=4, noptepochs=4, cliprange=0.2, cliprange_vf=None,
+                 verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None,
full_tensorboard_log=False):

super(PPO2, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True,
_init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs)

self.learning_rate = learning_rate
self.cliprange = cliprange
+        self.cliprange_vf = cliprange_vf
self.n_steps = n_steps
self.ent_coef = ent_coef
self.vf_coef = vf_coef
@@ -143,11 +150,36 @@ def setup_model(self):
self.entropy = tf.reduce_mean(train_model.proba_distribution.entropy())

vpred = train_model.value_flat
-            vpredclipped = self.old_vpred_ph + tf.clip_by_value(
-                train_model.value_flat - self.old_vpred_ph, - self.clip_range_ph, self.clip_range_ph)

+            # Value function clipping: not present in the original PPO
+            if self.cliprange_vf is None:
+                # Default behavior (legacy from OpenAI baselines):
+                # use the same clipping as for the policy
+                self.clip_range_vf_ph = self.clip_range_ph
+                self.cliprange_vf = self.cliprange
+            elif isinstance(self.cliprange_vf, (float, int)) and self.cliprange_vf < 0:
+                # Original PPO implementation: no value function clipping
+                self.clip_range_vf_ph = None
+            else:
+                # Last possible behavior: clipping range
+                # specific to the value function
+                self.clip_range_vf_ph = tf.placeholder(tf.float32, [], name="clip_range_vf_ph")
+
+            if self.clip_range_vf_ph is None:
+                # No clipping
+                vpred_clipped = train_model.value_flat
+            else:
+                # Clip the difference between old and new value
+                # NOTE: this depends on the reward scaling
+                vpred_clipped = self.old_vpred_ph + \
+                    tf.clip_by_value(train_model.value_flat - self.old_vpred_ph,
+                                     - self.clip_range_vf_ph, self.clip_range_vf_ph)

            vf_losses1 = tf.square(vpred - self.rewards_ph)
-            vf_losses2 = tf.square(vpredclipped - self.rewards_ph)
+            vf_losses2 = tf.square(vpred_clipped - self.rewards_ph)
self.vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))

ratio = tf.exp(self.old_neglog_pac_ph - neglogpac)
pg_losses = -self.advs_ph * ratio
pg_losses2 = -self.advs_ph * tf.clip_by_value(ratio, 1.0 - self.clip_range_ph, 1.0 +
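For reference, the clipped value loss assembled above can be written in plain NumPy (a sketch, not part of the commit; here clip_range_vf=None means no clipping, i.e. the original PPO loss):

    import numpy as np

    def clipped_vf_loss(vpred, old_vpred, returns, clip_range_vf=None):
        if clip_range_vf is None:
            # No clipping: plain squared error
            vpred_clipped = vpred
        else:
            # Keep the new prediction within +/- clip_range_vf of the old one
            vpred_clipped = old_vpred + np.clip(vpred - old_vpred,
                                                -clip_range_vf, clip_range_vf)
        # Pessimistic maximum of clipped and unclipped squared errors
        return 0.5 * np.mean(np.maximum(np.square(vpred - returns),
                                        np.square(vpred_clipped - returns)))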
@@ -184,6 +216,9 @@ def setup_model(self):
tf.summary.scalar('learning_rate', tf.reduce_mean(self.learning_rate_ph))
tf.summary.scalar('advantage', tf.reduce_mean(self.advs_ph))
tf.summary.scalar('clip_range', tf.reduce_mean(self.clip_range_ph))
+                    if self.clip_range_vf_ph is not None:
+                        tf.summary.scalar('clip_range_vf', tf.reduce_mean(self.clip_range_vf_ph))
+
tf.summary.scalar('old_neglog_action_probabilty', tf.reduce_mean(self.old_neglog_pac_ph))
tf.summary.scalar('old_value_pred', tf.reduce_mean(self.old_vpred_ph))

@@ -210,7 +245,7 @@ def setup_model(self):
self.summary = tf.summary.merge_all()

def _train_step(self, learning_rate, cliprange, obs, returns, masks, actions, values, neglogpacs, update,
-                    writer, states=None):
+                    writer, states=None, cliprange_vf=None):
"""
Training of PPO2 Algorithm
@@ -227,16 +262,21 @@ def _train_step(self, learning_rate, cliprange, obs, returns, masks, actions, va
:param states: (np.ndarray) For recurrent policies, the internal state of the recurrent model
+        :param cliprange_vf: (float) Clipping factor for the value function
        :return: policy gradient loss, value function loss, policy entropy,
            approximation of kl divergence, updated clipping range, training update operation
"""
advs = returns - values
advs = (advs - advs.mean()) / (advs.std() + 1e-8)
-        td_map = {self.train_model.obs_ph: obs, self.action_ph: actions, self.advs_ph: advs, self.rewards_ph: returns,
+        td_map = {self.train_model.obs_ph: obs, self.action_ph: actions,
+                  self.advs_ph: advs, self.rewards_ph: returns,
self.learning_rate_ph: learning_rate, self.clip_range_ph: cliprange,
self.old_neglog_pac_ph: neglogpacs, self.old_vpred_ph: values}
if states is not None:
td_map[self.train_model.states_ph] = states
td_map[self.train_model.dones_ph] = masks

+        if cliprange_vf is not None and cliprange_vf >= 0:
+            td_map[self.clip_range_vf_ph] = cliprange_vf
+
if states is None:
update_fac = self.n_batch // self.nminibatches // self.noptepochs + 1
else:
@@ -267,6 +307,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=1, tb_lo
# Transform to callable if needed
self.learning_rate = get_schedule_fn(self.learning_rate)
self.cliprange = get_schedule_fn(self.cliprange)
+        cliprange_vf = get_schedule_fn(self.cliprange_vf)

new_tb_log = self._init_num_timesteps(reset_num_timesteps)

@@ -280,16 +321,18 @@
ep_info_buf = deque(maxlen=100)
t_first_start = time.time()

-        nupdates = total_timesteps // self.n_batch
-        for update in range(1, nupdates + 1):
+        n_updates = total_timesteps // self.n_batch
+        for update in range(1, n_updates + 1):
assert self.n_batch % self.nminibatches == 0
batch_size = self.n_batch // self.nminibatches
t_start = time.time()
-            frac = 1.0 - (update - 1.0) / nupdates
+            frac = 1.0 - (update - 1.0) / n_updates
lr_now = self.learning_rate(frac)
-            cliprangenow = self.cliprange(frac)
+            cliprange_now = self.cliprange(frac)
+            cliprange_vf_now = cliprange_vf(frac)
# true_reward is the reward without discount
obs, returns, masks, actions, values, neglogpacs, states, ep_infos, true_reward = runner.run()
+            self.num_timesteps += self.n_batch
ep_info_buf.extend(ep_infos)
mb_loss_vals = []
if states is None: # nonrecurrent version
@@ -303,9 +346,8 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=1, tb_lo
end = start + batch_size
mbinds = inds[start:end]
slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
-                    mb_loss_vals.append(self._train_step(lr_now, cliprangenow, *slices, writer=writer,
-                                                         update=timestep))
-                self.num_timesteps += (self.n_batch * self.noptepochs) // batch_size * update_fac
+                    mb_loss_vals.append(self._train_step(lr_now, cliprange_now, *slices, writer=writer,
+                                                         update=timestep, cliprange_vf=cliprange_vf_now))
else: # recurrent version
update_fac = self.n_batch // self.nminibatches // self.noptepochs // self.n_steps + 1
assert self.n_envs % self.nminibatches == 0
@@ -322,9 +364,9 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=1, tb_lo
mb_flat_inds = flat_indices[mb_env_inds].ravel()
slices = (arr[mb_flat_inds] for arr in (obs, returns, masks, actions, values, neglogpacs))
mb_states = states[mb_env_inds]
-                    mb_loss_vals.append(self._train_step(lr_now, cliprangenow, *slices, update=timestep,
-                                                         writer=writer, states=mb_states))
-                self.num_timesteps += (self.n_envs * self.noptepochs) // envs_per_batch * update_fac
+                    mb_loss_vals.append(self._train_step(lr_now, cliprange_now, *slices, update=timestep,
+                                                         writer=writer, states=mb_states,
+                                                         cliprange_vf=cliprange_vf_now))

loss_vals = np.mean(mb_loss_vals, axis=0)
t_now = time.time()
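The recurrent branch above shuffles whole environments instead of single steps, so every minibatch is made of complete trajectories. A standalone sketch of that indexing (assumed shapes, illustrative only):

    import numpy as np

    n_envs, n_steps, nminibatches = 8, 128, 4
    envs_per_batch = n_envs // nminibatches
    # One row of step indices per environment trajectory
    flat_indices = np.arange(n_envs * n_steps).reshape(n_envs, n_steps)
    env_indices = np.arange(n_envs)
    np.random.shuffle(env_indices)
    for start in range(0, n_envs, envs_per_batch):
        mb_env_inds = env_indices[start:start + envs_per_batch]
        mb_flat_inds = flat_indices[mb_env_inds].ravel()  # whole trajectories only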
@@ -339,7 +381,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=1, tb_lo
if self.verbose >= 1 and (update % log_interval == 0 or update == 1):
explained_var = explained_variance(values, returns)
logger.logkv("serial_timesteps", update * self.n_steps)
logger.logkv("nupdates", update)
logger.logkv("n_updates", update)
logger.logkv("total_timesteps", self.num_timesteps)
logger.logkv("fps", fps)
logger.logkv("explained_variance", float(explained_var))
@@ -371,6 +413,7 @@ def save(self, save_path):
"nminibatches": self.nminibatches,
"noptepochs": self.noptepochs,
"cliprange": self.cliprange,
"cliprange_vf": self.cliprange_vf,
"verbose": self.verbose,
"policy": self.policy,
"observation_space": self.observation_space,
@@ -474,8 +517,9 @@ def get_schedule_fn(value_schedule):
"""
# If the passed schedule is a float
# create a constant function
-    if isinstance(value_schedule, float):
-        value_schedule = constfn(value_schedule)
+    if isinstance(value_schedule, (float, int)):
+        # Cast to float to avoid errors
+        value_schedule = constfn(float(value_schedule))
else:
assert callable(value_schedule)
return value_schedule
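A short sketch of how these schedules are consumed in learn() (illustrative values):

    # Constants become constant functions; callables pass through unchanged
    cliprange = get_schedule_fn(0.2)                         # always 0.2
    cliprange_vf = get_schedule_fn(lambda frac: 0.3 * frac)  # annealed

    n_updates = 100
    for update in range(1, n_updates + 1):
        frac = 1.0 - (update - 1.0) / n_updates  # 1.0 at the start, approaches 0.0
        cliprange_now = cliprange(frac)
        cliprange_vf_now = cliprange_vf(frac)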
13 changes: 6 additions & 7 deletions tests/test_a2c_conv.py
@@ -1,17 +1,17 @@
-import tensorflow as tf
+import gym
import numpy as np
+import tensorflow as tf

from stable_baselines.a2c.utils import conv
-import gym
from stable_baselines.common.input import observation_input


ENV_ID = 'BreakoutNoFrameskip-v4'
SEED = 3


def test_conv_kernel():
"""
test convolution kernel with various input formats
"""
"""Test convolution kernel with various input formats."""
filter_size_1 = 4 # The size of squared filter for the first layer
filter_size_2 = (3, 5) # The size of non-squared filter for the second layer
target_shape_1 = [2, 52, 40, 32] # The desired shape of the first layer
@@ -24,8 +24,7 @@ def test_conv_kernel():
env = gym.make(ENV_ID)
ob_space = env.observation_space

-    graph = tf.Graph()
-    with graph.as_default():
+    with tf.Graph().as_default():
_, scaled_images = observation_input(ob_space, n_batch, scale=scale)
activ = tf.nn.relu
layer_1 = activ(conv(scaled_images, 'c1', n_filters=32, filter_size=filter_size_1, stride=4
20 changes: 20 additions & 0 deletions tests/test_ppo2.py
@@ -0,0 +1,20 @@
import os

import pytest

from stable_baselines import PPO2


@pytest.mark.parametrize("cliprange", [0.2, lambda x: 0.1 * x])
@pytest.mark.parametrize("cliprange_vf", [None, 0.2, lambda x: 0.3 * x, -1.0])
def test_clipping(cliprange, cliprange_vf):
"""Test the different clipping (policy and vf)"""
model = PPO2('MlpPolicy', 'CartPole-v1',
cliprange=cliprange, cliprange_vf=cliprange_vf).learn(1000)
model.save('./ppo2_clip.pkl')
env = model.get_env()
model = PPO2.load('./ppo2_clip.pkl', env=env)
model.learn(1000)

if os.path.exists('./ppo2_clip.pkl'):
os.remove('./ppo2_clip.pkl')
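The new test can be run locally with ``pytest tests/test_ppo2.py``; it covers constant, callable, and negative values of ``cliprange_vf``, including a save/load round-trip.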
