Commit

Fixes replay buffer device after loading in OffPolicyAlgorithm (#1662)
* sets replay buffer device after loading

* update changelog

* update changelog

* correct changelog

* add test for replay buffer device

* Fix test to actually test the bug fix

* [ci skip] Update version

* [ci skip] Update docker images

---------

Co-authored-by: PatrickHelm <patrick.helm@gmx.net>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
3 people committed Sep 3, 2023
1 parent 16c6a88 commit e071796
Showing 6 changed files with 16 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,7 +1,7 @@
 ARG PARENT_IMAGE
 FROM $PARENT_IMAGE
 ARG PYTORCH_DEPS=cpuonly
-ARG PYTHON_VERSION=3.8
+ARG PYTHON_VERSION=3.10
 ARG MAMBA_DOCKERFILE_ACTIVATE=1 # (otherwise python will not be found)

 # Install micromamba env and dependencies
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========

-Release 2.2.0a0 (WIP)
+Release 2.2.0a1 (WIP)
 --------------------------

 Breaking Changes:
@@ -25,6 +25,7 @@ Bug Fixes:
 - Moves VectorizedActionNoise into ``_setup_learn()`` in OffPolicyAlgorithm (@PatrickHelm)
 - Prevents out of bound error on Windows if no seed is passed (@PatrickHelm)
 - Calls ``callback.update_locals()`` before ``callback.on_rollout_end()`` in OnPolicyAlgorithm (@PatrickHelm)
+- Fixes replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm)


 Deprecations:
@@ -37,7 +38,7 @@ Others:
 - Fixed ``stable_baselines3/common/vec_envs/vec_transpose.py`` type hints
 - Fixed ``stable_baselines3/common/vec_env/vec_video_recorder.py`` type hints
 - Fixed ``stable_baselines3/common/save_util.py`` type hints
-
+- Updated docker images to Ubuntu Jammy using micromamba 1.5

 Documentation:
 ^^^^^^^^^^^^^^
4 changes: 2 additions & 2 deletions scripts/build_docker.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

-CPU_PARENT=mambaorg/micromamba:1.4-kinetic
-GPU_PARENT=mambaorg/micromamba:1.4.1-focal-cuda-11.7.1
+CPU_PARENT=mambaorg/micromamba:1.5-jammy
+GPU_PARENT=mambaorg/micromamba:1.5-jammy-cuda-11.7.1

 TAG=stablebaselines/stable-baselines3
 VERSION=$(cat ./stable_baselines3/version.txt)
3 changes: 3 additions & 0 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -246,6 +246,9 @@ def load_replay_buffer(
             if truncate_last_traj:
                 self.replay_buffer.truncate_last_trajectory()

+        # Update saved replay buffer device to match current setting, see GH#1561
+        self.replay_buffer.device = self.device
+
     def _setup_learn(
         self,
         total_timesteps: int,
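
The effect of the change to `load_replay_buffer` can be illustrated without Stable-Baselines3 itself. The following is only a minimal sketch, assuming a saved buffer that records the device it was created on. `ToyBuffer`, `ToyAlgorithm`, the `sync_device` flag, the JSON serialization, and the plain-string devices are all hypothetical stand-ins for illustration, not the library's classes or API.

```python
import json


class ToyBuffer:
    """Hypothetical stand-in for a replay buffer that stores its device."""

    def __init__(self, device: str, observations=None) -> None:
        self.device = device
        self.observations = observations if observations is not None else [0.0, 1.0, 2.0]


class ToyAlgorithm:
    """Hypothetical stand-in for an off-policy algorithm owning a buffer."""

    def __init__(self, device: str) -> None:
        self.device = device
        self.replay_buffer = ToyBuffer(device)

    def save_replay_buffer(self) -> str:
        # The device is serialized together with the data.
        return json.dumps(vars(self.replay_buffer))

    def load_replay_buffer(self, payload: str, sync_device: bool = True) -> None:
        # Loading restores whatever device was stored at save time.
        self.replay_buffer = ToyBuffer(**json.loads(payload))
        if sync_device:
            # Same idea as the commit: overwrite the stale device
            # with the algorithm's current one.
            self.replay_buffer.device = self.device


payload = ToyAlgorithm(device="cuda").save_replay_buffer()

# Without the sync, the loaded buffer keeps the device it was saved with.
stale = ToyAlgorithm(device="cpu")
stale.load_replay_buffer(payload, sync_device=False)

# With the sync, the buffer follows the algorithm's current device.
fixed = ToyAlgorithm(device="cpu")
fixed.load_replay_buffer(payload)

print(stale.replay_buffer.device, fixed.replay_buffer.device)  # prints "cuda cpu"
```

In the sketch, the stored data is unchanged by the sync; only the `device` attribute is overwritten, which mirrors the single assignment added in this commit.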
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-2.2.0a0
+2.2.0a1
7 changes: 6 additions & 1 deletion tests/test_save_load.py
@@ -359,7 +359,12 @@ def test_save_load_replay_buffer(tmp_path, model_class):
     old_replay_buffer = deepcopy(model.replay_buffer)
     model.save_replay_buffer(path)
     model.replay_buffer = None
-    model.load_replay_buffer(path)
+    for device in ["cpu", "cuda"]:
+        # Manually force device to check that the replay buffer device
+        # is correctly updated
+        model.device = th.device(device)
+        model.load_replay_buffer(path)
+        assert model.replay_buffer.device.type == model.device.type

     assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations)
     assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions)
