diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0efc16e56..676bf9cb3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,13 +34,7 @@ jobs: # cpu version of pytorch pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu - # Install Atari Roms - pip install autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz - - pip install .[extra_no_roms,tests,docs] + pip install .[extra,tests,docs] # Use headless version pip install opencv-python-headless - name: Lint with ruff diff --git a/pyproject.toml b/pyproject.toml index 8e20ffe00..eafb9b5bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"] # ClassVar, implicit optional check not needed for tests "./tests/*.py"= ["RUF012", "RUF013"] - [tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 15 diff --git a/setup.py b/setup.py index 9d56dfd77..b89d59cba 100644 --- a/setup.py +++ b/setup.py @@ -70,37 +70,13 @@ """ # noqa:E501 -# Atari Games download is sometimes problematic: -# https://github.com/Farama-Foundation/AutoROM/issues/39 -# That's why we define extra packages without it. -extra_no_roms = [ - # For render - "opencv-python", - "pygame", - # Tensorboard support - "tensorboard>=2.9.1", - # Checking memory taken by replay buffer - "psutil", - # For progress bar callback - "tqdm", - "rich", - # For atari games, - "shimmy[atari]~=1.3.0", - "pillow", -] - -extra_packages = extra_no_roms + [ # noqa: RUF005 - # For atari roms, - "autorom[accept-rom-license]~=0.6.1", -] - setup( name="stable_baselines3", packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gymnasium>=0.28.1,<0.30", + "gymnasium>=1.0.0a1,<1.1.0", "numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302 "torch>=1.13", # For saving models @@ -133,8 +109,21 @@ # Copy button for code snippets "sphinx_copybutton", ], - "extra": extra_packages, - "extra_no_roms": extra_no_roms, + "extra": [ + # For render + "opencv-python", + "pygame", + # Tensorboard support + "tensorboard>=2.9.1", + # Checking memory taken by replay buffer + "psutil", + # For progress bar callback + "tqdm", + "rich", + # For atari games, + "ale-py>=0.9.0", + "pillow", + ], }, description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.", author="Antonin Raffin", diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 15ecfb681..a37d3f254 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -115,7 +115,7 @@ def _obs_from_buf(self) -> VecEnvObs: def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: """Return attribute from vectorized environment (see base class).""" target_envs = self._get_target_envs(indices) - return [getattr(env_i, attr_name) for env_i in target_envs] + return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs] def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: """Set attribute inside vectorized environments (see base class).""" @@ -126,7 +126,7 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: """Call instance methods of vectorized environments.""" target_envs = self._get_target_envs(indices) - return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] + return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs] def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: """Check if worker environments are wrapped with a given wrapper""" diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index c598c735a..9ebd16c8f 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -54,10 +54,10 @@ def _worker( elif cmd == "get_spaces": remote.send((env.observation_space, env.action_space)) elif cmd == "env_method": - method = getattr(env, data[0]) + method = env.get_wrapper_attr(data[0]) remote.send(method(*data[1], **data[2])) elif cmd == "get_attr": - remote.send(getattr(env, data)) + remote.send(env.get_wrapper_attr(data)) elif cmd == "set_attr": remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value] elif cmd == "is_wrapped": @@ -221,7 +221,7 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Sp assert len(obs) > 0, "need observations from at least one environment" if isinstance(space, spaces.Dict): - assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" + assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces" assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) elif isinstance(space, spaces.Tuple): diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 855f50edc..2bb66a295 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -19,7 +19,7 @@ def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: :param obs: a dict of numpy arrays. :return: a dict of copied numpy arrays. """ - assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" + assert isinstance(obs, dict), f"unexpected type for observations '{type(obs)}'" return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) @@ -60,7 +60,7 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[ """ check_for_nested_spaces(obs_space) if isinstance(obs_space, spaces.Dict): - assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" + assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces" subspaces = obs_space.spaces elif isinstance(obs_space, spaces.Tuple): subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment] diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 52faebd1f..bf2153e84 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -1,13 +1,180 @@ +import json import os -from typing import Callable +import os.path +import tempfile +from typing import Callable, List, Optional -from gymnasium.wrappers.monitoring import video_recorder +import numpy as np +from gymnasium import error, logger from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv +# This is copy and pasted from Gymnasium v0.26.1 +class VideoRecorder: + """VideoRecorder renders a nice movie of a rollout, frame by frame. + + It comes with an ``enabled`` option, so you can still use the same code on episodes where you don't want to record video. + + Note: + You are responsible for calling :meth:`close` on a created VideoRecorder, or else you may leak an encoder process. + """ + + def __init__( + self, + env, + path: Optional[str] = None, + metadata: Optional[dict] = None, + enabled: bool = True, + base_path: Optional[str] = None, + ): + """Video recorder renders a nice movie of a rollout, frame by frame. + + Args: + env (Env): Environment to take video of. + path (Optional[str]): Path to the video file; will be randomly chosen if omitted. + metadata (Optional[dict]): Contents to save to the metadata file. + enabled (bool): Whether to actually record video, or just no-op (for convenience) + base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added. + + Raises: + Error: You can pass at most one of `path` or `base_path` + Error: Invalid path given that must have a particular file extension + """ + try: + # check that moviepy is now installed + import moviepy # noqa: F401 + except ImportError as e: + raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install moviepy`") from e + + self._async = env.metadata.get("semantics.async") + self.enabled = enabled + self._closed = False + + self.render_history: List[np.ndarray] = [] + self.env = env + + self.render_mode = env.render_mode + + if "rgb_array_list" != self.render_mode and "rgb_array" != self.render_mode: + logger.warn( + f"Disabling video recorder because environment {env} was not initialized with any compatible video " + "mode between `rgb_array` and `rgb_array_list`" + ) + # Disable since the environment has not been initialized with a compatible `render_mode` + self.enabled = False + + # Don't bother setting anything else if not enabled + if not self.enabled: + return + + if path is not None and base_path is not None: + raise error.Error("You can pass at most one of `path` or `base_path`.") + + required_ext = ".mp4" + if path is None: + if base_path is not None: + # Base path given, append ext + path = base_path + required_ext + else: + # Otherwise, just generate a unique filename + with tempfile.NamedTemporaryFile(suffix=required_ext) as f: + path = f.name + self.path = path + + path_base, actual_ext = os.path.splitext(self.path) + + if actual_ext != required_ext: + raise error.Error(f"Invalid path given: {self.path} -- must have file extension {required_ext}.") + + self.frames_per_sec = env.metadata.get("render_fps", 30) + + self.broken = False + + # Dump metadata + self.metadata = metadata or {} + self.metadata["content_type"] = "video/mp4" + self.metadata_path = f"{path_base}.meta.json" + self.write_metadata() + + logger.info(f"Starting new video recorder writing to {self.path}") + self.recorded_frames: List[np.ndarray] = [] + + @property + def functional(self): + """Returns if the video recorder is functional, is enabled and not broken.""" + return self.enabled and not self.broken + + def capture_frame(self): + """Render the given `env` and add the resulting frame to the video.""" + frame = self.env.render() + if isinstance(frame, List): + self.render_history += frame + frame = frame[-1] + + if not self.functional: + return + if self._closed: + logger.warn("The video recorder has been closed and no frames will be captured anymore.") + return + logger.debug("Capturing video frame: path=%s", self.path) + + if frame is None: + if self._async: + return + else: + # Indicates a bug in the environment: don't want to raise + # an error here. + logger.warn( + "Env returned None on `render()`. Disabling further rendering for video recorder by marking as " + f"disabled: path={self.path} metadata_path={self.metadata_path}" + ) + self.broken = True + else: + self.recorded_frames.append(frame) + + def close(self): + """Flush all data to disk and close any open frame encoders.""" + if not self.enabled or self._closed: + return + + # First close the environment + self.env.close() + + # Close the encoder + if len(self.recorded_frames) > 0: + try: + from moviepy.video.io.ImageSequenceClip import ImageSequenceClip + except ImportError as e: + raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install moviepy`") from e + + logger.debug(f"Closing video encoder: path={self.path}") + clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec) + clip.write_videofile(self.path) + else: + # No frames captured. Set metadata. + if self.metadata is None: + self.metadata = {} + self.metadata["empty"] = True + + self.write_metadata() + + # Stop tracking this for autoclose + self._closed = True + + def write_metadata(self): + """Writes metadata to metadata path.""" + with open(self.metadata_path, "w") as f: + json.dump(self.metadata, f) + + def __del__(self): + """Closes the environment correctly when the recorder is deleted.""" + # Make sure we've closed up shop when garbage collecting + self.close() + + class VecVideoRecorder(VecEnvWrapper): """ Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. @@ -22,7 +189,7 @@ class VecVideoRecorder(VecEnvWrapper): :param name_prefix: Prefix to the video name """ - video_recorder: video_recorder.VideoRecorder + video_recorder: VideoRecorder def __init__( self, @@ -73,9 +240,7 @@ def start_video_recorder(self) -> None: video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" base_path = os.path.join(self.video_folder, video_name) - self.video_recorder = video_recorder.VideoRecorder( - env=self.env, base_path=base_path, metadata={"step_id": self.step_id} - ) + self.video_recorder = VideoRecorder(env=self.env, base_path=base_path, metadata={"step_id": self.step_id}) self.video_recorder.capture_frame() self.recorded_frames = 1 diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index f093e47e7..8049c6887 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -117,12 +117,11 @@ def test_consistency(model_class): """ use_discrete_actions = model_class == DQN dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True) + dict_env.seed(10) dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) - dict_env.seed(10) obs, _ = dict_env.reset() - kwargs = {} n_steps = 256 if model_class in {A2C, PPO}: diff --git a/tests/test_gae.py b/tests/test_gae.py index 83b95a4c0..bb674cffa 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -73,7 +73,7 @@ def _on_rollout_end(self): buffer = self.model.rollout_buffer rollout_size = buffer.size() - max_steps = self.training_env.envs[0].max_steps + max_steps = self.training_env.envs[0].get_wrapper_attr("max_steps") gamma = self.model.gamma gae_lambda = self.model.gae_lambda value = self.model.policy.constant_value diff --git a/tests/test_logger.py b/tests/test_logger.py index dfa3691ed..82605db94 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -540,6 +540,7 @@ def test_rollout_success_rate_on_policy_algorithm(tmp_path): """ STATS_WINDOW_SIZE = 10 + # Add dummy successes with 0.3, 0.5 and 0.8 success_rate of length STATS_WINDOW_SIZE dummy_successes = [ [True] * 3 + [False] * 7, @@ -551,16 +552,17 @@ def test_rollout_success_rate_on_policy_algorithm(tmp_path): # Monitor the env to track the success info monitor_file = str(tmp_path / "monitor.csv") env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",)) + steps_per_log = env.unwrapped.steps_per_log # Equip the model of a custom logger to check the success_rate info - model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1) + model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=steps_per_log, verbose=1) logger = InMemoryLogger() model.set_logger(logger) # Make the model learn and check that the success rate corresponds to the ratio of dummy successes - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.3 - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.5 - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.8 diff --git a/tests/test_utils.py b/tests/test_utils.py index 4cc8b7e9f..01227855b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,7 @@ import os import shutil +import ale_py import gymnasium as gym import numpy as np import pytest @@ -24,6 +25,8 @@ ) from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv +gym.register_envs(ale_py) + @pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")]) @pytest.mark.parametrize("n_envs", [1, 2])