
Commit

Return probability density in continuous action spaces, support returning log probabilities (#397)

* Support Gaussian probabilities and logp calculation

* Fix linting + missing normalizer

* Add & fix tests

* Fix Gaussian PDF calculation

* Bugfix in Gaussian probability calculation

* Address review comments
AdamGleave authored and araffin committed Jul 18, 2019
1 parent 0e940c7 commit 2bc3c87
Showing 8 changed files with 71 additions and 42 deletions.
4 changes: 4 additions & 0 deletions docs/misc/changelog.rst
@@ -15,6 +15,10 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^

- Add support for continuous action spaces to `action_probability`, computing the PDF of a Gaussian
policy in addition to the existing support for categorical stochastic policies.
- Add flag to `action_probability` to return log-probabilities.

Bug Fixes:
^^^^^^^^^^
- fixed a bug in ``traj_segment_generator`` where the ``episode_starts`` was wrongly recorded,
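For reference, a minimal usage sketch of the extended API described in the changelog entry above. The PPO2 model and the Pendulum-v0 environment are illustrative choices, not part of this commit:

import numpy as np
from stable_baselines import PPO2

# Pendulum-v0 has a Box (continuous) action space.
model = PPO2("MlpPolicy", "Pendulum-v0").learn(1000)
obs = model.env.reset()
actions = np.array([model.env.action_space.sample()])

# With `actions` given on a Box space, the return value is the probability
# *density* under the diagonal Gaussian policy rather than an array of zeros.
density = model.action_probability(obs, actions=actions)

# New `logp` flag: the same quantity in log-space.
log_density = model.action_probability(obs, actions=actions, logp=True)
assert np.allclose(density, np.exp(log_density))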
64 changes: 43 additions & 21 deletions stable_baselines/common/base_class.py
Expand Up @@ -331,29 +331,29 @@ def predict(self, observation, state=None, mask=None, deterministic=False):
pass

@abstractmethod
def action_probability(self, observation, state=None, mask=None, actions=None):
def action_probability(self, observation, state=None, mask=None, actions=None, logp=False):
"""
If ``actions`` is ``None``, then get the model's action probability distribution from a given observation
If ``actions`` is ``None``, then get the model's action probability distribution from a given observation.
depending on the action space the output is:
Depending on the action space the output is:
- Discrete: probability for each possible action
- Box: mean and standard deviation of the action output
However if ``actions`` is not ``None``, this function will return the probability that the given actions are
taken with the given parameters (observation, state, ...) on this model.
.. warning::
When working with continuous probability distribution (e.g. Gaussian distribution for continuous action)
the probability of taking a particular action is exactly zero.
See http://blog.christianperone.com/2019/01/ for a good explanation
taken with the given parameters (observation, state, ...) on this model. For discrete action spaces, it
returns the probability mass; for continuous action spaces, the probability density. This is because the
probability mass will always be zero in continuous spaces, see http://blog.christianperone.com/2019/01/
for a good explanation
:param observation: (np.ndarray) the input observation
:param state: (np.ndarray) The last states (can be None, used in recurrent policies)
:param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
:param actions: (np.ndarray) (OPTIONAL) For calculating the likelihood that the given actions are chosen by
the model for each of the given parameters. Must have the same number of actions and observations.
(set to None to return the complete action probability distribution)
:return: (np.ndarray) the model's action probability
:param logp: (bool) (OPTIONAL) When specified with actions, returns probability in log-space.
This has no effect if actions is None.
:return: (np.ndarray) the model's (log) action probability
"""
pass

@@ -592,7 +592,7 @@ def predict(self, observation, state=None, mask=None, deterministic=False):

return clipped_actions, states

def action_probability(self, observation, state=None, mask=None, actions=None):
def action_probability(self, observation, state=None, mask=None, actions=None, logp=False):
if state is None:
state = self.initial_state
if mask is None:
@@ -609,46 +609,68 @@ def action_probability(self, observation, state=None, mask=None, actions=None):
return None

if actions is not None: # comparing the action distribution, to given actions
prob = None
logprob = None
actions = np.array([actions])
if isinstance(self.action_space, gym.spaces.Discrete):
actions = actions.reshape((-1,))
assert observation.shape[0] == actions.shape[0], \
"Error: batch sizes differ for actions and observations."
actions_proba = actions_proba[np.arange(actions.shape[0]), actions]
prob = actions_proba[np.arange(actions.shape[0]), actions]

elif isinstance(self.action_space, gym.spaces.MultiDiscrete):
actions = actions.reshape((-1, len(self.action_space.nvec)))
assert observation.shape[0] == actions.shape[0], \
"Error: batch sizes differ for actions and observations."
# Discrete action probability, over multiple categories
actions = np.swapaxes(actions, 0, 1) # swap axis for easier categorical split
actions_proba = np.prod([proba[np.arange(act.shape[0]), act]
prob = np.prod([proba[np.arange(act.shape[0]), act]
for proba, act in zip(actions_proba, actions)], axis=0)

elif isinstance(self.action_space, gym.spaces.MultiBinary):
actions = actions.reshape((-1, self.action_space.n))
assert observation.shape[0] == actions.shape[0], \
"Error: batch sizes differ for actions and observations."
# Bernoulli action probability, for every action
actions_proba = np.prod(actions_proba * actions + (1 - actions_proba) * (1 - actions), axis=1)
prob = np.prod(actions_proba * actions + (1 - actions_proba) * (1 - actions), axis=1)

elif isinstance(self.action_space, gym.spaces.Box):
warnings.warn("The probabilty of taken a given action is exactly zero for a continuous distribution."
"See http://blog.christianperone.com/2019/01/ for a good explanation")
actions_proba = np.zeros((observation.shape[0], 1), dtype=np.float32)
actions = actions.reshape((-1, ) + self.action_space.shape)
mean, logstd = actions_proba
std = np.exp(logstd)

n_elts = np.prod(mean.shape[1:]) # first dimension is batch size
log_normalizer = n_elts/2 * np.log(2 * np.pi) + np.sum(logstd, axis=1)

# Diagonal Gaussian action probability, for every action
logprob = -np.sum(np.square(actions - mean) / (2 * np.square(std)), axis=1) - log_normalizer

else:
warnings.warn("Warning: action_probability not implemented for {} actions space. Returning None."
.format(type(self.action_space).__name__))
return None

# Return in space (log or normal) requested by user, converting if necessary
if logp:
if logprob is None:
logprob = np.log(prob)
ret = logprob
else:
if prob is None:
prob = np.exp(logprob)
ret = prob

# normalize action proba shape for the different gym spaces
actions_proba = actions_proba.reshape((-1, 1))
ret = ret.reshape((-1, 1))
else:
ret = actions_proba

if not vectorized_env:
if state is not None:
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
actions_proba = actions_proba[0]
ret = ret[0]

return actions_proba
return ret

def get_parameter_list(self):
return self.params
@@ -710,7 +732,7 @@ def predict(self, observation, state=None, mask=None, deterministic=False):
pass

@abstractmethod
def action_probability(self, observation, state=None, mask=None, actions=None):
def action_probability(self, observation, state=None, mask=None, actions=None, logp=False):
pass

@abstractmethod
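As a standalone sanity check of the diagonal Gaussian log-density computed in the new Box branch of action_probability above (an editorial sketch, not part of the diff; scipy is used only to provide a reference value):

import numpy as np
from scipy.stats import norm

rng = np.random.RandomState(0)
batch, dim = 4, 3
mean = rng.randn(batch, dim)
logstd = 0.1 * rng.randn(batch, dim)
std = np.exp(logstd)
actions = rng.randn(batch, dim)

# Same computation as the Box branch: log N(actions; mean, std^2) for a diagonal Gaussian.
n_elts = np.prod(mean.shape[1:])  # first dimension is batch size
log_normalizer = n_elts / 2 * np.log(2 * np.pi) + np.sum(logstd, axis=1)
logprob = -np.sum(np.square(actions - mean) / (2 * np.square(std)), axis=1) - log_normalizer

# Reference: sum of independent per-dimension normal log-densities.
reference = norm.logpdf(actions, loc=mean, scale=std).sum(axis=1)
assert np.allclose(logprob, reference)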
2 changes: 1 addition & 1 deletion stable_baselines/ddpg/ddpg.py
@@ -1039,7 +1039,7 @@ def predict(self, observation, state=None, mask=None, deterministic=True):

return actions, None

def action_probability(self, observation, state=None, mask=None, actions=None):
def action_probability(self, observation, state=None, mask=None, actions=None, logp=False):
observation = np.array(observation)

if actions is not None:
4 changes: 3 additions & 1 deletion stable_baselines/deepq/dqn.py
@@ -303,7 +303,7 @@ def predict(self, observation, state=None, mask=None, deterministic=True):

return actions, None

def action_probability(self, observation, state=None, mask=None, actions=None):
def action_probability(self, observation, state=None, mask=None, actions=None, logp=False):
observation = np.array(observation)
vectorized_env = self._is_vectorized_observation(observation, self.observation_space)

@@ -318,6 +318,8 @@ def action_probability(self, observation, state=None, mask=None, actions=None):
actions_proba = actions_proba[np.arange(actions.shape[0]), actions]
# normalize action proba shape
actions_proba = actions_proba.reshape((-1, 1))
if logp:
actions_proba = np.log(actions_proba)

if not vectorized_env:
if state is not None:
4 changes: 2 additions & 2 deletions stable_baselines/her/her.py
@@ -126,8 +126,8 @@ def _check_obs(self, observation):
def predict(self, observation, state=None, mask=None, deterministic=True):
return self.model.predict(self._check_obs(observation), state, mask, deterministic)

def action_probability(self, observation, state=None, mask=None, actions=None):
return self.model.action_probability(self._check_obs(observation), state, mask, actions)
def action_probability(self, observation, state=None, mask=None, actions=None, logp=False):
return self.model.action_probability(self._check_obs(observation), state, mask, actions, logp)

def _save_to_file(self, save_path, data=None, params=None):
# HACK to save the replay wrapper
17 changes: 6 additions & 11 deletions stable_baselines/sac/sac.py
@@ -497,19 +497,14 @@ def learn(self, total_timesteps, callback=None, seed=None,
infos_values = []
return self

def action_probability(self, observation, state=None, mask=None, actions=None):
if actions is None:
warnings.warn("Even thought SAC has a Gaussian policy, it cannot return a distribution as it "
"is squashed by an tanh before being scaled and ouputed. Therefore 'action_probability' "
"will only work with the 'actions' keyword argument being used. Returning None.")
return None
def action_probability(self, observation, state=None, mask=None, actions=None, logp=False):
if actions is not None:
raise ValueError("Error: SAC does not have action probabilities.")

observation = np.array(observation)

warnings.warn("The probabilty of taken a given action is exactly zero for a continuous distribution."
"See http://blog.christianperone.com/2019/01/ for a good explanation")
warnings.warn("Even though SAC has a Gaussian policy, it cannot return a distribution as it "
"is squashed by a tanh before being scaled and ouputed.")

return np.zeros((observation.shape[0], 1), dtype=np.float32)
return None

def predict(self, observation, state=None, mask=None, deterministic=True):
observation = np.array(observation)
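For background on why SAC declines to report a density at all (an editorial note, not part of the diff): SAC samples a Gaussian pre-activation u and outputs tanh(u) before rescaling, so the density of the squashed action needs a change-of-variables correction, log p(a) = log N(u; mu, sigma) - sum_i log(1 - tanh(u_i)^2), which the generic Gaussian path in base_class.py does not apply. A minimal numerical illustration of that correction, assuming a single unscaled action in (-1, 1):

import numpy as np
from scipy.stats import norm

mu, sigma = 0.3, 0.5
u = np.linspace(-3.0, 3.0, 2001)   # Gaussian pre-activation
a = np.tanh(u)                     # squashed action

log_p_u = norm.logpdf(u, loc=mu, scale=sigma)
log_p_a = log_p_u - np.log(1.0 - np.tanh(u) ** 2)  # Jacobian of tanh

# The squashed density still integrates to ~1 over the action interval (-1, 1).
print(np.trapz(np.exp(log_p_a), a))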
9 changes: 5 additions & 4 deletions tests/test_continuous.py
@@ -104,14 +104,15 @@ def test_model_manipulation(request, model_class):
observations = observations.reshape((-1, 1))
actions = np.array([env.action_space.sample() for _ in range(10)])

if model_class == DDPG:
if model_class in [DDPG, SAC]:
with pytest.raises(ValueError):
model.action_probability(observations, actions=actions)
else:
with pytest.warns(UserWarning):
actions_probas = model.action_probability(observations, actions=actions)
actions_probas = model.action_probability(observations, actions=actions)
assert actions_probas.shape == (len(actions), 1), actions_probas.shape
assert np.all(actions_probas == 0.0), actions_probas
assert np.all(actions_probas >= 0), actions_probas
actions_logprobas = model.action_probability(observations, actions=actions, logp=True)
assert np.allclose(actions_probas, np.exp(actions_logprobas)), (actions_probas, actions_logprobas)

# assert <15% diff
assert abs(acc_reward - loaded_acc_reward) / max(acc_reward, loaded_acc_reward) < 0.15, \
9 changes: 7 additions & 2 deletions tests/test_identity.py
@@ -48,9 +48,14 @@ def test_identity(model_name):
action, _ = model.predict(obs)
obs, reward, _, _ = env.step(action)
reward_sum += reward

assert model.action_probability(obs).shape == (1, 10), "Error: action_probability not returning correct shape"
assert np.prod(model.action_probability(obs, actions=env.action_space.sample()).shape) == 1, \
"Error: not scalar probability"
action = env.action_space.sample()
action_prob = model.action_probability(obs, actions=action)
assert np.prod(action_prob.shape) == 1, "Error: not scalar probability"
action_logprob = model.action_probability(obs, actions=action, logp=True)
assert np.allclose(action_prob, np.exp(action_logprob)), (action_prob, action_logprob)

assert reward_sum > 0.9 * n_trials
# Free memory
del model, env
