Skip to content

Commit

Permalink
Fix stable_baselines3/common/vec_env/vec_check_nan.py type hints (#1226)
Browse files Browse the repository at this point in the history

* super() init style

* "async_step" arg to "event"; "news" to "dones"; improve docstring

* Remove vec_check_nan from mypy exclude

* Update changelog
  • Loading branch information
qgallouedec committed Dec 22, 2022
1 parent 9aff113 commit 5549b34
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 26 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Others:
- Fixed ``stable_baselines3/common/env_util.py`` type hints
- Fixed ``stable_baselines3/common/preprocessing.py`` type hints
- Fixed ``stable_baselines3/common/atari_wrappers.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/vec_check_nan.py`` type hints
- Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong)
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
- Set tensors construction directly on the device (~8% speed boost on GPU)
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ exclude = (?x)(
| stable_baselines3/common/vec_env/stacked_observations.py$
| stable_baselines3/common/vec_env/subproc_vec_env.py$
| stable_baselines3/common/vec_env/util.py$
| stable_baselines3/common/vec_env/vec_check_nan.py$
| stable_baselines3/common/vec_env/vec_extract_dict_obs.py$
| stable_baselines3/common/vec_env/vec_frame_stack.py$
| stable_baselines3/common/vec_env/vec_monitor.py$
Expand Down
47 changes: 22 additions & 25 deletions stable_baselines3/common/vec_env/vec_check_nan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,44 +11,40 @@ class VecCheckNan(VecEnvWrapper):
allowing you to know from what the NaN or inf originated.
:param venv: the vectorized environment to wrap
:param raise_exception: Whether or not to raise a ValueError, instead of a UserWarning
:param warn_once: Whether or not to only warn once.
:param check_inf: Whether or not to check for +inf or -inf as well
:param raise_exception: Whether to raise a ValueError, instead of a UserWarning
:param warn_once: Whether to only warn once.
:param check_inf: Whether to check for +inf or -inf as well
"""

def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True):
VecEnvWrapper.__init__(self, venv)
def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True) -> None:
super().__init__(venv)
self.raise_exception = raise_exception
self.warn_once = warn_once
self.check_inf = check_inf
self._actions = None
self._observations = None

self._user_warned = False

def step_async(self, actions: np.ndarray) -> None:
self._check_val(async_step=True, actions=actions)
self._actions: np.ndarray
self._observations: VecEnvObs

def step_async(self, actions: np.ndarray) -> None:
self._check_val(event="step_async", actions=actions)
self._actions = actions
self.venv.step_async(actions)

def step_wait(self) -> VecEnvStepReturn:
observations, rewards, news, infos = self.venv.step_wait()

self._check_val(async_step=False, observations=observations, rewards=rewards, news=news)

observations, rewards, dones, infos = self.venv.step_wait()
self._check_val(event="step_wait", observations=observations, rewards=rewards, dones=dones)
self._observations = observations
return observations, rewards, news, infos
return observations, rewards, dones, infos

def reset(self) -> VecEnvObs:
observations = self.venv.reset()
self._actions = None

self._check_val(async_step=False, observations=observations)

self._check_val(event="reset", observations=observations)
self._observations = observations
return observations

def _check_val(self, *, async_step: bool, **kwargs) -> None:
def _check_val(self, event: str, **kwargs) -> None:
# if warn and warn once and have warned once: then stop checking
if not self.raise_exception and self.warn_once and self._user_warned:
return
Expand All @@ -72,13 +68,14 @@ def _check_val(self, *, async_step: bool, **kwargs) -> None:

msg += ".\r\nOriginated from the "

if not async_step:
if self._actions is None:
msg += "environment observation (at reset)"
else:
msg += f"environment, Last given value was: \r\n\taction={self._actions}"
else:
if event == "reset":
msg += "environment observation (at reset)"
elif event == "step_wait":
msg += f"environment, Last given value was: \r\n\taction={self._actions}"
elif event == "step_async":
msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}"
else:
raise ValueError("Internal error.")

if self.raise_exception:
raise ValueError(msg)
Expand Down

0 comments on commit 5549b34

Please sign in to comment.