Commit

Merge branch 'master' into feat/gymnasium-support
araffin committed Apr 3, 2023
2 parents 5e1f507 + 5a70af8 commit aa1a64c
Showing 10 changed files with 146 additions and 63 deletions.
15 changes: 0 additions & 15 deletions .coveragerc

This file was deleted.

3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
@@ -4,7 +4,7 @@ Changelog
==========


Release 1.8.0a12 (WIP)
Release 1.8.0a13 (WIP)
--------------------------

.. warning::
@@ -61,6 +61,7 @@ Others:
- Moved from ``setup.cfg`` to ``pyproject.toml`` configuration file
- Switched from ``flake8`` to ``ruff``
- Upgraded AutoROM to latest version
- Fixed ``stable_baselines3/dqn/*.py`` type hints
- Added ``extra_no_roms`` option for package installation without Atari Roms

Documentation:
17 changes: 15 additions & 2 deletions pyproject.toml
@@ -61,8 +61,6 @@ exclude = """(?x)(
| stable_baselines3/common/vec_env/vec_normalize.py$
| stable_baselines3/common/vec_env/vec_transpose.py$
| stable_baselines3/common/vec_env/vec_video_recorder.py$
| stable_baselines3/dqn/dqn.py$
| stable_baselines3/dqn/policies.py$
| stable_baselines3/her/her_replay_buffer.py$
| stable_baselines3/ppo/ppo.py$
| stable_baselines3/sac/policies.py$
@@ -89,3 +87,18 @@ filterwarnings = [
markers = [
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
]

[tool.coverage.run]
disable_warnings = ["couldnt-parse"]
branch = false
omit = [
"tests/*",
"setup.py",
# Require graphical interface
"stable_baselines3/common/results_plotter.py",
# Require ffmpeg
"stable_baselines3/common/vec_env/vec_video_recorder.py",
]

[tool.coverage.report]
exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"]
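The new [tool.coverage.run] and [tool.coverage.report] tables replace the deleted .coveragerc. A minimal sketch of how they are picked up, assuming coverage.py with TOML support is installed; the stable_baselines3 import is only an example of code being measured, not part of this commit:

import coverage

cov = coverage.Coverage()  # reads [tool.coverage.run] from pyproject.toml
cov.start()
import stable_baselines3   # code imported/executed here is measured
cov.stop()
cov.report()               # honors the [tool.coverage.report] exclude_lines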
8 changes: 4 additions & 4 deletions stable_baselines3/common/base_class.py
@@ -92,6 +92,7 @@ class BaseAlgorithm(ABC):

# Policy aliases (see _get_policy_from_name())
policy_aliases: Dict[str, Type[BasePolicy]] = {}
policy: BasePolicy

def __init__(
self,
@@ -123,9 +124,9 @@ def __init__(
self._vec_normalize_env = unwrap_vec_normalize(env)
self.verbose = verbose
self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
self.observation_space = None # type: Optional[spaces.Space]
self.action_space = None # type: Optional[spaces.Space]
self.n_envs = None
self.observation_space: spaces.Space
self.action_space: spaces.Space
self.n_envs: int
self.num_timesteps = 0
# Used for updating schedules
self._total_timesteps = 0
@@ -134,7 +135,6 @@ def __init__(
self.seed = seed
self.action_noise: Optional[ActionNoise] = None
self.start_time = None
self.policy = None
self.learning_rate = learning_rate
self.tensorboard_log = tensorboard_log
self.lr_schedule = None # type: Optional[Schedule]
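The base_class.py hunks above replace "= None  # type: ..." comments with bare annotations. A minimal illustration of that pattern on a hypothetical class (not part of stable-baselines3), assuming the gymnasium-based spaces used on this branch: the attribute is declared for the type checker without a value, so its type stays spaces.Space rather than Optional[spaces.Space], and it is assigned later during setup.

from gymnasium import spaces

class AlgoSketch:
    def __init__(self) -> None:
        # Declared but not assigned: the type checker knows these attributes
        # exist and are never None; _setup() must run before they are used.
        self.observation_space: spaces.Space
        self.n_envs: int

    def _setup(self, observation_space: spaces.Space, n_envs: int) -> None:
        self.observation_space = observation_space
        self.n_envs = n_envs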
18 changes: 9 additions & 9 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -100,7 +100,7 @@ def __init__(
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
sde_support: bool = True,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
):
super().__init__(
policy=policy,
@@ -125,21 +125,15 @@
self.gradient_steps = gradient_steps
self.action_noise = action_noise
self.optimize_memory_usage = optimize_memory_usage
if replay_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
self.replay_buffer_class = DictReplayBuffer
else:
self.replay_buffer_class = ReplayBuffer
else:
self.replay_buffer_class = replay_buffer_class
self.replay_buffer_class = replay_buffer_class
self.replay_buffer_kwargs = replay_buffer_kwargs or {}
self._episode_storage = None

# Save train freq parameter, will be converted later to TrainFreq object
self.train_freq = train_freq

self.actor = None # type: Optional[th.nn.Module]
self.replay_buffer = None # type: Optional[ReplayBuffer]
self.replay_buffer: Optional[ReplayBuffer] = None
# Update policy keyword arguments
if sde_support:
self.policy_kwargs["use_sde"] = self.use_sde
@@ -174,6 +168,12 @@ def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)

if self.replay_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
self.replay_buffer_class = DictReplayBuffer
else:
self.replay_buffer_class = ReplayBuffer

if self.replay_buffer is None:
# Make a local copy as we should not pickle
# the environment when using HerReplayBuffer
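The off_policy_algorithm.py hunks above defer the default replay-buffer choice from __init__ to _setup_model(), once the observation space is known. A condensed sketch of that selection logic as a standalone helper (the function name is hypothetical and the gymnasium import assumes this branch's spaces; for illustration only):

from typing import Optional, Type

from gymnasium import spaces

from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer


def resolve_replay_buffer_class(
    replay_buffer_class: Optional[Type[ReplayBuffer]],
    observation_space: spaces.Space,
) -> Type[ReplayBuffer]:
    # An explicit user choice wins; otherwise pick based on the observation
    # space, mirroring the default now applied in _setup_model().
    if replay_buffer_class is not None:
        return replay_buffer_class
    if isinstance(observation_space, spaces.Dict):
        return DictReplayBuffer
    return ReplayBuffer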
2 changes: 1 addition & 1 deletion stable_baselines3/common/on_policy_algorithm.py
@@ -70,7 +70,7 @@ def __init__(
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
):
super().__init__(
policy=policy,
24 changes: 22 additions & 2 deletions stable_baselines3/common/policies.py
@@ -64,7 +64,7 @@ def __init__(
action_space: spaces.Space,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
features_extractor: Optional[nn.Module] = None,
features_extractor: Optional[BaseFeaturesExtractor] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
@@ -84,7 +84,7 @@ def __init__(

self.optimizer_class = optimizer_class
self.optimizer_kwargs = optimizer_kwargs
self.optimizer = None # type: Optional[th.optim.Optimizer]
self.optimizer: th.optim.Optimizer

self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs
@@ -207,6 +207,26 @@ def set_training_mode(self, mode: bool) -> None:
"""
self.train(mode)

def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> bool:
"""
Check whether or not the observation is vectorized,
applying transposition to images (so that they are channel-first) if needed.
This is used in DQN when sampling a random action (epsilon-greedy policy).

:param observation: the input observation to check
:return: whether the given observation is vectorized or not
"""
vectorized_env = False
if isinstance(observation, dict):
for key, obs in observation.items():
obs_space = self.observation_space.spaces[key]
vectorized_env = vectorized_env or is_vectorized_observation(maybe_transpose(obs, obs_space), obs_space)
else:
vectorized_env = is_vectorized_observation(
maybe_transpose(observation, self.observation_space), self.observation_space
)
return vectorized_env

def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]:
"""
Convert an input observation to a PyTorch tensor that can be fed to a model.
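For context, a minimal usage sketch of the new is_vectorized_observation helper added above; the environment id and algorithm choice are illustrative, not part of this commit:

import numpy as np

from stable_baselines3 import DQN

model = DQN("MlpPolicy", "CartPole-v1")
single_obs = model.observation_space.sample()     # shape (4,)
batched_obs = np.stack([single_obs, single_obs])  # shape (2, 4)

assert not model.policy.is_vectorized_observation(single_obs)
assert model.policy.is_vectorized_observation(batched_obs)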
16 changes: 9 additions & 7 deletions stable_baselines3/dqn/dqn.py
@@ -9,9 +9,8 @@
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy

SelfDQN = TypeVar("SelfDQN", bound="DQN")
@@ -93,7 +92,7 @@ def __init__(
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
) -> None:
super().__init__(
policy,
env,
@@ -129,8 +128,9 @@ def __init__(
# "epsilon" for the epsilon-greedy exploration
self.exploration_rate = 0.0
# Linear schedule will be defined in `_setup_model()`
self.exploration_schedule = None
self.q_net, self.q_net_target = None, None
self.exploration_schedule: Schedule
self.q_net: th.nn.Module
self.q_net_target: th.nn.Module

if _init_setup_model:
self._setup_model()
@@ -160,6 +160,8 @@ def _setup_model(self) -> None:
self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)

def _create_aliases(self) -> None:
# For type checker:
assert isinstance(self.policy, DQNPolicy)
self.q_net = self.policy.q_net
self.q_net_target = self.policy.q_net_target

@@ -186,7 +188,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
losses = []
for _ in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]

with th.no_grad():
# Compute the next Q-values using the target network
@@ -239,7 +241,7 @@ def predict(
(used in recurrent policies)
"""
if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
if self.policy.is_vectorized_observation(observation):
if isinstance(observation, dict):
n_batch = observation[list(observation.keys())[0]].shape[0]
else:
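The predict() hunk above now delegates the vectorization check to the policy. A condensed, standalone sketch of the epsilon-greedy branch it implements (dict observations omitted for brevity; the function name is hypothetical):

import numpy as np


def epsilon_greedy_predict(model, observation, deterministic: bool = False):
    # With probability exploration_rate, sample random actions instead of
    # querying the Q-network (only when not acting deterministically).
    if not deterministic and np.random.rand() < model.exploration_rate:
        if model.policy.is_vectorized_observation(observation):
            n_batch = observation.shape[0]
            return np.array([model.action_space.sample() for _ in range(n_batch)])
        return np.array(model.action_space.sample())
    action, _ = model.policy.predict(observation, deterministic=deterministic)
    return action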
33 changes: 20 additions & 13 deletions stable_baselines3/dqn/policies.py
@@ -30,13 +30,13 @@ class QNetwork(BasePolicy):
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
features_extractor: nn.Module,
action_space: spaces.Discrete,
features_extractor: BaseFeaturesExtractor,
features_dim: int,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
):
) -> None:
super().__init__(
observation_space,
action_space,
@@ -49,9 +49,9 @@ def __init__(

self.net_arch = net_arch
self.activation_fn = activation_fn
self.features_extractor = features_extractor
self.features_dim = features_dim
action_dim = self.action_space.n # number of actions
assert isinstance(self.action_space, spaces.Discrete)
action_dim = int(self.action_space.n) # number of actions
q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn)
self.q_net = nn.Sequential(*q_net)

@@ -62,6 +62,8 @@ def forward(self, obs: th.Tensor) -> th.Tensor:
:param obs: Observation
:return: The estimated Q-Value for each action.
"""
# For type checker:
assert isinstance(self.features_extractor, BaseFeaturesExtractor)
return self.q_net(self.extract_features(obs, self.features_extractor))

def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
@@ -107,7 +109,7 @@ class DQNPolicy(BasePolicy):
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Discrete,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@@ -116,7 +118,7 @@ def __init__(
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
) -> None:
super().__init__(
observation_space,
action_space,
@@ -144,7 +146,8 @@ def __init__(
"normalize_images": normalize_images,
}

self.q_net, self.q_net_target = None, None
self.q_net: QNetwork
self.q_net_target: QNetwork
self._build(lr_schedule)

def _build(self, lr_schedule: Schedule) -> None:
@@ -163,7 +166,11 @@ def _build(self, lr_schedule: Schedule) -> None:
self.q_net_target.set_training_mode(False)

# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.optimizer = self.optimizer_class( # type: ignore[call-arg]
self.parameters(),
lr=lr_schedule(1),
**self.optimizer_kwargs,
)

def make_q_net(self) -> QNetwork:
# Make sure we always have separate networks for features extractors etc
@@ -228,7 +235,7 @@ class CnnPolicy(DQNPolicy):
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
action_space: spaces.Discrete,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@@ -237,7 +244,7 @@ def __init__(
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
) -> None:
super().__init__(
observation_space,
action_space,
@@ -273,7 +280,7 @@ class MultiInputPolicy(DQNPolicy):
def __init__(
self,
observation_space: spaces.Dict,
action_space: spaces.Space,
action_space: spaces.Discrete,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@@ -282,7 +289,7 @@ def __init__(
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
) -> None:
super().__init__(
observation_space,
action_space,
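To illustrate the tightened spaces.Discrete annotations in dqn/policies.py, a small construction sketch; the spaces and learning-rate schedule are arbitrary example values, assuming the gymnasium-based spaces used on this branch:

import numpy as np
from gymnasium import spaces

from stable_baselines3.dqn.policies import DQNPolicy

observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
action_space = spaces.Discrete(2)  # must now be Discrete per the annotations
policy = DQNPolicy(observation_space, action_space, lr_schedule=lambda _: 3e-4)

obs_tensor, _ = policy.obs_to_tensor(observation_space.sample())
q_values = policy.q_net(obs_tensor)  # shape (1, 2): one Q-value per action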
