Add check for common mistake when mixing Gym/VecEnv API (#1696)
araffin committed Sep 25, 2023
1 parent b85fa75 commit 2ca94cb
Showing 4 changed files with 23 additions and 2 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
@@ -3,7 +3,7 @@
Changelog
==========

-Release 2.2.0a4 (WIP)
+Release 2.2.0a5 (WIP)
--------------------------

Breaking Changes:
@@ -13,6 +13,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined)
- Improved error message when mixing Gym API with VecEnv API (see GH#1694)

Bug Fixes:
^^^^^^^^^^
11 changes: 11 additions & 0 deletions stable_baselines3/common/policies.py
@@ -343,6 +343,17 @@ def predict(
        # Switch to eval mode (this affects batch norm / dropout)
        self.set_training_mode(False)

        # Check for a common mistake: the user mixing the Gym API with the VecEnv API
        # Tuple observations are not supported by SB3, so this check is safe
        if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
            raise ValueError(
                "You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
                "You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
                "vs `obs = vec_env.reset()` (SB3 VecEnv). "
                "See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
                "and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
            )

        observation, vectorized_env = self.obs_to_tensor(observation)

        with th.no_grad():
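
For illustration only, the check added to predict() can be sketched outside of SB3 as a standalone predicate. The helper names `looks_like_gym_reset_output` and `validate_observation` are hypothetical and not part of the SB3 API, the error message is shortened, and plain Python lists stand in for NumPy arrays so the snippet runs without dependencies.

```python
def looks_like_gym_reset_output(observation):
    """Return True if `observation` has the shape of a Gym-style `(obs, info)` pair.

    Hypothetical helper mirroring the condition added in this commit.
    """
    return (
        isinstance(observation, tuple)
        and len(observation) == 2
        and isinstance(observation[1], dict)
    )


def validate_observation(observation):
    """Raise early with a helpful message instead of failing later on a tuple."""
    if looks_like_gym_reset_output(observation):
        raise ValueError(
            "You are probably mixing Gym API with SB3 VecEnv API: "
            "`obs, info = env.reset()` (Gym) vs `obs = vec_env.reset()` (SB3 VecEnv)."
        )
    return observation


# Gym-style reset output: an (obs, info) tuple. Lists stand in for arrays here.
gym_style = ([0.0, 0.1, 0.0, -0.1], {})
# VecEnv-style reset output: a batch of observations, no info dict.
vec_env_style = [[0.0, 0.1, 0.0, -0.1]]

assert looks_like_gym_reset_output(gym_style)
assert not looks_like_gym_reset_output(vec_env_style)
```

As the comment in the diff notes, SB3 does not support tuple observations, so a two-element tuple whose second item is a dict is treated as a misuse rather than as a valid observation.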
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-2.2.0a4
+2.2.0a5
9 changes: 9 additions & 0 deletions tests/test_predict.py
@@ -116,3 +116,12 @@ def test_subclassed_space_env(model_class):
    model.learn(300)
    obs, _ = env.reset()
    env.step(model.predict(obs))


def test_mixing_gym_vecenv_api():
    env = gym.make("CartPole-v1")
    model = PPO("MlpPolicy", env)
    # reset() returns a tuple (obs, info)
    wrong_obs = env.reset()
    with pytest.raises(ValueError, match="mixing Gym API"):
        model.predict(wrong_obs)
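
The difference in return signatures that this test exercises can be sketched with two toy classes. `ToyGymEnv` and `ToyVecEnv` are hypothetical stand-ins that only mimic the return shapes of the Gym and SB3 VecEnv APIs; they are not real environments, and the observation and reward values are placeholders.

```python
class ToyGymEnv:
    """Hypothetical stand-in mimicking the Gym API return shapes."""

    def reset(self):
        # Gym: reset() returns (obs, info)
        return [0.0, 0.0], {}

    def step(self, action):
        # Gym: step() returns (obs, reward, terminated, truncated, info)
        return [0.0, 0.0], 1.0, False, False, {}


class ToyVecEnv:
    """Hypothetical stand-in mimicking the SB3 VecEnv API return shapes."""

    def __init__(self, n_envs=2):
        self.n_envs = n_envs

    def reset(self):
        # VecEnv: reset() returns only the batched observations
        return [[0.0, 0.0] for _ in range(self.n_envs)]

    def step(self, actions):
        # VecEnv: step() returns (obs, rewards, dones, infos), one entry per env
        obs = [[0.0, 0.0] for _ in range(self.n_envs)]
        rewards = [1.0] * self.n_envs
        dones = [False] * self.n_envs
        infos = [{} for _ in range(self.n_envs)]
        return obs, rewards, dones, infos


# Gym-style loop: unpack the tuple before handing `obs` to a model
gym_env = ToyGymEnv()
obs, info = gym_env.reset()
obs, reward, terminated, truncated, info = gym_env.step(0)

# VecEnv-style loop: reset() already returns the observation batch
vec_env = ToyVecEnv(n_envs=2)
vec_obs = vec_env.reset()
vec_obs, rewards, dones, infos = vec_env.step([0, 0])
```

Passing the un-unpacked result of `gym_env.reset()` to predict() is the mistake the new test guards against; unpacking it as `obs, info = env.reset()` avoids the error.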
