Fix type annotation bundle (SAC, TD3, A2C, PPO, base class) (#1436)
* Fix SAC type hints, improve DQN ones

* Fix A2C and TD3 type hints

* Fix PPO type hints

* Fix on-policy type hints

* Fix base class type annotation, do not use defaults

* Update version
araffin committed Apr 13, 2023
1 parent 3bc8918 commit 923bd46
Showing 18 changed files with 170 additions and 108 deletions.
5 changes: 5 additions & 0 deletions docs/misc/changelog.rst
@@ -41,6 +41,11 @@ Deprecations:

Others:
^^^^^^^
- Fixed ``stable_baselines3/a2c/*.py`` type hints
- Fixed ``stable_baselines3/ppo/*.py`` type hints
- Fixed ``stable_baselines3/sac/*.py`` type hints
- Fixed ``stable_baselines3/td3/*.py`` type hints
- Fixed ``stable_baselines3/common/base_class.py`` type hints
- Upgraded docker images to use mamba/micromamba and CUDA 11.7
- Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks
- Improve type annotation of wrappers
10 changes: 1 addition & 9 deletions pyproject.toml
@@ -35,17 +35,14 @@ ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/a2c/a2c.py$
| stable_baselines3/common/base_class.py$
| stable_baselines3/common/buffers.py$
stable_baselines3/common/buffers.py$
| stable_baselines3/common/callbacks.py$
| stable_baselines3/common/distributions.py$
| stable_baselines3/common/envs/bit_flipping_env.py$
| stable_baselines3/common/envs/identity_env.py$
| stable_baselines3/common/envs/multi_input_envs.py$
| stable_baselines3/common/logger.py$
| stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/on_policy_algorithm.py$
| stable_baselines3/common/policies.py$
| stable_baselines3/common/save_util.py$
| stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$
@@ -62,11 +59,6 @@ exclude = """(?x)(
| stable_baselines3/common/vec_env/vec_transpose.py$
| stable_baselines3/common/vec_env/vec_video_recorder.py$
| stable_baselines3/her/her_replay_buffer.py$
| stable_baselines3/ppo/ppo.py$
| stable_baselines3/sac/policies.py$
| stable_baselines3/sac/sac.py$
| stable_baselines3/td3/policies.py$
| stable_baselines3/td3/td3.py$
| tests/test_logger.py$
| tests/test_train_eval_mode.py$
)"""
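Removing these entries from the mypy exclude list means the A2C/PPO/SAC/TD3 modules and the base class are now type-checked like the rest of the package. A minimal sketch of reproducing that check locally, assuming mypy is installed and the working directory is the repository root (mypy.api is mypy's programmatic entry point; the paths listed are the ones removed from the exclude list above):

# Sketch: run mypy programmatically on the modules that are no longer excluded.
from mypy import api

modules = [
    "stable_baselines3/a2c/a2c.py",
    "stable_baselines3/common/base_class.py",
    "stable_baselines3/ppo/ppo.py",
    "stable_baselines3/sac/policies.py",
    "stable_baselines3/sac/sac.py",
    "stable_baselines3/td3/policies.py",
    "stable_baselines3/td3/td3.py",
]

report, errors, exit_status = api.run(modules)
print(report)
print(f"mypy exit status: {exit_status}")  # 0 means no type errors were reported

Running mypy on the same paths from the command line gives the same result; the api module just makes it easy to script.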
73 changes: 43 additions & 30 deletions stable_baselines3/common/base_class.py
@@ -22,7 +22,7 @@
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict
from stable_baselines3.common.utils import (
check_for_correct_spaces,
get_device,
@@ -44,21 +44,22 @@
SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm")


def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymEnv]:
def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv:
"""If env is a string, make the environment; otherwise, return env.
:param env: The environment to learn from.
:param verbose: Verbosity level: 0 for no output, 1 for indicating if environment is created
:return: A Gym (vector) environment.
"""
if isinstance(env, str):
env_id = env
if verbose >= 1:
print(f"Creating environment from the given name '{env}'")
print(f"Creating environment from the given name '{env_id}'")
# Set render_mode to `rgb_array` as default, so we can record video
try:
env = gym.make(env, render_mode="rgb_array")
env = gym.make(env_id, render_mode="rgb_array")
except TypeError:
env = gym.make(env)
env = gym.make(env_id)
return env


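The env_id rename above is the usual fix for this kind of mypy error: once env has been reassigned inside the try block, mypy can no longer prove that the argument passed to gym.make in the except branch is still a str. Keeping the id string under its own name sidesteps that. A standalone sketch of the same pattern, using gymnasium directly instead of SB3's GymEnv alias (make_if_needed is a hypothetical name):

# Sketch: bind the id string to its own name so every gym.make() call is
# unambiguously passed a str, even in the except branch.
from typing import Union

import gymnasium as gym


def make_if_needed(env: Union[gym.Env, str]) -> gym.Env:
    if isinstance(env, str):
        env_id = env  # narrowed to str and never reassigned
        try:
            env = gym.make(env_id, render_mode="rgb_array")
        except TypeError:
            env = gym.make(env_id)
    return env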
@@ -95,6 +96,11 @@ class BaseAlgorithm(ABC):
# Policy aliases (see _get_policy_from_name())
policy_aliases: Dict[str, Type[BasePolicy]] = {}
policy: BasePolicy
observation_space: spaces.Space
action_space: spaces.Space
n_envs: int
lr_schedule: Schedule
_logger: Logger

def __init__(
self,
@@ -111,8 +117,8 @@ def __init__(
seed: Optional[int] = None,
use_sde: bool = False,
sde_sample_freq: int = -1,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
):
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
) -> None:
if isinstance(policy, str):
self.policy_class = self._get_policy_from_name(policy)
else:
@@ -122,25 +128,19 @@ def __init__(
if verbose >= 1:
print(f"Using {self.device} device")

self.env = None # type: Optional[GymEnv]
# get VecNormalize object if needed
self._vec_normalize_env = unwrap_vec_normalize(env)
self.verbose = verbose
self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
self.observation_space: spaces.Space
self.action_space: spaces.Space
self.n_envs: int

self.num_timesteps = 0
# Used for updating schedules
self._total_timesteps = 0
# Used for computing fps, it is updated at each call of learn()
self._num_timesteps_at_start = 0
self.seed = seed
self.action_noise: Optional[ActionNoise] = None
self.start_time = None
self.start_time = 0.0
self.learning_rate = learning_rate
self.tensorboard_log = tensorboard_log
self.lr_schedule = None # type: Optional[Schedule]
self._last_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
self._last_episode_starts = None # type: Optional[np.ndarray]
# When using VecNormalize:
@@ -151,17 +151,17 @@ def __init__(
self.sde_sample_freq = sde_sample_freq
# Track the training progress remaining (from 1 to 0)
# this is used to update the learning rate
self._current_progress_remaining = 1
self._current_progress_remaining = 1.0
# Buffers for logging
self._stats_window_size = stats_window_size
self.ep_info_buffer = None # type: Optional[deque]
self.ep_success_buffer = None # type: Optional[deque]
# For logging (and TD3 delayed updates)
self._n_updates = 0 # type: int
# The logger object
self._logger = None # type: Logger
# Whether the user passed a custom logger or not
self._custom_logger = False
self.env: Optional[VecEnv] = None
self._vec_normalize_env: Optional[VecNormalize] = None

# Create and wrap the env if needed
if env is not None:
@@ -173,6 +173,9 @@
self.n_envs = env.num_envs
self.env = env

# get VecNormalize object if needed
self._vec_normalize_env = unwrap_vec_normalize(env)

if supported_action_spaces is not None:
assert isinstance(self.action_space, supported_action_spaces), (
f"The algorithm only supports {supported_action_spaces} as action spaces "
@@ -217,7 +220,7 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve
env = Monitor(env)
if verbose >= 1:
print("Wrapping the env in a DummyVecEnv.")
env = DummyVecEnv([lambda: env])
env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value]

# Make sure that dict-spaces are not nested (not supported)
check_for_nested_spaces(env.observation_space)
@@ -230,11 +233,11 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve
# the other channel last), VecTransposeImage will throw an error
for space in env.observation_space.spaces.values():
wrap_with_vectranspose = wrap_with_vectranspose or (
is_image_space(space) and not is_image_space_channels_first(space)
is_image_space(space) and not is_image_space_channels_first(space) # type: ignore[arg-type]
)
else:
wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first(
env.observation_space
env.observation_space # type: ignore[arg-type]
)

if wrap_with_vectranspose:
@@ -416,7 +419,10 @@ def _setup_learn(

# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
self._last_obs = self.env.reset() # pytype: disable=annotation-type-mismatch
assert self.env is not None
# pytype: disable=annotation-type-mismatch
self._last_obs = self.env.reset() # type: ignore[assignment]
# pytype: enable=annotation-type-mismatch
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
@@ -439,6 +445,9 @@ def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.nd
:param infos: List of additional information about the transition.
:param dones: Termination signals
"""
assert self.ep_info_buffer is not None
assert self.ep_success_buffer is not None

if dones is None:
dones = np.array([False] * len(infos))
for idx, info in enumerate(infos):
@@ -562,7 +571,7 @@ def set_random_seed(self, seed: Optional[int] = None) -> None:

def set_parameters(
self,
load_path_or_dict: Union[str, Dict[str, Dict]],
load_path_or_dict: Union[str, TensorDict],
exact_match: bool = True,
device: Union[th.device, str] = "auto",
) -> None:
@@ -578,7 +587,7 @@ def set_parameters(
can be used to update only specific parameters.
:param device: Device on which the code should run.
"""
params = None
params = {}
if isinstance(load_path_or_dict, dict):
params = load_path_or_dict
else:
@@ -616,7 +625,7 @@ def set_parameters(
#
# Solution: Just load the state-dict as is, and trust
# the user has provided a sensible state dictionary.
attr.load_state_dict(params[name])
attr.load_state_dict(params[name]) # type: ignore[arg-type]
else:
# Assume attr is th.nn.Module
attr.load_state_dict(params[name], strict=exact_match)
@@ -674,6 +683,9 @@ def load( # noqa: C901
print_system_info=print_system_info,
)

assert data is not None, "No data found in the saved file"
assert params is not None, "No params found in the saved file"

# Remove stored device information and replace with ours
if "policy_kwargs" in data:
if "device" in data["policy_kwargs"]:
@@ -714,13 +726,14 @@ def load( # noqa: C901
if "env" in data:
env = data["env"]

# noinspection PyArgumentList
model = cls( # pytype: disable=not-instantiable,wrong-keyword-args
# pytype: disable=not-instantiable,wrong-keyword-args
model = cls(
policy=data["policy_class"],
env=env,
device=device,
_init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args
_init_setup_model=False, # type: ignore[call-arg]
)
# pytype: enable=not-instantiable,wrong-keyword-args

# load parameters
model.__dict__.update(data)
@@ -758,12 +771,12 @@ def load( # noqa: C901
continue
# Set the data attribute directly to avoid issue when using optimizers
# See https://github.com/DLR-RM/stable-baselines3/issues/391
recursive_setattr(model, name + ".data", pytorch_variables[name].data)
recursive_setattr(model, f"{name}.data", pytorch_variables[name].data)

# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
if model.use_sde:
model.policy.reset_noise() # pytype: disable=attribute-error
model.policy.reset_noise() # type: ignore[operator] # pytype: disable=attribute-error
return model

def get_parameters(self) -> Dict[str, Dict]:
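One detail worth calling out from the signature change above: supported_action_spaces is now typed as a tuple of space classes (Type[spaces.Space]) rather than space instances, which matches what isinstance actually accepts. A minimal self-contained sketch of that pattern, with a hypothetical check_action_space helper rather than the SB3 code:

# Sketch: a tuple of space *classes* can be passed directly to isinstance().
from typing import Optional, Tuple, Type

from gymnasium import spaces


def check_action_space(
    action_space: spaces.Space,
    supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
) -> None:
    if supported_action_spaces is not None:
        assert isinstance(action_space, supported_action_spaces), (
            f"Only {supported_action_spaces} are supported as action spaces, "
            f"got {type(action_space)}"
        )


check_action_space(spaces.Discrete(4), supported_action_spaces=(spaces.Box, spaces.Discrete))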
11 changes: 9 additions & 2 deletions stable_baselines3/common/buffers.py
@@ -335,6 +335,15 @@ class RolloutBuffer(BaseBuffer):
:param n_envs: Number of parallel environments
"""

observations: np.ndarray
actions: np.ndarray
rewards: np.ndarray
advantages: np.ndarray
returns: np.ndarray
episode_starts: np.ndarray
log_probs: np.ndarray
values: np.ndarray

def __init__(
self,
buffer_size: int,
@@ -348,8 +357,6 @@ def __init__(
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
self.gae_lambda = gae_lambda
self.gamma = gamma
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
self.generator_ready = False
self.reset()

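This hunk is the clearest example of the "do not use defaults" part of the commit message: instead of initialising the buffer arrays to None in __init__ (which forces Optional types everywhere they are used), they are declared as class-level annotations and only assigned real arrays in reset(). A minimal sketch of the pattern with a hypothetical MiniRolloutBuffer, not the SB3 implementation:

# Sketch: declare array attributes at class level, allocate them in reset().
import numpy as np


class MiniRolloutBuffer:
    # Declarations only: no None defaults, so the attributes are plain np.ndarray
    # for the type checker and never need Optional handling.
    observations: np.ndarray
    rewards: np.ndarray

    def __init__(self, buffer_size: int, obs_dim: int) -> None:
        self.buffer_size = buffer_size
        self.obs_dim = obs_dim
        self.reset()

    def reset(self) -> None:
        self.observations = np.zeros((self.buffer_size, self.obs_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size,), dtype=np.float32)


buffer = MiniRolloutBuffer(buffer_size=8, obs_dim=3)
print(buffer.observations.shape)  # (8, 3), no None check required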
2 changes: 1 addition & 1 deletion stable_baselines3/common/callbacks.py
@@ -313,7 +313,7 @@ class ConvertCallback(BaseCallback):
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""

def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0):
def __init__(self, callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]], verbose: int = 0):
super().__init__(verbose)
self.callback = callback

5 changes: 3 additions & 2 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -75,6 +75,8 @@ class OffPolicyAlgorithm(BaseAlgorithm):
:param supported_action_spaces: The action spaces supported by the algorithm.
"""

actor: th.nn.Module

def __init__(
self,
policy: Union[str, Type[BasePolicy]],
@@ -129,15 +131,14 @@ def __init__(
self.gradient_steps = gradient_steps
self.action_noise = action_noise
self.optimize_memory_usage = optimize_memory_usage
self.replay_buffer: Optional[ReplayBuffer] = None
self.replay_buffer_class = replay_buffer_class
self.replay_buffer_kwargs = replay_buffer_kwargs or {}
self._episode_storage = None

# Save train freq parameter, will be converted later to TrainFreq object
self.train_freq = train_freq

self.actor = None # type: Optional[th.nn.Module]
self.replay_buffer: Optional[ReplayBuffer] = None
# Update policy keyword arguments
if sde_support:
self.policy_kwargs["use_sde"] = self.use_sde
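Attributes that genuinely start out empty, such as replay_buffer above, keep their Optional annotation; elsewhere in the commit (see the assert self.env / ep_info_buffer additions in base_class.py) the pattern is to assert that the attribute is not None before use, which mypy treats as a type narrowing. A minimal sketch of that pattern with a hypothetical MiniAlgorithm:

# Sketch: Optional attribute plus an assert, which mypy uses to narrow the type.
from collections import deque
from typing import Deque, Optional


class MiniAlgorithm:
    ep_info_buffer: Optional[Deque[dict]] = None

    def _setup_learn(self, stats_window_size: int = 100) -> None:
        self.ep_info_buffer = deque(maxlen=stats_window_size)

    def _update_info_buffer(self, info: dict) -> None:
        # Without the assert, mypy flags the append() call on the Optional type.
        assert self.ep_info_buffer is not None, "call _setup_learn() first"
        self.ep_info_buffer.append(info)


algo = MiniAlgorithm()
algo._setup_learn()
algo._update_info_buffer({"episode": {"r": 1.0, "l": 10}})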
