Automatically create paths for saved objects (#80)
* automatically create paths for saved objects

* Minor Corrections, more tests

* linting

* typing

* Correct mode checking

* corrected tests to reflect new verbose functionality
m-rph committed Jul 2, 2020
1 parent 7d8ebb9 commit 4aa66ed
Showing 5 changed files with 383 additions and 96 deletions.
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
@@ -17,6 +17,8 @@ New Features:
- Buffer dtype is now set according to action and observation spaces for ``ReplayBuffer``
- Added warning when allocation of a buffer may exceed the available memory of the system
when ``psutil`` is available
- Saving models now automatically creates the necessary folders and raises appropriate warnings (@PartiallyTyped)
- Refactored opening paths for saving and loading to use strings, pathlib or io.BufferedIOBase (@PartiallyTyped)

Bug Fixes:
^^^^^^^^^^
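The two new entries above describe the user-facing change: save targets no longer need to exist in advance, and saving accepts strings, pathlib.Path objects, or open binary buffers. A minimal usage sketch, assuming PPO on CartPole-v1 (the folder and file names are illustrative, not taken from this diff):

import pathlib

import gym

from stable_baselines3 import PPO

model = PPO("MlpPolicy", gym.make("CartPole-v1"))

# Intermediate folders such as logs/experiment_1 are created automatically.
model.save("logs/experiment_1/ppo_cartpole")

# pathlib.Path targets behave the same way.
model.save(pathlib.Path("logs") / "experiment_2" / "ppo_cartpole")

# An already-open binary buffer is written to directly; no folders are created.
with open("logs/experiment_1/ppo_cartpole_copy.zip", "wb") as file_handler:
    model.save(file_handler)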
13 changes: 10 additions & 3 deletions stable_baselines3/common/base_class.py
@@ -2,6 +2,8 @@
from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
from abc import ABC, abstractmethod
from collections import deque
import pathlib
import io

import gym
import torch as th
@@ -291,7 +293,7 @@ def predict(self, observation: np.ndarray,
return self.policy.predict(observation, state, mask, deterministic)

@classmethod
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs):
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAlgorithm':
"""
Load the model from a zip-file
@@ -475,11 +477,16 @@ def excluded_save_params(self) -> List[str]:
"""
return ["policy", "device", "env", "eval_env", "replay_buffer", "rollout_buffer", "_vec_normalize_env"]

def save(self, path: str, exclude: Optional[List[str]] = None, include: Optional[List[str]] = None) -> None:
def save(
self,
path: Union[str, pathlib.Path, io.BufferedIOBase],
exclude: Optional[List[str]] = None,
include: Optional[List[str]] = None,
) -> None:
"""
Save all the attributes of the object and the model parameters in a zip-file.
:param path: path to the file where the rl agent should be saved
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) path to the file where the RL agent should be saved
:param exclude: name of parameters that should be excluded in addition to the default one
:param include: name of parameters that might be excluded but should be included anyway
"""
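The widened save signature depends on path handling added to stable_baselines3.common.save_util in this commit but not shown in the excerpt above. A hedged sketch of the behaviour such a helper needs; the name open_writable_path and its signature are hypothetical, not the library's actual API:

import io
import pathlib
import warnings
from typing import Union


def open_writable_path(path: Union[str, pathlib.Path, io.BufferedIOBase],
                       verbose: int = 0) -> io.BufferedIOBase:
    # Hypothetical helper: file-like objects pass through untouched,
    # while str / pathlib targets get missing parent folders created.
    if isinstance(path, io.BufferedIOBase):
        return path
    path = pathlib.Path(path)
    if not path.parent.exists():
        if verbose > 0:
            warnings.warn(f"Path '{path.parent}' does not exist. Creating it.")
        path.parent.mkdir(parents=True, exist_ok=True)
    return path.open("wb")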
21 changes: 11 additions & 10 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -1,7 +1,8 @@
import time
import pickle
import warnings
import pathlib
from typing import Union, Type, Optional, Dict, Any, Callable, List, Tuple
import io

import gym
import torch as th
@@ -16,6 +17,7 @@
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.save_util import save_to_pkl, load_from_pkl


class OffPolicyAlgorithm(BaseAlgorithm):
@@ -126,7 +128,7 @@ def __init__(self,
# For gSDE only
self.use_sde_at_warmup = use_sde_at_warmup

def _setup_model(self):
def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
self.replay_buffer = ReplayBuffer(self.buffer_size, self.observation_space,
@@ -136,24 +138,23 @@ def _setup_model(self):
self.lr_schedule, **self.policy_kwargs)
self.policy = self.policy.to(self.device)

def save_replay_buffer(self, path: str):
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
"""
Save the replay buffer as a pickle file.
:param path: (str) Path to the file where the replay buffer should be saved
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the file where the replay buffer should be saved.
If path is a str or pathlib.Path, the path is automatically created if necessary.
"""
assert self.replay_buffer is not None, "The replay buffer is not defined"
with open(path, 'wb') as file_handler:
pickle.dump(self.replay_buffer, file_handler)
save_to_pkl(path, self.replay_buffer, self.verbose)

def load_replay_buffer(self, path: str):
def load_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
"""
Load a replay buffer from a pickle file.
:param path: (str) Path to the pickled replay buffer.
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the pickled replay buffer.
"""
with open(path, 'rb') as file_handler:
self.replay_buffer = pickle.load(file_handler)
self.replay_buffer = load_from_pkl(path, self.verbose)
assert isinstance(self.replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class'

def _setup_learn(self,
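With save_to_pkl and load_from_pkl in place, the replay buffer follows the same path rules as the model itself. A short usage sketch, assuming SAC on Pendulum-v0 (algorithm, environment, and folder names are illustrative):

import gym

from stable_baselines3 import SAC

model = SAC("MlpPolicy", gym.make("Pendulum-v0"), verbose=1)
model.learn(total_timesteps=1000)

# logs/buffers is created automatically if it does not exist; per the verbose
# handling added in this commit, a message about the created path may be emitted.
model.save_replay_buffer("logs/buffers/sac_replay_buffer.pkl")

# Loading accepts the same str, pathlib.Path, or io.BufferedIOBase targets.
model.load_replay_buffer("logs/buffers/sac_replay_buffer.pkl")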
