Merge pull request #42 from hill-a/ppo2-cont-nan
Hotfix Continuous Actions
hill-a committed Oct 2, 2018
2 parents 0e4805d + 61a37e6 commit f0fef3f
Showing 7 changed files with 50 additions and 35 deletions.
docs/misc/changelog.rst (3 changes: 2 additions & 1 deletion)
@@ -14,7 +14,8 @@ Pre Release 2.0.1.a0 (WIP)
- fixes for DQN action_probability
- re-added double DQN + refactored DQN policies **breaking changes**
- replaced `async` with `async_eigen_decomp` in ACKTR/KFAC for python 3.7 compat

- removed action clipping for prediction of continuous actions (see issue #36)
- fixed NaN issue due to clipping the continuous action in the wrong place (issue #36)

Release 2.0.0 (2018-09-18)
--------------------------
docs/modules/policies.rst (5 changes: 5 additions & 0 deletions)
@@ -12,6 +12,11 @@ If you need more control on the policy architecture, You can also create a custo

CnnPolicies are for images only. MlpPolicies are made for other types of features (e.g. robot joints)

.. warning::
For all algorithms (except DDPG), continuous actions are only clipped during training
(to avoid out-of-bound errors). However, you have to manually clip the action when using
the `predict()` method.

.. rubric:: Available Policies

.. autosummary::
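As an illustration of the warning above, here is a minimal sketch of clipping a predicted continuous action by hand before stepping the environment. The environment, policy string, and training budget are placeholders for illustration, not part of this commit:

```python
import gym
import numpy as np

from stable_baselines import PPO2
from stable_baselines.common.vec_env import DummyVecEnv

# Placeholder setup: any Box-action environment works the same way.
env = DummyVecEnv([lambda: gym.make("Pendulum-v0")])
model = PPO2("MlpPolicy", env).learn(10000)

obs = env.reset()
action, _states = model.predict(obs)
# predict() returns the raw (unclipped) continuous action, so clip it before env.step()
clipped_action = np.clip(action, env.action_space.low, env.action_space.high)
obs, reward, done, info = env.step(clipped_action)
```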
stable_baselines/a2c/a2c.py (7 changes: 6 additions & 1 deletion)
@@ -1,5 +1,6 @@
import time

import gym
import numpy as np
import tensorflow as tf

@@ -292,7 +293,11 @@ def run(self):
mb_actions.append(actions)
mb_values.append(values)
mb_dones.append(self.dones)
obs, rewards, dones, _ = self.env.step(actions)
clipped_actions = actions
# Clip the actions to avoid out-of-bound errors
if isinstance(self.env.action_space, gym.spaces.Box):
clipped_actions = np.clip(actions, self.env.action_space.low, self.env.action_space.high)
obs, rewards, dones, _ = self.env.step(clipped_actions)
self.states = states
self.dones = dones
self.obs = obs
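The same guard is repeated in the ACER, PPO2, and TRPO runners below. Factored out for clarity, it amounts to the small helper sketched here (for illustration only; the commit keeps the check inline):

```python
import gym
import numpy as np

def clip_action_for_env(action, action_space):
    """Return an action that is safe to pass to env.step().

    Only Box (continuous) spaces are clipped; the raw action is kept elsewhere
    for the gradient update, so clipping never changes the sampling distribution.
    """
    if isinstance(action_space, gym.spaces.Box):
        return np.clip(action, action_space.low, action_space.high)
    return action
```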
stable_baselines/acer/acer_simple.py (6 changes: 5 additions & 1 deletion)
@@ -603,7 +603,11 @@ def run(self):
mb_actions.append(actions)
mb_mus.append(mus)
mb_dones.append(self.dones)
obs, rewards, dones, _ = self.env.step(actions)
clipped_actions = actions
# Clip the actions to avoid out-of-bound errors
if isinstance(self.env.action_space, Box):
clipped_actions = np.clip(actions, self.env.action_space.low, self.env.action_space.high)
obs, rewards, dones, _ = self.env.step(clipped_actions)
# state information for stateful models like LSTM
self.states = states
self.dones = dones
stable_baselines/common/distributions.py (43 changes: 16 additions & 27 deletions)
@@ -1,6 +1,6 @@
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import math_ops
import numpy as np
from gym import spaces

from stable_baselines.a2c.utils import linear
@@ -210,35 +210,32 @@ def sample_dtype(self):


class DiagGaussianProbabilityDistributionType(ProbabilityDistributionType):
def __init__(self, size, bounds=(-np.inf, np.inf)):
def __init__(self, size):
"""
The probability distribution type for multivariate gaussian input
:param size: (int) the number of dimensions of the multivariate gaussian
:param bounds: (float, float) the lower and upper bounds limit for the action space
"""
self.size = size
self.bounds = bounds

def probability_distribution_class(self):
return DiagGaussianProbabilityDistribution

def proba_distribution_from_flat(self, flat, bounds=(-np.inf, np.inf)):
def proba_distribution_from_flat(self, flat):
"""
returns the probability distribution from flat probabilities
:param flat: ([float]) the flat probabilities
:param bounds: (float, float) the lower and upper bounds limit for the action space
:return: (ProbabilityDistribution) the instance of the ProbabilityDistribution associated
"""
return self.probability_distribution_class()(flat, bounds)
return self.probability_distribution_class()(flat)

def proba_distribution_from_latent(self, pi_latent_vector, vf_latent_vector, init_scale=1.0, init_bias=0.0):
mean = linear(pi_latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
logstd = tf.get_variable(name='logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
logstd = tf.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
q_values = linear(vf_latent_vector, 'q', self.size, init_scale=init_scale, init_bias=init_bias)
return self.proba_distribution_from_flat(pdparam, self.bounds), mean, q_values
return self.proba_distribution_from_flat(pdparam), mean, q_values

def param_shape(self):
return [2 * self.size]
@@ -319,7 +316,7 @@ def entropy(self):
return tf.reduce_sum(p_0 * (tf.log(z_0) - a_0), axis=-1)

def sample(self):
uniform = tf.random_uniform(tf.shape(self.logits))
uniform = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype)
return tf.argmax(self.logits - tf.log(-tf.log(uniform)), axis=-1)

@classmethod
@@ -374,29 +371,24 @@ def fromflat(cls, flat):


class DiagGaussianProbabilityDistribution(ProbabilityDistribution):
def __init__(self, flat, bounds=(-np.inf, np.inf)):
def __init__(self, flat):
"""
Probability distributions from multivariate gaussian input
:param flat: ([float]) the multivariate gaussian input data
:param bounds: (float, float) the lower and upper bounds limit for the action space
"""
self.flat = flat
mean, logstd = tf.split(axis=len(flat.shape) - 1, num_or_size_splits=2, value=flat)
self.mean = mean
self.logstd = logstd
self.std = tf.exp(logstd)
self.bounds = bounds

def flatparam(self):
return self.flat

def mode(self):
low = self.bounds[0]
high = self.bounds[1]

# clip the output (clip_by_value does not broadcast correctly)
return tf.minimum(tf.maximum(self.mean, low), high)
# Bounds are taken into account outside this class (during training only)
return self.mean

def neglogp(self, x):
return 0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), axis=-1) \
@@ -412,22 +404,19 @@ def entropy(self):
return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1)

def sample(self):
low = self.bounds[0]
high = self.bounds[1]

# clip the output (clip_by_value does not broadcast correctly)
return tf.minimum(tf.maximum(self.mean + self.std * tf.random_normal(tf.shape(self.mean)), low), high)
# Bounds are taken into account outside this class (during training only)
# Otherwise, it changes the distribution and breaks PPO2 for instance
return self.mean + self.std * tf.random_normal(tf.shape(self.mean), dtype=self.mean.dtype)

@classmethod
def fromflat(cls, flat, bounds=(-np.inf, np.inf)):
def fromflat(cls, flat):
"""
Create an instance of this from new multivariate gaussian input
:param flat: ([float]) the multivariate gaussian input data
:param bounds: (float, float) the lower and upper bounds limit for the action space
:return: (ProbabilityDistribution) the instance from the given multivariate gaussian input data
"""
return cls(flat, bounds)
return cls(flat)


class BernoulliProbabilityDistribution(ProbabilityDistribution):
@@ -484,7 +473,7 @@ def make_proba_dist_type(ac_space):
"""
if isinstance(ac_space, spaces.Box):
assert len(ac_space.shape) == 1, "Error: the action space must be a vector"
return DiagGaussianProbabilityDistributionType(ac_space.shape[0], (ac_space.low, ac_space.high))
return DiagGaussianProbabilityDistributionType(ac_space.shape[0])
elif isinstance(ac_space, spaces.Discrete):
return CategoricalProbabilityDistributionType(ac_space.n)
elif isinstance(ac_space, spaces.MultiDiscrete):
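Why clipping inside `sample()` is harmful: once the policy mean drifts outside the action bounds, a clipped sample sits far in the tail of the Gaussian that is assumed to have produced it, and quantities built from its log-likelihood (such as the PPO2 importance ratio) can blow up. Below is a rough NumPy sketch of that mismatch, using the same diagonal-Gaussian formula as `neglogp()`; it is an illustration only, not code from this commit:

```python
import numpy as np

def neglogp(x, mean, logstd):
    # Diagonal-Gaussian negative log-likelihood, mirroring
    # DiagGaussianProbabilityDistribution.neglogp
    std = np.exp(logstd)
    return (0.5 * np.sum(np.square((x - mean) / std), axis=-1)
            + 0.5 * np.log(2.0 * np.pi) * x.shape[-1]
            + np.sum(logstd, axis=-1))

mean, logstd = np.array([3.0]), np.array([-1.0])   # policy mean drifted outside [-1, 1]
raw = mean + np.exp(logstd) * np.random.randn(1)   # what sample() now returns
clipped = np.clip(raw, -1.0, 1.0)                  # what the old, clipping sample() returned

# The clipped action is very unlikely under the Gaussian itself, so exponentiated
# log-likelihood differences computed from it can overflow and yield NaNs.
print(neglogp(raw, mean, logstd), neglogp(clipped, mean, logstd))
```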
stable_baselines/ppo2/ppo2.py (9 changes: 7 additions & 2 deletions)
@@ -1,8 +1,9 @@
import time
from collections import deque
import sys
import multiprocessing
from collections import deque

import gym
import numpy as np
import tensorflow as tf

@@ -399,7 +400,11 @@ def run(self):
mb_values.append(values)
mb_neglogpacs.append(neglogpacs)
mb_dones.append(self.dones)
self.obs[:], rewards, self.dones, infos = self.env.step(actions)
clipped_actions = actions
# Clip the actions to avoid out-of-bound errors
if isinstance(self.env.action_space, gym.spaces.Box):
clipped_actions = np.clip(actions, self.env.action_space.low, self.env.action_space.high)
self.obs[:], rewards, self.dones, infos = self.env.step(clipped_actions)
for info in infos:
maybeep_info = info.get('episode')
if maybeep_info:
stable_baselines/trpo_mpi/utils.py (12 changes: 9 additions & 3 deletions)
@@ -1,3 +1,4 @@
import gym
import numpy as np

from stable_baselines.common.vec_env import VecEnv
@@ -83,11 +84,16 @@ def traj_segment_generator(policy, env, horizon, reward_giver=None, gail=False):
actions[i] = action[0]
prev_actions[i] = prevac

clipped_action = action
# Clip the actions to avoid out-of-bound errors
if isinstance(env.action_space, gym.spaces.Box):
clipped_action = np.clip(action, env.action_space.low, env.action_space.high)

if gail:
rew = reward_giver.get_reward(observation, action[0])
observation, true_rew, done, _info = env.step(action[0])
rew = reward_giver.get_reward(observation, clipped_action[0])
observation, true_rew, done, _info = env.step(clipped_action[0])
else:
observation, rew, done, _info = env.step(action[0])
observation, rew, done, _info = env.step(clipped_action[0])
true_rew = rew
rews[i] = rew
true_rews[i] = true_rew
