
Deprecation of shared layers in MlpExtractor (#1252)
* Deprecation warning for shared layers in MlpExtractor

* Updated changelog

* Updated custom policy doc

* Update doc and deprecation

* Fix doc build

* Minor edits

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
AlexPasqua and araffin committed Jan 5, 2023
1 parent 4fa17dc commit 30a1984
Showing 8 changed files with 153 additions and 107 deletions.
110 changes: 46 additions & 64 deletions docs/guide/custom_policy.rst
@@ -51,8 +51,6 @@ Each of these network have a features extractor followed by a fully-connected ne
.. image:: ../_static/img/sb3_policy.png


.. .. figure:: https://cdn-images-1.medium.com/max/960/1*h4WTQNVIsvMXJTCpXm_TAw.gif

Custom Network Architecture
^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -90,13 +88,13 @@ using ``policy_kwargs`` parameter:
# of two layers of size 32 each with Relu activation function
# Note: an extra linear layer will be added on top of the pi and the vf nets, respectively
policy_kwargs = dict(activation_fn=th.nn.ReLU,
                     net_arch=[dict(pi=[32, 32], vf=[32, 32])])
                     net_arch=dict(pi=[32, 32], vf=[32, 32]))
# Create the agent
model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
# Retrieve the environment
env = model.get_env()
# Train the agent
model.learn(total_timesteps=100000)
model.learn(total_timesteps=20_000)
# Save the agent
model.save("ppo_cartpole")
@@ -114,13 +112,14 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t

.. note::

By default the features extractor is shared between the actor and the critic to save computation (when applicable).
For on-policy algorithms, the features extractor is shared by default between the actor and the critic to save computation (when applicable).
However, this can be changed by setting ``share_features_extractor=False`` in the
``policy_kwargs`` (both for on-policy and off-policy algorithms).


.. warning::
If the features extractor is **non-shared**, it is **not** possible to have shared layers in the ``mlp_extractor``.
Please note that this option is **deprecated**; in a future release, the layers in the ``mlp_extractor`` will have to be non-shared.
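
For illustration, here is a minimal sketch of disabling the shared features extractor
(``PPO`` and ``CartPole-v1`` are used here only as an example):

.. code-block:: python

    from stable_baselines3 import PPO

    # Use separate features extractors for the actor and the critic
    policy_kwargs = dict(share_features_extractor=False)
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(total_timesteps=1_000)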


.. code-block:: python
@@ -240,64 +239,56 @@ downsampling and "vector" with a single linear layer.
On-Policy Algorithms
^^^^^^^^^^^^^^^^^^^^

Shared Networks
Custom Networks
---------------

The ``net_arch`` parameter of ``A2C`` and ``PPO`` policies allows you to specify the number and size of the hidden layers and how many
of them are shared between the policy network and the value network. It is assumed to be a list with the following
structure:

1. An arbitrary number (zero allowed) of integers, each specifying the number of units in a shared layer.
   If the number of ints is zero, there will be no shared layers.
2. An optional dict, to specify the non-shared layers for the value network and the policy network.
   It is formatted like ``dict(vf=[<value layer sizes>], pi=[<policy layer sizes>])``.
   If either key (``pi`` or ``vf``) is missing, no non-shared layers (an empty list) are assumed for that network.

In short: ``[<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])]``.

Examples
~~~~~~~~

Two shared layers of size 128: ``net_arch=[128, 128]``


.. code-block:: none
.. warning::
Shared layers in the ``mlp_extractor`` are **deprecated**.
In a future release all layers will have to be non-shared.
If needed, you can implement a custom policy network (see `advanced example below <#advanced-example>`_).

         obs
          |
        <128>
          |
        <128>
         / \
    action value
.. warning::
In the next Stable-Baselines3 release, the behavior of ``net_arch=[128, 128]`` will change
to match that of off-policy algorithms: it will create **separate** networks for the actor
and the critic (instead of the currently shared ones), with the same architecture.


Value network deeper than policy network, first layer shared: ``net_arch=[128, dict(vf=[256, 256])]``
If you need a network architecture that is different for the actor and the critic when using ``PPO``, ``A2C`` or ``TRPO``,
you can pass a dictionary of the following structure: ``dict(pi=[<actor network architecture>], vf=[<critic network architecture>])``.

.. code-block:: none
For example, if you want a different architecture for the actor (aka ``pi``) and the critic (value-function aka ``vf``) networks,
then you can specify ``net_arch=dict(pi=[32, 32], vf=[64, 64])``.

         obs
          |
        <128>
         / \
    action <256>
              |
            <256>
              |
            value
.. Otherwise, to have actor and critic that share the same network architecture,
.. you only need to specify ``net_arch=[128, 128]`` (here, two hidden layers of 128 units each).
Examples
~~~~~~~~

Initially shared then diverging: ``[128, dict(vf=[256], pi=[16])]``
.. TODO(antonin): uncomment when shared network is removed
.. Same architecture for actor and critic with two layers of size 128: ``net_arch=[128, 128]``
..
.. .. code-block:: none
..
..        obs
..        / \
..    <128> <128>
..      |     |
..    <128> <128>
..      |     |
..   action value
Different architectures for actor and critic: ``net_arch=dict(pi=[32, 32], vf=[64, 64])``

.. code-block:: none
         obs
          |
        <128>
         / \
      <16> <256>
        |     |
    action  value

         obs
         / \
      <32> <64>
        |    |
      <32> <64>
        |    |
    action  value
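
A runnable sketch of this example (the choice of ``CartPole-v1`` and the training budget are only illustrative):

.. code-block:: python

    from stable_baselines3 import PPO

    # Actor (pi) with two hidden layers of 32 units,
    # critic (vf) with two hidden layers of 64 units
    policy_kwargs = dict(net_arch=dict(pi=[32, 32], vf=[64, 64]))
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(total_timesteps=10_000)
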
Advanced Example
@@ -334,7 +325,7 @@ If your task requires even more granular control over the policy/value architect
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super(CustomNetwork, self).__init__()
        super().__init__()
        # IMPORTANT:
        # Save output dimensions, used to create the distributions
@@ -370,8 +361,6 @@ If your task requires even more granular control over the policy/value architect
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Callable[[float], float],
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
*args,
**kwargs,
):
@@ -380,8 +369,6 @@
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
# Pass remaining arguments to base class
*args,
**kwargs,
@@ -402,21 +389,16 @@ If your task requires even more granular control over the policy/value architect
Off-Policy Algorithms
^^^^^^^^^^^^^^^^^^^^^

If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG`` or ``TD3``,
you can pass a dictionary of the following structure: ``dict(qf=[<critic network architecture>], pi=[<actor network architecture>])``.
If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG``, ``TQC`` or ``TD3``,
you can pass a dictionary of the following structure: ``dict(pi=[<actor network architecture>], qf=[<critic network architecture>])``.

For example, if you want a different architecture for the actor (aka ``pi``) and the critic (Q-function aka ``qf``) networks,
then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``.
then you can specify ``net_arch=dict(pi=[64, 64], qf=[400, 300])``.

Otherwise, to have actor and critic that share the same network architecture,
you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each).
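
A short sketch of such a configuration (``SAC`` on ``Pendulum-v1`` is used here purely for illustration):

.. code-block:: python

    from stable_baselines3 import SAC

    # Actor (pi) with two hidden layers of 64 units,
    # critic (qf) with hidden layers of 400 and 300 units
    policy_kwargs = dict(net_arch=dict(pi=[64, 64], qf=[400, 300]))
    model = SAC("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(total_timesteps=5_000)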


.. note::
Unlike their on-policy counterparts, off-policy algorithms do not allow shared layers between
the actor and the critic (other than the features extractor), to prevent issues with target networks.


.. note::
For advanced customization of off-policy algorithms policies, please take a look at the code.
A good understanding of the algorithm used is required, see discussion in `issue #425 <https://github.com/DLR-RM/stable-baselines3/issues/425>`_
48 changes: 34 additions & 14 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
@@ -4,9 +4,16 @@ Changelog
==========


Release 1.7.0a11 (WIP)
Release 1.7.0a12 (WIP)
--------------------------

.. warning::

Shared layers in the MLP policy (``mlp_extractor``) are now deprecated for PPO, A2C and TRPO.
This feature will be removed in SB3 v1.8.0, and ``net_arch=[64, 64]`` will then create **separate**
networks with the same architecture, to be consistent with the off-policy algorithms.
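
As a sketch, the dictionary form that matches the upcoming behavior of ``net_arch=[64, 64]``
(the ``[64, 64]`` sizes are only illustrative):

.. code-block:: python

    # Explicitly request separate actor and critic networks of the same size,
    # instead of relying on the deprecated shared-layers list syntax
    policy_kwargs = dict(net_arch=dict(pi=[64, 64], vf=[64, 64]))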


.. note::

A2C and PPO saved with SB3 < 1.7.0 will show a warning about
@@ -34,8 +41,15 @@ New Features:
- Added ``normalized_image`` parameter to ``NatureCNN`` and ``CombinedExtractor``
- Added support for Python 3.10

SB3-Contrib
^^^^^^^^^^^
`SB3-Contrib`_
^^^^^^^^^^^^^^
- Fixed a bug in ``RecurrentPPO`` where the LSTM states were incorrectly reshaped for ``n_lstm_layers > 1`` (thanks @kolbytn)
- Fixed ``RuntimeError: rnn: hx is not contiguous`` while predicting terminal values for ``RecurrentPPO`` when ``n_lstm_layers > 1``

`RL Zoo`_
^^^^^^^^^
- Added support for using a python file for configuration
- Added ``monitor_kwargs`` parameter

Bug Fixes:
^^^^^^^^^^
@@ -52,6 +66,7 @@ Bug Fixes:
Deprecations:
^^^^^^^^^^^^^
- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()``
- Deprecated shared layers in ``MlpExtractor`` (@AlexPasqua)

Others:
^^^^^^^
@@ -99,8 +114,12 @@ New Features:
- Added progress bar callback
- The `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ can now be installed as a package (``pip install rl_zoo3``)

SB3-Contrib
^^^^^^^^^^^
`SB3-Contrib`_
^^^^^^^^^^^^^^

`RL Zoo`_
^^^^^^^^^
- RL Zoo is now a python package and can be installed using ``pip install rl_zoo3``

Bug Fixes:
^^^^^^^^^^
@@ -135,8 +154,8 @@ New Features:
- Added option for ``Monitor`` to append to an existing file instead of overwriting it (@sidney-tio)
- The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys

SB3-Contrib
^^^^^^^^^^^
`SB3-Contrib`_
^^^^^^^^^^^^^^
- Fixed the issue of wrongly passing policy arguments when using ``CnnLstmPolicy`` or ``MultiInputLstmPolicy`` with ``RecurrentPPO`` (@mlodel)

Bug Fixes:
@@ -192,8 +211,8 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^

SB3-Contrib
^^^^^^^^^^^
`SB3-Contrib`_
^^^^^^^^^^^^^^
- Added Recurrent PPO (PPO LSTM). See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53


@@ -246,8 +265,8 @@ New Features:
depending on desired maximum width of output.
- Allow PPO to turn off advantage normalization (see `PR #763 <https://github.com/DLR-RM/stable-baselines3/pull/763>`_) @vwxyzjn

SB3-Contrib
^^^^^^^^^^^
`SB3-Contrib`_
^^^^^^^^^^^^^^
- Coming soon: Cross Entropy Method, see https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/62

Bug Fixes:
@@ -309,8 +328,8 @@ New Features:
- Added ``skip`` option to ``VecTransposeImage`` to skip transforming the channel order when the heuristic is wrong
- Added ``copy()`` and ``combine()`` methods to ``RunningMeanStd``

SB3-Contrib
^^^^^^^^^^^
`SB3-Contrib`_
^^^^^^^^^^^^^^
- Added Trust Region Policy Optimization (TRPO) (@cyprienc)
- Added Augmented Random Search (ARS) (@sgillen)
- Coming soon: PPO LSTM, see https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53
@@ -1137,7 +1156,8 @@ and `Quentin Gallouédec`_ (aka @qgallouedec).
.. _Quentin Gallouédec: https://gallouedec.com/
.. _@qgallouedec: https://github.com/qgallouedec


.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
.. _RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo

Contributors:
-------------
2 changes: 1 addition & 1 deletion stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
@@ -617,7 +617,7 @@ def set_parameters(
f"expected {objects_needing_update}, got {updated_objects}"
)

@classmethod
@classmethod # noqa: C901
def load(
cls: Type[SelfBaseAlgorithm],
path: Union[str, pathlib.Path, io.BufferedIOBase],
2 changes: 1 addition & 1 deletion stable_baselines3/common/envs/identity_env.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
import numpy as np
from gym import spaces

from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
from stable_baselines3.common.type_aliases import GymStepReturn

T = TypeVar("T", int, np.ndarray)

28 changes: 23 additions & 5 deletions stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
@@ -418,7 +418,8 @@ def __init__(
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
@@ -451,12 +452,28 @@ def __init__(
normalize_images=normalize_images,
)

        # Convert [dict()] to dict() as shared networks are deprecated
        if isinstance(net_arch, list) and len(net_arch) > 0:
            if isinstance(net_arch[0], dict):
                warnings.warn(
                    (
                        "As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, "
                        "you should now pass directly a dictionary and not a list "
                        "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
                    ),
                )
                net_arch = net_arch[0]
            else:
                # Note: deprecation warning will be emitted
                # by the MlpExtractor constructor
                pass

        # Default network architecture, from stable-baselines
        if net_arch is None:
            if features_extractor_class == NatureCNN:
                net_arch = []
            else:
                net_arch = [dict(pi=[64, 64], vf=[64, 64])]
                net_arch = dict(pi=[64, 64], vf=[64, 64])

        self.net_arch = net_arch
        self.activation_fn = activation_fn
@@ -472,7 +489,8 @@ def __init__(
            self.pi_features_extractor = self.features_extractor
            self.vf_features_extractor = self.make_features_extractor()
            # if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
            if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
            # TODO(antonin): update the check once we change net_arch behavior
            if isinstance(net_arch, list) and len(net_arch) > 0:
                raise ValueError(
                    "Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
                )
@@ -752,7 +770,7 @@ def __init__(
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
Expand Down Expand Up @@ -825,7 +843,7 @@ def __init__(
observation_space: spaces.Dict,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
