-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor observation stacking (#1238)
* refactor stacking obs * Improve docstring * remove all StackedDictObservations * Update tests and make stacked obs clearer * Fix type check * fix stacked_observation_space * undo init change, deprecate StackedDictObservations * deprecate stack_observation_space * type hints * ignore pytype errors * undo vecenv doc change * Deprecation warning in StackedDictObs doctstring * Fix vec_env.rst * Fix __all__ sorting * fix pytype ignore statement * Update docstring * stack * Remove n_stack * Update changelog * Simplify code * Rename test file * Re-use variable for shift * Fix doc build * Remove pytype comment * Disable pytype error --------- Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
- Loading branch information
1 parent
411ff69
commit 2e4a450
Showing
8 changed files
with
459 additions
and
234 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
316 changes: 128 additions & 188 deletions
316
stable_baselines3/common/vec_env/stacked_observations.py
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,63 +1,40 @@ | ||
from typing import Any, Dict, List, Optional, Tuple, Union | ||
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union | ||
|
||
import numpy as np | ||
from gym import spaces | ||
|
||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper | ||
from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations | ||
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations | ||
|
||
|
||
class VecFrameStack(VecEnvWrapper): | ||
""" | ||
Frame stacking wrapper for vectorized environment. Designed for image observations. | ||
Uses the StackedObservations class, or StackedDictObservations depending on the observations space | ||
:param venv: the vectorized environment to wrap | ||
:param venv: Vectorized environment to wrap | ||
:param n_stack: Number of frames to stack | ||
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. | ||
If None, automatically detect channel to stack over in case of image observation or default to "last" (default). | ||
Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces | ||
""" | ||
|
||
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None): | ||
self.venv = venv | ||
self.n_stack = n_stack | ||
|
||
wrapped_obs_space = venv.observation_space | ||
|
||
if isinstance(wrapped_obs_space, spaces.Box): | ||
assert not isinstance( | ||
channels_order, dict | ||
), f"Expected None or string for channels_order but received {channels_order}" | ||
self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) | ||
|
||
elif isinstance(wrapped_obs_space, spaces.Dict): | ||
self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) | ||
|
||
else: | ||
raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces") | ||
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Mapping[str, str]]] = None) -> None: | ||
assert isinstance( | ||
venv.observation_space, (spaces.Box, spaces.Dict) | ||
), "VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces" | ||
|
||
observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space) | ||
VecEnvWrapper.__init__(self, venv, observation_space=observation_space) | ||
self.stacked_obs = StackedObservations(venv.num_envs, n_stack, venv.observation_space, channels_order) | ||
observation_space = self.stacked_obs.stacked_observation_space | ||
super().__init__(venv, observation_space=observation_space) | ||
|
||
def step_wait( | ||
self, | ||
) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]: | ||
observations, rewards, dones, infos = self.venv.step_wait() | ||
|
||
observations, infos = self.stackedobs.update(observations, dones, infos) | ||
|
||
observations, infos = self.stacked_obs.update(observations, dones, infos) | ||
return observations, rewards, dones, infos | ||
|
||
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: | ||
""" | ||
Reset all environments | ||
""" | ||
observation = self.venv.reset() # pytype:disable=annotation-type-mismatch | ||
|
||
observation = self.stackedobs.reset(observation) | ||
observation = self.stacked_obs.reset(observation) | ||
return observation | ||
|
||
def close(self) -> None: | ||
self.venv.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
1.8.0a3 | ||
1.8.0a4 |
Oops, something went wrong.