Commit

Merge branch 'master' into sde

araffin committed May 24, 2020
2 parents f068ada + 9b42b97 commit cc6794b
Showing 23 changed files with 312 additions and 100 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
@@ -32,6 +32,9 @@ jobs:
pip install .[extra,tests,docs]
# Use headless version
pip install opencv-python-headless
+- name: Build the doc
+  run: |
+    make doc
- name: Type check
run: |
make type
10 changes: 8 additions & 2 deletions README.md
@@ -167,8 +167,8 @@ All the following examples can be executed online using Google colab notebooks:

| **Name** | **Recurrent** | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | **Multi Processing** |
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
-| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
-| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
+| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
| TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |

@@ -232,5 +232,11 @@ Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hi
To anyone interested in making the baselines better, there is still some documentation that needs to be done.
If you want to contribute, please read the [**CONTRIBUTING.md**](./CONTRIBUTING.md) guide first.

## Acknowledgments

The initial work to develop Stable Baselines3 was partially funded by the project *Reduced Complexity Models* from the *Helmholtz-Gemeinschaft Deutscher Forschungszentren*.

The original version, Stable Baselines, was created in the [robotics lab U2IS](http://u2is.ensta-paristech.fr/index.php?lang=en) ([INRIA Flowers](https://flowers.inria.fr/) team) at [ENSTA ParisTech](http://www.ensta-paristech.fr/en).


Logo credits: [L.M. Tenkes](https://www.instagram.com/lucillehue/)
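As a quick illustration of the newly ticked `MultiDiscrete`/`MultiBinary` columns in the support table above: any Gym environment exposing one of these action spaces should now train with A2C or PPO. A minimal sketch (the `ToyMultiDiscreteEnv` below is hypothetical, written for this example against the Gym API of the time, and is not part of this commit):

import numpy as np
import gym
from gym import spaces

from stable_baselines3 import PPO


class ToyMultiDiscreteEnv(gym.Env):
    """Dummy environment with a MultiDiscrete action space (illustration only)."""

    def __init__(self):
        super(ToyMultiDiscreteEnv, self).__init__()
        self.action_space = spaces.MultiDiscrete([3, 2])  # two discrete sub-actions
        self.observation_space = spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        reward = float(action[0] == 1)  # dummy reward on the first sub-action
        return self.observation_space.sample(), reward, True, {}


model = PPO('MlpPolicy', ToyMultiDiscreteEnv(), n_steps=64, verbose=0)
model.learn(total_timesteps=128)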
2 changes: 1 addition & 1 deletion docs/common/cmd_utils.rst → docs/common/cmd_util.rst
@@ -1,4 +1,4 @@
-.. _cmd_utils:
+.. _cmd_util:

Command Utils
=========================
16 changes: 8 additions & 8 deletions docs/guide/algos.rst
@@ -5,14 +5,14 @@ This table displays the rl algorithms that are implemented in the Stable Baselin
along with some useful characteristics: support for discrete/continuous actions, multiprocessing.


-============ =========== ============ ================
-Name         ``Box``     ``Discrete`` Multi Processing
-============ =========== ============ ================
-A2C          ✔️          ✔️           ✔️
-PPO          ✔️          ✔️           ✔️
-SAC          ✔️          ❌           ❌
-TD3          ✔️          ❌           ❌
-============ =========== ============ ================
+============ =========== ============ ================= =============== ================
+Name         ``Box``     ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
+============ =========== ============ ================= =============== ================
+A2C          ✔️          ✔️           ✔️                ✔️              ✔️
+PPO          ✔️          ✔️           ✔️                ✔️              ✔️
+SAC          ✔️          ❌           ❌                ❌              ❌
+TD3          ✔️          ❌           ❌                ❌              ❌
+============ =========== ============ ================= =============== ================


.. note::
2 changes: 1 addition & 1 deletion docs/guide/examples.rst
@@ -9,7 +9,7 @@ Try it online with Colab Notebooks!
All the following examples can be executed online using Google colab |colab|
notebooks:

-- `Full Tutorial <https://github.com/araffin/rl-tutorial-jnrr19>`_
+- `Full Tutorial <https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3>`_
- `All Notebooks <https://github.com/Stable-Baselines-Team/rl-colab-notebooks/tree/sb3>`_
- `Getting Started`_
- `Training, Saving, Loading`_
4 changes: 2 additions & 2 deletions docs/guide/rl_tips.rst
@@ -15,14 +15,14 @@ General advice when using Reinforcement Learning
TL;DR
-----

-1. Read about RL and Stable Baselines
+1. Read about RL and Stable Baselines3
2. Do quantitative experiments and hyperparameter tuning if needed
3. Evaluate the performance using a separate test environment
4. For better performance, increase the training budget


Like any other subject, if you want to work with RL, you should first read about it (we have a dedicated `resource page <rl.html>`_ to get you started)
-to understand what you are using. We also recommend you read Stable Baselines (SB) documentation and do the `tutorial <https://github.com/araffin/rl-tutorial-jnrr19>`_.
+to understand what you are using. We also recommend you read Stable Baselines3 (SB3) documentation and do the `tutorial <https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3>`_.
It covers basic usage and guides you towards more advanced concepts of the library (e.g. callbacks and wrappers).

Reinforcement Learning differs from other machine learning methods in several ways. The data used to train the agent is collected
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -63,7 +63,7 @@ Main Features
:caption: Common

common/atari_wrappers
-common/cmd_utils
+common/cmd_util
common/distributions
common/evaluation
common/env_checker
10 changes: 6 additions & 4 deletions docs/misc/changelog.rst
@@ -3,10 +3,9 @@
Changelog
==========

-Pre-Release 0.6.0a8 (WIP)
+Pre-Release 0.6.0a10 (WIP)
------------------------------


Breaking Changes:
^^^^^^^^^^^^^^^^^

@@ -15,13 +14,16 @@ New Features:
- Added env checker (Sync with Stable Baselines)
- Added ``VecCheckNan`` and ``VecVideoRecorder`` (Sync with Stable Baselines)
- Added determinism tests
-- Added ``cmd_utils`` and ``atari_wrappers``
+- Added ``cmd_util`` and ``atari_wrappers``
+- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation spaces (@rolandgvc)
+- Added ``MultiCategorical`` and ``Bernoulli`` distributions for PPO/A2C (@rolandgvc)

Bug Fixes:
^^^^^^^^^^
- Fixed a bug that prevented a model trained on CPU from being loaded on GPU
- Fixed version number that had a new line included
- Fixed weird seg fault in docker image due to FakeImageEnv by reducing screen size
- Fixed ``sde_sample_freq`` that was not taken into account for SAC

Deprecations:
^^^^^^^^^^^^^
@@ -226,4 +228,4 @@ And all the contributors:
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
-@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta
+@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur
8 changes: 4 additions & 4 deletions docs/modules/a2c.rst
@@ -28,10 +28,10 @@ Can I use?
============= ====== ===========
Space         Action Observation
============= ====== ===========
-Discrete
+Discrete      ✔️     ✔️
Box           ✔️     ✔️
-MultiDiscrete
-MultiBinary
+MultiDiscrete ✔️     ✔️
+MultiBinary   ✔️     ✔️
============= ====== ===========


@@ -46,7 +46,7 @@ Train an A2C agent on ``CartPole-v1`` using 4 environments.
from stable_baselines3 import A2C
from stable_baselines3.a2c import MlpPolicy
-from stable_baselines3.common.cmd_utils import make_vec_env
+from stable_baselines3.common.cmd_util import make_vec_env
# Parallel environments
env = make_vec_env('CartPole-v1', n_envs=4)
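The example above is cut off by the diff view; a typical continuation, following the usual SB3 pattern (a sketch, not text from this commit):

model = A2C(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
model.save("a2c_cartpole")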
8 changes: 4 additions & 4 deletions docs/modules/ppo.rst
@@ -38,10 +38,10 @@ Can I use?
============= ====== ===========
Space         Action Observation
============= ====== ===========
-Discrete
+Discrete      ✔️     ✔️
Box           ✔️     ✔️
-MultiDiscrete
-MultiBinary
+MultiDiscrete ✔️     ✔️
+MultiBinary   ✔️     ✔️
============= ====== ===========

Example
@@ -55,7 +55,7 @@ Train a PPO agent on ``Pendulum-v0`` using 4 environments.
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
-from stable_baselines3.common.cmd_utils import make_vec_env
+from stable_baselines3.common.cmd_util import make_vec_env
# Parallel environments
env = make_vec_env('CartPole-v1', n_envs=4)
6 changes: 3 additions & 3 deletions docs/modules/sac.rst
@@ -58,10 +58,10 @@ Can I use?
============= ====== ===========
Space         Action Observation
============= ====== ===========
-Discrete      ❌
+Discrete      ❌     ✔️
Box           ✔️     ✔️
-MultiDiscrete ❌
-MultiBinary   ❌
+MultiDiscrete ❌     ✔️
+MultiBinary   ❌     ✔️
============= ====== ===========


6 changes: 3 additions & 3 deletions docs/modules/td3.rst
@@ -50,10 +50,10 @@ Can I use?
============= ====== ===========
Space         Action Observation
============= ====== ===========
-Discrete      ❌
+Discrete      ❌     ✔️
Box           ✔️     ✔️
-MultiDiscrete ❌
-MultiBinary   ❌
+MultiDiscrete ❌     ✔️
+MultiBinary   ❌     ✔️
============= ====== ===========


2 changes: 1 addition & 1 deletion stable_baselines3/common/base_class.py
@@ -801,7 +801,7 @@ def collect_rollouts(self,

while not done:

-if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
+if self.use_sde and self.sde_sample_freq > 0 and total_steps % self.sde_sample_freq == 0:
# Sample a new noise matrix
self.actor.reset_noise()

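The renamed counter matters because the condition is meant to redraw the gSDE noise matrix every ``sde_sample_freq`` environment steps throughout collection. A minimal sketch of that cadence (names illustrative, not repo code):

sde_sample_freq = 4
for total_steps in range(10):
    if sde_sample_freq > 0 and total_steps % sde_sample_freq == 0:
        print(f"step {total_steps}: reset_noise()")  # fires at steps 0, 4, 8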
1 change: 1 addition & 0 deletions stable_baselines3/common/buffers.py
@@ -20,6 +20,7 @@ class BaseBuffer(object):
to which the values will be converted
:param n_envs: (int) Number of parallel environments
"""

def __init__(self,
buffer_size: int,
observation_space: spaces.Space,
123 changes: 115 additions & 8 deletions stable_baselines3/common/distributions.py
@@ -1,9 +1,8 @@
-from typing import Optional, Tuple, Dict, Any
-
+from typing import Optional, Tuple, Dict, Any, List
import gym
import torch as th
import torch.nn as nn
-from torch.distributions import Normal, Categorical
+from torch.distributions import Normal, Categorical, Bernoulli
from gym import spaces

from stable_baselines3.common.preprocessing import get_action_dim
@@ -88,7 +87,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
:return: (th.Tensor) shape: (n_batch,)
"""
if len(tensor.shape) > 1:
-tensor = tensor.sum(axis=1)
+tensor = tensor.sum(dim=1)
else:
tensor = tensor.sum()
return tensor
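# Editor's sketch (not part of this commit): the change above swaps NumPy's
# `axis` keyword for PyTorch's native `dim` argument. What the helper computes:
import torch as th
log_prob = th.zeros(8, 3)         # shape (n_batch, n_action_dims)
print(log_prob.sum(dim=1).shape)  # torch.Size([8]): one summed log-prob per batch element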
@@ -292,6 +291,114 @@ def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions)


class MultiCategoricalDistribution(Distribution):
"""
MultiCategorical distribution for multi discrete actions.
:param action_dims: (List[int]) List of sizes of discrete action spaces
"""

def __init__(self, action_dims: List[int]):
super(MultiCategoricalDistribution, self).__init__()
self.action_dims = action_dims
self.distributions = None

def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits (flattened) of the MultiCategorical distribution.
You can then get probabilities using a softmax on each sub-space.
:param latent_dim: (int) Dimension of the last layer
of the policy network (before the action layer)
:return: (nn.Linear)
"""

action_logits = nn.Linear(latent_dim, sum(self.action_dims))
return action_logits

def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution':
self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
return self

def mode(self) -> th.Tensor:
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)

def sample(self) -> th.Tensor:
return th.stack([dist.sample() for dist in self.distributions], dim=1)

def entropy(self) -> th.Tensor:
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)

def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)

def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob

def log_prob(self, actions: th.Tensor) -> th.Tensor:
# Extract each discrete action and compute log prob for their respective distributions
return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions,
th.unbind(actions, dim=1))], dim=1).sum(dim=1)
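# Editor's sketch (not part of this commit): standalone use of the class above
# for an action space with two sub-spaces of sizes 3 and 2.
import torch as th
dist = MultiCategoricalDistribution(action_dims=[3, 2])
net = dist.proba_distribution_net(latent_dim=16)    # nn.Linear(16, 5): flat logits
logits = net(th.randn(4, 16))                       # batch of 4 latent vectors
actions = dist.proba_distribution(logits).sample()  # shape (4, 2): one index per sub-space
log_prob = dist.log_prob(actions)                   # shape (4,): summed over sub-spaces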


class BernoulliDistribution(Distribution):
"""
Bernoulli distribution for MultiBinary action spaces.
:param action_dims: (int) Number of binary actions
"""

def __init__(self, action_dims: int):
super(BernoulliDistribution, self).__init__()
self.distribution = None
self.action_dims = action_dims

def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits of the Bernoulli distribution.
:param latent_dim: (int) Dimension of the last layer
of the policy network (before the action layer)
:return: (nn.Linear)
"""
action_logits = nn.Linear(latent_dim, self.action_dims)
return action_logits

def proba_distribution(self, action_logits: th.Tensor) -> 'BernoulliDistribution':
self.distribution = Bernoulli(logits=action_logits)
return self

def mode(self) -> th.Tensor:
return th.round(self.distribution.probs)

def sample(self) -> th.Tensor:
return self.distribution.sample()

def entropy(self) -> th.Tensor:
return self.distribution.entropy().sum(dim=1)

def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)

def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob

def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions).sum(dim=1)
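# Editor's sketch (not part of this commit): standalone use of the class above
# for four independent binary actions.
import torch as th
dist = BernoulliDistribution(action_dims=4)
net = dist.proba_distribution_net(latent_dim=16)    # nn.Linear(16, 4): one logit per bit
logits = net(th.randn(2, 16))
actions = dist.proba_distribution(logits).sample()  # shape (2, 4), entries in {0., 1.}
log_prob = dist.log_prob(actions)                   # shape (2,): summed across the 4 bits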


class StateDependentNoiseDistribution(Distribution):
"""
Distribution class for using generalized State Dependent Exploration (gSDE).
@@ -551,10 +658,10 @@ def make_proba_distribution(action_space: gym.spaces.Space,
return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs)
-# elif isinstance(action_space, spaces.MultiDiscrete):
-#     return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
-# elif isinstance(action_space, spaces.MultiBinary):
-#     return BernoulliDistribution(action_space.n, **dist_kwargs)
+elif isinstance(action_space, spaces.MultiDiscrete):
+    return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
+elif isinstance(action_space, spaces.MultiBinary):
+    return BernoulliDistribution(action_space.n, **dist_kwargs)
else:
raise NotImplementedError("Error: probability distribution, not implemented for action space"
f"of type {type(action_space)}."
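With the two branches uncommented, ``make_proba_distribution`` now dispatches over all four basic Gym spaces. A quick sketch of the mapping (illustrative, assuming the module's imports):

from gym import spaces

make_proba_distribution(spaces.Box(low=-1, high=1, shape=(2,)))  # -> DiagGaussianDistribution
make_proba_distribution(spaces.Discrete(5))                      # -> CategoricalDistribution
make_proba_distribution(spaces.MultiDiscrete([3, 2]))            # -> MultiCategoricalDistribution
make_proba_distribution(spaces.MultiBinary(4))                   # -> BernoulliDistribution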
