Skip to content

Commit

Permalink
Merge branch 'fix_tests' into feat/gymnasium-support
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Mar 20, 2023
2 parents 524d0bb + 986e6c0 commit 3dadca4
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,16 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
# Unpack
obs, reward, terminated, truncated, info = data

if _is_goal_env(env):
# Make mypy happy, already checked
assert isinstance(observation_space, spaces.Dict)
_check_goal_env_obs(obs, observation_space, "step")
_check_goal_env_compute_reward(obs, env, reward, info) # type: ignore[arg-type]
elif isinstance(observation_space, spaces.Dict):
if isinstance(observation_space, spaces.Dict):
assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary"

# Additional checks for GoalEnvs
if _is_goal_env(env):
# Make mypy happy, already checked
assert isinstance(observation_space, spaces.Dict)
_check_goal_env_obs(obs, observation_space, "step")
_check_goal_env_compute_reward(obs, env, float(reward), info)

if not obs.keys() == observation_space.spaces.keys():
raise AssertionError(
"The observation keys returned by `step()` must match the observation "
Expand Down

0 comments on commit 3dadca4

Please sign in to comment.