Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Gymnasium to v1.0.0 #1837

Open
wants to merge 24 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
08e5f9a
Update Gymnasium to v1.0.0a1
pseudo-rnd-thoughts Feb 13, 2024
f73c08e
Comment out `gymnasium.wrappers.monitor` (todo update to VideoRecord)
pseudo-rnd-thoughts Feb 13, 2024
08d3ac9
Fix ruff warnings
pseudo-rnd-thoughts Feb 13, 2024
eb55500
Register Atari envs
pseudo-rnd-thoughts Feb 13, 2024
686d1a0
Update `getattr` to `Env.get_wrapper_attr`
pseudo-rnd-thoughts Feb 13, 2024
da48aed
Reorder imports
pseudo-rnd-thoughts Feb 13, 2024
b063f94
Fix `seed` order
pseudo-rnd-thoughts Feb 13, 2024
6e11f93
Fix collecting `max_steps`
pseudo-rnd-thoughts Feb 13, 2024
7958dba
Merge branch 'master' into gymnasium-1.0.0a1
araffin Feb 19, 2024
d7ed302
Merge branch 'master' into gymnasium-1.0.0a1
araffin Mar 4, 2024
39f0900
Copy and paste video recorder to prevent the need to rewrite the vec …
pseudo-rnd-thoughts Apr 3, 2024
2f403da
Use `typing.List` rather than list
pseudo-rnd-thoughts Apr 3, 2024
1f8c554
Merge branch 'master' into gymnasium-1.0.0a1
pseudo-rnd-thoughts Apr 3, 2024
c32e198
Fix env attribute forwarding
pseudo-rnd-thoughts Apr 3, 2024
34637a5
Separate out env attribute collection from its utilisation
pseudo-rnd-thoughts Apr 4, 2024
0f52339
Merge branch 'master' into gymnasium-1.0.0a1
araffin Apr 8, 2024
79e6e1d
Merge branch 'master' into gymnasium-1.0.0a1
araffin Apr 22, 2024
a42a15e
Merge branch 'master' into gymnasium-1.0.0a1
araffin May 8, 2024
96abd7d
Merge branch 'master' into gymnasium-1.0.0a1
pseudo-rnd-thoughts May 21, 2024
aadb895
Update for Gymnasium alpha 2
pseudo-rnd-thoughts May 21, 2024
0890cd4
Remove assert for OrderedDict
pseudo-rnd-thoughts May 21, 2024
b1e15b4
Merge branch 'master' into gymnasium-1.0.0a1
araffin Jun 10, 2024
eef7cfd
Merge branch 'master' into gymnasium-1.0.0a1
araffin Jun 29, 2024
e5b7104
Update setup.py
araffin Jun 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"]
# ClassVar, implicit optional check not needed for tests
"./tests/*.py"= ["RUF012", "RUF013"]


[tool.ruff.lint.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 15
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gymnasium>=0.28.1,<0.30",
"gymnasium==1.0.0a1",
"numpy>=1.20",
"torch>=1.13",
# For saving models
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _obs_from_buf(self) -> VecEnvObs:
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_envs = self._get_target_envs(indices)
return [getattr(env_i, attr_name) for env_i in target_envs]
return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs]

def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
"""Set attribute inside vectorized environments (see base class)."""
Expand All @@ -126,7 +126,7 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) ->
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
"""Call instance methods of vectorized environments."""
target_envs = self._get_target_envs(indices)
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs]

