Skip to content

Commit

Permalink
Fix render bug for vec env wrappers (#1525)
Browse files Browse the repository at this point in the history
* Fix render bug for vec env wrappers

* Fix tests and update changelog

* Better fix, backward compatible

* remove render_mode from VecEnv init

* Make DictObsVecEnv inherit from VecEnv

* format

* Fix env_is_wrapped

* try/except getting render mode ( (#1525 (comment))

* update version

* Fix env_is_wrapped in test_vec_extract_dict

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
  • Loading branch information
3 people committed Jun 7, 2023
1 parent 32778dd commit ffe26cc
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 21 deletions.
4 changes: 2 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.0.0a12 (WIP)
Release 2.0.0a13 (WIP)
--------------------------

**Gymnasium support**
Expand Down Expand Up @@ -64,7 +64,7 @@ Others:
- Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks
- Improve type annotation of wrappers
- Tests envs are now checked too
- Added render test for ``VecEnv``
- Added render test for ``VecEnv`` and ``VecEnvWrapper``
- Update issue templates and env info saved with the model
- Changed ``seed()`` method return type from ``List`` to ``Sequence``
- Updated env checker doc and requirements for tuple spaces/goal envs
Expand Down
18 changes: 12 additions & 6 deletions stable_baselines3/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,24 @@ def __init__(
num_envs: int,
observation_space: spaces.Space,
action_space: spaces.Space,
render_mode: Optional[str] = None,
):
self.num_envs = num_envs
self.observation_space = observation_space
self.action_space = action_space
self.render_mode = render_mode
# store info returned by the reset method
self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
# seeds to be used in the next call to env.reset()
self._seeds: List[Optional[int]] = [None for _ in range(num_envs)]
try:
render_modes = self.get_attr("render_mode")
except AttributeError:
warnings.warn("The `render_mode` attribute is not defined in your environment. It will be set to None.")
render_modes = [None for _ in range(num_envs)]

assert all(
render_mode == render_modes[0] for render_mode in render_modes
), "render_mode mode should be the same for all environments"
self.render_mode = render_modes[0]

def _reset_seeds(self) -> None:
"""
Expand Down Expand Up @@ -313,15 +321,13 @@ def __init__(
venv: VecEnv,
observation_space: Optional[spaces.Space] = None,
action_space: Optional[spaces.Space] = None,
render_mode: Optional[str] = None,
):
self.venv = venv
VecEnv.__init__(
self,

super().__init__(
num_envs=venv.num_envs,
observation_space=observation_space or venv.observation_space,
action_space=action_space or venv.action_space,
render_mode=render_mode,
)
self.class_attributes = dict(inspect.getmembers(self.__class__))

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]]):
"Please read https://github.com/DLR-RM/stable-baselines3/issues/1151 for more information."
)
env = self.envs[0]
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space, env.render_mode)
super().__init__(len(env_fns), env.observation_space, env.action_space)
obs_space = env.observation_space
self.keys, shapes, dtypes = obs_space_info(obs_space)

Expand Down
4 changes: 1 addition & 3 deletions stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[
self.remotes[0].send(("get_spaces", None))
observation_space, action_space = self.remotes[0].recv()

self.remotes[0].send(("get_attr", "render_mode"))
render_mode = self.remotes[0].recv()
VecEnv.__init__(self, len(env_fns), observation_space, action_space, render_mode)
super().__init__(len(env_fns), observation_space, action_space)

def step_async(self, actions: np.ndarray) -> None:
for remote, action in zip(self.remotes, actions):
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.0a12
2.0.0a13
9 changes: 9 additions & 0 deletions tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,4 +586,13 @@ def test_render(vec_env_class):
for _ in range(10):
vec_env.step([vec_env.action_space.sample() for _ in range(n_envs)])
vec_env.render()

# Check that it still works with vec env wrapper
vec_env = VecFrameStack(vec_env, 2)
vec_env.render()
assert vec_env.render_mode == "rgb_array"
vec_env = VecNormalize(vec_env)
assert vec_env.render_mode == "rgb_array"
vec_env.render()

vec_env.close()
34 changes: 26 additions & 8 deletions tests/test_vec_extract_dict_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
from stable_baselines3.common.vec_env import VecEnv, VecExtractDictObs, VecMonitor


class DictObsVecEnv:
class DictObsVecEnv(VecEnv):
"""Custom Environment that produces observation in a dictionary like the procgen env"""

metadata = {"render.modes": ["human"]}
metadata = {"render_modes": ["human"]}

def __init__(self):
self.num_envs = 4
self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)})
self.n_steps = 0
self.max_steps = 5
self.render_mode = None

def step_async(self, actions):
self.actions = actions
Expand All @@ -25,25 +26,42 @@ def step_wait(self):
done = self.n_steps >= self.max_steps
if done:
infos = [
{"terminal_observation": {"rgb": np.zeros((86, 86))}, "TimeLimit.truncated": True}
{"terminal_observation": {"rgb": np.zeros((86, 86), dtype=np.float32)}, "TimeLimit.truncated": True}
for _ in range(self.num_envs)
]
else:
infos = []
return (
{"rgb": np.zeros((self.num_envs, 86, 86))},
np.zeros((self.num_envs,)),
{"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)},
np.zeros((self.num_envs,), dtype=np.float32),
np.ones((self.num_envs,), dtype=bool) * done,
infos,
)

def reset(self):
self.n_steps = 0
return {"rgb": np.zeros((self.num_envs, 86, 86))}
return {"rgb": np.zeros((self.num_envs, 86, 86), dtype=np.float32)}

def render(self, close=False):
def render(self, mode=""):
pass

def get_attr(self, attr_name, indices=None):
indices = range(self.num_envs) if indices is None else indices
return [getattr(self, attr_name) for _ in indices]

def close(self):
pass

def env_is_wrapped(self, wrapper_class, indices=None):
indices = range(self.num_envs) if indices is None else indices
return [False for _ in indices]

def env_method(self):
raise NotImplementedError # not used in the test

def set_attr(self, attr_name, value, indices=None) -> None:
raise NotImplementedError # not used in the test


def test_extract_dict_obs():
"""Test VecExtractDictObs"""
Expand Down

0 comments on commit ffe26cc

Please sign in to comment.