Multiprocessing support for HerReplayBuffer (#704)
* IM compat. modif from old fork

* mp her working, without offline sampling

* update readme and doc

* fix discrete action/obs space case

* handle offline sampling

* fix pos to be consistent with the old version

* improve typing and docstring

* fix discrete obs special case

* new her, using episode uid

* deal with full buffer

* offline not implemented

* info storage; compute_reward as arg; offline sampling error

* offline sampling; timeout_termination; fix last_trans detection

* rm max_episode_length from tests

* fix loading and loading test

* Fix episode sampling strategy

* Episode interrupted not valid

* Typo

* Fix infos sampling, next_obs desired goals, offline sampling

* update tests for multienvs

* speed up code

* handle timeout sampling when sampling

* give up ep_uid for ep_start and ep_length

* speed up sampling

* Improve docstring

* Typos and renaming

* Fix typing

* Fix linter warnings

* Renaming + add note

* fix reward type

* Fix future sampling strategy

* Fix future goal selection strategy

* env_fn as lambda

* Re-fix linter warnings

* Formatting

* Fix offline sampling

* restore the initial performance budget

* Remove max_episode_length for HerReplayBuffer kwargs

* SubprocVecEnv compat test

* Dedicated SubprocVecEnv test, rm n_envs from parametrization

* Back to using the env arg instead of compute_reward

* Up VecEnv import

* fix lint warnings

* fix docstring

* Fix device issue

* actor_loss_modifier in SAC and TD3

* Merge RewardModifier and ActorLossModifier into Surgeon

* update surgeon for rnd

* fix unintended merge

* fix unintended merge

* fix unintended merge

* Rm unintended merge

* Fix KeyError

* Remove useless `all_inds`

* Minor docstring format

* Fix hint

* speedup!

* Speedup again

* speedup

* np.nonzero

* fix env normalization

* flat sampling for speedup

* typo

* drop online

* format

* remove observation from env_checker (see #1335)

* update changelog

* default device to "auto"

* add comment for info storage

* add comment for ep_start and ep_length attributes

* a[b][c] to a[b, c]

* comment flatnonzero and unravel_index

* update _sample_goals docstring

* Fix future goal sampling for split episode

* add informative error message for learning_starts too small

* use keyword arg for env

* try fix pytype

* Update stable_baselines3/common/off_policy_algorithm.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Add `copy_info_dict` option

* Ignore pytype

* Update changelog

* Rename variables and improve documentation

* Ignore new bug bear rule

* Add note about future strategy

* Add deprecation warning

* Fix bug trying to pickle buffer kwargs

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
qgallouedec and araffin committed Mar 20, 2023
1 parent e5deeed commit c5adad8
Showing 14 changed files with 426 additions and 627 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -173,7 +173,7 @@ All the following examples can be executed online using Google Colab notebooks:
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: |
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| QR-DQN<sup>[1](#f1)</sup> | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| RecurrentPPO<sup>[1](#f1)</sup> | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
2 changes: 1 addition & 1 deletion docs/guide/algos.rst
@@ -12,7 +12,7 @@ ARS [#f1]_ ✔️ ✔️ ❌ ❌
A2C ✔️ ✔️ ✔️ ✔️ ✔️
DDPG ✔️ ❌ ❌ ❌ ✔️
DQN ❌ ✔️ ❌ ❌ ✔️
HER ✔️ ✔️ ❌ ❌
HER ✔️ ✔️ ❌ ❌ ✔️
PPO ✔️ ✔️ ✔️ ✔️ ✔️
QR-DQN [#f1]_ ❌ ️ ✔️ ❌ ❌ ✔️
RecurrentPPO [#f1]_ ✔️ ✔️ ✔️ ✔️ ✔️
4 changes: 0 additions & 4 deletions docs/guide/examples.rst
@@ -450,10 +450,6 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
replay_buffer_kwargs=dict(
n_sampled_goal=n_sampled_goal,
goal_selection_strategy="future",
# IMPORTANT: because the env is not wrapped with a TimeLimit wrapper
# we have to manually specify the max number of steps per episode
max_episode_length=100,
online_sampling=True,
),
verbose=1,
buffer_size=int(1e6),
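With ``max_episode_length`` and ``online_sampling`` gone, the buffer configuration for this example shrinks accordingly. A minimal sketch of the updated usage, assuming the third-party ``highway-env`` package (which registers ``parking-v0``) and illustrative hyperparameters:

    import gym
    import highway_env  # registers "parking-v0" (third-party dependency, assumed installed)
    from stable_baselines3 import SAC, HerReplayBuffer

    env = gym.make("parking-v0")
    model = SAC(
        "MultiInputPolicy",
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=dict(
            n_sampled_goal=4,
            goal_selection_strategy="future",
            # max_episode_length and online_sampling were removed:
            # episode boundaries are tracked by the buffer and sampling is always online
        ),
        verbose=1,
        buffer_size=int(1e6),
    )
    model.learn(int(1e5))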
4 changes: 1 addition & 3 deletions docs/guide/migration.rst
@@ -177,10 +177,8 @@ Despite this change, no change in performance should be expected.
HER
^^^

The ``HER`` implementation now also supports online sampling of the new goals. This is done in a vectorized version.
The ``HER`` implementation now only supports online sampling of the new goals. This is done in a vectorized version.
The goal selection strategy ``RANDOM`` is no longer supported.
``HER`` now supports ``VecNormalize`` wrapper but only when ``online_sampling=True``.
For performance reasons, the maximum number of steps per episodes must be specified (see :ref:`HER <her>` documentation).


New logger API
8 changes: 7 additions & 1 deletion docs/misc/changelog.rst
@@ -4,7 +4,7 @@ Changelog
==========


Release 1.8.0a9 (WIP)
Release 1.8.0a10 (WIP)
--------------------------

.. warning::
@@ -20,12 +20,18 @@ Breaking Changes:
- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)
- Refactored ``StackedObservations`` (it now handles dict obs, ``StackedDictObservations`` was removed)
- You must now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()``
- Dropped offline sampling for ``HerReplayBuffer``
- As ``HerReplayBuffer`` was refactored to support multiprocessing, previous replay buffers are incompatible with this new version
- ``HerReplayBuffer`` doesn't require a ``max_episode_length`` anymore

New Features:
^^^^^^^^^^^^^
- Added ``repeat_action_probability`` argument in ``AtariWrapper``.
- Only use ``NoopResetEnv`` and ``MaxAndSkipEnv`` when needed in ``AtariWrapper``
- Added support for dict/tuple observations spaces for ``VecCheckNan``, the check is now active in the ``env_checker()`` (@DavyMorgan)
- Added multiprocessing support for ``HerReplayBuffer``
- ``HerReplayBuffer`` now supports all datatypes supported by ``ReplayBuffer``


`SB3-Contrib`_
^^^^^^^^^^^^^^
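To make the headline change concrete, here is a minimal sketch of training with ``HerReplayBuffer`` across several envs, assuming the toy ``BitFlippingEnv`` bundled with SB3 and illustrative settings (not taken from this commit):

    from stable_baselines3 import SAC, HerReplayBuffer
    from stable_baselines3.common.envs import BitFlippingEnv
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import SubprocVecEnv

    # Several goal-conditioned envs running in parallel subprocesses
    vec_env = make_vec_env(
        lambda: BitFlippingEnv(n_bits=10, continuous=True),
        n_envs=4,
        vec_env_cls=SubprocVecEnv,
    )

    model = SAC(
        "MultiInputPolicy",
        vec_env,
        replay_buffer_class=HerReplayBuffer,
        # no max_episode_length and no online_sampling anymore
        replay_buffer_kwargs=dict(n_sampled_goal=4, goal_selection_strategy="future"),
        verbose=1,
    )
    model.learn(5_000)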
21 changes: 6 additions & 15 deletions docs/modules/her.rst
@@ -27,21 +27,19 @@ It creates "virtual" transitions by relabeling transitions (changing the desired
- a dictionary observation space with three keys: ``observation``, ``achieved_goal`` and ``desired_goal``


.. warning::

For performance reasons, the maximum number of steps per episodes must be specified.
In most cases, it will be inferred if you specify ``max_episode_steps`` when registering the environment
or if you use a ``gym.wrappers.TimeLimit`` (and ``env.spec`` is not None).
Otherwise, you can directly pass ``max_episode_length`` to the model constructor


.. warning::

Because it needs access to ``env.compute_reward()``
``HER`` must be loaded with the env. If you just want to use the trained policy
without instantiating the environment, we recommend saving the policy only.


.. note::

Compared to other implementations, the ``future`` goal sampling strategy is inclusive:
the current transition can be used when re-sampling.


Notes
-----

@@ -77,11 +75,6 @@ This example is only to demonstrate the use of the library and its functions, an
# Available strategies (cf paper): future, final, episode
goal_selection_strategy = "future" # equivalent to GoalSelectionStrategy.FUTURE
# If True the HER transitions will get sampled online
online_sampling = True
# Time limit for the episodes
max_episode_length = N_BITS
# Initialize the model
model = model_class(
"MultiInputPolicy",
Expand All @@ -91,8 +84,6 @@ This example is only to demonstrate the use of the library and its functions, an
replay_buffer_kwargs=dict(
n_sampled_goal=4,
goal_selection_strategy=goal_selection_strategy,
online_sampling=online_sampling,
max_episode_length=max_episode_length,
),
verbose=1,
)
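A short sketch of the save/reload workflow that the warning above refers to, assuming ``BitFlippingEnv`` and an arbitrary file name:

    from stable_baselines3 import DQN, HerReplayBuffer
    from stable_baselines3.common.envs import BitFlippingEnv

    env = BitFlippingEnv(n_bits=10)
    model = DQN(
        "MultiInputPolicy",
        env,
        learning_starts=100,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=dict(n_sampled_goal=4, goal_selection_strategy="future"),
        verbose=1,
    )
    model.learn(2_000)
    model.save("her_dqn_bitflip")

    # The buffer needs env.compute_reward(), so the env must be passed when loading
    model = DQN.load("her_dqn_bitflip", env=env)
    model.learn(2_000, reset_num_timesteps=False)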
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -5,7 +5,8 @@ line-length = 127
target-version = "py37"
# See https://beta.ruff.rs/docs/rules/
select = ["E", "F", "B", "UP", "C90", "RUF"]
ignore = []
# Ignore explicit stacklevel`
ignore = ["B028"]

[tool.ruff.per-file-ignores]
# Default implementation in abstract methods
2 changes: 1 addition & 1 deletion stable_baselines3/common/env_checker.py
@@ -117,7 +117,7 @@ def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name:
f"The current observation contains {len(observation_space.spaces)} keys: {list(observation_space.spaces.keys())}"
)

for key in ["observation", "achieved_goal", "desired_goal"]:
for key in ["achieved_goal", "desired_goal"]:
if key not in observation_space.spaces:
raise AssertionError(
f"The observation returned by the `{method_name}()` method of a goal-conditioned env requires the '{key}' "
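For illustration, a goal-env observation space that satisfies the relaxed check might look as follows (shapes are arbitrary; only the two goal keys are required):

    import numpy as np
    from gym import spaces

    # Only "achieved_goal" and "desired_goal" are mandatory now;
    # a separate "observation" entry is optional (see #1335)
    goal_obs_space = spaces.Dict(
        {
            "achieved_goal": spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32),
            "desired_goal": spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32),
        }
    )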
58 changes: 16 additions & 42 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -125,10 +125,14 @@ def __init__(
self.gradient_steps = gradient_steps
self.action_noise = action_noise
self.optimize_memory_usage = optimize_memory_usage
self.replay_buffer_class = replay_buffer_class
if replay_buffer_kwargs is None:
replay_buffer_kwargs = {}
self.replay_buffer_kwargs = replay_buffer_kwargs
if replay_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
self.replay_buffer_class = DictReplayBuffer
else:
self.replay_buffer_class = ReplayBuffer
else:
self.replay_buffer_class = replay_buffer_class
self.replay_buffer_kwargs = replay_buffer_kwargs or {}
self._episode_storage = None

# Save train freq parameter, will be converted later to TrainFreq object
@@ -170,45 +174,21 @@ def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)

# Use DictReplayBuffer if needed
if self.replay_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
self.replay_buffer_class = DictReplayBuffer
else:
self.replay_buffer_class = ReplayBuffer

elif self.replay_buffer_class == HerReplayBuffer:
assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`"

# If using offline sampling, we need a classic replay buffer too
if self.replay_buffer_kwargs.get("online_sampling", True):
replay_buffer = None
else:
replay_buffer = DictReplayBuffer(
self.buffer_size,
self.observation_space,
self.action_space,
device=self.device,
optimize_memory_usage=self.optimize_memory_usage,
)

self.replay_buffer = HerReplayBuffer(
self.env,
self.buffer_size,
device=self.device,
replay_buffer=replay_buffer,
**self.replay_buffer_kwargs,
)

if self.replay_buffer is None:
# Make a local copy as we should not pickle
# the environment when using HerReplayBuffer
replay_buffer_kwargs = self.replay_buffer_kwargs.copy()
if issubclass(self.replay_buffer_class, HerReplayBuffer):
assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`"
replay_buffer_kwargs["env"] = self.env
self.replay_buffer = self.replay_buffer_class(
self.buffer_size,
self.observation_space,
self.action_space,
device=self.device,
n_envs=self.n_envs,
optimize_memory_usage=self.optimize_memory_usage,
**self.replay_buffer_kwargs,
**replay_buffer_kwargs, # pytype:disable=wrong-keyword-args
)

self.policy = self.policy_class( # pytype:disable=not-instantiable
@@ -276,12 +256,7 @@ def _setup_learn(
# when using memory efficient replay buffer
# see https://github.com/DLR-RM/stable-baselines3/issues/46

# Special case when using HerReplayBuffer,
# the classic replay buffer is inside it when using offline sampling
if isinstance(self.replay_buffer, HerReplayBuffer):
replay_buffer = self.replay_buffer.replay_buffer
else:
replay_buffer = self.replay_buffer
replay_buffer = self.replay_buffer

truncate_last_traj = (
self.optimize_memory_usage
@@ -552,7 +527,6 @@ def collect_rollouts(

callback.on_rollout_start()
continue_training = True

while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
# Sample a new noise matrix
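As a standalone illustration of what ``_setup_model()`` now does, the sketch below builds the buffer directly and passes the env by keyword, mirroring the call above; the exact parameter names of ``HerReplayBuffer`` are assumed here, not taken from this diff:

    from stable_baselines3 import HerReplayBuffer
    from stable_baselines3.common.envs import BitFlippingEnv
    from stable_baselines3.common.vec_env import DummyVecEnv

    vec_env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=True)])

    # The off-policy algorithm injects env/device/n_envs itself and keeps
    # replay_buffer_kwargs free of the env so the kwargs stay picklable
    buffer = HerReplayBuffer(
        buffer_size=10_000,
        observation_space=vec_env.observation_space,
        action_space=vec_env.action_space,
        env=vec_env,
        device="auto",
        n_envs=vec_env.num_envs,
        n_sampled_goal=4,
        goal_selection_strategy="future",
    )

Keeping the env out of ``replay_buffer_kwargs`` (it is injected into a local copy instead) is what lets the algorithm be saved without trying to pickle the environment.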
