Skip to content

Commit

Permalink
Removed shared layers in mlp_extractor (#1292)
Browse files Browse the repository at this point in the history
* Modified actor-critic policies & MlpExtractor class

ActorCriticPolicy:
  - changed type hint of net_arch param: now it's a dict
  - removed check that if features extractor is not shared: no shared layers are allowed in the mlp_extractor regardless of the features extractor
ActorCriticCnnPolicy:
  - changed type hint of net_arch param: now it's a dict
MultiInputActorcriticPolicy:
  - changed type hint of net_arch param: now it's a dict
MlpExtractor:
  - changed type hint of net_arch param: now it's a dict
  - adapted networks creation
  - adapted methods: forward, forward_actor & forward_critic

* Removed shared layers in mlp_extractor

* Updated docs and changelog + reformat

* Updated custom policy tests

* Removed test on deprecation warning for share layers in mlp_extractor

Now shared layers are removed

* Update version

* Update RL Zoo doc

* Fix linter warnings

* Add ruff to Makefile (experimental)

* Add backward compat code and minor updates

* Update tests

* Add backward compatibility

* Fix test

* Improve compat code

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
AlexPasqua and araffin committed Jan 23, 2023
1 parent 69fdf15 commit b702884
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 163 deletions.
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ lint:
# exit-zero treats all errors as warnings.
flake8 ${LINT_PATHS} --count --exit-zero --statistics

ruff:
# stop the build if there are Python syntax errors or undefined names
# see https://lintlyci.github.io/Flake8Rules/
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
# exit-zero treats all errors as warnings.
ruff ${LINT_PATHS} --exit-zero --line-length 127

format:
# Sort imports
isort ${LINT_PATHS}
Expand Down
45 changes: 15 additions & 30 deletions docs/guide/custom_policy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,6 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
``policy_kwargs`` (both for on-policy and off-policy algorithms).


.. warning::
If the features extractor is **non-shared**, it is **not** possible to have shared layers in the ``mlp_extractor``.
Please note that this option is **deprecated**, therefore in a future release the layers in the ``mlp_extractor`` will have to be non-shared.


.. code-block:: python
import torch as th
Expand Down Expand Up @@ -242,41 +237,31 @@ On-Policy Algorithms
Custom Networks
---------------

.. warning::
Shared layers in the the ``mlp_extractor`` are **deprecated**.
In a future release all layers will have to be non-shared.
If needed, you can implement a custom policy network (see `advanced example below <#advanced-example>`_).

.. warning::
In the next Stable-Baselines3 release, the behavior of ``net_arch=[128, 128]`` will change
to match the one of off-policy algorithms: it will create **separate** networks (instead of shared currently)
for the actor and the critic, with the same architecture.


If you need a network architecture that is different for the actor and the critic when using ``PPO``, ``A2C`` or ``TRPO``,
you can pass a dictionary of the following structure: ``dict(pi=[<actor network architecture>], vf=[<critic network architecture>])``.

For example, if you want a different architecture for the actor (aka ``pi``) and the critic ( value-function aka ``vf``) networks,
then you can specify ``net_arch=dict(pi=[32, 32], vf=[64, 64])``.

.. Otherwise, to have actor and critic that share the same network architecture,
.. you only need to specify ``net_arch=[128, 128]`` (here, two hidden layers of 128 units each).
Otherwise, to have actor and critic that share the same network architecture,
you only need to specify ``net_arch=[128, 128]`` (here, two hidden layers of 128 units each, this is equivalent to ``net_arch=dict(pi=[128, 128], vf=[128, 128])``).

If shared layers are needed, you need to implement a custom policy network (see `advanced example below <#advanced-example>`_).

Examples
~~~~~~~~

.. TODO(antonin): uncomment when shared network is removed
.. Same architecture for actor and critic with two layers of size 128: ``net_arch=[128, 128]``
..
.. .. code-block:: none
..
.. obs
.. / \
.. <128> <128>
.. | |
.. <128> <128>
.. | |
.. action value
Same architecture for actor and critic with two layers of size 128: ``net_arch=[128, 128]``

.. code-block:: none
obs
/ \
<128> <128>
| |
<128> <128>
| |
action value
Different architectures for actor and critic: ``net_arch=dict(pi=[32, 32], vf=[64, 64])``

Expand Down
19 changes: 13 additions & 6 deletions docs/guide/rl_zoo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ Goals of this repository:
Installation
------------

Option 1: install the python package ``pip install rl_zoo3``

or:

1. Clone the repository:

::
Expand All @@ -42,7 +46,10 @@ Installation
::

apt-get install swig cmake ffmpeg
# full dependencies
pip install -r requirements.txt
# minimal dependencies
pip install -e .


Train an Agent
Expand All @@ -56,21 +63,21 @@ using:

::

python train.py --algo algo_name --env env_id
python -m rl_zoo3.train --algo algo_name --env env_id

For example (with evaluation and checkpoints):

::

python train.py --algo ppo --env CartPole-v1 --eval-freq 10000 --save-freq 50000
python -m rl_zoo3.train --algo ppo --env CartPole-v1 --eval-freq 10000 --save-freq 50000


Continue training (here, load pretrained agent for Breakout and continue
training for 5000 steps):

::

python train.py --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000
python -m rl_zoo3.train --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000


Enjoy a Trained Agent
Expand All @@ -80,13 +87,13 @@ If the trained agent exists, then you can see it in action using:

::

python enjoy.py --algo algo_name --env env_id
python -m rl_zoo3.enjoy --algo algo_name --env env_id

For example, enjoy A2C on Breakout during 5000 timesteps:

::

python enjoy.py --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000
python -m rl_zoo3.enjoy --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000


Hyperparameter Optimization
Expand All @@ -100,7 +107,7 @@ with a budget of 1000 trials and a maximum of 50000 steps:

::

python train.py --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \
python -m rl_zoo3.train --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \
--sampler random --pruner median


Expand Down
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ Changelog
==========


Release 1.8.0a1 (WIP)
Release 1.8.0a2 (WIP)
--------------------------


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)

