Hotfix to load policies saved with SB3 <= v1.6 (#1234)
* Hotfix to load policies saved with SB3 <= v1.6

* Add warning and test

* Update doc
araffin committed Dec 22, 2022
1 parent 3c028f3 commit e78ba6f
Showing 4 changed files with 38 additions and 4 deletions.
10 changes: 9 additions & 1 deletion docs/misc/changelog.rst
@@ -4,9 +4,17 @@ Changelog
 ==========


-Release 1.7.0a9 (WIP)
+Release 1.7.0a10 (WIP)
 --------------------------

+.. note::
+
+  A2C and PPO saved with SB3 < 1.7.0 will show a warning about
+  missing keys in the state dict when loaded with SB3 >= 1.7.0.
+  To suppress the warning, simply save the model again.
+  You can find more info in `issue #1233 <https://github.com/DLR-RM/stable-baselines3/issues/1233>`_
+
+
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
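The changelog note says the warning goes away once the model is saved again. A minimal sketch of that round trip follows. The helper name `resave_model` is hypothetical and not part of SB3; it only assumes that `model_class` exposes `load(path, **kwargs)` and that the returned model has `save(path)`, as SB3 algorithm classes such as `A2C` and `PPO` do.

```python
def resave_model(model_class, path, **load_kwargs):
    """Load a saved model and immediately write it back to the same path.

    Hypothetical helper (not part of SB3). `model_class` must provide
    `load(path, **kwargs)` and the loaded model must provide `save(path)`.
    """
    # Loading an old file may emit the "missing keys" warning once.
    model = model_class.load(path, **load_kwargs)
    # Saving again stores the state dict in the layout of the installed version.
    model.save(path)
    return model
```

With SB3 installed this could be called as `resave_model(A2C, "a2c_model.zip")` after `from stable_baselines3 import A2C`; the file name is a placeholder.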
22 changes: 20 additions & 2 deletions stable_baselines3/common/base_class.py
@@ -3,6 +3,7 @@
 import io
 import pathlib
 import time
+import warnings
 from abc import ABC, abstractmethod
 from collections import deque
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
@@ -705,8 +706,25 @@ def load(
         model.__dict__.update(kwargs)
         model._setup_model()

-        # put state_dicts back in place
-        model.set_parameters(params, exact_match=True, device=device)
+        try:
+            # put state_dicts back in place
+            model.set_parameters(params, exact_match=True, device=device)
+        except RuntimeError as e:
+            # Patch to load Policy saved using SB3 < 1.7.0
+            # the error is probably due to old policy being loaded
+            # See https://github.com/DLR-RM/stable-baselines3/issues/1233
+            if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e):
+                model.set_parameters(params, exact_match=False, device=device)
+                warnings.warn(
+                    "You are probably loading a model saved with SB3 < 1.7.0, "
+                    "we deactivated exact_match so you can save the model "
+                    "again to avoid issues in the future "
+                    "(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
+                    f"Original error: {e} \n"
+                    "Note: the model should still work fine, this is only a warning."
+                )
+            else:
+                raise e

         # put other pytorch variables back in place
         if pytorch_variables is not None:
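The hunk above tries a strict load first and retries without `exact_match` only when the error message points at the known legacy case. The pattern can be sketched with the standard library alone. This is a toy under stated assumptions, not the SB3 implementation: `set_parameters` here is a stand-in that works on plain dicts, and `load_with_fallback` is a hypothetical name.

```python
import warnings


def set_parameters(model_keys, params, exact_match=True):
    """Toy stand-in for a strict state-dict load (not the SB3 method)."""
    missing = sorted(set(model_keys) - set(params))
    if missing and exact_match:
        # Reuse the message prefix that the hotfix searches for.
        raise RuntimeError(f"Missing key(s) in state_dict: {', '.join(missing)}")
    # Keep only the entries the model expects and the file provides.
    return {key: params[key] for key in model_keys if key in params}


def load_with_fallback(model_keys, params):
    """Strict load first; relax only for the known legacy error."""
    try:
        return set_parameters(model_keys, params, exact_match=True)
    except RuntimeError as e:
        message = str(e)
        if "pi_features_extractor" in message and "Missing key(s) in state_dict" in message:
            loaded = set_parameters(model_keys, params, exact_match=False)
            warnings.warn(f"Loaded without exact_match. Original error: {message}")
            return loaded
        # Any other mismatch is a real problem and must not be hidden.
        raise
```

Matching on the message text keeps the relaxation narrow: a file that lacks unrelated keys still fails loudly instead of loading a partially initialised model.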
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-1.7.0a9
+1.7.0a10
8 changes: 8 additions & 0 deletions tests/test_save_load.py
@@ -338,6 +338,14 @@ def test_save_load_env_cnn(tmp_path, model_class):
     # clear file from os
     os.remove(tmp_path / "test_save.zip")
+
+    # Check we can load models saved with SB3 < 1.7.0
+    if model_class == A2C:
+        del model.policy.pi_features_extractor
+        model.save(tmp_path / "test_save")
+        with pytest.warns(UserWarning):
+            model_class.load(str(tmp_path / "test_save.zip"), env=env, **kwargs).learn(100)
+        os.remove(tmp_path / "test_save.zip")


 @pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
 def test_save_load_replay_buffer(tmp_path, model_class):
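The added test relies on `pytest.warns(UserWarning)`, which fails the test when the block emits no matching warning. For readers without pytest, roughly the same check can be sketched with the standard `warnings` module. The helper name `assert_warns_user_warning` is hypothetical and is only an approximation of what `pytest.warns` does.

```python
import warnings


def assert_warns_user_warning(func, *args, **kwargs):
    """Call `func` and fail unless it emitted at least one UserWarning."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, even ones already shown earlier in the process.
        warnings.simplefilter("always")
        result = func(*args, **kwargs)
    if not any(issubclass(w.category, UserWarning) for w in caught):
        raise AssertionError("expected a UserWarning, but none was emitted")
    return result
```

In the test above, the wrapped callable would be the `model_class.load(...)` call on the file saved after `pi_features_extractor` was deleted.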
