Skip to content

Commit

Permalink
Fix Atari Roms download, enable RUF linting (#1379)
Browse files Browse the repository at this point in the history
* Add extra no Atari and fix CI for forks

* Enable ruff rules

* Change to no roms
  • Loading branch information
araffin committed Mar 12, 2023
1 parent 10e8386 commit 470771b
Show file tree
Hide file tree
Showing 21 changed files with 69 additions and 59 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ jobs:
env:
TERM: xterm-256color
FORCE_COLOR: 1
ATARI_ROMS: ${{ secrets.ATARI_ROMS }}

# Skip CI if [ci skip] in the commit message
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
Expand All @@ -37,11 +36,11 @@ jobs:
# Install Atari Roms
pip install autorom
wget $ATARI_ROMS
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
pip install .[extra,tests,docs]
pip install .[extra_no_roms,tests,docs]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
Expand Down
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 1.8.0a8 (WIP)
Release 1.8.0a9 (WIP)
--------------------------


Expand Down Expand Up @@ -46,6 +46,7 @@ Others:
- Moved from ``setup.cfg`` to ``pyproject.toml`` configuration file
- Switched from ``flake8`` to ``ruff``
- Upgraded AutoROM to latest version
- Added ``extra_no_roms`` option for package installation without Atari Roms

Documentation:
^^^^^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
line-length = 127
# Assume Python 3.7
target-version = "py37"
# TODO(antonin): activate "RUF" https://beta.ruff.rs/docs/rules/#ruff-specific-rules-ruf
select = ["E", "F", "B", "UP", "C90"]
# See https://beta.ruff.rs/docs/rules/
select = ["E", "F", "B", "UP", "C90", "RUF"]
ignore = []

[tool.ruff.per-file-ignores]
Expand Down
40 changes: 25 additions & 15 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,29 @@
""" # noqa:E501

# Atari Games download is sometimes problematic:
# https://github.com/Farama-Foundation/AutoROM/issues/39
# That's why we define extra packages without it.
extra_no_roms = [
# For render
"opencv-python",
# Tensorboard support
"tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
# For progress bar callback
"tqdm",
"rich",
# For Atari games
"ale-py==0.7.4",
"pillow",
]

extra_packages = extra_no_roms + [ # noqa: RUF005
# For Atari ROMs
"autorom[accept-rom-license]~=0.5.5",
]


setup(
name="stable_baselines3",
Expand Down Expand Up @@ -119,21 +142,8 @@
# Copy button for code snippets
"sphinx_copybutton",
],
"extra": [
# For render
"opencv-python",
# For atari games,
"ale-py==0.7.4",
"autorom[accept-rom-license]~=0.5.5",
"pillow",
# Tensorboard support
"tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
# For progress bar callback
"tqdm",
"rich",
],
"extra": extra_packages,
"extra_no_roms": extra_no_roms,
},
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
author="Antonin Raffin",
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/atari_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env: gym.Env, skip: int = 4) -> None:
super().__init__(env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype)
self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype)
self._skip = skip

def step(self, action: int) -> GymStepReturn:
Expand Down
20 changes: 10 additions & 10 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
"""
shape = arr.shape
if len(shape) < 3:
shape = shape + (1,)
shape = (*shape, 1)
return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])

def size(self) -> int:
Expand Down Expand Up @@ -199,13 +199,13 @@ def __init__(
)
self.optimize_memory_usage = optimize_memory_usage

self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)

if optimize_memory_usage:
# `observations` also contains the next observation
self.next_observations = None
else:
self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)

self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)

Expand Down Expand Up @@ -243,8 +243,8 @@ def add(
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space, spaces.Discrete):
obs = obs.reshape((self.n_envs,) + self.obs_shape)
next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)
obs = obs.reshape((self.n_envs, *self.obs_shape))
next_obs = next_obs.reshape((self.n_envs, *self.obs_shape))

# Same, for actions
action = action.reshape((self.n_envs, self.action_dim))
Expand Down Expand Up @@ -354,7 +354,7 @@ def __init__(
self.reset()

def reset(self) -> None:
self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.float32)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
Expand Down Expand Up @@ -428,7 +428,7 @@ def add(
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space, spaces.Discrete):
obs = obs.reshape((self.n_envs,) + self.obs_shape)
obs = obs.reshape((self.n_envs, *self.obs_shape))

