Skip to content

Commit

Permalink
Fix image-based normalized env loading (#1321)
Browse files Browse the repository at this point in the history
* Fix

* Add test

* Update changelog

* fix memory error avoidance

* Update version

* image env test

* black

* check_shape_equal

* check shape equal in vecnormalize

* Allow spaces not to be box or dict

* rm `test_save_load_vecnormalized_image` in favor of `test_vec_env`

* Remove unused imports

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
  • Loading branch information
3 people committed Feb 15, 2023
1 parent 7a1e429 commit 12e9917
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 7 deletions.
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 1.8.0a4 (WIP)
Release 1.8.0a5 (WIP)
--------------------------


Expand All @@ -29,6 +29,7 @@ Bug Fixes:
- Fixed Atari wrapper that missed the reset condition (@luizapozzobon)
- Added the argument ``dtype`` (default to ``float32``) to the noise for consistency with gym action (@sidney-tio)
- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)
- Fixed loading of normalized image-based environments

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -212,7 +213,7 @@ Bug Fixes:
- Fixed missing verbose parameter passing in the ``EvalCallback`` constructor (@burakdmb)
- Fixed the issue that when updating the target network in DQN, SAC, TD3, the ``running_mean`` and ``running_var`` properties of batch norm layers are not updated (@honglu2875)
- Fixed incorrect type annotation of the replay_buffer_class argument in ``common.OffPolicyAlgorithm`` initializer, where an instance instead of a class was required (@Rocamonde)
- Fixed loading saved model with different number of envrionments
- Fixed loading saved model with different number of environments
- Removed ``forward()`` abstract method declaration from ``common.policies.BaseModel`` (already defined in ``torch.nn.Module``) to fix type errors in subclasses (@Rocamonde)
- Fixed the return type of ``.load()`` and ``.learn()`` methods in ``BaseAlgorithm`` so that they now use ``TypeVar`` (@Rocamonde)
- Fixed an issue where keys with different tags but the same key raised an error in ``common.logger.HumanOutputFormat`` (@Rocamonde and @AdamGleave)
Expand Down
18 changes: 18 additions & 0 deletions stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,24 @@ def check_for_correct_spaces(env: GymEnv, observation_space: spaces.Space, actio
raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")


def check_shape_equal(space1: spaces.Space, space2: spaces.Space) -> None:
"""
If the spaces are Box, check that they have the same shape.
If the spaces are Dict, it recursively checks the subspaces.
:param space1: Space
:param space2: Other space
"""
if isinstance(space1, spaces.Dict):
assert isinstance(space2, spaces.Dict), "spaces must be of the same type"
assert space1.spaces.keys() == space2.spaces.keys(), "spaces must have the same keys"
for key in space1.spaces.keys():
check_shape_equal(space1.spaces[key], space2.spaces[key])
elif isinstance(space1, spaces.Box):
assert space1.shape == space2.shape, "spaces must have the same shape"


def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool:
"""
For box observation type, detects and validates the shape,
Expand Down
9 changes: 6 additions & 3 deletions stable_baselines3/common/vec_env/vec_normalize.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import pickle
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
Expand Down Expand Up @@ -159,10 +160,12 @@ def set_venv(self, venv: VecEnv) -> None:
"""
if self.venv is not None:
raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.")
VecEnvWrapper.__init__(self, venv)
self.venv = venv
self.num_envs = venv.num_envs
self.class_attributes = dict(inspect.getmembers(self.__class__))

# Check only that the observation_space match
utils.check_for_correct_spaces(venv, self.observation_space, venv.action_space)
# Check that the observation_space shape match
utils.check_shape_equal(self.observation_space, venv.observation_space)
self.returns = np.zeros(self.num_envs)

def step_wait(self) -> VecEnvStepReturn:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.8.0a4
1.8.0a5
21 changes: 21 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
from stable_baselines3.common.utils import (
check_shape_equal,
get_parameters_by_name,
get_system_info,
is_vectorized_observation,
Expand Down Expand Up @@ -509,3 +510,23 @@ def test_is_vectorized_observation():
discrete_obs = np.ones((1, 1), dtype=np.int8)
dict_obs = {"box": box_obs, "discrete": discrete_obs}
is_vectorized_observation(dict_obs, dict_space)


def test_check_shape_equal():
space1 = spaces.Box(low=0, high=1, shape=(2, 2))
space2 = spaces.Box(low=-1, high=1, shape=(2, 2))
check_shape_equal(space1, space2)

space1 = spaces.Box(low=0, high=1, shape=(2, 2))
space2 = spaces.Box(low=-1, high=2, shape=(3, 3))
with pytest.raises(AssertionError):
check_shape_equal(space1, space2)

space1 = spaces.Dict({"key1": spaces.Box(low=0, high=1, shape=(2, 2)), "key2": spaces.Box(low=0, high=1, shape=(2, 2))})
space2 = spaces.Dict({"key1": spaces.Box(low=-1, high=2, shape=(2, 2)), "key2": spaces.Box(low=-1, high=2, shape=(2, 2))})
check_shape_equal(space1, space2)

space1 = spaces.Dict({"key1": spaces.Box(low=0, high=1, shape=(2, 2)), "key2": spaces.Box(low=0, high=1, shape=(2, 2))})
space2 = spaces.Dict({"key1": spaces.Box(low=-1, high=2, shape=(3, 3)), "key2": spaces.Box(low=-1, high=2, shape=(2, 2))})
with pytest.raises(AssertionError):
check_shape_equal(space1, space2)
7 changes: 6 additions & 1 deletion tests/test_vec_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from gym import spaces

from stable_baselines3 import SAC, TD3, HerReplayBuffer
from stable_baselines3.common.envs import FakeImageEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env import (
Expand Down Expand Up @@ -118,6 +119,10 @@ def make_dict_env():
return Monitor(DummyDictEnv())


def make_image_env():
return Monitor(FakeImageEnv())


def check_rms_equal(rmsa, rmsb):
if isinstance(rmsa, dict):
for key in rmsa.keys():
Expand Down Expand Up @@ -244,7 +249,7 @@ def test_obs_rms_vec_normalize():
assert np.allclose(env.ret_rms.mean, 5.688, atol=1e-3)


@pytest.mark.parametrize("make_env", [make_env, make_dict_env])
@pytest.mark.parametrize("make_env", [make_env, make_dict_env, make_image_env])
def test_vec_env(tmp_path, make_env):
"""Test VecNormalize Object"""
clip_obs = 0.5
Expand Down

0 comments on commit 12e9917

Please sign in to comment.