Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Framestack obs env_checker #1574

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Bug Fixes:
for ``Inf`` and ``NaN`` (@lutogniew)
- Fixed HER ``truncate_last_trajectory()`` (@lbergmann1)
- Fixed HER desired and achieved goal order in reward computation (@JonathanKuelz)
- Fixed ``env_checker`` for ``FrameStack`` observation (@corentinlger)

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -1410,4 +1411,4 @@ And all the contributors:
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @corentinlger
10 changes: 8 additions & 2 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.wrappers.frame_stack import LazyFrames

from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
Expand Down Expand Up @@ -171,7 +172,9 @@ def _check_goal_env_compute_reward(
assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}"


def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None:
def _check_obs(
obs: Union[tuple, dict, np.ndarray, int, LazyFrames], observation_space: spaces.Space, method_name: str
) -> None:
"""
Check that the observation returned by the environment
correspond to the declared one.
Expand All @@ -187,7 +190,10 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac
# `sample()` will return a np.int64 instead of an int
assert np.issubdtype(type(obs), np.integer), f"The observation returned by `{method_name}()` method must be an int"
elif _is_numpy_array_space(observation_space):
assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array"
# Check if obs is a ndarray or a FrameStacking of ndarrays
assert isinstance(
obs, (np.ndarray, LazyFrames)
), f"The observation returned by `{method_name}()` method must be a numpy array"

# Additional checks for numpy arrays, so the error message is clearer (see GH#1399)
if isinstance(obs, np.ndarray):
Expand Down