Update TD3/DDPG/DQN defaults for consistency (#1785)
* Update TD3/DDPG/DQN defaults for consistency

* Update changelog
araffin committed Jan 12, 2024
1 parent a653aec commit a9273f9
Showing 5 changed files with 34 additions and 10 deletions.
28 changes: 26 additions & 2 deletions docs/misc/changelog.rst
@@ -3,12 +3,36 @@
Changelog
==========


Release 2.3.0a0 (WIP)
Release 2.3.0a1 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- The default hyperparameters of ``TD3`` and ``DDPG`` have been changed to be more consistent with ``SAC``

.. code-block:: python

    # SB3 < 2.3.0 default hyperparameters
    # model = TD3("MlpPolicy", env, train_freq=(1, "episode"), gradient_steps=-1, batch_size=100)
    # SB3 >= 2.3.0:
    model = TD3("MlpPolicy", env, train_freq=1, gradient_steps=1, batch_size=256)
.. note::

    Two inconsistencies remain: the default network architecture for ``TD3``/``DDPG`` is ``[400, 300]`` instead of ``[256, 256]`` as for ``SAC`` (kept for backward-compatibility reasons, see the `report on the influence of the network size <https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-Influence-of-policy-net--Vmlldzo2NDg1Mzk3>`_), and the default learning rate is 1e-3 instead of 3e-4 as for ``SAC`` (kept for performance reasons, see the `W&B report on the influence of the lr <https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-RL-Zoo-v2-3-0a0-vs-SB3-TD3-RL-Zoo-2-2-1---Vmlldzo2MjUyNTQx>`_).
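If you want to also align these two remaining hyperparameters with ``SAC``, one possible sketch is to pass them explicitly (assuming the standard ``policy_kwargs``/``net_arch`` mechanism; not what this commit changes):

.. code-block:: python

    # Sketch: match SAC's network size and learning rate for TD3
    model = TD3(
        "MlpPolicy",
        env,
        learning_rate=3e-4,
        policy_kwargs=dict(net_arch=[256, 256]),
    )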



- The default ``learning_starts`` parameter of ``DQN`` has been changed to be consistent with the other off-policy algorithms


.. code-block:: python

    # SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to the Atari default hyperparameters
    # model = DQN("MlpPolicy", env, learning_starts=50_000)
    # SB3 >= 2.3.0:
    model = DQN("MlpPolicy", env, learning_starts=100)
New Features:
^^^^^^^^^^^^^
6 changes: 3 additions & 3 deletions stable_baselines3/ddpg/ddpg.py
@@ -60,11 +60,11 @@ def __init__(
learning_rate: Union[float, Schedule] = 1e-3,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 100,
batch_size: int = 100,
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
gradient_steps: int = -1,
train_freq: Union[int, Tuple[int, str]] = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
2 changes: 1 addition & 1 deletion stable_baselines3/dqn/dqn.py
@@ -79,7 +79,7 @@ def __init__(
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-4,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 50000,
learning_starts: int = 100,
batch_size: int = 32,
tau: float = 1.0,
gamma: float = 0.99,
6 changes: 3 additions & 3 deletions stable_baselines3/td3/td3.py
@@ -83,11 +83,11 @@ def __init__(
learning_rate: Union[float, Schedule] = 1e-3,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 100,
batch_size: int = 100,
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
gradient_steps: int = -1,
train_freq: Union[int, Tuple[int, str]] = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
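If you prefer the previous episodic update schedule shown in the removed defaults above (one update phase per episode, with ``gradient_steps=-1`` meaning as many gradient steps as environment steps collected), a minimal sketch restoring it explicitly:

.. code-block:: python

    # Sketch: explicitly restore the pre-2.3.0 TD3/DDPG update schedule
    model = TD3(
        "MlpPolicy",
        env,
        train_freq=(1, "episode"),  # previous default: update once per episode
        gradient_steps=-1,  # previous default: one gradient step per collected transition
        batch_size=100,  # previous default batch size
    )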
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
2.3.0a0
2.3.0a1
