Skip to content

Commit

Permalink
Fix render_mode when loading VecNormalize (#1671)
Browse files Browse the repository at this point in the history
* Fix render_mode when loading VecNormalize

* Switch from isort to ruff, and cap black version

* Add test and update changelog
  • Loading branch information
araffin committed Sep 12, 2023
1 parent 57dbefe commit 9971276
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 21 deletions.
6 changes: 3 additions & 3 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pip install -e .[docs,tests,extra]

## Codestyle

We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [ruff](https://github.com/astral-sh/ruff) (isort rules) to sort the imports.
For the documentation, we use the default line length of 88 characters per line.

**Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.
Expand All @@ -63,7 +63,7 @@ def my_function(arg1: type1, arg2: type2) -> returntype:

Before proposing a PR, please open an issue, where the feature will be discussed. This prevent from duplicated PR to be proposed and also ease the code review process.

Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave or @Miffyli).
Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave, @Miffyli or @qgallouedec).
A PR must pass the Continuous Integration tests to be merged with the master branch.


Expand All @@ -85,7 +85,7 @@ Type checking with `pytype` and `mypy`:
make type
```

Codestyle check with `black`, `isort` and `ruff`:
Codestyle check with `black`, and `ruff` (`isort` rules):

```
make check-codestyle
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ lint:

format:
# Sort imports
isort ${LINT_PATHS}
ruff --select I ${LINT_PATHS} --fix
# Reformat using black
black ${LINT_PATHS}

check-codestyle:
# Sort imports
isort --check ${LINT_PATHS}
ruff --select I ${LINT_PATHS}
# Reformat using black
black --check ${LINT_PATHS}

Expand Down
11 changes: 9 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
Changelog
==========

Release 2.2.0a1 (WIP)
Release 2.2.0a2 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version

New Features:
^^^^^^^^^^^^^
Expand All @@ -18,14 +19,19 @@ New Features:
`RL Zoo`_
^^^^^^^^^

`SBX`_
^^^^^^^^^
- Added ``DDPG`` and ``TD3``

Bug Fixes:
^^^^^^^^^^
- Prevents using squash_output and not use_sde in ActorCritcPolicy (@PatrickHelm)
- Performs unscaling of actions in collect_rollout in OnPolicyAlgorithm (@PatrickHelm)
- Moves VectorizedActionNoise into ``_setup_learn()`` in OffPolicyAlgorithm (@PatrickHelm)
- Prevents out of bound error on Windows if no seed is passed (@PatrickHelm)
- Calls ``callback.update_locals()`` before ``callback.on_rollout_end()`` in OnPolicyAlgorithm (@PatrickHelm)
- Fixes replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm)
- Fixed replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm)
- Fixed ``render_mode`` which was not properly loaded when using ``VecNormalize.load()``


Deprecations:
Expand Down Expand Up @@ -1424,6 +1430,7 @@ and `Quentin Gallouédec`_ (aka @qgallouedec).

.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
.. _RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo
.. _SBX: https://github.com/araffin/sbx

Contributors:
-------------
Expand Down
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@ max-complexity = 15
[tool.black]
line-length = 127

[tool.isort]
profile = "black"
line_length = 127
src_paths = ["stable_baselines3"]

[tool.pytype]
inputs = ["stable_baselines3"]
disable = ["pyi-error"]
Expand Down
8 changes: 3 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,10 @@
# Type check
"pytype",
"mypy",
# Lint code (flake8 replacement)
"ruff",
# Sort imports
"isort>=5.0",
# Lint code and sort imports (flake8 and isort replacement)
"ruff>=0.0.288",
# Reformat
"black",
"black>=23.9.1,<24",
],
"docs": [
"sphinx>=5.3,<7.0",
Expand Down
1 change: 1 addition & 0 deletions stable_baselines3/common/vec_env/vec_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def set_venv(self, venv: VecEnv) -> None:
self.venv = venv
self.num_envs = venv.num_envs
self.class_attributes = dict(inspect.getmembers(self.__class__))
self.render_mode = venv.render_mode

# Check that the observation_space shape match
utils.check_shape_equal(self.observation_space, venv.observation_space)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.2.0a1
2.2.0a2
23 changes: 20 additions & 3 deletions tests/test_vec_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def make_env():
return Monitor(gym.make(ENV_ID))


def make_env_render():
return Monitor(gym.make(ENV_ID, render_mode="rgb_array"))


def make_dict_env():
return Monitor(DummyDictEnv())

Expand Down Expand Up @@ -257,14 +261,17 @@ def test_obs_rms_vec_normalize():
assert np.allclose(env.ret_rms.mean, 5.688, atol=1e-3)


@pytest.mark.parametrize("make_env", [make_env, make_dict_env, make_image_env])
def test_vec_env(tmp_path, make_env):
@pytest.mark.parametrize("make_gym_env", [make_env, make_dict_env, make_image_env])
def test_vec_env(tmp_path, make_gym_env):
"""Test VecNormalize Object"""
clip_obs = 0.5
clip_reward = 5.0

orig_venv = DummyVecEnv([make_env])
orig_venv = DummyVecEnv([make_gym_env])
norm_venv = VecNormalize(orig_venv, norm_obs=True, norm_reward=True, clip_obs=clip_obs, clip_reward=clip_reward)
assert orig_venv.render_mode is None
assert norm_venv.render_mode is None

_, done = norm_venv.reset(), [False]
while not done[0]:
actions = [norm_venv.action_space.sample()]
Expand All @@ -278,9 +285,19 @@ def test_vec_env(tmp_path, make_env):

path = tmp_path / "vec_normalize"
norm_venv.save(path)
assert orig_venv.render_mode is None
deserialized = VecNormalize.load(path, venv=orig_venv)
assert deserialized.render_mode is None
check_vec_norm_equal(norm_venv, deserialized)

# Check that render mode is properly updated
vec_env = DummyVecEnv([make_env_render])
assert vec_env.render_mode == "rgb_array"
# Test that loading and wrapping keep the correct render mode
if make_gym_env == make_env:
assert VecNormalize.load(path, venv=vec_env).render_mode == "rgb_array"
assert VecNormalize(vec_env).render_mode == "rgb_array"


def test_get_original():
venv = _make_warmstart_cartpole()
Expand Down

0 comments on commit 9971276

Please sign in to comment.