* Add DDPG + TD3 with any number of critics * Allow any number of critics for SAC * Update doc * [ci skip] Update DDPG example * Remove unused parameter * Add DDPG to identity test * Fix computation with n_critics=1,3 * Update doc * Apply suggestions from code review Co-authored-by: Adam Gleave <adam@gleave.me> * Update docstrings for off-policy algos * Add check for sde Co-authored-by: Adam Gleave <adam@gleave.me>
1 parent 208890d · commit 5ff176b
Showing 23 changed files with 315 additions and 48 deletions.
@@ -0,0 +1,104 @@
.. _ddpg:

.. automodule:: stable_baselines3.ddpg


DDPG
====

`Deep Deterministic Policy Gradient (DDPG) <https://spinningup.openai.com/en/latest/algorithms/ddpg.html>`_ combines the
tricks from DQN with the deterministic policy gradient, to obtain an algorithm for continuous actions.


.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy


Notes
-----

- Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf
- DDPG Paper: https://arxiv.org/abs/1509.02971
- OpenAI Spinning Guide for DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html

.. note::

    The default policy for DDPG uses a ReLU activation, to match the original paper,
    whereas most other algorithms' ``MlpPolicy`` uses a tanh activation.


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌      ✔️
Box           ✔️      ✔️
MultiDiscrete ❌      ✔️
MultiBinary   ❌      ✔️
============= ====== ===========


Example
-------

.. code-block:: python

    import gym
    import numpy as np

    from stable_baselines3 import DDPG
    from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

    env = gym.make('Pendulum-v0')

    # The noise objects for DDPG
    n_actions = env.action_space.shape[-1]
    action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

    model = DDPG('MlpPolicy', env, action_noise=action_noise, verbose=1)
    model.learn(total_timesteps=10000, log_interval=10)
    model.save("ddpg_pendulum")
    env = model.get_env()

    del model  # remove to demonstrate saving and loading

    model = DDPG.load("ddpg_pendulum")

    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()


Parameters
----------

.. autoclass:: DDPG
    :members:
    :inherited-members:

.. _ddpg_policies:

DDPG Policies
-------------

.. autoclass:: MlpPolicy
    :members:
    :inherited-members:


.. .. autoclass:: CnnPolicy
..    :members:
..    :inherited-members:
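The example above imports ``OrnsteinUhlenbeckActionNoise`` but only uses Gaussian noise. As background, the sketch below shows what an Ornstein-Uhlenbeck process computes and why it is sometimes preferred for exploration: successive samples are temporally correlated, so the noise "drifts" smoothly instead of jittering. This is a simplified standalone version in plain NumPy, not SB3's actual implementation; the function name and default coefficients are illustrative.

```python
import numpy as np

def ou_noise_step(x, mu, theta=0.15, sigma=0.2, dt=1e-2, rng=None):
    """One Euler-Maruyama step of an Ornstein-Uhlenbeck process.

    theta pulls the state back toward the mean mu,
    sigma scales the Gaussian innovation added each step.
    """
    rng = rng if rng is not None else np.random.default_rng()
    return x + theta * (mu - x) * dt + sigma * np.sqrt(dt) * rng.standard_normal(x.shape)

rng = np.random.default_rng(0)
mu = np.zeros(1)   # the noise reverts toward this mean
x = np.zeros(1)
samples = []
for _ in range(1000):
    x = ou_noise_step(x, mu, rng=rng)
    samples.append(x.copy())
samples = np.asarray(samples)
# Successive samples are strongly correlated, unlike independent
# Gaussian noise, which gives smoother exploration trajectories.
```

With independent Gaussian noise (as in ``NormalActionNoise``) the correlation between successive samples would be near zero; here it is close to one.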
@@ -117,3 +117,5 @@ Deprecations
forkserver
cuda
Polyak
gSDE
rollouts
@@ -0,0 +1,2 @@
from stable_baselines3.ddpg.ddpg import DDPG
from stable_baselines3.ddpg.policies import MlpPolicy, CnnPolicy
@@ -0,0 +1,116 @@
import torch as th
from typing import Type, Union, Callable, Optional, Dict, Any

from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.td3.td3 import TD3
from stable_baselines3.td3.policies import TD3Policy


class DDPG(TD3):
    """
    Deep Deterministic Policy Gradient (DDPG).

    Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf
    DDPG Paper: https://arxiv.org/abs/1509.02971
    Introduction to DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html

    Note: we treat DDPG as a special case of its successor TD3.

    :param policy: (DDPGPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: (float or callable) learning rate for the Adam optimizer,
        the same learning rate will be used for all networks (Q-Values, Actor and Value function),
        it can be a function of the current progress remaining (from 1 to 0)
    :param buffer_size: (int) size of the replay buffer
    :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
    :param batch_size: (int) Minibatch size for each gradient update
    :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1)
    :param gamma: (float) the discount factor
    :param train_freq: (int) Update the model every ``train_freq`` steps. Set to ``-1`` to disable.
    :param gradient_steps: (int) How many gradient steps to do after each rollout
        (see ``train_freq`` and ``n_episodes_rollout``).
        Setting it to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param n_episodes_rollout: (int) Update the model every ``n_episodes_rollout`` episodes.
        Note that this cannot be used at the same time as ``train_freq``. Set to ``-1`` to disable.
    :param action_noise: (ActionNoise) the action noise type (None by default), this can help
        for hard exploration problems. Cf. common.noise for the different action noise types.
    :param optimize_memory_usage: (bool) Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param create_eval_env: (bool) Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing a string for the environment)
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: (int) Seed for the pseudo random generators
    :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    """

    def __init__(self, policy: Union[str, Type[TD3Policy]],
                 env: Union[GymEnv, str],
                 learning_rate: Union[float, Callable] = 1e-3,
                 buffer_size: int = int(1e6),
                 learning_starts: int = 100,
                 batch_size: int = 100,
                 tau: float = 0.005,
                 gamma: float = 0.99,
                 train_freq: int = -1,
                 gradient_steps: int = -1,
                 n_episodes_rollout: int = 1,
                 action_noise: Optional[ActionNoise] = None,
                 optimize_memory_usage: bool = False,
                 tensorboard_log: Optional[str] = None,
                 create_eval_env: bool = False,
                 policy_kwargs: Optional[Dict[str, Any]] = None,
                 verbose: int = 0,
                 seed: Optional[int] = None,
                 device: Union[th.device, str] = 'auto',
                 _init_setup_model: bool = True):

        super(DDPG, self).__init__(policy=policy,
                                   env=env,
                                   learning_rate=learning_rate,
                                   buffer_size=buffer_size,
                                   learning_starts=learning_starts,
                                   batch_size=batch_size,
                                   tau=tau, gamma=gamma,
                                   train_freq=train_freq,
                                   gradient_steps=gradient_steps,
                                   n_episodes_rollout=n_episodes_rollout,
                                   action_noise=action_noise,
                                   policy_kwargs=policy_kwargs,
                                   tensorboard_log=tensorboard_log,
                                   verbose=verbose, device=device,
                                   create_eval_env=create_eval_env, seed=seed,
                                   optimize_memory_usage=optimize_memory_usage,
                                   # Remove all tricks from TD3 to obtain DDPG:
                                   # we still need to specify target_policy_noise > 0 to avoid errors
                                   policy_delay=1, target_noise_clip=0.0, target_policy_noise=0.1,
                                   _init_setup_model=False)

        # Use only one critic
        if 'n_critics' not in self.policy_kwargs:
            self.policy_kwargs['n_critics'] = 1

        if _init_setup_model:
            self._setup_model()

    def learn(self,
              total_timesteps: int,
              callback: MaybeCallback = None,
              log_interval: int = 4,
              eval_env: Optional[GymEnv] = None,
              eval_freq: int = -1,
              n_eval_episodes: int = 5,
              tb_log_name: str = "DDPG",
              eval_log_path: Optional[str] = None,
              reset_num_timesteps: bool = True) -> OffPolicyAlgorithm:

        return super(DDPG, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,
                                       eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes,
                                       tb_log_name=tb_log_name, eval_log_path=eval_log_path,
                                       reset_num_timesteps=reset_num_timesteps)
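The ``tau`` parameter above controls the "Polyak update" of the target networks: after each training step, target parameters are moved a small fraction of the way toward the online parameters instead of being copied wholesale. A minimal NumPy sketch of that rule (the function name is illustrative, not SB3's internal API):

```python
import numpy as np

def polyak_update(params, target_params, tau):
    """Soft update: target <- tau * params + (1 - tau) * target."""
    return [tau * p + (1.0 - tau) * tp for p, tp in zip(params, target_params)]

params = [np.ones(3)]    # online network parameters
target = [np.zeros(3)]   # target network parameters
target = polyak_update(params, target, tau=0.005)
# With tau=0.005, the target moves 0.5% of the way toward the online
# parameters each update, so target values change slowly and smoothly.
```

A small ``tau`` keeps the bootstrapped Q-targets nearly stationary between updates, which stabilizes training at the cost of slower target tracking.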
@@ -0,0 +1,2 @@
# DDPG can be viewed as a special case of TD3
from stable_baselines3.td3.policies import MlpPolicy, CnnPolicy  # noqa:F401
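The "special case of TD3" framing hinges on the constructor arguments seen above: ``policy_delay=1``, ``target_noise_clip=0.0``, and ``n_critics=1``. The sketch below illustrates why ``target_noise_clip=0.0`` disables TD3's target policy smoothing even though ``target_policy_noise`` stays positive: the drawn noise is clipped to the empty range ``[-0.0, 0.0]`` and therefore zeroed out. This is a simplified standalone illustration, not SB3's actual training code, and assumes actions bounded in ``[-1, 1]``.

```python
import numpy as np

def target_action(next_action, target_policy_noise, target_noise_clip, rng):
    """TD3-style target policy smoothing (simplified sketch)."""
    noise = rng.normal(0.0, target_policy_noise, size=next_action.shape)
    # Clipping to [-target_noise_clip, target_noise_clip]; with a clip of
    # 0.0 every noise value collapses to zero, recovering plain DDPG.
    noise = np.clip(noise, -target_noise_clip, target_noise_clip)
    return np.clip(next_action + noise, -1.0, 1.0)

rng = np.random.default_rng(0)
a = np.array([0.3, -0.7])
# DDPG settings: noise is drawn (so the code path stays valid) but zeroed
ddpg_a = target_action(a, target_policy_noise=0.1, target_noise_clip=0.0, rng=rng)
```

This explains the comment in the constructor: a positive ``target_policy_noise`` is still supplied only to keep the noise-drawing code valid, while the zero clip makes it a no-op.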