def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
"""Check if worker environments are wrapped with a given wrapper"""
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ def _worker(
elif cmd == "get_spaces":
remote.send((env.observation_space, env.action_space))
elif cmd == "env_method":
method = getattr(env, data[0])
method = env.get_wrapper_attr(data[0])
remote.send(method(*data[1], **data[2]))
elif cmd == "get_attr":
remote.send(getattr(env, data))
remote.send(env.get_wrapper_attr(data))
elif cmd == "set_attr":
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
elif cmd == "is_wrapped":
Expand Down
177 changes: 171 additions & 6 deletions stable_baselines3/common/vec_env/vec_video_recorder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,180 @@
import json
import os
from typing import Callable
import os.path
import tempfile
from typing import Callable, List, Optional

from gymnasium.wrappers.monitoring import video_recorder
import numpy as np
from gymnasium import error, logger

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv


# This is copied and pasted from Gymnasium v0.26.1
class VideoRecorder:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the equivalent in gym v1.0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""VideoRecorder renders a nice movie of a rollout, frame by frame.

It comes with an ``enabled`` option, so you can still use the same code on episodes where you don't want to record video.

Note:
You are responsible for calling :meth:`close` on a created VideoRecorder, or else you may leak an encoder process.
"""

def __init__(
    self,
    env,
    path: Optional[str] = None,
    metadata: Optional[dict] = None,
    enabled: bool = True,
    base_path: Optional[str] = None,
):
    """Video recorder renders a nice movie of a rollout, frame by frame.

    Args:
        env (Env): Environment to take video of.
        path (Optional[str]): Path to the video file; will be randomly chosen if omitted.
        metadata (Optional[dict]): Contents to save to the metadata file. The mapping is copied,
            the caller's object is left untouched.
        enabled (bool): Whether to actually record video, or just no-op (for convenience)
        base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added.

    Raises:
        DependencyNotInstalled: MoviePy cannot be imported.
        Error: You can pass at most one of `path` or `base_path`
        Error: Invalid path given that must have a particular file extension
    """
    # Stay inert until construction has fully succeeded: ``__del__`` calls ``close()``,
    # and ``close()`` is a no-op while ``enabled`` is False. Without this, a constructor
    # that raises below would leave a half-built recorder that closes ``env`` (and then
    # fails on missing attributes) when it is garbage collected.
    self.enabled = False
    self._closed = False

    try:
        # check that moviepy is now installed
        import moviepy  # noqa: F401
    except ImportError as e:
        raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install moviepy`") from e

    self._async = env.metadata.get("semantics.async")
    self.render_history: List[np.ndarray] = []
    self.env = env
    self.render_mode = env.render_mode

    if self.render_mode not in ("rgb_array", "rgb_array_list"):
        logger.warn(
            f"Disabling video recorder because environment {env} was not initialized with any compatible video "
            "mode between `rgb_array` and `rgb_array_list`"
        )
        # Disable since the environment has not been initialized with a compatible `render_mode`
        enabled = False

    # Don't bother setting anything else if not enabled
    if not enabled:
        return

    if path is not None and base_path is not None:
        raise error.Error("You can pass at most one of `path` or `base_path`.")

    required_ext = ".mp4"
    if path is None:
        if base_path is not None:
            # Base path given, append ext
            path = base_path + required_ext
        else:
            # Otherwise, just generate a unique filename. Only the name is kept:
            # the temporary file itself is removed when the ``with`` block exits.
            with tempfile.NamedTemporaryFile(suffix=required_ext) as f:
                path = f.name
    self.path = path

    path_base, actual_ext = os.path.splitext(self.path)

    if actual_ext != required_ext:
        raise error.Error(f"Invalid path given: {self.path} -- must have file extension {required_ext}.")

    self.frames_per_sec = env.metadata.get("render_fps", 30)

    self.broken = False

    # Dump metadata (work on a copy so the caller's dict is not modified)
    self.metadata = dict(metadata) if metadata else {}
    self.metadata["content_type"] = "video/mp4"
    self.metadata_path = f"{path_base}.meta.json"
    self.write_metadata()

    logger.info(f"Starting new video recorder writing to {self.path}")
    self.recorded_frames: List[np.ndarray] = []

    # Everything ``capture_frame()`` and ``close()`` rely on now exists.
    self.enabled = True

@property
def functional(self):
"""Returns if the video recorder is functional, is enabled and not broken."""
return self.enabled and not self.broken

def capture_frame(self):
"""Render the given `env` and add the resulting frame to the video."""
frame = self.env.render()
if isinstance(frame, List):
self.render_history += frame
frame = frame[-1]

if not self.functional:
return
if self._closed:
logger.warn("The video recorder has been closed and no frames will be captured anymore.")
return
logger.debug("Capturing video frame: path=%s", self.path)

if frame is None:
if self._async:
return
else:
# Indicates a bug in the environment: don't want to raise
# an error here.
logger.warn(
"Env returned None on `render()`. Disabling further rendering for video recorder by marking as "
f"disabled: path={self.path} metadata_path={self.metadata_path}"
)
self.broken = True
else:
self.recorded_frames.append(frame)

def close(self):
"""Flush all data to disk and close any open frame encoders."""
if not self.enabled or self._closed:
return

# First close the environment
self.env.close()

# Close the encoder
if len(self.recorded_frames) > 0:
try:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
except ImportError as e:
raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install moviepy`") from e

logger.debug(f"Closing video encoder: path={self.path}")
clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
clip.write_videofile(self.path)
else:
# No frames captured. Set metadata.
if self.metadata is None:
self.metadata = {}
self.metadata["empty"] = True

self.write_metadata()

# Stop tracking this for autoclose
self._closed = True

def write_metadata(self):
"""Writes metadata to metadata path."""
with open(self.metadata_path, "w") as f:
json.dump(self.metadata, f)

def __del__(self):
"""Closes the environment correctly when the recorder is deleted."""
# Make sure we've closed up shop when garbage collecting
self.close()


class VecVideoRecorder(VecEnvWrapper):
"""
Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video.
Expand All @@ -22,7 +189,7 @@ class VecVideoRecorder(VecEnvWrapper):
:param name_prefix: Prefix to the video name
"""

video_recorder: video_recorder.VideoRecorder
video_recorder: VideoRecorder

def __init__(
self,
Expand Down Expand Up @@ -73,9 +240,7 @@ def start_video_recorder(self) -> None:

video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}"
base_path = os.path.join(self.video_folder, video_name)
self.video_recorder = video_recorder.VideoRecorder(
env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
)
self.video_recorder = VideoRecorder(env=self.env, base_path=base_path, metadata={"step_id": self.step_id})

self.video_recorder.capture_frame()
self.recorded_frames = 1
Expand Down
3 changes: 1 addition & 2 deletions tests/test_dict_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,11 @@ def test_consistency(model_class):
"""
use_discrete_actions = model_class == DQN
dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True)
dict_env.seed(10)
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
env = gym.wrappers.FlattenObservation(dict_env)
dict_env.seed(10)
obs, _ = dict_env.reset()

