Skip to content

Commit

Permalink
Fix stable_baselines3/common/preprocessing.py type hints (#1217)
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Dec 18, 2022
1 parent 6d55a09 commit 07094c3
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Others:
- Fixed ``stable_baselines3/common/type_aliases.py`` type hint
- Fixed ``stable_baselines3/common/torch_layers.py`` type hint
- Fixed ``stable_baselines3/common/env_util.py`` type hint
- Fixed ``stable_baselines3/common/preprocessing.py`` type hints
- Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong)
- Upgraded GitHub CI/setup-python to v4 and checkout to v3

Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ exclude = (?x)(
| stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/on_policy_algorithm.py$
| stable_baselines3/common/policies.py$
| stable_baselines3/common/preprocessing.py$
| stable_baselines3/common/save_util.py$
| stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$
| stable_baselines3/common/utils.py$
Expand Down
3 changes: 2 additions & 1 deletion stable_baselines3/common/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def preprocess_obs(

elif isinstance(observation_space, spaces.Dict):
# Do not modify by reference the original observation
assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
preprocessed_obs = {}
for key, _obs in obs.items():
preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
Expand Down Expand Up @@ -155,7 +156,7 @@ def get_obs_shape(
else:
return (int(observation_space.n),)
elif isinstance(observation_space, spaces.Dict):
return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()}
return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc]

else:
raise NotImplementedError(f"{observation_space} observation space is not supported")
Expand Down

0 comments on commit 07094c3

Please sign in to comment.