Skip to content

Commit

Permalink
Refactor observation stacking (#1238)
Browse files Browse the repository at this point in the history
* refactor stacking obs

* Improve docstring

* remove all StackedDictObservations

* Update tests and make stacked obs clearer

* Fix type check

* fix stacked_observation_space

* undo init change, deprecate StackedDictObservations

* deprecate stack_observation_space

* type hints

* ignore pytype errors

* undo vecenv doc change

* Deprecation warning in StackedDictObs doctstring

* Fix vec_env.rst

* Fix __all__ sorting

* fix pytype ignore statement

* Update docstring

* stack

* Remove n_stack

* Update changelog

* Simplify code

* Rename test file

* Re-use variable for shift

* Fix doc build

* Remove pytype comment

* Disable pytype error

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
qgallouedec and araffin committed Feb 6, 2023
1 parent 411ff69 commit 2e4a450
Show file tree
Hide file tree
Showing 8 changed files with 459 additions and 234 deletions.
6 changes: 0 additions & 6 deletions docs/guide/vec_envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,6 @@ StackedObservations
.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedObservations
:members:

StackedDictObservations
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedDictObservations
:members:

VecNormalize
~~~~~~~~~~~~

Expand Down
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ Changelog
==========


Release 1.8.0a3 (WIP)
Release 1.8.0a4 (WIP)
--------------------------


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)
- Refactored ``StackedObservations`` (it now handles dict obs, ``StackedDictObservations`` was removed)

New Features:
^^^^^^^^^^^^^
Expand All @@ -36,6 +37,7 @@ Others:
- Fixed ``tests/test_tensorboard.py`` type hint
- Fixed ``tests/test_vec_normalize.py`` type hint
- Fixed ``stable_baselines3/common/monitor.py`` type hint
- Added tests for StackedObservations

Documentation:
^^^^^^^^^^^^^^
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ exclude = (?x)(
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/base_vec_env.py$
| stable_baselines3/common/vec_env/dummy_vec_env.py$
| stable_baselines3/common/vec_env/stacked_observations.py$
| stable_baselines3/common/vec_env/subproc_vec_env.py$
| stable_baselines3/common/vec_env/util.py$
| stable_baselines3/common/vec_env/vec_extract_dict_obs.py$
Expand Down
3 changes: 1 addition & 2 deletions stable_baselines3/common/vec_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
Expand Down Expand Up @@ -78,7 +78,6 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
"VecEnv",
"VecEnvWrapper",
"DummyVecEnv",
"StackedDictObservations",
"StackedObservations",
"SubprocVecEnv",
"VecCheckNan",
Expand Down
316 changes: 128 additions & 188 deletions stable_baselines3/common/vec_env/stacked_observations.py

Large diffs are not rendered by default.

47 changes: 12 additions & 35 deletions stable_baselines3/common/vec_env/vec_frame_stack.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,40 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

import numpy as np
from gym import spaces

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations


class VecFrameStack(VecEnvWrapper):
"""
Frame stacking wrapper for vectorized environment. Designed for image observations.
Uses the StackedObservations class, or StackedDictObservations depending on the observations space
:param venv: the vectorized environment to wrap
:param venv: Vectorized environment to wrap
:param n_stack: Number of frames to stack
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces
"""

def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None):
self.venv = venv
self.n_stack = n_stack

wrapped_obs_space = venv.observation_space

if isinstance(wrapped_obs_space, spaces.Box):
assert not isinstance(
channels_order, dict
), f"Expected None or string for channels_order but received {channels_order}"
self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)

elif isinstance(wrapped_obs_space, spaces.Dict):
self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)

else:
raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces")
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Mapping[str, str]]] = None) -> None:
assert isinstance(
venv.observation_space, (spaces.Box, spaces.Dict)
), "VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces"

observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space)
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
self.stacked_obs = StackedObservations(venv.num_envs, n_stack, venv.observation_space, channels_order)
observation_space = self.stacked_obs.stacked_observation_space
super().__init__(venv, observation_space=observation_space)

def step_wait(
self,
) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:
observations, rewards, dones, infos = self.venv.step_wait()

observations, infos = self.stackedobs.update(observations, dones, infos)

observations, infos = self.stacked_obs.update(observations, dones, infos)
return observations, rewards, dones, infos

def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Reset all environments
"""
observation = self.venv.reset() # pytype:disable=annotation-type-mismatch

observation = self.stackedobs.reset(observation)
observation = self.stacked_obs.reset(observation)
return observation

def close(self) -> None:
self.venv.close()
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.8.0a3
1.8.0a4

0 comments on commit 2e4a450

Please sign in to comment.