Commit

Support for MultiBinary / MultiDiscrete spaces (#13)
* multicategorical dist and test

* fixed List annotation

* bernoulli dist and test

* added distributions to preprocessing (needs testing)

* fixed and tested distributions

* added changelog and fixed ppo policy

* minor fix

* dist fixes, added test_spaces

* clean up

* modified changelog

* additional fixes

* minor changelog mod

* hot encoding fix, flake8 clean up

* lint tests

* preprocessing fix

* fixed bernoulli bug

* removed commented prints

* Update changelog.rst

* included suggested modifications

* linting fix

* increased space dim

* Update doc and tests

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
rolandgvc and araffin committed May 18, 2020
1 parent 15ff6d4 commit 91adefd
Showing 17 changed files with 293 additions and 91 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -162,8 +162,8 @@ All the following examples can be executed online using Google colab notebooks:

| **Name** | **Recurrent** | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | **Multi Processing** |
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
| TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |

16 changes: 8 additions & 8 deletions docs/guide/algos.rst
@@ -5,14 +5,14 @@ This table displays the rl algorithms that are implemented in the Stable Baselin
along with some useful characteristics: support for discrete/continuous actions, multiprocessing.


============ =========== ============ ================
Name ``Box`` ``Discrete`` Multi Processing
============ =========== ============ ================
A2C ✔️ ✔️ ✔️
PPO ✔️ ✔️ ✔️
SAC ✔️ ❌ ❌
TD3 ✔️ ❌ ❌
============ =========== ============ ================
============ =========== ============ ================= =============== ================
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
A2C ✔️ ✔️ ✔️ ✔️ ✔️
PPO ✔️ ✔️ ✔️ ✔️ ✔️
SAC ✔️ ❌ ❌ ❌ ❌
TD3 ✔️ ❌ ❌ ❌ ❌
============ =========== ============ ================= =============== ================


.. note::
7 changes: 4 additions & 3 deletions docs/misc/changelog.rst
@@ -3,10 +3,9 @@
Changelog
==========

Pre-Release 0.6.0a8 (WIP)
Pre-Release 0.6.0a9 (WIP)
------------------------------


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Remove State-Dependent Exploration (SDE) support for ``TD3``
@@ -17,6 +16,8 @@ New Features:
- Added ``VecCheckNan`` and ``VecVideoRecorder`` (Sync with Stable Baselines)
- Added determinism tests
- Added ``cmd_utils`` and ``atari_wrappers``
- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation spaces (@rolandgvc)
- Added ``MultiCategorical`` and ``Bernoulli`` distributions for PPO/A2C (@rolandgvc)

Bug Fixes:
^^^^^^^^^^
@@ -227,4 +228,4 @@ And all the contributors:
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta
@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta @rolandgvc
6 changes: 3 additions & 3 deletions docs/modules/a2c.rst
@@ -28,10 +28,10 @@ Can I use?
============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete
Discrete ✔️ ✔️
Box ✔️ ✔️
MultiDiscrete
MultiBinary
MultiDiscrete ✔️ ✔️
MultiBinary ✔️ ✔️
============= ====== ===========


6 changes: 3 additions & 3 deletions docs/modules/ppo.rst
@@ -38,10 +38,10 @@ Can I use?
============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete
Discrete ✔️ ✔️
Box ✔️ ✔️
MultiDiscrete
MultiBinary
MultiDiscrete ✔️ ✔️
MultiBinary ✔️ ✔️
============= ====== ===========

Example
6 changes: 3 additions & 3 deletions docs/modules/sac.rst
@@ -58,10 +58,10 @@ Can I use?
============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ❌
Discrete ❌ ✔️
Box ✔️ ✔️
MultiDiscrete ❌
MultiBinary ❌
MultiDiscrete ❌ ✔️
MultiBinary ❌ ✔️
============= ====== ===========


6 changes: 3 additions & 3 deletions docs/modules/td3.rst
@@ -50,10 +50,10 @@ Can I use?
============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ❌
Discrete ❌ ✔️
Box ✔️ ✔️
MultiDiscrete ❌
MultiBinary ❌
MultiDiscrete ❌ ✔️
MultiBinary ❌ ✔️
============= ====== ===========


1 change: 1 addition & 0 deletions stable_baselines3/common/buffers.py
@@ -20,6 +20,7 @@ class BaseBuffer(object):
to which the values will be converted
:param n_envs: (int) Number of parallel environments
"""

def __init__(self,
buffer_size: int,
observation_space: spaces.Space,
123 changes: 115 additions & 8 deletions stable_baselines3/common/distributions.py
@@ -1,9 +1,8 @@
from typing import Optional, Tuple, Dict, Any

from typing import Optional, Tuple, Dict, Any, List
import gym
import torch as th
import torch.nn as nn
from torch.distributions import Normal, Categorical
from torch.distributions import Normal, Categorical, Bernoulli
from gym import spaces

from stable_baselines3.common.preprocessing import get_action_dim
@@ -88,7 +87,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
:return: (th.Tensor) shape: (n_batch,)
"""
if len(tensor.shape) > 1:
tensor = tensor.sum(axis=1)
tensor = tensor.sum(dim=1)
else:
tensor = tensor.sum()
return tensor
@@ -292,6 +291,114 @@ def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions)


class MultiCategoricalDistribution(Distribution):
"""
MultiCategorical distribution for multi discrete actions.
:param action_dims: (List[int]) List of sizes of discrete action spaces
"""

def __init__(self, action_dims: List[int]):
super(MultiCategoricalDistribution, self).__init__()
self.action_dims = action_dims
self.distributions = None

def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits (flattened) of the MultiCategorical distribution.
You can then get probabilities using a softmax on each sub-space.
:param latent_dim: (int) Dimension of the last layer
of the policy network (before the action layer)
:return: (nn.Linear)
"""

action_logits = nn.Linear(latent_dim, sum(self.action_dims))
return action_logits

def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution':
self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
return self

def mode(self) -> th.Tensor:
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)

def sample(self) -> th.Tensor:
return th.stack([dist.sample() for dist in self.distributions], dim=1)

def entropy(self) -> th.Tensor:
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)

def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)

def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob

def log_prob(self, actions: th.Tensor) -> th.Tensor:
# Extract each discrete action and compute log prob for their respective distributions
return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions,
th.unbind(actions, dim=1))], dim=1).sum(dim=1)
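
The class above treats a MultiDiscrete action as independent categorical sub-actions: one flat logit vector is split by `action_dims`, and the per-sub-action log-probabilities are summed. The following is a minimal, dependency-free sketch of that arithmetic for a single sample. It uses plain Python instead of PyTorch, and the helper names (`log_softmax`, `split_logits`, `multi_categorical_log_prob`, `multi_categorical_mode`) are hypothetical illustrations, not part of the commit.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over the logits of one sub-space."""
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return [x - log_norm for x in logits]


def split_logits(flat_logits: Sequence[float], action_dims: Sequence[int]) -> List[List[float]]:
    """Cut the flat logit vector into one chunk per discrete sub-action."""
    if len(flat_logits) != sum(action_dims):
        raise ValueError("expected exactly sum(action_dims) logits")
    chunks, start = [], 0
    for dim in action_dims:
        chunks.append(list(flat_logits[start:start + dim]))
        start += dim
    return chunks


def multi_categorical_log_prob(flat_logits: Sequence[float],
                               action_dims: Sequence[int],
                               action: Sequence[int]) -> float:
    """Sum of the per-sub-action log-probabilities (sub-actions are independent)."""
    chunks = split_logits(flat_logits, action_dims)
    return sum(log_softmax(chunk)[choice] for chunk, choice in zip(chunks, action))


def multi_categorical_mode(flat_logits: Sequence[float],
                           action_dims: Sequence[int]) -> List[int]:
    """Most likely index in each sub-space (argmax over each chunk)."""
    return [max(range(len(chunk)), key=chunk.__getitem__)
            for chunk in split_logits(flat_logits, action_dims)]
```

With all-zero logits and `action_dims = [2, 3]`, every joint action has probability 1/2 * 1/3, so the log-probability is `-log(6)` regardless of the action chosen.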


class BernoulliDistribution(Distribution):
"""
Bernoulli distribution for MultiBinary action spaces.
:param action_dims: (int) Number of binary actions
"""

def __init__(self, action_dims: int):
super(BernoulliDistribution, self).__init__()
self.distribution = None
self.action_dims = action_dims

def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits of the Bernoulli distribution.
:param latent_dim: (int) Dimension of the last layer
of the policy network (before the action layer)
:return: (nn.Linear)
"""
action_logits = nn.Linear(latent_dim, self.action_dims)
return action_logits

def proba_distribution(self, action_logits: th.Tensor) -> 'BernoulliDistribution':
self.distribution = Bernoulli(logits=action_logits)
return self

def mode(self) -> th.Tensor:
return th.round(self.distribution.probs)

def sample(self) -> th.Tensor:
return self.distribution.sample()

def entropy(self) -> th.Tensor:
return self.distribution.entropy().sum(dim=1)

def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)

def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob

def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions).sum(dim=1)
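
For MultiBinary actions, the class above models each bit as an independent Bernoulli variable parameterized by a logit, sums the per-bit log-probabilities, and takes the rounded probability as the mode. A minimal single-sample sketch of that arithmetic in plain Python follows. The helper names (`softplus`, `bernoulli_log_prob`, `bernoulli_mode`) are hypothetical and only illustrate the math; the commit itself relies on `torch.distributions.Bernoulli`.

```python
import math
from typing import List, Sequence


def softplus(x: float) -> float:
    """Stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bernoulli_log_prob(logits: Sequence[float], action: Sequence[int]) -> float:
    """Sum over bits of log p(bit), using
    log(sigmoid(l)) = -softplus(-l) and log(1 - sigmoid(l)) = -softplus(l)."""
    total = 0.0
    for logit, bit in zip(logits, action):
        total += -softplus(-logit) if bit == 1 else -softplus(logit)
    return total


def bernoulli_mode(logits: Sequence[float]) -> List[int]:
    """A bit is 1 when its probability exceeds 0.5, i.e. when its logit is positive."""
    return [1 if logit > 0 else 0 for logit in logits]
```

With all-zero logits each bit has probability 0.5, so a three-bit action has log-probability `3 * log(0.5)`.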


class StateDependentNoiseDistribution(Distribution):
"""
Distribution class for using generalized State Dependent Exploration (gSDE).
@@ -551,10 +658,10 @@ def make_proba_distribution(action_space: gym.spaces.Space,
return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs)
# elif isinstance(action_space, spaces.MultiDiscrete):
# return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
# elif isinstance(action_space, spaces.MultiBinary):
# return BernoulliDistribution(action_space.n, **dist_kwargs)
elif isinstance(action_space, spaces.MultiDiscrete):
return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
elif isinstance(action_space, spaces.MultiBinary):
return BernoulliDistribution(action_space.n, **dist_kwargs)
else:
raise NotImplementedError("Error: probability distribution, not implemented for action space "
f"of type {type(action_space)}."
71 changes: 33 additions & 38 deletions stable_baselines3/common/policies.py
@@ -206,8 +206,8 @@ def predict(self, observation: np.ndarray,
# Handle the different cases for images
# as PyTorch use channel first format
if is_image_space(self.observation_space):
if (observation.shape == self.observation_space.shape or
observation.shape[1:] == self.observation_space.shape):
if (observation.shape == self.observation_space.shape
or observation.shape[1:] == self.observation_space.shape):
pass
else:
# Try to re-order the channels
@@ -279,40 +279,40 @@ def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.s
elif observation.shape[1:] == observation_space.shape:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
"Box environment, please use {} ".format(observation_space.shape) +
"or (n_env, {}) for the observation shape."
raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
+ f"Box environment, please use {observation_space.shape} "
+ "or (n_env, {}) for the observation shape."
.format(", ".join(map(str, observation_space.shape))))
elif isinstance(observation_space, gym.spaces.Discrete):
if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
return False
elif len(observation.shape) == 1:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
"Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
# TODO: add support for MultiDiscrete and MultiBinary observation spaces
# elif isinstance(observation_space, gym.spaces.MultiDiscrete):
# if observation.shape == (len(observation_space.nvec),):
# return False
# elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
# return True
# else:
# raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) +
# "environment, please use ({},) or ".format(len(observation_space.nvec)) +
# "(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
# elif isinstance(observation_space, gym.spaces.MultiBinary):
# if observation.shape == (observation_space.n,):
# return False
# elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
# return True
# else:
# raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
# "environment, please use ({},) or ".format(observation_space.n) +
# "(n_env, {}) for the observation shape.".format(observation_space.n))
raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
+ "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")

elif isinstance(observation_space, gym.spaces.MultiDiscrete):
if observation.shape == (len(observation_space.nvec),):
return False
elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
return True
else:
raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
+ f"environment, please use ({len(observation_space.nvec)},) or "
+ f"(n_env, {len(observation_space.nvec)}) for the observation shape.")
elif isinstance(observation_space, gym.spaces.MultiBinary):
if observation.shape == (observation_space.n,):
return False
elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
return True
else:
raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
+ f"environment, please use ({observation_space.n},) or "
+ f"(n_env, {observation_space.n}) for the observation shape.")
else:
raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
.format(observation_space))
raise ValueError("Error: Cannot determine if the observation is vectorized "
+ f" with the space type {observation_space}.")
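
The MultiDiscrete and MultiBinary branches above apply the same rule: a shape of `(n,)` is a single observation, a shape of `(n_env, n)` is a batch, and anything else is an error. A minimal sketch of that rule on plain shape tuples follows. The function `is_vectorized_shape` and its parameters are hypothetical names for illustration; the commit performs the check on NumPy arrays and gym spaces.

```python
from typing import Tuple


def is_vectorized_shape(shape: Tuple[int, ...], n_components: int, space_name: str) -> bool:
    """Return False for a single observation (n,), True for a batch (n_env, n).

    `n_components` stands for len(space.nvec) for MultiDiscrete
    or space.n for MultiBinary.
    """
    if shape == (n_components,):
        return False
    if len(shape) == 2 and shape[1] == n_components:
        return True
    raise ValueError(
        f"Error: Unexpected observation shape {shape} for {space_name} "
        f"environment, please use ({n_components},) or "
        f"(n_env, {n_components}) for the observation shape."
    )
```

For example, with four components, `(4,)` is treated as a single observation and `(8, 4)` as a batch from eight environments, while `(3,)` raises a `ValueError`.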

def _get_data(self) -> Dict[str, Any]:
"""
@@ -447,7 +447,7 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[
raise ValueError(f"Error: the policy type {base_policy_type} is not registered!")
if name not in _policy_registry[base_policy_type]:
raise ValueError(f"Error: unknown policy type {name},"
"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
f" the only registered policy types are: {list(_policy_registry[base_policy_type].keys())}!")
return _policy_registry[base_policy_type][name]


@@ -460,14 +460,10 @@ def register_policy(name: str, policy: Type[BasePolicy]) -> None:
:param policy: (Type[BasePolicy]) the policy class
"""
sub_class = None
# For building the doc
try:
for cls in BasePolicy.__subclasses__():
if issubclass(policy, cls):
sub_class = cls
break
except AttributeError:
sub_class = str(th.random.randint(100))
for cls in BasePolicy.__subclasses__():
if issubclass(policy, cls):
sub_class = cls
break
if sub_class is None:
raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!")

@@ -511,7 +507,6 @@ def __init__(self, feature_dim: int,
device: Union[th.device, str] = 'auto'):
super(MlpExtractor, self).__init__()
device = get_device(device)

shared_net, policy_net, value_net = [], [], []
policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
value_only_layers = [] # Layer sizes of the network that only belongs to the value network