Modified ActorCriticPolicy to support non-shared features extractor (#1148)

* Modified ActorCriticPolicy to support non-shared features extractor

* Refactored features extraction with non-shared features extractor in ActorCriticPolicy and updated doc

Doc update: added a warning to the custom policy docs stating that, if the features extractor is non-shared, it is not possible to have shared layers in the mlp_extractor

* Moved the share_features_extractor attribute into the class

* Updated custom policy doc for non-shared features extractor

* Updated changelog

* Made some if-statements more readable in policies.py

The if-statements are related to the shared/non-shared features extractor in ActorCritic policies

* Simplify implementation and add run test

* Keep order in module gains to keep previous results consistent

* Fix test

* Improved docstring in policies.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Added some tests

* feature extractor -> features extractor

* Fix test

* Fix env_id in test

* Make features extractor parameter explicit

* Remove duplicate

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
4 people committed Dec 20, 2022
1 parent 8452106 commit 2cfcec4
Showing 11 changed files with 138 additions and 47 deletions.
22 changes: 12 additions & 10 deletions docs/guide/custom_policy.rst
@@ -108,17 +108,19 @@ using ``policy_kwargs`` parameter:
Custom Feature Extractor
^^^^^^^^^^^^^^^^^^^^^^^^

If you want to have a custom feature extractor (e.g. custom CNN when using images), you can define class
If you want to have a custom features extractor (e.g. custom CNN when using images), you can define a class
that derives from ``BaseFeaturesExtractor`` and then pass it to the model when training.


.. note::

By default the feature extractor is shared between the actor and the critic to save computation (when applicable).
However, this can be changed by defining a custom policy for on-policy algorithms
(see `issue #1066 <https://github.com/DLR-RM/stable-baselines3/issues/1066#issuecomment-1246866844>`_
for more information) or setting ``share_features_extractor=False`` in the
``policy_kwargs`` for off-policy algorithms (and when applicable).
By default the features extractor is shared between the actor and the critic to save computation (when applicable).
However, this can be changed by setting ``share_features_extractor=False`` in the
``policy_kwargs`` (both for on-policy and off-policy algorithms).


.. warning::
If the features extractor is **non-shared**, it is **not** possible to have shared layers in the ``mlp_extractor``.
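As a hedged illustration (assuming PPO, the ``CartPole-v1`` environment, and the list-of-dict ``net_arch`` format used at this point in the codebase), a non-shared features extractor can be combined with separate policy/value layers like this:

.. code-block:: python

    from stable_baselines3 import PPO

    # With a non-shared features extractor, net_arch may only contain
    # separate pi/vf layers, no shared ones
    policy_kwargs = dict(
        share_features_extractor=False,
        net_arch=[dict(pi=[64, 64], vf=[64, 64])],
    )
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(total_timesteps=1_000)

    # A leading integer, e.g. net_arch=[128, dict(pi=[64], vf=[64])], would be a
    # shared layer and raises a ValueError when share_features_extractor=False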


.. code-block:: python
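    # Hedged, illustrative sketch of such a custom features extractor: the CNN
    # layout, features_dim=128, and the Atari env id are assumptions made for
    # demonstration (an Atari-capable install, e.g. stable-baselines3[extra],
    # is assumed), not the exact example shipped with the docs.
    import torch as th
    import torch.nn as nn
    from gym import spaces

    from stable_baselines3 import PPO
    from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


    class CustomCNN(BaseFeaturesExtractor):
        def __init__(self, observation_space: spaces.Box, features_dim: int = 128):
            super().__init__(observation_space, features_dim)
            n_input_channels = observation_space.shape[0]
            self.cnn = nn.Sequential(
                nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4),
                nn.ReLU(),
                nn.Flatten(),
            )
            # Compute the flattened size with a dummy forward pass
            with th.no_grad():
                n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]
            self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

        def forward(self, observations: th.Tensor) -> th.Tensor:
            return self.linear(self.cnn(observations))


    policy_kwargs = dict(
        features_extractor_class=CustomCNN,
        features_extractor_kwargs=dict(features_dim=128),
    )
    model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1)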
@@ -174,7 +176,7 @@ Multiple Inputs and Dictionary Observations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Stable Baselines3 supports handling of multiple inputs by using ``Dict`` Gym space. This can be done using
``MultiInputPolicy``, which by default uses the ``CombinedExtractor`` feature extractor to turn multiple
``MultiInputPolicy``, which by default uses the ``CombinedExtractor`` features extractor to turn multiple
inputs into a single vector, handled by the ``net_arch`` network.

By default, ``CombinedExtractor`` processes multiple inputs as follows:
@@ -184,7 +186,7 @@ By default, ``CombinedExtractor`` processes multiple inputs as follows:
2. If input is not an image, flatten it (no layers).
3. Concatenate all previous vectors into one long vector and pass it to policy.
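A hedged usage sketch (assuming the bundled ``SimpleMultiObsEnv`` toy environment, which exposes a ``Dict`` observation space with an image key and a vector key):

.. code-block:: python

    from stable_baselines3 import PPO
    from stable_baselines3.common.envs import SimpleMultiObsEnv

    env = SimpleMultiObsEnv(random_start=False)
    model = PPO("MultiInputPolicy", env, verbose=1)
    model.learn(total_timesteps=1_000)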

Much like above, you can define custom feature extractors. The following example assumes the environment has two keys in the
Much like above, you can define custom features extractors. The following example assumes the environment has two keys in the
observation space dictionary: "image" is a (1,H,W) image (channel first), and "vector" is a (D,) dimensional vector. We process "image" with a simple
downsampling and "vector" with a single linear layer.

@@ -319,7 +321,7 @@ If your task requires even more granular control over the policy/value architect
class CustomNetwork(nn.Module):
"""
Custom network for policy and value function.
It receives as input the features extracted by the feature extractor.
It receives as input the features extracted by the features extractor.
:param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
:param last_layer_dim_pi: (int) number of units for the last layer of the policy network
@@ -411,7 +413,7 @@ you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256


.. note::
Compared to their on-policy counterparts, no shared layers (other than the feature extractor)
Compared to their on-policy counterparts, no shared layers (other than the features extractor)
between the actor and the critic are allowed (to prevent issues with target networks).
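A hedged illustration for the off-policy case (assuming SAC and the ``Pendulum-v1`` environment):

.. code-block:: python

    from stable_baselines3 import SAC

    # Two hidden layers of 256 units each for both actor and critic;
    # the features extractor can also be made non-shared here
    policy_kwargs = dict(net_arch=[256, 256], share_features_extractor=False)
    model = SAC("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(total_timesteps=1_000)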


14 changes: 8 additions & 6 deletions docs/misc/changelog.rst
@@ -4,7 +4,7 @@ Changelog
==========


Release 1.7.0a7 (WIP)
Release 1.7.0a8 (WIP)
--------------------------

Breaking Changes:
@@ -18,6 +18,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Introduced mypy type checking
- Added option to have non-shared features extractor between actor and critic in on-policy algorithms (@AlexPasqua)
- Added ``with_bias`` argument to ``create_mlp``
- Added support for multidimensional ``spaces.MultiBinary`` observations
- Features extractors now properly support unnormalized image-like observations (3D tensor)
@@ -40,6 +41,7 @@ Bug Fixes:

Deprecations:
^^^^^^^^^^^^^
- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()``
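A hedged sketch of the new call pattern (``policy`` and ``obs`` are placeholders):

.. code-block:: python

    # Before (now emits a DeprecationWarning):
    features = policy.extract_features(obs)
    # After: pass the features extractor explicitly
    features = policy.extract_features(obs, policy.features_extractor)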

Others:
^^^^^^^
@@ -685,8 +687,8 @@ Bug Fixes:
- Fix model creation initializing CUDA even when `device="cpu"` is provided
- Fix ``check_env`` not checking if the env has a Dict actionspace before calling ``_check_nan`` (@wmmc88)
- Update the check for spaces unsupported by Stable Baselines 3 to include checks on the action space (@wmmc88)
- Fixed feature extractor bug for target network where the same net was shared instead
of being separate. This bug affects ``SAC``, ``DDPG`` and ``TD3`` when using ``CnnPolicy`` (or custom feature extractor)
- Fixed features extractor bug for target network where the same net was shared instead
of being separate. This bug affects ``SAC``, ``DDPG`` and ``TD3`` when using ``CnnPolicy`` (or custom features extractor)
- Fixed a bug when passing an environment when loading a saved model with a ``CnnPolicy``, the passed env was not wrapped properly
(the bug was introduced when implementing ``HER`` so it should not be present in previous versions)

@@ -763,7 +765,7 @@ Others:
Documentation:
^^^^^^^^^^^^^^
- Added ``StopTrainingOnMaxEpisodes`` details and example (@xicocaio)
- Updated custom policy section (added custom feature extractor example)
- Updated custom policy section (added custom features extractor example)
- Re-enable ``sphinx_autodoc_typehints``
- Updated doc style for type hints and remove duplicated type hints

@@ -801,7 +803,7 @@ Bug Fixes:
- Use ``cloudpickle.load`` instead of ``pickle.load`` in ``CloudpickleWrapper``. (@shwang)
- Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37)
- Fixed approximate entropy calculation in PPO and A2C. (@andyshih12)
- Fixed DQN target network sharing feature extractor with the main network.
- Fixed DQN target network sharing features extractor with the main network.
- Fixed storing correct ``dones`` in on-policy algorithm rollout collection. (@andyshih12)
- Fixed number of filters in final convolutional layer in NatureCNN to match original implementation.

@@ -841,7 +843,7 @@ Breaking Changes:
- ``render()`` method of ``VecEnvs`` now only accept one argument: ``mode``
- Created new file common/torch_layers.py, similar to SB refactoring

- Contains all PyTorch network layer definitions and feature extractors: ``MlpExtractor``, ``create_mlp``, ``NatureCNN``
- Contains all PyTorch network layer definitions and features extractors: ``MlpExtractor``, ``create_mlp``, ``NatureCNN``

- Renamed ``BaseRLModel`` to ``BaseAlgorithm`` (along with offpolicy and onpolicy variants)
- Moved on-policy and off-policy base algorithms to ``common/on_policy_algorithm.py`` and ``common/off_policy_algorithm.py``, respectively.
2 changes: 1 addition & 1 deletion stable_baselines3/common/env_checker.py
@@ -50,7 +50,7 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
if observation_space.shape[non_channel_idx] < 36 or observation_space.shape[1] < 36:
warnings.warn(
"The minimal resolution for an image is 36x36 for the default `CnnPolicy`. "
"You might need to use a custom feature extractor "
"You might need to use a custom features extractor "
"cf. https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html"
)

96 changes: 80 additions & 16 deletions stable_baselines3/common/policies.py
@@ -2,6 +2,7 @@

import collections
import copy
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
@@ -117,16 +118,28 @@ def make_features_extractor(self) -> BaseFeaturesExtractor:
"""Helper method to create a features extractor."""
return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)

def extract_features(self, obs: th.Tensor) -> th.Tensor:
def extract_features(self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None) -> th.Tensor:
"""
Preprocess the observation if needed and extract features.
:param obs:
:return:
:param obs: The observation
:param features_extractor: The features extractor to use. If it is set to None,
the features extractor of the policy is used.
:return: The features
"""
assert self.features_extractor is not None, "No features extractor was set"
if features_extractor is None:
warnings.warn(
(
"When calling extract_features(), you should explicitely pass a features_extractor as parameter. "
"This will be mandatory in Stable-Baselines v1.8.0"
),
DeprecationWarning,
)

features_extractor = features_extractor or self.features_extractor
assert features_extractor is not None, "No features extractor was set"
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
return self.features_extractor(preprocessed_obs)
return features_extractor(preprocessed_obs)

def _get_constructor_parameters(self) -> Dict[str, Any]:
"""
@@ -391,6 +404,7 @@ class ActorCriticPolicy(BasePolicy):
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
@@ -414,6 +428,7 @@ def __init__(
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
@@ -447,8 +462,20 @@ def __init__(
self.activation_fn = activation_fn
self.ortho_init = ortho_init

self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.share_features_extractor = share_features_extractor
self.features_extractor = self.make_features_extractor()
self.features_dim = self.features_extractor.features_dim
if self.share_features_extractor:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.features_extractor
else:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.make_features_extractor()
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
raise ValueError(
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
)

self.log_std_init = log_std_init
dist_kwargs = None
@@ -555,6 +582,13 @@ def _build(self, lr_schedule: Schedule) -> None:
self.action_net: 0.01,
self.value_net: 1,
}
if not self.share_features_extractor:
# Note(antonin): this is to keep SB3 results
# consistent, see GH#1148
del module_gains[self.features_extractor]
module_gains[self.pi_features_extractor] = np.sqrt(2)
module_gains[self.vf_features_extractor] = np.sqrt(2)

for module, gain in module_gains.items():
module.apply(partial(self.init_weights, gain=gain))

@@ -571,7 +605,12 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
if self.share_features_extractor:
latent_pi, latent_vf = self.mlp_extractor(features)
else:
pi_features, vf_features = features
latent_pi = self.mlp_extractor.forward_actor(pi_features)
latent_vf = self.mlp_extractor.forward_critic(vf_features)
# Evaluate the values for the given observations
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
@@ -580,6 +619,20 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso
actions = actions.reshape((-1,) + self.action_space.shape)
return actions, values, log_prob

def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
"""
Preprocess the observation if needed and extract features.
:param obs: Observation
:return: the output of the features extractor(s)
"""
if self.share_features_extractor:
return super().extract_features(obs, self.features_extractor)
else:
pi_features = super().extract_features(obs, self.pi_features_extractor)
vf_features = super().extract_features(obs, self.vf_features_extractor)
return pi_features, vf_features

def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
"""
Retrieve action distribution given the latent codes.
@@ -620,14 +673,19 @@ def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tenso
Evaluate actions according to the current policy,
given the observations.
:param obs:
:param actions:
:param obs: Observation
:param actions: Actions
:return: estimated value, log likelihood of taking those actions
and entropy of the action distribution.
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
latent_pi, latent_vf = self.mlp_extractor(features)
if self.share_features_extractor:
latent_pi, latent_vf = self.mlp_extractor(features)
else:
pi_features, vf_features = features
latent_pi = self.mlp_extractor.forward_actor(pi_features)
latent_vf = self.mlp_extractor.forward_critic(vf_features)
distribution = self._get_action_dist_from_latent(latent_pi)
log_prob = distribution.log_prob(actions)
values = self.value_net(latent_vf)
@@ -641,18 +699,18 @@ def get_distribution(self, obs: th.Tensor) -> Distribution:
:param obs:
:return: the action distribution.
"""
features = self.extract_features(obs)
features = super().extract_features(obs, self.pi_features_extractor)
latent_pi = self.mlp_extractor.forward_actor(features)
return self._get_action_dist_from_latent(latent_pi)

def predict_values(self, obs: th.Tensor) -> th.Tensor:
"""
Get the estimated values according to the current policy given the observations.
:param obs:
:param obs: Observation
:return: the estimated values.
"""
features = self.extract_features(obs)
features = super().extract_features(obs, self.vf_features_extractor)
latent_vf = self.mlp_extractor.forward_critic(features)
return self.value_net(latent_vf)

@@ -680,6 +738,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
@@ -703,6 +762,7 @@ def __init__(
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
@@ -721,6 +781,7 @@ def __init__(
squash_output,
features_extractor_class,
features_extractor_kwargs,
share_features_extractor,
normalize_images,
optimizer_class,
optimizer_kwargs,
@@ -749,7 +810,8 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: Uses the CombinedExtractor
:param features_extractor_kwargs: Keyword arguments
to pass to the feature extractor.
to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
@@ -773,6 +835,7 @@ def __init__(
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
@@ -791,6 +854,7 @@ def __init__(
squash_output,
features_extractor_class,
features_extractor_kwargs,
share_features_extractor,
normalize_images,
optimizer_class,
optimizer_kwargs,
@@ -858,7 +922,7 @@ def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
# Learn the features extractor using the policy loss only
# when the features_extractor is shared with the actor
with th.set_grad_enabled(not self.share_features_extractor):
features = self.extract_features(obs)
features = self.extract_features(obs, self.features_extractor)
qvalue_input = th.cat([features, actions], dim=1)
return tuple(q_net(qvalue_input) for q_net in self.q_networks)

@@ -869,5 +933,5 @@ def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
(e.g. when updating the policy in TD3).
"""
with th.no_grad():
features = self.extract_features(obs)
features = self.extract_features(obs, self.features_extractor)
return self.q_networks[0](th.cat([features, actions], dim=1))
