Skip to content

Commit

Permalink
Add argument features_extractor to `ActorCriticPolicy.extract_features` (#1710)
Browse files Browse the repository at this point in the history

* add argument to extract_features

* remove empty lines

* changelog and version
  • Loading branch information
qgallouedec committed Oct 9, 2023
1 parent c6c660e commit c6bf251
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.2.0a6 (WIP)
Release 2.2.0a7 (WIP)
--------------------------

Breaking Changes:
Expand Down Expand Up @@ -52,6 +52,7 @@ Others:
- Fixed ``stable_baselines3/common/buffers.py`` type hints
- Fixed ``stable_baselines3/her/her_replay_buffer.py`` type hints
- Buffers do not call an additional ``.copy()`` when storing new transitions
- Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument

Documentation:
^^^^^^^^^^^^^^
Expand Down
22 changes: 16 additions & 6 deletions stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def extract_features(self, obs: th.Tensor, features_extractor: BaseFeaturesExtra
"""
Preprocess the observation if needed and extract features.
:param obs: The observation
:param features_extractor: The features extractor to use.
:return: The extracted features
:param obs: Observation
:param features_extractor: The features extractor to use.
:return: The extracted features
"""
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
return features_extractor(preprocessed_obs)
Expand Down Expand Up @@ -642,16 +642,26 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso
actions = actions.reshape((-1, *self.action_space.shape))
return actions, values, log_prob

def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
def extract_features(
self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
"""
Preprocess the observation if needed and extract features.
:param obs: Observation
:return: the output of the features extractor(s)
:param features_extractor: The features extractor to use. If None, then ``self.features_extractor`` is used.
:return: The extracted features. If features extractor is not shared, returns a tuple with the
features for the actor and the features for the critic.
"""
if self.share_features_extractor:
return super().extract_features(obs, self.features_extractor)
return super().extract_features(obs, self.features_extractor if features_extractor is None else features_extractor)
else:
if features_extractor is not None:
warnings.warn(
"Provided features_extractor will be ignored because the features extractor is not shared.",
UserWarning,
)

pi_features = super().extract_features(obs, self.pi_features_extractor)
vf_features = super().extract_features(obs, self.vf_features_extractor)
return pi_features, vf_features
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.2.0a6
2.2.0a7

0 comments on commit c6bf251

Please sign in to comment.