Skip to content

Commit

Permalink
Fix stable_baselines3/common/atari_wrappers.py type hints (#1216)
Browse files Browse the repository at this point in the history
* Fix `stable_baselines3/common/atari_wrappers.py` type hints

* Fix initialization

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
  • Loading branch information
araffin and qgallouedec committed Dec 18, 2022
1 parent 07094c3 commit 0c1bc0b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 19 deletions.
9 changes: 5 additions & 4 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ Others:
- Fixed flake8 config to be compatible with flake8 6+
- Goal-conditioned environments are now characterized by the availability of the ``compute_reward`` method, rather than by their inheritance from ``gym.GoalEnv``
- Replaced ``CartPole-v0`` with ``CartPole-v1`` in tests
- Fixed ``tests/test_distributions.py`` type hint
- Fixed ``stable_baselines3/common/type_aliases.py`` type hint
- Fixed ``stable_baselines3/common/torch_layers.py`` type hint
- Fixed ``stable_baselines3/common/env_util.py`` type hint
- Fixed ``tests/test_distributions.py`` type hints
- Fixed ``stable_baselines3/common/type_aliases.py`` type hints
- Fixed ``stable_baselines3/common/torch_layers.py`` type hints
- Fixed ``stable_baselines3/common/env_util.py`` type hints
- Fixed ``stable_baselines3/common/preprocessing.py`` type hints
- Fixed ``stable_baselines3/common/atari_wrappers.py`` type hints
- Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong)
- Upgraded GitHub CI/setup-python to v4 and checkout to v3

Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ follow_imports = silent
show_error_codes = True
exclude = (?x)(
stable_baselines3/a2c/a2c.py$
| stable_baselines3/common/atari_wrappers.py$
| stable_baselines3/common/base_class.py$
| stable_baselines3/common/buffers.py$
| stable_baselines3/common/callbacks.py$
Expand Down
28 changes: 14 additions & 14 deletions stable_baselines3/common/atari_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class NoopResetEnv(gym.Wrapper):
:param noop_max: the maximum value of no-ops to run
"""

def __init__(self, env: gym.Env, noop_max: int = 30):
gym.Wrapper.__init__(self, env)
def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
super().__init__(env)
self.noop_max = noop_max
self.override_num_noops = None
self.noop_action = 0
Expand Down Expand Up @@ -50,8 +50,8 @@ class FireResetEnv(gym.Wrapper):
:param env: the environment to wrap
"""

def __init__(self, env: gym.Env):
gym.Wrapper.__init__(self, env)
def __init__(self, env: gym.Env) -> None:
super().__init__(env)
assert env.unwrapped.get_action_meanings()[1] == "FIRE"
assert len(env.unwrapped.get_action_meanings()) >= 3

Expand All @@ -74,8 +74,8 @@ class EpisodicLifeEnv(gym.Wrapper):
:param env: the environment to wrap
"""

def __init__(self, env: gym.Env):
gym.Wrapper.__init__(self, env)
def __init__(self, env: gym.Env) -> None:
super().__init__(env)
self.lives = 0
self.was_real_done = True

Expand Down Expand Up @@ -119,8 +119,8 @@ class MaxAndSkipEnv(gym.Wrapper):
:param skip: number of ``skip``-th frame
"""

def __init__(self, env: gym.Env, skip: int = 4):
gym.Wrapper.__init__(self, env)
def __init__(self, env: gym.Env, skip: int = 4) -> None:
super().__init__(env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype)
self._skip = skip
Expand All @@ -134,7 +134,7 @@ def step(self, action: int) -> GymStepReturn:
:return: observation, reward, done, information
"""
total_reward = 0.0
done = None
done = False
for i in range(self._skip):
obs, reward, done, info = self.env.step(action)
if i == self._skip - 2:
Expand All @@ -161,8 +161,8 @@ class ClipRewardEnv(gym.RewardWrapper):
:param env: the environment
"""

def __init__(self, env: gym.Env):
gym.RewardWrapper.__init__(self, env)
def __init__(self, env: gym.Env) -> None:
super().__init__(env)

def reward(self, reward: float) -> float:
"""
Expand All @@ -184,8 +184,8 @@ class WarpFrame(gym.ObservationWrapper):
:param height:
"""

def __init__(self, env: gym.Env, width: int = 84, height: int = 84):
gym.ObservationWrapper.__init__(self, env)
def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None:
super().__init__(env)
self.width = width
self.height = height
self.observation_space = spaces.Box(
Expand Down Expand Up @@ -234,7 +234,7 @@ def __init__(
screen_size: int = 84,
terminal_on_life_loss: bool = True,
clip_reward: bool = True,
):
) -> None:
env = NoopResetEnv(env, noop_max=noop_max)
env = MaxAndSkipEnv(env, skip=frame_skip)
if terminal_on_life_loss:
Expand Down

0 comments on commit 0c1bc0b

Please sign in to comment.