Skip to content

Commit

Permalink
save_replay_buffer now receives as argument the file path instead of …
Browse files Browse the repository at this point in the history
…the folder path (#63)

* save_replay_buffer now receives as argument the file path instead of the folder path

* Update changelog.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
Tirafesi and araffin committed Jun 17, 2020
1 parent a861f33 commit 644d2c1
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Pre-Release 0.8.0a0 (WIP)

Breaking Changes:
^^^^^^^^^^^^^^^^^
- ``save_replay_buffer`` now receives as argument the file path instead of the folder path (@tirafesi)

New Features:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -325,3 +326,4 @@ And all the contributors:
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi
4 changes: 2 additions & 2 deletions stable_baselines3/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def save_replay_buffer(self, path: str):
"""
Save the replay buffer as a pickle file.
:param path: (str) Path to a log folder
:param path: (str) Path to the file where the replay buffer should be saved
"""
assert self.replay_buffer is not None, "The replay buffer is not defined"
with open(os.path.join(path, 'replay_buffer.pkl'), 'wb') as file_handler:
with open(path, 'wb') as file_handler:
pickle.dump(self.replay_buffer, file_handler)

def load_replay_buffer(self, path: str):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_save_load_replay_buffer(model_class):
model = model_class('MlpPolicy', 'Pendulum-v0', buffer_size=1000)
model.learn(500)
old_replay_buffer = deepcopy(model.replay_buffer)
model.save_replay_buffer(log_folder)
model.save_replay_buffer(replay_path)
model.replay_buffer = None
model.load_replay_buffer(replay_path)

Expand Down

0 comments on commit 644d2c1

Please sign in to comment.