Skip to content

Commit

Permalink
Add support for dict/tuple obs space for VecCheckNaN (#1348)
Browse files Browse the repository at this point in the history
* Add support for dict/tuple obs space for VecCheckNaN

* Handle list too

* Address comments from code review

* Ignore B028 (explicit stack level)
  • Loading branch information
araffin committed Feb 27, 2023
1 parent 085bdd5 commit ed8783c
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 14 deletions.
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 1.8.0a6 (WIP)
Release 1.8.0a7 (WIP)
--------------------------


Expand All @@ -18,6 +18,7 @@ New Features:
^^^^^^^^^^^^^
- Added ``repeat_action_probability`` argument in ``AtariWrapper``.
- Only use ``NoopResetEnv`` and ``MaxAndSkipEnv`` when needed in ``AtariWrapper``
- Added support for dict/tuple observations spaces for ``VecCheckNan``, the check is now active in the ``env_checker()`` (@DavyMorgan)

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down Expand Up @@ -1230,4 +1231,4 @@ And all the contributors:
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ exclude = (?x)(

[flake8]
# line breaks before and after binary operators
ignore = W503,W504,E203,E231
# ignore explicit stack level
ignore = W503,W504,E203,E231,B028
# Ignore import not used when aliases are defined
per-file-ignores =
# Default implementation in abstract methods
Expand Down
10 changes: 7 additions & 3 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from gym import spaces

from stable_baselines3.common.preprocessing import is_image_space_channels_first
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan


Expand Down Expand Up @@ -380,6 +380,10 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
if not skip_render_check:
_check_render(env, warn=warn) # pragma: no cover

# The check only works with numpy arrays
if _is_numpy_array_space(observation_space) and _is_numpy_array_space(action_space):
try:
check_for_nested_spaces(env.observation_space)
# The check doesn't support nested observations/dict actions
# A warning about it has already been emitted
_check_nan(env)
except NotImplementedError:
pass
39 changes: 32 additions & 7 deletions stable_baselines3/common/vec_env/vec_check_nan.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
from typing import List, Tuple

import numpy as np
from gym import spaces

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper

Expand All @@ -26,6 +28,8 @@ def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool

self._actions: np.ndarray
self._observations: VecEnvObs
if isinstance(venv.action_space, spaces.Dict):
raise NotImplementedError("VecCheckNan doesn't support dict action spaces")

def step_async(self, actions: np.ndarray) -> None:
self._check_val(event="step_async", actions=actions)
Expand All @@ -44,19 +48,40 @@ def reset(self) -> VecEnvObs:
self._observations = observations
return observations

def check_array_value(self, name: str, value: np.ndarray) -> List[Tuple[str, str]]:
"""
Check for inf and NaN for a single numpy array.
:param name: Name of the value being check
:param value: Value (numpy array) to check
:return: A list of issues found.
"""
found = []
has_nan = np.any(np.isnan(value))
has_inf = self.check_inf and np.any(np.isinf(value))
if has_inf:
found.append((name, "inf"))
if has_nan:
found.append((name, "nan"))
return found

def _check_val(self, event: str, **kwargs) -> None:
# if warn and warn once and have warned once: then stop checking
if not self.raise_exception and self.warn_once and self._user_warned:
return

found = []
for name, val in kwargs.items():
has_nan = np.any(np.isnan(val))
has_inf = self.check_inf and np.any(np.isinf(val))
if has_inf:
found.append((name, "inf"))
if has_nan:
found.append((name, "nan"))
for name, value in kwargs.items():
if isinstance(value, (np.ndarray, list)):
found += self.check_array_value(name, np.asarray(value))
elif isinstance(value, dict):
for inner_name, inner_val in value.items():
found += self.check_array_value(f"{name}.{inner_name}", inner_val)
elif isinstance(value, tuple):
for idx, inner_val in enumerate(value):
found += self.check_array_value(f"{name}.{idx}", inner_val)
else:
raise TypeError(f"Unsupported observation type {type(value)}.")

if found:
self._user_warned = True
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.8.0a6
1.8.0a7

0 comments on commit ed8783c

Please sign in to comment.