New Features:
^^^^^^^^^^^^^
Expand Down
6 changes: 5 additions & 1 deletion stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,11 @@ def load(
if "policy_kwargs" in data:
if "device" in data["policy_kwargs"]:
del data["policy_kwargs"]["device"]
# backward compatibility, convert to new format
if "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0:
saved_net_arch = data["policy_kwargs"]["net_arch"]
if isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
data["policy_kwargs"]["net_arch"] = saved_net_arch[0]

if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
raise ValueError(
Expand Down Expand Up @@ -726,7 +731,6 @@ def load(
)
else:
raise e

# put other pytorch variables back in place
if pytorch_variables is not None:
for name in pytorch_variables:
Expand Down
29 changes: 24 additions & 5 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,11 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size

def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
def _get_samples(
self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None,
) -> RolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
data = (
self.observations[batch_inds],
self.actions[batch_inds],
Expand Down Expand Up @@ -603,7 +607,11 @@ def add(
self.full = True
self.pos = 0

def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
def sample(
self,
batch_size: int,
env: Optional[VecNormalize] = None,
) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
"""
Sample elements from the replay buffer.
Expand All @@ -614,7 +622,11 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictRep
"""
return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env)

def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
def _get_samples(
self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None,
) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
# Sample randomly the env idx
env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

Expand Down Expand Up @@ -743,7 +755,10 @@ def add(
if self.pos == self.buffer_size:
self.full = True

def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch] #FIXME
def get(
self,
batch_size: Optional[int] = None,
) -> Generator[DictRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch] #FIXME
assert self.full, ""
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
Expand All @@ -767,7 +782,11 @@ def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSa
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size

def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
def _get_samples(
self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None,
) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME

return DictRolloutBufferSamples(
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
Expand Down
37 changes: 12 additions & 25 deletions stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,7 @@ def __init__(
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
Expand Down Expand Up @@ -452,21 +451,15 @@ def __init__(
normalize_images=normalize_images,
)

# Convert [dict()] to dict() as shared network are deprecated
if isinstance(net_arch, list) and len(net_arch) > 0:
if isinstance(net_arch[0], dict):
warnings.warn(
(
"As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, "
"you should now pass directly a dictionary and not a list "
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
),
)
net_arch = net_arch[0]
else:
# Note: deprecation warning will be emitted
# by the MlpExtractor constructor
pass
if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
warnings.warn(
(
"As shared layers in the mlp_extractor are removed since SB3 v1.8.0, "
"you should now pass directly a dictionary and not a list "
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
),
)
net_arch = net_arch[0]

# Default network architecture, from stable-baselines
if net_arch is None:
Expand All @@ -488,12 +481,6 @@ def __init__(
else:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.make_features_extractor()
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
# TODO(antonin): update the check once we change net_arch behavior
if isinstance(net_arch, list) and len(net_arch) > 0:
raise ValueError(
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
)

self.log_std_init = log_std_init
dist_kwargs = None
Expand Down Expand Up @@ -770,7 +757,7 @@ def __init__(
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
Expand Down Expand Up @@ -843,7 +830,7 @@ def __init__(
observation_space: spaces.Dict,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
Expand Down

0 comments on commit b702884

Please sign in to comment.