kwargs = {}
n_steps = 256

if model_class in {A2C, PPO}:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gae.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _on_rollout_end(self):
buffer = self.model.rollout_buffer
rollout_size = buffer.size()

max_steps = self.training_env.envs[0].max_steps
max_steps = self.training_env.envs[0].get_wrapper_attr("max_steps")
gamma = self.model.gamma
gae_lambda = self.model.gae_lambda
value = self.model.policy.constant_value
Expand Down
8 changes: 4 additions & 4 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,14 +553,14 @@ def test_rollout_success_rate_on_policy_algorithm(tmp_path):
env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",))

# Equip the model with a custom logger to check the success_rate info
model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1)
model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.env.steps_per_log, verbose=1)
araffin marked this conversation as resolved.
Show resolved Hide resolved
logger = InMemoryLogger()
model.set_logger(logger)

# Make the model learn and check that the success rate corresponds to the ratio of dummy successes
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
model.learn(total_timesteps=env.env.ep_per_log * ep_steps, log_interval=1)
assert logger.name_to_value["rollout/success_rate"] == 0.3
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
model.learn(total_timesteps=env.env.ep_per_log * ep_steps, log_interval=1)
assert logger.name_to_value["rollout/success_rate"] == 0.5
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
model.learn(total_timesteps=env.env.ep_per_log * ep_steps, log_interval=1)
assert logger.name_to_value["rollout/success_rate"] == 0.8
4 changes: 4 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import torch as th
from gymnasium import spaces
from shimmy import registration

import stable_baselines3 as sb3
from stable_baselines3 import A2C
Expand All @@ -24,6 +25,9 @@
)
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

# A hack to get the Atari environments registered with Gymnasium 1.0.0 alpha 1
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved
registration._register_atari_envs()


@pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")])
@pytest.mark.parametrize("n_envs", [1, 2])
Expand Down
Loading