Skip to content

Commit

Permalink
Fix type hints for callbacks, utils and VecTranspose (#1648)
Browse files Browse the repository at this point in the history
* Fix type hints in `common/utils.py`

* Fix `VecTranspose` type annotations

* Fix types for callbacks

* Update changelog

* Fix video recorder type hints

* Fix save utils type hints

* Allow BytesIO

* Improve error message

* Make logger and training env properties

* Clarify which open_path fn is called
  • Loading branch information
araffin committed Aug 29, 2023
1 parent f4ec0f6 commit e9f0f23
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 58 deletions.
15 changes: 5 additions & 10 deletions docs/guide/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,18 @@ A child callback is for instance :ref:`StopTrainingOnRewardThreshold <StopTraini
"""
Base class for triggering callback on event.
:param callback: (Optional[BaseCallback]) Callback that will be called
when an event is triggered.
:param callback: Callback that will be called when an event is triggered.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
super(EventCallback, self).__init__(verbose=verbose)
def __init__(self, callback: BaseCallback, verbose: int = 0):
super().__init__(verbose=verbose)
self.callback = callback
# Give access to the parent
if callback is not None:
self.callback.parent = self
self.callback.parent = self
...
def _on_event(self) -> bool:
if self.callback is not None:
return self.callback()
return True
return self.callback()
Callback Collection
Expand Down
34 changes: 34 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,40 @@
Changelog
==========


Release 2.2.0a0 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^

`SB3-Contrib`_
^^^^^^^^^^^^^^

`RL Zoo`_
^^^^^^^^^

Bug Fixes:
^^^^^^^^^^

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^
- Fixed ``stable_baselines3/common/callbacks.py`` type hints
- Fixed ``stable_baselines3/common/utils.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/vec_transpose.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/vec_video_recorder.py`` type hints
- Fixed ``stable_baselines3/common/save_util.py`` type hints

Documentation:
^^^^^^^^^^^^^^


Release 2.1.0 (2023-08-17)
--------------------------

Expand Down
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,11 @@ follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/common/buffers.py$
| stable_baselines3/common/callbacks.py$
| stable_baselines3/common/distributions.py$
| stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/policies.py$
| stable_baselines3/common/save_util.py$
| stable_baselines3/common/utils.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| stable_baselines3/common/vec_env/vec_transpose.py$
| stable_baselines3/common/vec_env/vec_video_recorder.py$
| stable_baselines3/her/her_replay_buffer.py$
| tests/test_logger.py$
| tests/test_train_eval_mode.py$
Expand Down
49 changes: 32 additions & 17 deletions stable_baselines3/common/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,9 @@ class BaseCallback(ABC):
# The RL model
# Type hint as string to avoid circular import
model: "base_class.BaseAlgorithm"
logger: Logger

def __init__(self, verbose: int = 0):
super().__init__()
# An alias for self.model.get_env(), the environment used for training
self.training_env = None # type: Union[gym.Env, VecEnv, None]
# Number of times the callback was called
self.n_calls = 0 # type: int
# n_envs * n times env.step() was called
Expand All @@ -51,15 +48,25 @@ def __init__(self, verbose: int = 0):
# to have access to the parent object
self.parent = None # type: Optional[BaseCallback]

@property
def training_env(self) -> VecEnv:
training_env = self.model.get_env()
assert (
training_env is not None
), "`model.get_env()` returned None, you must initialize the model with an environment to use callbacks"
return training_env

@property
def logger(self) -> Logger:
return self.model.logger

# Type hint as string to avoid circular import
def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
"""
Initialize the callback by saving references to the
RL model and the training environment for convenience.
"""
self.model = model
self.training_env = model.get_env()
self.logger = model.logger
self._init_callback()

def _init_callback(self) -> None:
Expand Down Expand Up @@ -147,6 +154,7 @@ def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
self.callback = callback
# Give access to the parent
if callback is not None:
assert self.callback is not None
self.callback.parent = self

def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
Expand Down Expand Up @@ -291,14 +299,14 @@ def _on_step(self) -> bool:
if self.save_replay_buffer and hasattr(self.model, "replay_buffer") and self.model.replay_buffer is not None:
# If model has a replay buffer, save it too
replay_buffer_path = self._checkpoint_path("replay_buffer_", extension="pkl")
self.model.save_replay_buffer(replay_buffer_path)
self.model.save_replay_buffer(replay_buffer_path) # type: ignore[attr-defined]
if self.verbose > 1:
print(f"Saving model replay buffer checkpoint to {replay_buffer_path}")

if self.save_vecnormalize and self.model.get_vec_normalize_env() is not None:
# Save the VecNormalize statistics
vec_normalize_path = self._checkpoint_path("vecnormalize_", extension="pkl")
self.model.get_vec_normalize_env().save(vec_normalize_path)
self.model.get_vec_normalize_env().save(vec_normalize_path) # type: ignore[union-attr]
if self.verbose >= 2:
print(f"Saving model VecNormalize to {vec_normalize_path}")

Expand Down Expand Up @@ -382,20 +390,20 @@ def __init__(

# Convert to VecEnv for consistency
if not isinstance(eval_env, VecEnv):
eval_env = DummyVecEnv([lambda: eval_env])
eval_env = DummyVecEnv([lambda: eval_env]) # type: ignore[list-item, return-value]

self.eval_env = eval_env
self.best_model_save_path = best_model_save_path
# Logs will be written in ``evaluations.npz``
if log_path is not None:
log_path = os.path.join(log_path, "evaluations")
self.log_path = log_path
self.evaluations_results = []
self.evaluations_timesteps = []
self.evaluations_length = []
self.evaluations_results: List[List[float]] = []
self.evaluations_timesteps: List[int] = []
self.evaluations_length: List[List[int]] = []
# For computing success rate
self._is_success_buffer = []
self.evaluations_successes = []
self._is_success_buffer: List[bool] = []
self.evaluations_successes: List[List[bool]] = []

def _init_callback(self) -> None:
# Does not work in some corner cases, where the wrapper is not the same
Expand Down Expand Up @@ -458,6 +466,8 @@ def _on_step(self) -> bool:
)

if self.log_path is not None:
assert isinstance(episode_rewards, list)
assert isinstance(episode_lengths, list)
self.evaluations_timesteps.append(self.num_timesteps)
self.evaluations_results.append(episode_rewards)
self.evaluations_length.append(episode_lengths)
Expand All @@ -478,7 +488,7 @@ def _on_step(self) -> bool:

mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
self.last_mean_reward = mean_reward
self.last_mean_reward = float(mean_reward)

if self.verbose >= 1:
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
Expand All @@ -502,7 +512,7 @@ def _on_step(self) -> bool:
print("New best mean reward!")
if self.best_model_save_path is not None:
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
self.best_mean_reward = mean_reward
self.best_mean_reward = float(mean_reward)
# Trigger callback on new best model, if needed
if self.callback_on_new_best is not None:
continue_training = self.callback_on_new_best.on_step()
Expand Down Expand Up @@ -536,12 +546,14 @@ class StopTrainingOnRewardThreshold(BaseCallback):
threshold reached
"""

parent: EvalCallback

def __init__(self, reward_threshold: float, verbose: int = 0):
super().__init__(verbose=verbose)
self.reward_threshold = reward_threshold

def _on_step(self) -> bool:
assert self.parent is not None, "``StopTrainingOnMinimumReward`` callback must be used " "with an ``EvalCallback``"
assert self.parent is not None, "``StopTrainingOnMinimumReward`` callback must be used with an ``EvalCallback``"
# Convert np.bool_ to bool, otherwise callback() is False won't work
continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
if self.verbose >= 1 and not continue_training:
Expand Down Expand Up @@ -630,6 +642,8 @@ class StopTrainingOnNoModelImprovement(BaseCallback):
:param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because no new best model
"""

parent: EvalCallback

def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
super().__init__(verbose=verbose)
self.max_no_improvement_evals = max_no_improvement_evals
Expand Down Expand Up @@ -666,6 +680,8 @@ class ProgressBarCallback(BaseCallback):
using tqdm and rich packages.
"""

pbar: tqdm # pytype: disable=invalid-annotation

def __init__(self) -> None:
super().__init__()
if tqdm is None:
Expand All @@ -674,7 +690,6 @@ def __init__(self) -> None:
"It is included if you install stable-baselines with the extra packages: "
"`pip install stable-baselines3[extra]`"
)
self.pbar = None

def _on_training_start(self) -> None:
# Initialize progress bar
Expand Down
28 changes: 16 additions & 12 deletions stable_baselines3/common/save_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = No


@functools.singledispatch
def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None):
def open_path(
path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None
) -> Union[io.BufferedWriter, io.BufferedReader, io.BytesIO]:
"""
Opens a path for reading or writing with a preferred suffix and raises debug information.
If the provided path is a derivative of io.BufferedIOBase it ensures that the file
Expand All @@ -201,18 +203,21 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb
is not None, we attempt to open the path with the suffix.
:return:
"""
if not isinstance(path, io.BufferedIOBase):
raise TypeError("Path parameter has invalid type.", io.BufferedIOBase)
# Note(antonin): the true annotation should be IO[bytes]
# but there is no easy way to check that
allowed_types = (io.BufferedWriter, io.BufferedReader, io.BytesIO)
if not isinstance(path, allowed_types):
raise TypeError(f"Path {path} parameter has invalid type: expected one of {allowed_types}.")
if path.closed:
raise ValueError("File stream is closed.")
raise ValueError(f"File stream {path} is closed.")
mode = mode.lower()
try:
mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode]
except KeyError as e:
raise ValueError("Expected mode to be either 'w' or 'r'.") from e
if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable():
e1 = "writable" if "w" == mode else "readable"
raise ValueError(f"Expected a {e1} file.")
error_msg = "writable" if "w" == mode else "readable"
raise ValueError(f"Expected a {error_msg} file.")
return path


Expand All @@ -231,7 +236,7 @@ def open_path_str(path: str, mode: str, verbose: int = 0, suffix: Optional[str]
is not None, we attempt to open the path with the suffix.
:return:
"""
return open_path(pathlib.Path(path), mode, verbose, suffix)
return open_path_pathlib(pathlib.Path(path), mode, verbose, suffix)


@open_path.register(pathlib.Path)
Expand All @@ -255,7 +260,7 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O

if mode == "r":
try:
path = path.open("rb")
return open_path(path.open("rb"), mode, verbose, suffix)
except FileNotFoundError as error:
if suffix is not None and suffix != "":
newpath = pathlib.Path(f"{path}.{suffix}")
Expand All @@ -270,20 +275,19 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O
path = pathlib.Path(f"{path}.{suffix}")
if path.exists() and path.is_file() and verbose >= 2:
warnings.warn(f"Path '{path}' exists, will overwrite it.")
path = path.open("wb")
return open_path(path.open("wb"), mode, verbose, suffix)
except IsADirectoryError:
warnings.warn(f"Path '{path}' is a folder. Will save instead to {path}_2")
path = pathlib.Path(f"{path}_2")
except FileNotFoundError: # Occurs when the parent folder doesn't exist
warnings.warn(f"Path '{path.parent}' does not exist. Will create it.")
path.parent.mkdir(exist_ok=True, parents=True)

# if opening was successful uses the identity function
# if opening was successful uses the open_path() function
# if opening failed with IsADirectory|FileNotFound, calls open_path_pathlib
# with corrections
# if reading failed with FileNotFoundError, calls open_path_pathlib with suffix

return open_path(path, mode, verbose, suffix)
return open_path_pathlib(path, mode, verbose, suffix)


def save_to_zip_file(
Expand Down
8 changes: 4 additions & 4 deletions stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
SummaryWriter = None
SummaryWriter = None # type: ignore[misc, assignment]

from stable_baselines3.common.logger import Logger, configure
from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit
Expand Down Expand Up @@ -396,21 +396,21 @@ def is_vectorized_observation(observation: Union[int, np.ndarray], observation_s

for space_type, is_vec_obs_func in is_vec_obs_func_dict.items():
if isinstance(observation_space, space_type):
return is_vec_obs_func(observation, observation_space)
return is_vec_obs_func(observation, observation_space) # type: ignore[operator]
else:
# for-else happens if no break is called
raise ValueError(f"Error: Cannot determine if the observation is vectorized with the space type {observation_space}.")


def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
def safe_mean(arr: Union[np.ndarray, list, deque]) -> float:
"""
Compute the mean of an array if there is at least one element.
For empty array, return NaN. It is used for logging only.
:param arr: Numpy array or list of values
:return:
"""
return np.nan if len(arr) == 0 else np.mean(arr)
return np.nan if len(arr) == 0 else float(np.mean(arr)) # type: ignore[arg-type]


def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]:
Expand Down

0 comments on commit e9f0f23

Please sign in to comment.