Add documentation about default network architecture (#1353)
* Add documentation about default network architecture

* [ci skip] Rename custom policy section to Policy Networks
araffin committed Mar 2, 2023
1 parent ed8783c commit f0382a2
Showing 4 changed files with 27 additions and 26 deletions.
23 changes: 0 additions & 23 deletions .gitlab-ci.yml

This file was deleted.

26 changes: 24 additions & 2 deletions docs/guide/custom_policy.rst
@@ -1,7 +1,7 @@
.. _custom_policy:

Custom Policy Network
=====================
Policy Networks
===============

Stable Baselines3 provides policy networks for images (CnnPolicies),
other types of input features (MlpPolicies) and multiple different inputs (MultiInputPolicies).
@@ -51,6 +51,28 @@ Each of these network have a features extractor followed by a fully-connected ne
.. image:: ../_static/img/sb3_policy.png


Default Network Architecture
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The default network architecture used by SB3 depends on the algorithm and the observation space.
You can visualize the architecture by printing ``model.policy`` (see `issue #329 <https://github.com/DLR-RM/stable-baselines3/issues/329>`_).


For 1D observation spaces, a two-layer fully connected network is used with:

- 64 units (per layer) for PPO/A2C/DQN
- 256 units for SAC
- [400, 300] units for TD3/DDPG (values are taken from the original TD3 paper)
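
To make the sizes above concrete, here is a small worked example in plain Python. The lookup table restates the list; ``mlp_param_count`` is a hypothetical helper (not part of SB3), and the input dimension of 4 is an assumed example value. It counts only the weights and biases of the hidden layers, not the output heads.

```python
# Default hidden layer sizes for 1D observations, restating the list above.
DEFAULT_HIDDEN_SIZES = {
    "PPO": [64, 64],
    "A2C": [64, 64],
    "DQN": [64, 64],
    "SAC": [256, 256],
    "TD3": [400, 300],
    "DDPG": [400, 300],
}


def mlp_param_count(input_dim: int, hidden_sizes: list[int]) -> int:
    """Hypothetical helper: count weights and biases of the hidden layers."""
    total = 0
    previous = input_dim
    for size in hidden_sizes:
        total += previous * size + size  # weight matrix plus bias vector
        previous = size
    return total


# Assumed example: a 4-dimensional observation vector.
ppo_params = mlp_param_count(4, DEFAULT_HIDDEN_SIZES["PPO"])
td3_params = mlp_param_count(4, DEFAULT_HIDDEN_SIZES["TD3"])
print(ppo_params, td3_params)  # 4480 122300
```

Under these assumptions the TD3/DDPG hidden layers hold roughly 27 times as many parameters as the PPO/A2C/DQN ones.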

For image observation spaces, the "Nature CNN" (see code for more details) is used for feature extraction, and SAC/TD3 also keep the same fully connected network after it.
The other algorithms only have a linear layer after the CNN.
The CNN is shared between actor and critic for A2C/PPO (on-policy algorithms) to reduce computation.
Off-policy algorithms (TD3, DDPG, SAC, ...) have separate feature extractors: one for the actor and one for the critic, since the best performance is obtained with this configuration.

For mixed observations (dictionary observations), the two architectures above are combined, i.e., a CNN for images and then a two-layer fully connected network
(with a smaller output size for the CNN).



Custom Network Architecture
^^^^^^^^^^^^^^^^^^^^^^^^^^^
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -44,8 +44,8 @@ Main Features
guide/algos
guide/examples
guide/vec_envs
guide/custom_env
guide/custom_policy
guide/custom_env
guide/callbacks
guide/tensorboard
guide/integrations
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
@@ -42,6 +42,7 @@ Others:
- Fixed ``tests/test_vec_normalize.py`` type hint
- Fixed ``stable_baselines3/common/monitor.py`` type hint
- Added tests for StackedObservations
- Removed Gitlab CI file

Documentation:
^^^^^^^^^^^^^^
@@ -50,6 +51,7 @@ Documentation:
- Fixed typo in ``A2C`` docstring (@AlexPasqua)
- Renamed timesteps to episodes for ``log_interval`` description (@theSquaredError)
- Removed note about gif creation for Atari games (@harveybellini)
- Added information about default network architecture

Release 1.7.0 (2023-01-10)
--------------------------
