Implement DDPG (#92)
* Add DDPG + TD3 with any number of critics

* Allow any number of critics for SAC

* Update doc

* [ci skip] Update DDPG example

* Remove unused parameter

* Add DDPG to identity test

* Fix computation with n_critics=1,3

* Update doc

* Apply suggestions from code review

Co-authored-by: Adam Gleave <adam@gleave.me>

* Update docstrings for off-policy algos

* Add check for sde

Co-authored-by: Adam Gleave <adam@gleave.me>
araffin and AdamGleave committed Jul 16, 2020
1 parent 208890d commit 5ff176b
Showing 23 changed files with 315 additions and 48 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -40,7 +40,6 @@ These algorithms will make it easier for the research community and industry to
Please look at the issue for more details.
Planned features:

- [ ] DDPG (you can use its successor TD3 for now)
- [ ] HER

### Planned features (v1.1+)
@@ -152,13 +151,16 @@ All the following examples can be executed online using Google colab notebooks:
- [Monitor Training and Plotting](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb)
- [Atari Games](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb)
- [RL Baselines Zoo](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb)
- [PyBullet](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb)


## Implemented Algorithms

| **Name** | **Recurrent** | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | **Multi Processing** |
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
| DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :x: |
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
| TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
3 changes: 2 additions & 1 deletion docs/guide/algos.rst
@@ -9,10 +9,11 @@ along with some useful characteristics: support for discrete/continuous actions,
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
A2C ✔️ ✔️ ✔️ ✔️ ✔️
DDPG ✔️ ❌ ❌ ❌ ❌
DQN ❌ ✔️ ❌ ❌ ❌
PPO ✔️ ✔️ ✔️ ✔️ ✔️
SAC ✔️ ❌ ❌ ❌ ❌
TD3 ✔️ ❌ ❌ ❌ ❌
DQN ❌ ✔️ ❌ ❌ ❌
============ =========== ============ ================= =============== ================


4 changes: 2 additions & 2 deletions docs/guide/custom_policy.rst
@@ -37,8 +37,8 @@ You can also easily define a custom architecture for the policy (or value) netwo
.. note::

Defining a custom policy class is equivalent to passing ``policy_kwargs``.
However, it lets you name the policy and so makes usually the code clearer.
``policy_kwargs`` should be rather used when doing hyperparameter search.
However, it lets you name the policy and so usually makes the code clearer.
``policy_kwargs`` is particularly useful when doing hyperparameter search.
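For instance, a minimal sketch of the ``policy_kwargs`` route (assuming the usual ``net_arch`` key accepted by SB3 policies) could look like:

.. code-block:: python

    from stable_baselines3 import PPO

    # Same effect as defining a custom policy class with two 64-unit hidden layers
    model = PPO('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=[64, 64]))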



3 changes: 2 additions & 1 deletion docs/index.rst
@@ -55,10 +55,11 @@ Main Features

modules/base
modules/a2c
modules/ddpg
modules/dqn
modules/ppo
modules/sac
modules/td3
modules/dqn

.. toctree::
:maxdepth: 1
6 changes: 5 additions & 1 deletion docs/misc/changelog.rst
@@ -3,7 +3,7 @@
Changelog
==========

Pre-Release 0.8.0a3 (WIP)
Pre-Release 0.8.0a4 (WIP)
------------------------------

Breaking Changes:
@@ -12,6 +12,8 @@ Breaking Changes:
- ``save_replay_buffer`` now receives as argument the file path instead of the folder path (@tirafesi)
- Refactored ``Critic`` class for ``TD3`` and ``SAC``, it is now called ``ContinuousCritic``
and has an additional parameter ``n_critics``
- ``SAC`` and ``TD3`` now accept an arbitrary number of critics (e.g. ``policy_kwargs=dict(n_critics=3)``)
instead of only 2 previously

New Features:
^^^^^^^^^^^^^
@@ -21,6 +23,7 @@ New Features:
when ``psutil`` is available
- Saving models now automatically creates the necessary folders and raises appropriate warnings (@PartiallyTyped)
- Refactored opening paths for saving and loading to use strings, pathlib or io.BufferedIOBase (@PartiallyTyped)
- Added ``DDPG`` algorithm as a special case of ``TD3``.
- Introduced ``BaseModel`` abstract parent for ``BasePolicy``, which critics inherit from.

Bug Fixes:
@@ -38,6 +41,7 @@ Others:
- Added ``_on_step()`` for off-policy base class
- Optimized replay buffer size by removing the need of ``next_observations`` numpy array
- Ignored errors from newer pytype version
- Added a check when using ``gSDE``

Documentation:
^^^^^^^^^^^^^^
104 changes: 104 additions & 0 deletions docs/modules/ddpg.rst
@@ -0,0 +1,104 @@
.. _ddpg:

.. automodule:: stable_baselines3.ddpg


DDPG
====

`Deep Deterministic Policy Gradient (DDPG) <https://spinningup.openai.com/en/latest/algorithms/ddpg.html>`_ combines the
tricks from DQN (experience replay and target networks) with the deterministic policy gradient, to obtain an algorithm for continuous actions.


.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy


Notes
-----

- Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf
- DDPG Paper: https://arxiv.org/abs/1509.02971
- OpenAI Spinning Guide for DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html

.. note::

    The default policy for DDPG uses a ReLU activation, to match the original paper,
    whereas most other algorithms' ``MlpPolicy`` uses a tanh activation.
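If a different activation is preferred, it can be overridden through ``policy_kwargs`` — a minimal sketch, assuming the standard ``activation_fn`` argument of SB3 policies:

.. code-block:: python

    import torch as th
    from stable_baselines3 import DDPG

    # Hypothetical override: switch the DDPG policy to tanh activations
    model = DDPG('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(activation_fn=th.nn.Tanh))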


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:


============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ❌ ✔️
Box ✔️ ✔️
MultiDiscrete ❌ ✔️
MultiBinary ❌ ✔️
============= ====== ===========


Example
-------

.. code-block:: python

    import gym
    import numpy as np

    from stable_baselines3 import DDPG
    from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

    env = gym.make('Pendulum-v0')

    # The noise objects for DDPG
    n_actions = env.action_space.shape[-1]
    action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

    model = DDPG('MlpPolicy', env, action_noise=action_noise, verbose=1)
    model.learn(total_timesteps=10000, log_interval=10)
    model.save("ddpg_pendulum")
    env = model.get_env()

    del model  # remove to demonstrate saving and loading

    model = DDPG.load("ddpg_pendulum")

    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()
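``OrnsteinUhlenbeckActionNoise`` is imported above but unused; a minimal sketch of swapping it in for the Gaussian noise (reusing ``n_actions`` from the example) would be:

.. code-block:: python

    # Temporally correlated noise, as in the original DDPG paper
    action_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
    model = DDPG('MlpPolicy', env, action_noise=action_noise, verbose=1)
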
Parameters
----------

.. autoclass:: DDPG
:members:
:inherited-members:

.. _ddpg_policies:

DDPG Policies
-------------

.. autoclass:: MlpPolicy
:members:
:inherited-members:


.. .. autoclass:: CnnPolicy
.. :members:
.. :inherited-members:
2 changes: 2 additions & 0 deletions docs/spelling_wordlist.txt
@@ -117,3 +117,5 @@ Deprecations
forkserver
cuda
Polyak
gSDE
rollouts
1 change: 1 addition & 0 deletions setup.cfg
@@ -27,6 +27,7 @@ per-file-ignores =
./stable_baselines3/__init__.py:F401
./stable_baselines3/common/__init__.py:F401
./stable_baselines3/a2c/__init__.py:F401
./stable_baselines3/ddpg/__init__.py:F401
./stable_baselines3/dqn/__init__.py:F401
./stable_baselines3/ppo/__init__.py:F401
./stable_baselines3/sac/__init__.py:F401
3 changes: 2 additions & 1 deletion stable_baselines3/__init__.py
@@ -1,10 +1,11 @@
import os

from stable_baselines3.a2c import A2C
from stable_baselines3.ddpg import DDPG
from stable_baselines3.dqn import DQN
from stable_baselines3.ppo import PPO
from stable_baselines3.sac import SAC
from stable_baselines3.td3 import TD3
from stable_baselines3.dqn import DQN

# Read version from file
version_file = os.path.join(os.path.dirname(__file__), 'version.txt')
4 changes: 4 additions & 0 deletions stable_baselines3/common/base_class.py
@@ -150,6 +150,10 @@ def __init__(self,
raise ValueError("Error: the model does not support multiple envs; it requires "
"a single vectorized environment.")

if self.use_sde and not isinstance(self.observation_space, gym.spaces.Box):
raise ValueError("generalized State-Dependent Exploration (gSDE) can only "
"be used with continuous actions.")

def _wrap_env(self, env: GymEnv) -> VecEnv:
if not isinstance(env, VecEnv):
if self.verbose >= 1:
2 changes: 1 addition & 1 deletion stable_baselines3/common/noise.py
@@ -46,7 +46,7 @@ def __repr__(self) -> str:

class OrnsteinUhlenbeckActionNoise(ActionNoise):
"""
An Ornstein Uhlenbeck action noise, this is designed to aproximate brownian motion with friction.
An Ornstein Uhlenbeck action noise, this is designed to approximate Brownian motion with friction.
Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
9 changes: 6 additions & 3 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -35,10 +35,13 @@ class OffPolicyAlgorithm(BaseAlgorithm):
:param batch_size: (int) Minibatch size for each gradient update
:param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1)
:param gamma: (float) the discount factor
:param train_freq: (int) Update the model every ``train_freq`` steps.
:param gradient_steps: (int) How many gradient update after each step
:param train_freq: (int) Update the model every ``train_freq`` steps. Set to `-1` to disable.
:param gradient_steps: (int) How many gradient steps to do after each rollout
(see ``train_freq`` and ``n_episodes_rollout``)
Set to ``-1`` means to do as many gradient steps as steps done in the environment
during the rollout.
:param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes.
Note that this cannot be used at the same time as ``train_freq``
Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable.
:param action_noise: (ActionNoise) the action noise type (None by default), this can help
for hard exploration problem. Cf common.noise for the different action noise type.
:param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer
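As a rough illustration of how these parameters interact (a sketch, not part of this diff), the step-based and episodic update schedules would be selected roughly as follows:

from stable_baselines3 import DDPG

# Episodic updates: collect one full episode, then do as many gradient steps
# as environment steps collected (the defaults used by the DDPG constructor below)
model = DDPG('MlpPolicy', 'Pendulum-v0', train_freq=-1, n_episodes_rollout=1, gradient_steps=-1)

# Step-based updates: train every 100 environment steps, one gradient step each time
model = DDPG('MlpPolicy', 'Pendulum-v0', train_freq=100, n_episodes_rollout=-1, gradient_steps=1)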
2 changes: 2 additions & 0 deletions stable_baselines3/ddpg/__init__.py
@@ -0,0 +1,2 @@
from stable_baselines3.ddpg.ddpg import DDPG
from stable_baselines3.ddpg.policies import MlpPolicy, CnnPolicy
116 changes: 116 additions & 0 deletions stable_baselines3/ddpg/ddpg.py
@@ -0,0 +1,116 @@
import torch as th
from typing import Type, Union, Callable, Optional, Dict, Any

from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.td3.td3 import TD3
from stable_baselines3.td3.policies import TD3Policy


class DDPG(TD3):
"""
Deep Deterministic Policy Gradient (DDPG).
Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf
DDPG Paper: https://arxiv.org/abs/1509.02971
Introduction to DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html
Note: we treat DDPG as a special case of its successor TD3.
:param policy: (DDPGPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str)
:param learning_rate: (float or callable) learning rate for adam optimizer,
the same learning rate will be used for all networks (Q-Values, Actor and Value function)
it can be a function of the current progress remaining (from 1 to 0)
:param buffer_size: (int) size of the replay buffer
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
:param batch_size: (int) Minibatch size for each gradient update
:param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1)
:param gamma: (float) the discount factor
:param train_freq: (int) Update the model every ``train_freq`` steps. Set to `-1` to disable.
:param gradient_steps: (int) How many gradient steps to do after each rollout
(see ``train_freq`` and ``n_episodes_rollout``)
Set to ``-1`` means to do as many gradient steps as steps done in the environment
during the rollout.
:param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes.
Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable.
:param action_noise: (ActionNoise) the action noise type (None by default), this can help
for hard exploration problem. Cf common.noise for the different action noise type.
:param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param create_eval_env: (bool) Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
:param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug
:param seed: (int) Seed for the pseudo random generators
:param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""

def __init__(self, policy: Union[str, Type[TD3Policy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Callable] = 1e-3,
buffer_size: int = int(1e6),
learning_starts: int = 100,
batch_size: int = 100,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: int = -1,
gradient_steps: int = -1,
n_episodes_rollout: int = 1,
action_noise: Optional[ActionNoise] = None,
optimize_memory_usage: bool = False,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Dict[str, Any] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = 'auto',
_init_setup_model: bool = True):

super(DDPG, self).__init__(policy=policy,
env=env,
learning_rate=learning_rate,
buffer_size=buffer_size,
learning_starts=learning_starts,
batch_size=batch_size,
tau=tau, gamma=gamma,
train_freq=train_freq,
gradient_steps=gradient_steps,
n_episodes_rollout=n_episodes_rollout,
action_noise=action_noise,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose, device=device,
create_eval_env=create_eval_env, seed=seed,
optimize_memory_usage=optimize_memory_usage,
# Remove all tricks from TD3 to obtain DDPG:
# we still need to specify target_policy_noise > 0 to avoid errors
policy_delay=1, target_noise_clip=0.0, target_policy_noise=0.1,
_init_setup_model=False)

# Use only one critic
if 'n_critics' not in self.policy_kwargs:
self.policy_kwargs['n_critics'] = 1

if _init_setup_model:
self._setup_model()

def learn(self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "DDPG",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True) -> OffPolicyAlgorithm:

return super(DDPG, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,
eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name, eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps)
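Since ``n_critics`` defaults to 1 above, a hedged usage sketch of the single-critic default versus opting back into twin critics might be:

from stable_baselines3 import DDPG

# Default: a single critic, i.e. classic DDPG
model = DDPG('MlpPolicy', 'Pendulum-v0')

# Explicitly request two critics via policy_kwargs (closer to TD3's clipped double-Q)
model = DDPG('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(n_critics=2))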
2 changes: 2 additions & 0 deletions stable_baselines3/ddpg/policies.py
@@ -0,0 +1,2 @@
# DDPG can be viewed as a special case of TD3
from stable_baselines3.td3.policies import MlpPolicy, CnnPolicy # noqa:F401
