Release 2.2.1: Hotfix file closing (#1754)
* new closing policy

* revert #1742

* Add tests and update changelog

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
qgallouedec and araffin committed Nov 17, 2023
1 parent e1eac84 commit e3dea4b
Showing 4 changed files with 78 additions and 33 deletions.
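The "new closing policy" in the commit message means that the save/load helpers now close a file handle only when they opened it themselves, i.e. when the caller passed a str or pathlib.Path; file-like objects supplied by the caller are left open. A minimal sketch of that pattern (a hypothetical save_object helper for illustration, not SB3's actual API):

import io
import pathlib
from typing import Union

def save_object(path: Union[str, pathlib.Path, io.BufferedIOBase], payload: bytes) -> None:
    # Open the file ourselves only when given a path.
    file = open(path, "wb") if isinstance(path, (str, pathlib.Path)) else path
    try:
        file.write(payload)
    finally:
        # Close only handles we opened; caller-provided buffers stay open.
        if isinstance(path, (str, pathlib.Path)):
            file.close()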
13 changes: 11 additions & 2 deletions docs/misc/changelog.rst
@@ -3,10 +3,16 @@
 Changelog
 ==========
 
-Release 2.2.0 (2023-11-16)
+Release 2.2.1 (2023-11-17)
 --------------------------
 **Support for options at reset, bug fixes and better error messages**
 
+.. note::
+
+  SB3 v2.2.0 was yanked after a breaking change was found in `GH#1751 <https://github.com/DLR-RM/stable-baselines3/issues/1751>`_.
+  Please use SB3 v2.2.1 and not v2.2.0.
+
+
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
@@ -32,7 +38,9 @@ Bug Fixes:
 - Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD)
 - Fixed check_env for Sequence observation space (@corentinlger)
 - Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs)
-- Fixed ResourceWarning when loading and saving models (files were not closed)
+- Fixed ResourceWarning when loading and saving models (files were not closed); note that only paths are closed automatically,
+  the behavior stays the same for tempfiles (they need to be closed manually),
+  and the behavior is now consistent when loading/saving the replay buffer
 
 `SB3-Contrib`_
 ^^^^^^^^^^^^^^
@@ -76,6 +84,7 @@ Others:
 - Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
 - Fixed ``stable_baselines3/common/policies.py`` type hints
 - Switched to ``mypy`` only for checking types
+- Added tests to check consistency when saving/loading files
 
 Documentation:
 ^^^^^^^^^^^^^^
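In user-facing terms, the bullet above translates to the following behavior, a usage sketch mirroring the tests added in this commit: saving and loading via a path closes the underlying file automatically, a caller-supplied file object such as a tempfile is left open and must be closed by the caller, and replay buffers now follow the same rule.

import tempfile
from stable_baselines3 import PPO, SAC

# Path-based save/load: SB3 opens and closes the file itself.
PPO("MlpPolicy", "CartPole-v1").save("ppo_cartpole")
model = PPO.load("ppo_cartpole")

# Caller-supplied file object: SB3 leaves it open; here the context manager closes it.
with tempfile.TemporaryFile() as fp:
    model.save(fp)
    model = PPO.load(fp)
    assert not fp.closed

# The same policy applies to replay buffers.
sac = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200)
with tempfile.TemporaryFile() as fp:
    sac.save_replay_buffer(fp)
    fp.seek(0)
    sac.load_replay_buffer(fp)
    assert not fp.closed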
69 changes: 39 additions & 30 deletions stable_baselines3/common/save_util.py
@@ -308,28 +308,31 @@ def save_to_zip_file(
     :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
     :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
     """
-    with open_path(save_path, "w", verbose=0, suffix="zip") as save_path:
-        # data/params can be None, so do not
-        # try to serialize them blindly
-        if data is not None:
-            serialized_data = data_to_json(data)
-
-        # Create a zip-archive and write our objects there.
-        with zipfile.ZipFile(save_path, mode="w") as archive:
-            # Do not try to save "None" elements
-            if data is not None:
-                archive.writestr("data", serialized_data)
-            if pytorch_variables is not None:
-                with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
-                    th.save(pytorch_variables, pytorch_variables_file)
-            if params is not None:
-                for file_name, dict_ in params.items():
-                    with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
-                        th.save(dict_, param_file)
-            # Save metadata: library version when file was saved
-            archive.writestr("_stable_baselines3_version", sb3.__version__)
-            # Save system info about the current python env
-            archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
+    file = open_path(save_path, "w", verbose=0, suffix="zip")
+    # data/params can be None, so do not
+    # try to serialize them blindly
+    if data is not None:
+        serialized_data = data_to_json(data)
+
+    # Create a zip-archive and write our objects there.
+    with zipfile.ZipFile(file, mode="w") as archive:
+        # Do not try to save "None" elements
+        if data is not None:
+            archive.writestr("data", serialized_data)
+        if pytorch_variables is not None:
+            with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
+                th.save(pytorch_variables, pytorch_variables_file)
+        if params is not None:
+            for file_name, dict_ in params.items():
+                with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
+                    th.save(dict_, param_file)
+        # Save metadata: library version when file was saved
+        archive.writestr("_stable_baselines3_version", sb3.__version__)
+        # Save system info about the current python env
+        archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
+
+    if isinstance(save_path, (str, pathlib.Path)):
+        file.close()
 
 
 def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
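The reason for dropping the with open_path(...) block introduced in #1742 is that a context manager closes whatever open_path returns, including a buffer that the caller still needs (the regression reported in GH#1751). A small illustration of the difference, assuming an in-memory buffer:

import io
import zipfile

buffer = io.BytesIO()
with zipfile.ZipFile(buffer, mode="w") as archive:
    archive.writestr("data", "{}")

# 2.2.1 policy: the caller's buffer is left open and can still be read back.
buffer.seek(0)
print(zipfile.ZipFile(buffer).namelist())  # ['data']

# Had the helper closed `buffer` (as a `with` block around it would have),
# buffer.seek(0) would raise "ValueError: I/O operation on closed file".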
@@ -344,10 +347,12 @@ def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
     :param obj: The object to save.
     :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
     """
-    with open_path(path, "w", verbose=verbose, suffix="pkl") as file_handler:
-        # Use protocol>=4 to support saving replay buffers >= 4Gb
-        # See https://docs.python.org/3/library/pickle.html
-        pickle.dump(obj, file_handler, protocol=pickle.HIGHEST_PROTOCOL)
+    file = open_path(path, "w", verbose=verbose, suffix="pkl")
+    # Use protocol>=4 to support saving replay buffers >= 4Gb
+    # See https://docs.python.org/3/library/pickle.html
+    pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
+    if isinstance(path, (str, pathlib.Path)):
+        file.close()
 
 
 def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any:
@@ -360,8 +365,11 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any:
         path actually exists. If path is a io.BufferedIOBase the path exists.
     :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
     """
-    with open_path(path, "r", verbose=verbose, suffix="pkl") as file_handler:
-        return pickle.load(file_handler)
+    file = open_path(path, "r", verbose=verbose, suffix="pkl")
+    obj = pickle.load(file)
+    if isinstance(path, (str, pathlib.Path)):
+        file.close()
+    return obj
 
 
 def load_from_zip_file(
@@ -391,14 +399,14 @@ def load_from_zip_file(
     :return: Class parameters, model state_dicts (aka "params", dict of state_dict)
         and dict of pytorch variables
     """
-    load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")
+    file = open_path(load_path, "r", verbose=verbose, suffix="zip")
 
     # set device to cpu if cuda is not available
     device = get_device(device=device)
 
     # Open the zip archive and load data
     try:
-        with zipfile.ZipFile(load_path) as archive:
+        with zipfile.ZipFile(file) as archive:
             namelist = archive.namelist()
             # If data or parameters is not in the
             # zip archive, assume they were stored
@@ -451,5 +459,6 @@ def load_from_zip_file(
        # load_path wasn't a zip file
        raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
     finally:
-        load_path.close()
+        if isinstance(load_path, (str, pathlib.Path)):
+            file.close()
     return data, params, pytorch_variables
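Note that the pickle-based helpers above write and read sequentially and never rewind the stream, so reusing a single caller-supplied buffer for a save/load round trip requires seeking back to the start yourself (which is why the replay-buffer test below calls fp.seek(0)). A standalone illustration with plain pickle:

import io
import pickle

buffer = io.BytesIO()
pickle.dump({"obs": [1, 2, 3]}, buffer, protocol=pickle.HIGHEST_PROTOCOL)
buffer.seek(0)                 # pickle.load reads from the current position
data = pickle.load(buffer)
assert data == {"obs": [1, 2, 3]}
assert not buffer.closed       # closing the buffer remains the caller's job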
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-2.2.0
+2.2.1
27 changes: 27 additions & 0 deletions tests/test_save_load.py
@@ -3,6 +3,7 @@
 import json
 import os
 import pathlib
+import tempfile
 import warnings
 import zipfile
 from collections import OrderedDict
@@ -752,7 +753,33 @@ def test_dqn_target_update_interval(tmp_path):
 # Turn warnings into errors
 @pytest.mark.filterwarnings("error")
 def test_no_resource_warning(tmp_path):
+    # Check behavior of save/load
+    # see https://github.com/DLR-RM/stable-baselines3/issues/1751
+
     # check that files are properly closed
     # Create a PPO agent and save it
     PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole")
     PPO.load(tmp_path / "dqn_cartpole")
+
+    PPO("MlpPolicy", "CartPole-v1").save(str(tmp_path / "dqn_cartpole"))
+    PPO.load(str(tmp_path / "dqn_cartpole"))
+
+    # Do the same but in memory, should not close the file
+    with tempfile.TemporaryFile() as fp:
+        PPO("MlpPolicy", "CartPole-v1").save(fp)
+        PPO.load(fp)
+        assert not fp.closed
+
+    # Same but with replay buffer
+    model = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200)
+    model.save_replay_buffer(tmp_path / "replay")
+    model.load_replay_buffer(tmp_path / "replay")
+
+    model.save_replay_buffer(str(tmp_path / "replay"))
+    model.load_replay_buffer(str(tmp_path / "replay"))
+
+    with tempfile.TemporaryFile() as fp:
+        model.save_replay_buffer(fp)
+        fp.seek(0)
+        model.load_replay_buffer(fp)
+        assert not fp.closed
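For context, the @pytest.mark.filterwarnings("error") decorator above makes pytest treat any warning raised during the test as an error, so the ResourceWarning emitted for an unclosed file surfaces as a test failure instead of a console message. A standard-library sketch of the same mechanism:

import warnings

with warnings.catch_warnings():
    warnings.simplefilter("error")  # comparable to @pytest.mark.filterwarnings("error")
    try:
        warnings.warn("unclosed file", ResourceWarning)
    except ResourceWarning:
        print("the warning is raised as an error and would fail the test")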