Add sticky actions for Atari games (#1286)
* repeat_action_probability

* Add test

* Undo atari wrapper doc change since CI fails

* remove action_repeat_probability from make_atari_env

* Add sticky action wrapper and improve documentation

* Update changelog

* handle the case noop_max=0

* Update tests

* Comply with ALE implementation

* Reorder doc

* Add doc warning and don't wrap with sticky action when not needed

* fix docstring and reorder

* Move `action_repeat_probability` arg to the last position

* Add ref

* Update doc and wrap with frameskip only if needed

* Update changelog

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
qgallouedec and araffin committed Jan 26, 2023
1 parent 637988c commit 5ee9009
Showing 4 changed files with 110 additions and 41 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
@@ -4,7 +4,7 @@ Changelog
==========


Release 1.8.0a2 (WIP)
Release 1.8.0a3 (WIP)
--------------------------


@@ -14,6 +14,8 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Added ``action_repeat_probability`` argument in ``AtariWrapper``.
- Only use ``NoopResetEnv`` and ``MaxAndSkipEnv`` when needed in ``AtariWrapper``.

`SB3-Contrib`_
^^^^^^^^^^^^^^
81 changes: 62 additions & 19 deletions stable_baselines3/common/atari_wrappers.py
@@ -12,13 +12,39 @@
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn


class StickyActionEnv(gym.Wrapper):
"""
Sticky action.
Paper: https://arxiv.org/abs/1709.06009
Official implementation: https://github.com/mgbellemare/Arcade-Learning-Environment
:param env: Environment to wrap
:param action_repeat_probability: Probability of repeating the last action
"""

def __init__(self, env: gym.Env, action_repeat_probability: float) -> None:
super().__init__(env)
self.action_repeat_probability = action_repeat_probability
assert env.unwrapped.get_action_meanings()[0] == "NOOP"

def reset(self, **kwargs) -> GymObs:
self._sticky_action = 0 # NOOP
return self.env.reset(**kwargs)

def step(self, action: int) -> GymStepReturn:
if self.np_random.random() >= self.action_repeat_probability:
self._sticky_action = action
return self.env.step(self._sticky_action)
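
As a rough usage sketch (not part of the diff, and assuming ``ale-py`` and the Atari ROMs are installed), the new wrapper can also be applied on its own; with probability ``action_repeat_probability`` the previously executed action is repeated instead of the one passed to ``step()``:

import gym

from stable_baselines3.common.atari_wrappers import StickyActionEnv

# Illustrative values: any "*NoFrameskip-v4" ALE environment works here
env = gym.make("BreakoutNoFrameskip-v4")
env = StickyActionEnv(env, action_repeat_probability=0.25)

obs = env.reset()
for _ in range(10):
    action = env.action_space.sample()
    # With probability 0.25 the previous action is executed again,
    # mimicking ALE's sticky actions (Machado et al., 2018)
    obs, reward, done, info = env.step(action)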


class NoopResetEnv(gym.Wrapper):
"""
Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
:param env: the environment to wrap
:param noop_max: the maximum value of no-ops to run
:param env: Environment to wrap
:param noop_max: Maximum value of no-ops to run
"""

def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
@@ -47,7 +73,7 @@ class FireResetEnv(gym.Wrapper):
"""
Take action on reset for environments that are fixed until firing.
:param env: the environment to wrap
:param env: Environment to wrap
"""

def __init__(self, env: gym.Env) -> None:
@@ -71,7 +97,7 @@ class EpisodicLifeEnv(gym.Wrapper):
Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
:param env: the environment to wrap
:param env: Environment to wrap
"""

def __init__(self, env: gym.Env) -> None:
@@ -120,9 +146,11 @@ def reset(self, **kwargs) -> np.ndarray:
class MaxAndSkipEnv(gym.Wrapper):
"""
Return only every ``skip``-th frame (frameskipping)
and return the max between the two last frames.
:param env: the environment
:param skip: number of ``skip``-th frame
:param env: Environment to wrap
:param skip: Number of frames to skip
The same action will be taken ``skip`` times.
"""

def __init__(self, env: gym.Env, skip: int = 4) -> None:
@@ -159,9 +187,9 @@ def step(self, action: int) -> GymStepReturn:

class ClipRewardEnv(gym.RewardWrapper):
"""
Clips the reward to {+1, 0, -1} by its sign.
Clip the reward to {+1, 0, -1} by its sign.
:param env: the environment
:param env: Environment to wrap
"""

def __init__(self, env: gym.Env) -> None:
Expand All @@ -182,9 +210,9 @@ class WarpFrame(gym.ObservationWrapper):
Convert to grayscale and warp frames to 84x84 (default)
as done in the Nature paper and later work.
:param env: the environment
:param width:
:param height:
:param env: Environment to wrap
:param width: New frame width
:param height: New frame height
"""

def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None:
@@ -213,20 +241,29 @@ class AtariWrapper(gym.Wrapper):
Specifically:
* NoopReset: obtain initial state by taking random number of no-ops on reset.
* Noop reset: obtain initial state by taking a random number of no-ops on reset.
* Frame skipping: 4 by default
* Max-pooling: most recent two observations
* Termination signal when a life is lost.
* Resize to a square image: 84x84 by default
* Grayscale observation
* Clip reward to {-1, 0, 1}
* Sticky actions: disabled by default
See https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/
for a visual explanation.
.. warning::
Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``.
:param env: gym environment
:param noop_max: max number of no-ops
:param frame_skip: the frequency at which the agent experiences the game.
:param screen_size: resize Atari frame
:param terminal_on_life_loss: if True, then step() returns done=True whenever a life is lost.
:param env: Environment to wrap
:param noop_max: Max number of no-ops
:param frame_skip: Frequency at which the agent experiences the game.
This corresponds to repeating the action ``frame_skip`` times.
:param screen_size: Resize Atari frame
:param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost.
:param clip_reward: If True (default), the reward is clipped to {-1, 0, 1} depending on its sign.
:param action_repeat_probability: Probability of repeating the last action
"""

def __init__(
@@ -237,9 +274,15 @@ def __init__(
screen_size: int = 84,
terminal_on_life_loss: bool = True,
clip_reward: bool = True,
action_repeat_probability: float = 0.0,
) -> None:
env = NoopResetEnv(env, noop_max=noop_max)
env = MaxAndSkipEnv(env, skip=frame_skip)
if action_repeat_probability > 0.0:
env = StickyActionEnv(env, action_repeat_probability)
if noop_max > 0:
env = NoopResetEnv(env, noop_max=noop_max)
# frame_skip=1 is the same as no frame-skip (action repeat)
if frame_skip > 1:
env = MaxAndSkipEnv(env, skip=frame_skip)
if terminal_on_life_loss:
env = EpisodicLifeEnv(env)
if "FIRE" in env.unwrapped.get_action_meanings():
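For context, a minimal usage sketch of the updated ``AtariWrapper`` (not part of the diff; values are illustrative): ``noop_max=0`` now skips ``NoopResetEnv``, ``frame_skip=1`` skips ``MaxAndSkipEnv``, and a non-zero ``action_repeat_probability`` enables sticky actions.

import gym

from stable_baselines3.common.atari_wrappers import AtariWrapper

env = gym.make("PongNoFrameskip-v4")
# Sticky actions on, no noop reset (NoopResetEnv is not applied when noop_max=0)
env = AtariWrapper(
    env,
    noop_max=0,
    frame_skip=4,
    action_repeat_probability=0.25,
)
obs = env.reset()  # grayscale frame of shape (84, 84, 1)
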
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
1.8.0a2
1.8.0a3
64 changes: 44 additions & 20 deletions tests/test_utils.py
@@ -9,7 +9,7 @@

import stable_baselines3 as sb3
from stable_baselines3 import A2C
from stable_baselines3.common.atari_wrappers import ClipRewardEnv, MaxAndSkipEnv
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
@@ -55,30 +55,54 @@ def test_make_vec_env_func_checker():
env.close()


@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4"])
@pytest.mark.parametrize("n_envs", [1, 2])
@pytest.mark.parametrize("wrapper_kwargs", [None, dict(clip_reward=False, screen_size=60)])
def test_make_atari_env(env_id, n_envs, wrapper_kwargs):
env = make_atari_env(env_id, n_envs, wrapper_kwargs=wrapper_kwargs, monitor_dir=None, seed=0)
# Use Asterix as it does not require fire reset
@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4", "AsterixNoFrameskip-v4"])
@pytest.mark.parametrize("noop_max", [0, 10])
@pytest.mark.parametrize("action_repeat_probability", [0.0, 0.25])
@pytest.mark.parametrize("frame_skip", [1, 4])
@pytest.mark.parametrize("screen_size", [60])
@pytest.mark.parametrize("terminal_on_life_loss", [True, False])
@pytest.mark.parametrize("clip_reward", [True])
def test_make_atari_env(
env_id, noop_max, action_repeat_probability, frame_skip, screen_size, terminal_on_life_loss, clip_reward
):
n_envs = 2
wrapper_kwargs = {
"noop_max": noop_max,
"action_repeat_probability": action_repeat_probability,
"frame_skip": frame_skip,
"screen_size": screen_size,
"terminal_on_life_loss": terminal_on_life_loss,
"clip_reward": clip_reward,
}
venv = make_atari_env(
env_id,
n_envs=2,
wrapper_kwargs=wrapper_kwargs,
monitor_dir=None,
seed=0,
)

assert env.num_envs == n_envs
assert venv.num_envs == n_envs

obs = env.reset()
needs_fire_reset = env_id == "BreakoutNoFrameskip-v4"
expected_frame_number_low = frame_skip * 2 if needs_fire_reset else 0 # FIRE - UP on reset
expected_frame_number_high = expected_frame_number_low + noop_max
expected_shape = (n_envs, screen_size, screen_size, 1)

new_obs, reward, _, _ = env.step([env.action_space.sample() for _ in range(n_envs)])
obs = venv.reset()
frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs]
for frame_number in frame_numbers:
assert expected_frame_number_low <= frame_number <= expected_frame_number_high
assert obs.shape == expected_shape

assert obs.shape == new_obs.shape
new_obs, reward, _, _ = venv.step([venv.action_space.sample() for _ in range(n_envs)])

# Wrapped into DummyVecEnv
wrapped_atari_env = env.envs[0]
if wrapper_kwargs is not None:
assert obs.shape == (n_envs, 60, 60, 1)
assert wrapped_atari_env.observation_space.shape == (60, 60, 1)
assert not isinstance(wrapped_atari_env.env, ClipRewardEnv)
else:
assert obs.shape == (n_envs, 84, 84, 1)
assert wrapped_atari_env.observation_space.shape == (84, 84, 1)
assert isinstance(wrapped_atari_env.env, ClipRewardEnv)
new_frame_numbers = [env.unwrapped.ale.getEpisodeFrameNumber() for env in venv.envs]
for frame_number, new_frame_number in zip(frame_numbers, new_frame_numbers):
assert new_frame_number - frame_number == frame_skip
assert new_obs.shape == expected_shape
if clip_reward:
assert np.max(np.abs(reward)) < 1.0


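
As an illustration of the path the updated test exercises (not part of the diff; values are illustrative), the new ``AtariWrapper`` arguments are forwarded through the ``wrapper_kwargs`` argument of ``make_atari_env``:

from stable_baselines3.common.env_util import make_atari_env

venv = make_atari_env(
    "AsterixNoFrameskip-v4",
    n_envs=2,
    seed=0,
    wrapper_kwargs=dict(noop_max=10, frame_skip=4, action_repeat_probability=0.25),
)
obs = venv.reset()
assert obs.shape == (2, 84, 84, 1)  # (n_envs, screen_size, screen_size, n_channels)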
