Skip to content

Commit

Permalink
Make check_env assertions in regards to observation_space more action…
Browse files Browse the repository at this point in the history
…able (#1400)

* add instructions for running single tests in the README, add assertions for observation_space

* update changelog

* address linting warnings

* correct pytest command in the README

* correct review comments, run make commit-checks

* truncate lines that are too long

* address make lint warning about checking module availability

* fix tests

* use f-strings for formatting assertion messages

* fix type issue

* Refactor tests, improve error messages

---------

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
  • Loading branch information
FieteO and araffin committed Mar 29, 2023
1 parent c5adad8 commit b6aa507
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 13 deletions.
22 changes: 17 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ pip install stable-baselines3[extra]
**Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)).

This includes an optional dependencies like Tensorboard, OpenCV or `atari-py` to train on atari games. If you do not need those, you can use:
```
```sh
pip install stable-baselines3
```

Expand Down Expand Up @@ -194,20 +194,32 @@ Actions `gym.spaces`:


## Testing the installation
All unit tests in stable baselines3 can be run using `pytest` runner:
### Install dependencies
```sh
pip install -e .[docs,tests,extra]
```
pip install pytest pytest-cov
### Run tests
All unit tests in stable baselines3 can be run using `pytest` runner:
```sh
make pytest
```
To run a single test file:
```sh
python3 -m pytest -v tests/test_env_checker.py
```
To run a single test:
```sh
python3 -m pytest -v -k 'test_check_env_dict_action'
```

You can also do a static type check using `pytype` and `mypy`:
```
```sh
pip install pytype mypy
make type
```

Codestyle check with `ruff`:
```
```sh
pip install ruff
make lint
```
Expand Down
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 1.8.0a10 (WIP)
Release 1.8.0a11 (WIP)
--------------------------

.. warning::
Expand All @@ -31,6 +31,7 @@ New Features:
- Added support for dict/tuple observations spaces for ``VecCheckNan``, the check is now active in the ``env_checker()`` (@DavyMorgan)
- Added multiprocessing support for ``HerReplayBuffer``
- ``HerReplayBuffer`` now supports all datatypes supported by ``ReplayBuffer``
- Provide more helpful failure messages when validating the ``observation_space`` of custom gym environments using ``check_env``` (@FieteO)


`SB3-Contrib`_
Expand Down Expand Up @@ -1251,4 +1252,4 @@ And all the contributors:
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO
23 changes: 23 additions & 0 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,29 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac
elif _is_numpy_array_space(observation_space):
assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array"

# Additional checks for numpy arrays, so the error message is clearer (see GH#1399)
if isinstance(obs, np.ndarray):
# check obs dimensions, dtype and bounds
assert observation_space.shape == obs.shape, (
f"The observation returned by the `{method_name}()` method does not match the shape "
f"of the given observation space. Expected: {observation_space.shape}, actual shape: {obs.shape}"
)
assert observation_space.dtype == obs.dtype, (
f"The observation returned by the `{method_name}()` method does not match the data type "
f"of the given observation space. Expected: {observation_space.dtype}, actual dtype: {obs.dtype}"
)
if isinstance(observation_space, spaces.Box):
assert np.all(obs >= observation_space.low), (
f"The observation returned by the `{method_name}()` method does not match the lower bound "
f"of the given observation space. Expected: obs >= {np.min(observation_space.low)}, "
f"actual min value: {np.min(obs)} at index {np.argmin(obs)}"
)
assert np.all(obs <= observation_space.high), (
f"The observation returned by the `{method_name}()` method does not match the upper bound "
f"of the given observation space. Expected: obs <= {np.max(observation_space.high)}, "
f"actual max value: {np.max(obs)} at index {np.argmax(obs)}"
)

assert observation_space.contains(
obs
), f"The observation returned by the `{method_name}()` method does not match the given observation space"
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.8.0a10
1.8.0a11
77 changes: 77 additions & 0 deletions tests/test_env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,80 @@ def test_check_env_dict_action():

with pytest.warns(Warning):
check_env(env=test_env, warn=True)


@pytest.mark.parametrize(
"obs_tuple",
[
# Above upper bound
(
spaces.Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32),
np.array([1.0, 1.5, 0.5], dtype=np.float32),
r"Expected: obs <= 1\.0, actual max value: 1\.5 at index 1",
),
# Below lower bound
(
spaces.Box(low=0.0, high=2.0, shape=(3,), dtype=np.float32),
np.array([-1.0, 1.5, 0.5], dtype=np.float32),
r"Expected: obs >= 0\.0, actual min value: -1\.0 at index 0",
),
# Wrong dtype
(
spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32),
np.array([1.0, 1.5, 0.5], dtype=np.float64),
r"Expected: float32, actual dtype: float64",
),
# Wrong shape
(
spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32),
np.array([[1.0, 1.5, 0.5], [1.0, 1.5, 0.5]], dtype=np.float32),
r"Expected: \(3,\), actual shape: \(2, 3\)",
),
# Wrong shape (dict obs)
(
spaces.Dict({"obs": spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)}),
{"obs": np.array([[1.0, 1.5, 0.5], [1.0, 1.5, 0.5]], dtype=np.float32)},
r"Error while checking key=obs.*Expected: \(3,\), actual shape: \(2, 3\)",
),
# Wrong shape (multi discrete)
(
spaces.MultiDiscrete([3, 3]),
np.array([[2, 0]]),
r"Expected: \(2,\), actual shape: \(1, 2\)",
),
# Wrong shape (multi binary)
(
spaces.MultiBinary(3),
np.array([[1, 0, 0]]),
r"Expected: \(3,\), actual shape: \(1, 3\)",
),
],
)
@pytest.mark.parametrize(
# Check when it happens at reset or during step
"method",
["reset", "step"],
)
def test_check_env_detailed_error(obs_tuple, method):
"""
Check that the env checker returns more detail error
when the observation is not in the obs space.
"""
observation_space, wrong_obs, error_message = obs_tuple
good_obs = observation_space.sample()

class TestEnv(gym.Env):
action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

def reset(self):
return wrong_obs if method == "reset" else good_obs

def step(self, action):
obs = wrong_obs if method == "step" else good_obs
return obs, 0.0, True, {}

TestEnv.observation_space = observation_space

test_env = TestEnv()
with pytest.raises(AssertionError, match=error_message):
check_env(env=test_env)
7 changes: 2 additions & 5 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib.util
import os
import sys
import time
Expand Down Expand Up @@ -233,11 +234,7 @@ def test_report_video_to_tensorboard(tmp_path, read_log, capsys):


def is_moviepy_installed():
try:
import moviepy
except ModuleNotFoundError:
return False
return True
return importlib.util.find_spec("moviepy") is not None


@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
Expand Down

0 comments on commit b6aa507

Please sign in to comment.