# Same reshape, for actions
action = action.reshape((self.n_envs, self.action_dim))
Expand Down Expand Up @@ -528,11 +528,11 @@ def __init__(
self.optimize_memory_usage = optimize_memory_usage

self.observations = {
key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape, dtype=observation_space[key].dtype)
key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=observation_space[key].dtype)
for key, _obs_shape in self.obs_shape.items()
}
self.next_observations = {
key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape, dtype=observation_space[key].dtype)
key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=observation_space[key].dtype)
for key, _obs_shape in self.obs_shape.items()
}

Expand Down Expand Up @@ -699,7 +699,7 @@ def reset(self) -> None:
assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
self.observations = {}
for key, obs_input_shape in self.obs_shape.items():
self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32)
self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def _on_step(self) -> bool:

class EveryNTimesteps(EventCallback):
"""
Trigger a callback every ``n_steps`` timesteps
Trigger a callback every ``n_steps`` timesteps
:param n_steps: Number of timesteps between two triggers.
:param callback: Callback that will be called
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def __init__(
mode = "w" if override_existing else "a"
# Prevent newline issue on Windows, see GH issue #692
self.file_handler = open(filename, f"{mode}t", newline="\n")
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys)
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t", *extra_keys))
if override_existing:
self.file_handler.write(f"#{json.dumps(header)}\n")
self.logger.writeheader()
Expand Down
8 changes: 4 additions & 4 deletions stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -
obs_ = np.array(obs)
vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
# Add batch dimension if needed
observation[key] = obs_.reshape((-1,) + self.observation_space[key].shape)
observation[key] = obs_.reshape((-1, *self.observation_space[key].shape))

elif is_image_space(self.observation_space):
# Handle the different cases for images
Expand All @@ -242,7 +242,7 @@ def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -
# Dict obs need to be handled separately
vectorized_env = is_vectorized_observation(observation, self.observation_space)
# Add batch dimension if needed
observation = observation.reshape((-1,) + self.observation_space.shape)
observation = observation.reshape((-1, *self.observation_space.shape))

observation = obs_as_tensor(observation, self.device)
return observation, vectorized_env
Expand Down Expand Up @@ -330,7 +330,7 @@ def predict(
with th.no_grad():
actions = self._predict(observation, deterministic=deterministic)
# Convert to numpy, and reshape to the original action shape
actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))

if isinstance(self.action_space, spaces.Box):
if self.squash_output:
Expand Down Expand Up @@ -608,7 +608,7 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1,) + self.action_space.shape)
actions = actions.reshape((-1, *self.action_space.shape))
return actions, values, log_prob

def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/results_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
:return: rolling window on the input array
"""
shape = array.shape[:-1] + (array.shape[-1] - window + 1, window)
strides = array.strides + (array.strides[-1],)
strides = (*array.strides, array.strides[-1])
return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides)


Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/save_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def recursive_getattr(obj: Any, attr: str, *args) -> Any:
def _getattr(obj: Any, attr: str) -> Any:
return getattr(obj, attr, *args)

return functools.reduce(_getattr, [obj] + attr.split("."))
return functools.reduce(_getattr, [obj, *attr.split(".")])


def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]]):
obs_space = env.observation_space
self.keys, shapes, dtypes = obs_space_info(obs_space)

