Skip to content

Commit

Permalink
Type annotation bundle (logger, vec env, custom envs) (#1479)
Browse files Browse the repository at this point in the history
* Switch from List to Sequence for `seed()` type hint

* Fix logger type hints

* Improve replay buffer type hints

* Fix custom envs type annotations

* Fix VecMonitor type hints

* Fix RMSprop type hint

* Fix vec extract dict obs type hints

* Fix vec frame stack type annotations

* Fix base vec env type hints

* Fix dummy vec env type hints

* Fix for mypy

* Fixes for the tests

* mypy doesn't like when we overwrite type

* fix step of SimpleMultiObsEnv

* remove useless type specification

* Rm useless type hint

* Improve logger type hint

* format

* rm useless type hint

* Re-add variables in constructor, remove unused import

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
  • Loading branch information
araffin and qgallouedec committed May 4, 2023
1 parent d6ddee9 commit 63a0bb9
Show file tree
Hide file tree
Showing 16 changed files with 113 additions and 96 deletions.
9 changes: 8 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.0.0a6 (WIP)
Release 2.0.0a7 (WIP)
--------------------------

**Gymnasium support**
Expand Down Expand Up @@ -48,12 +48,19 @@ Others:
- Fixed ``stable_baselines3/sac/*.py`` type hints
- Fixed ``stable_baselines3/td3/*.py`` type hints
- Fixed ``stable_baselines3/common/base_class.py`` type hints
- Fixed ``stable_baselines3/common/logger.py`` type hints
- Fixed ``stable_baselines3/common/envs/*.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/vec_monitor|vec_extract_dict_obs|util.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/base_vec_env.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/vec_frame_stack.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/dummy_vec_env.py`` type hints
- Upgraded docker images to use mamba/micromamba and CUDA 11.7
- Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks
- Improve type annotation of wrappers
- Tests envs are now checked too
- Added render test for ``VecEnv``
- Update issue templates and env info saved with the model
- Changed ``seed()`` method return type from ``List`` to ``Sequence``

Documentation:
^^^^^^^^^^^^^^
Expand Down
11 changes: 0 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,12 @@ exclude = """(?x)(
stable_baselines3/common/buffers.py$
| stable_baselines3/common/callbacks.py$
| stable_baselines3/common/distributions.py$
| stable_baselines3/common/envs/bit_flipping_env.py$
| stable_baselines3/common/envs/identity_env.py$
| stable_baselines3/common/envs/multi_input_envs.py$
| stable_baselines3/common/logger.py$
| stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/policies.py$
| stable_baselines3/common/save_util.py$
| stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$
| stable_baselines3/common/utils.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/base_vec_env.py$
| stable_baselines3/common/vec_env/dummy_vec_env.py$
| stable_baselines3/common/vec_env/subproc_vec_env.py$
| stable_baselines3/common/vec_env/util.py$
| stable_baselines3/common/vec_env/vec_extract_dict_obs.py$
| stable_baselines3/common/vec_env/vec_frame_stack.py$
| stable_baselines3/common/vec_env/vec_monitor.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| stable_baselines3/common/vec_env/vec_transpose.py$
| stable_baselines3/common/vec_env/vec_video_recorder.py$
Expand Down
5 changes: 3 additions & 2 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,8 @@ class DictRolloutBuffer(RolloutBuffer):
:param n_envs: Number of parallel environments
"""

observations: Dict[str, np.ndarray]

def __init__(
self,
buffer_size: int,
Expand All @@ -697,8 +699,7 @@ def __init__(

self.gae_lambda = gae_lambda
self.gamma = gamma
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None

self.generator_ready = False
self.reset()

Expand Down
26 changes: 14 additions & 12 deletions stable_baselines3/common/envs/bit_flipping_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class BitFlippingEnv(Env):
"""

spec = EnvSpec("BitFlippingEnv-v0", "no-entry-point")
state: np.ndarray

def __init__(
self,
Expand All @@ -35,8 +36,10 @@ def __init__(
discrete_obs_space: bool = False,
image_obs_space: bool = False,
channel_first: bool = True,
render_mode: str = "human",
):
super().__init__()
self.render_mode = render_mode
# Shape of the observation when using image space
self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1)
# The achieved goal is determined by the current state
Expand Down Expand Up @@ -95,7 +98,6 @@ def __init__(
self.continuous = continuous
self.discrete_obs_space = discrete_obs_space
self.image_obs_space = image_obs_space
self.state = None
self.desired_goal = np.ones((n_bits,), dtype=self.observation_space["desired_goal"].dtype)
if max_steps is None:
max_steps = n_bits
Expand Down Expand Up @@ -127,21 +129,20 @@ def convert_to_bit_vector(self, state: Union[int, np.ndarray], batch_size: int)
"""
Convert to bit vector if needed.
:param state:
:param batch_size:
:return:
:param state: The state to be converted, which can be either an integer or a numpy array.
:param batch_size: The batch size.
:return: The state converted into a bit vector.
"""
# Convert back to bit vector
if isinstance(state, int):
state = np.array(state).reshape(batch_size, -1)
bit_vector = np.array(state).reshape(batch_size, -1)
# Convert to binary representation
state = ((state[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int)
bit_vector = ((bit_vector[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int)
elif self.image_obs_space:
state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255
bit_vector = state.reshape(batch_size, -1)[:, : len(self.state)] / 255
else:
state = np.array(state).reshape(batch_size, -1)

return state
bit_vector = np.array(state).reshape(batch_size, -1)
return bit_vector

def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
"""
Expand Down Expand Up @@ -205,10 +206,11 @@ def compute_reward(
distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
return -(distance > 0).astype(np.float32)

def render(self, mode: str = "human") -> Optional[np.ndarray]:
if mode == "rgb_array":
def render(self) -> Optional[np.ndarray]: # type: ignore[override]
if self.render_mode == "rgb_array":
return self.state.copy()
print(self.state)
return None

def close(self) -> None:
pass
14 changes: 6 additions & 8 deletions stable_baselines3/common/envs/multi_input_envs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(
self.init_possible_transitions()

self.num_col = num_col
self.state_mapping = []
self.state_mapping: List[Dict[str, np.ndarray]] = []
self.init_state_mapping(num_col, num_row)

self.max_state = len(self.state_mapping) - 1
Expand Down Expand Up @@ -121,20 +121,18 @@ def init_possible_transitions(self) -> None:
self.right_possible = [0, 1, 2, 12, 13, 14]
self.up_possible = [4, 8, 12, 7, 11, 15]

def step(self, action: Union[float, np.ndarray]) -> GymStepReturn:
def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
"""
Run one timestep of the environment's dynamics. When end of
episode is reached, you are responsible for calling `reset()`
to reset this environment's state.
Accepts an action and returns a tuple (observation, reward, done, info).
Accepts an action and returns a tuple (observation, reward, terminated, truncated, info).
:param action:
:return: tuple (observation, reward, done, info).
:return: tuple (observation, reward, terminated, truncated, info).
"""
if not self.discrete_actions:
action = np.argmax(action)
else:
action = int(action)
action = np.argmax(action) # type: ignore[assignment]

self.count += 1

Expand Down
69 changes: 40 additions & 29 deletions stable_baselines3/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tempfile
import warnings
from collections import defaultdict
from io import TextIOWrapper
from typing import Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -113,7 +114,7 @@ class KVWriter:
Key Value writer
"""

def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
"""
Write a dictionary to file
Expand All @@ -135,7 +136,7 @@ class SeqWriter:
sequence writer
"""

def write_sequence(self, sequence: List) -> None:
def write_sequence(self, sequence: List[str]) -> None:
"""
write_sequence an array to file
Expand Down Expand Up @@ -163,15 +164,16 @@ def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
if isinstance(filename_or_file, str):
self.file = open(filename_or_file, "w")
self.own_file = True
else:
assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}"
elif isinstance(filename_or_file, TextIOWrapper): # equivalent to `isinstance(..., TextIO)` (not supported)
self.file = filename_or_file
self.own_file = False
else:
raise ValueError(f"Expected file or str, got {filename_or_file}")

def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
# Create strings for printing
key2str = {}
tag = None
tag = ""
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
if excluded is not None and ("stdout" in excluded or "log" in excluded):
continue
Expand All @@ -197,9 +199,9 @@ def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
if key.find("/") > 0: # Find tag and add it to the dict
tag = key[: key.find("/") + 1]
key2str[(tag, self._truncate(tag))] = ""
# Remove tag from key
if tag is not None and tag in key:
key = str(" " + key[len(tag) :])
# Remove tag from key and indent the key
if len(tag) > 0 and tag in key:
key = f"{'':3}{key[len(tag) :]}"

truncated_key = self._truncate(key)
if (tag, truncated_key) in key2str:
Expand Down Expand Up @@ -240,8 +242,7 @@ def _truncate(self, string: str) -> str:
string = string[: self.max_length - 3] + "..."
return string

def write_sequence(self, sequence: List) -> None:
sequence = list(sequence)
def write_sequence(self, sequence: List[str]) -> None:
for i, elem in enumerate(sequence):
self.file.write(elem)
if i < len(sequence) - 1: # add space unless this is the last one
Expand All @@ -257,9 +258,7 @@ def close(self) -> None:
self.file.close()


def filter_excluded_keys(
key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], _format: str
) -> Dict[str, Any]:
def filter_excluded_keys(key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], _format: str) -> Dict[str, Any]:
"""
Filters the keys specified by ``key_exclude`` for the specified format
Expand All @@ -285,7 +284,7 @@ class JSONOutputFormat(KVWriter):
def __init__(self, filename: str):
self.file = open(filename, "w")

def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
def cast_to_json_serializable(value: Any):
if isinstance(value, Video):
raise FormatUnsupportedError(["json"], "video")
Expand Down Expand Up @@ -332,7 +331,7 @@ def __init__(self, filename: str):
self.separator = ","
self.quotechar = '"'

def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
# Add our current row to the history
key_values = filter_excluded_keys(key_values, key_excluded, "csv")
extra_keys = key_values.keys() - self.keys
Expand Down Expand Up @@ -394,10 +393,12 @@ class TensorBoardOutputFormat(KVWriter):
"""

def __init__(self, folder: str):
assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so"
assert SummaryWriter is not None, "tensorboard is not installed, you can use `pip install tensorboard` to do so"
self.writer = SummaryWriter(log_dir=folder)
self._is_closed = False

def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
assert not self._is_closed, "The SummaryWriter was closed, please re-create one."
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
if excluded is not None and "tensorboard" in excluded:
continue
Expand Down Expand Up @@ -437,7 +438,7 @@ def close(self) -> None:
"""
if self.writer:
self.writer.close()
self.writer = None
self._is_closed = True


def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter:
Expand Down Expand Up @@ -478,13 +479,24 @@ class Logger:
"""

def __init__(self, folder: Optional[str], output_formats: List[KVWriter]):
self.name_to_value = defaultdict(float) # values this iteration
self.name_to_count = defaultdict(int)
self.name_to_excluded = defaultdict(str)
self.name_to_value: Dict[str, float] = defaultdict(float) # values this iteration
self.name_to_count: Dict[str, int] = defaultdict(int)
self.name_to_excluded: Dict[str, Tuple[str, ...]] = {}
self.level = INFO
self.dir = folder
self.output_formats = output_formats

@staticmethod
def to_tuple(string_or_tuple: Optional[Union[str, Tuple[str, ...]]]) -> Tuple[str, ...]:
"""
Helper function to convert str to tuple of str.
"""
if string_or_tuple is None:
return ("",)
if isinstance(string_or_tuple, tuple):
return string_or_tuple
return (string_or_tuple,)

def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
"""
Log a value of some diagnostic
Expand All @@ -496,9 +508,9 @@ def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, .
:param exclude: outputs to be excluded
"""
self.name_to_value[key] = value
self.name_to_excluded[key] = exclude
self.name_to_excluded[key] = self.to_tuple(exclude)

def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
"""
The same as record(), but if called many times, values averaged.
Expand All @@ -507,12 +519,11 @@ def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[s
:param exclude: outputs to be excluded
"""
if value is None:
self.name_to_value[key] = None
return
old_val, count = self.name_to_value[key], self.name_to_count[key]
self.name_to_value[key] = old_val * count / (count + 1) + value / (count + 1)
self.name_to_count[key] = count + 1
self.name_to_excluded[key] = exclude
self.name_to_excluded[key] = self.to_tuple(exclude)

def dump(self, step: int = 0) -> None:
"""
Expand Down Expand Up @@ -592,7 +603,7 @@ def set_level(self, level: int) -> None:
"""
self.level = level

def get_dir(self) -> str:
def get_dir(self) -> Optional[str]:
"""
Get directory that log files are being written to.
will be None if there is no output directory (i.e., if you didn't call start)
Expand All @@ -610,15 +621,15 @@ def close(self) -> None:

# Misc
# ----------------------------------------
def _do_log(self, args) -> None:
def _do_log(self, args: Tuple[Any, ...]) -> None:
"""
log to the requested format outputs
:param args: the arguments to log
"""
for _format in self.output_formats:
if isinstance(_format, SeqWriter):
_format.write_sequence(map(str, args))
_format.write_sequence(list(map(str, args)))


def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
group.setdefault("centered", False)

@torch.no_grad()
def step(self, closure: Optional[Callable[[], None]] = None) -> Optional[torch.Tensor]:
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
:param closure: A closure that reevaluates the model
Expand Down

0 comments on commit 63a0bb9

Please sign in to comment.