self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys])
self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys])
self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
self.buf_infos = [{} for _ in range(self.num_envs)]
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/stacked_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
low = np.repeat(observation_space.low, n_stack, axis=self.repeat_axis)
high = np.repeat(observation_space.high, n_stack, axis=self.repeat_axis)
self.stacked_observation_space = spaces.Box(low=low, high=high, dtype=observation_space.dtype)
self.stacked_obs = np.zeros((num_envs,) + self.stacked_shape, dtype=observation_space.dtype)
self.stacked_obs = np.zeros((num_envs, *self.stacked_shape), dtype=observation_space.dtype)
else:
raise TypeError(
f"StackedObservations only supports Box and Dict as observation spaces. {observation_space} was provided."
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def learn(
)

def _excluded_save_params(self) -> List[str]:
return super()._excluded_save_params() + ["q_net", "q_net_target"]
return [*super()._excluded_save_params(), "q_net", "q_net_target"]

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]
Expand Down
12 changes: 6 additions & 6 deletions stable_baselines3/her/her_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,14 @@ def __init__(

# input dimensions for buffer initialization
input_shape = {
"observation": (self.env.num_envs,) + self.obs_shape,
"achieved_goal": (self.env.num_envs,) + self.goal_shape,
"desired_goal": (self.env.num_envs,) + self.goal_shape,
"observation": (self.env.num_envs, *self.obs_shape),
"achieved_goal": (self.env.num_envs, *self.goal_shape),
"desired_goal": (self.env.num_envs, *self.goal_shape),
"action": (self.action_dim,),
"reward": (1,),
"next_obs": (self.env.num_envs,) + self.obs_shape,
"next_achieved_goal": (self.env.num_envs,) + self.goal_shape,
"next_desired_goal": (self.env.num_envs,) + self.goal_shape,
"next_obs": (self.env.num_envs, *self.obs_shape),
"next_achieved_goal": (self.env.num_envs, *self.goal_shape),
"next_desired_goal": (self.env.num_envs, *self.goal_shape),
"done": (1,),
}
self._observation_keys = ["observation", "achieved_goal", "desired_goal"]
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def learn(
)

def _excluded_save_params(self) -> List[str]:
return super()._excluded_save_params() + ["actor", "critic", "critic_target"]
return super()._excluded_save_params() + ["actor", "critic", "critic_target"] # noqa: RUF005

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def learn(
)

def _excluded_save_params(self) -> List[str]:
return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]
return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"] # noqa: RUF005

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.8.0a8
1.8.0a9
2 changes: 1 addition & 1 deletion tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def test_report_video_to_tensorboard(tmp_path, read_log, capsys):

def is_moviepy_installed():
try:
import moviepy # noqa: F401
import moviepy
except ModuleNotFoundError:
return False
return True
Expand Down
8 changes: 4 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def test_is_vectorized_observation():
# pass
# All vectorized
box_space = spaces.Box(-1, 1, shape=(2,))
box_obs = np.ones((1,) + box_space.shape)
box_obs = np.ones((1, *box_space.shape))
assert is_vectorized_observation(box_obs, box_space)

discrete_space = spaces.Discrete(2)
Expand Down Expand Up @@ -485,13 +485,13 @@ def test_is_vectorized_observation():
# Vectorized with the wrong shape
with pytest.raises(ValueError):
discrete_obs = np.ones((1,), dtype=np.int8)
box_obs = np.ones((1, 2) + box_space.shape)
box_obs = np.ones((1, 2, *box_space.shape))
dict_obs = {"box": box_obs, "discrete": discrete_obs}
is_vectorized_observation(dict_obs, dict_space)

# Weird shape: error
with pytest.raises(ValueError):
discrete_obs = np.ones((1,) + box_space.shape, dtype=np.int8)
discrete_obs = np.ones((1, *box_space.shape), dtype=np.int8)
is_vectorized_observation(discrete_obs, discrete_space)

# wrong shape
Expand All @@ -506,7 +506,7 @@ def test_is_vectorized_observation():

# Almost good shape: one dimension too many for Discrete obs
with pytest.raises(ValueError):
box_obs = np.ones((1,) + box_space.shape)
box_obs = np.ones((1, *box_space.shape))
discrete_obs = np.ones((1, 1), dtype=np.int8)
dict_obs = {"box": box_obs, "discrete": discrete_obs}
is_vectorized_observation(dict_obs, dict_space)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,10 @@ def test_framestack_vecenv():
"""Test that framestack environment stacks on desired axis"""

image_space_shape = [12, 8, 3]
zero_acts = np.zeros([N_ENVS] + image_space_shape)
zero_acts = np.zeros([N_ENVS, *image_space_shape])

transposed_image_space_shape = image_space_shape[::-1]
transposed_zero_acts = np.zeros([N_ENVS] + transposed_image_space_shape)
transposed_zero_acts = np.zeros([N_ENVS, *transposed_image_space_shape])

def make_image_env():
return CustomGymEnv(
Expand Down

0 comments on commit 470771b

Please sign in to comment.