Skip to content

Commit

Permalink
Remove "device" argument from policies (#141)
Browse files Browse the repository at this point in the history
* Remove device arg from policies

* Clean up for PR

* Update test and doc

* Fix codestyle

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
qxcv and araffin committed Aug 23, 2020
1 parent 21e9994 commit 42ef6d4
Show file tree
Hide file tree
Showing 10 changed files with 40 additions and 63 deletions.
9 changes: 9 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ To cite this project in publications:
howpublished = {\url{https://github.com/DLR-RM/stable-baselines3}},
}
Contributing
------------

To any interested in making the rl baselines better, there are still some improvements
that need to be done.
You can check issues in the `repo <https://github.com/DLR-RM/stable-baselines3/issues>`_.

If you want to contribute, please read `CONTRIBUTING.md <https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md>`_ first.

Indices and tables
-------------------

Expand Down
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
Changelog
==========

Pre-Release 0.9.0a0 (WIP)
Pre-Release 0.9.0a1 (WIP)
------------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed ``device`` keyword argument of policies; use ``policy.to(device)`` instead. (@qxcv)

New Features:
^^^^^^^^^^^^^
Expand Down
1 change: 0 additions & 1 deletion stable_baselines3/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def __init__(
# Update policy keyword arguments
if sde_support:
self.policy_kwargs["use_sde"] = self.use_sde
self.policy_kwargs["device"] = self.device
# For gSDE only
self.use_sde_at_warmup = use_sde_at_warmup

Expand Down
1 change: 0 additions & 1 deletion stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def _setup_model(self) -> None:
self.action_space,
self.lr_schedule,
use_sde=self.use_sde,
device=self.device,
**self.policy_kwargs # pytype:disable=not-instantiable
)
self.policy = self.policy.to(self.device)
Expand Down
31 changes: 12 additions & 19 deletions stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class BaseModel(nn.Module, ABC):
:param observation_space: (gym.spaces.Space) The observation space of the environment
:param action_space: (gym.spaces.Space) The action space of the environment
:param device: (Union[th.device, str]) Device on which the code should run.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
to pass to the feature extractor.
Expand All @@ -52,7 +51,6 @@ def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
device: Union[th.device, str] = "auto",
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
features_extractor: Optional[nn.Module] = None,
Expand All @@ -70,7 +68,6 @@ def __init__(

self.observation_space = observation_space
self.action_space = action_space
self.device = get_device(device)
self.features_extractor = features_extractor
self.normalize_images = normalize_images

Expand Down Expand Up @@ -112,6 +109,16 @@ def _get_data(self) -> Dict[str, Any]:
normalize_images=self.normalize_images,
)

@property
def device(self) -> th.device:
"""Infer which device this policy lives on by inspecting its parameters.
If it has no parameters, the 'auto' device is used as a fallback.
:return: (th.device)"""
for param in self.parameters():
return param.device
return get_device("auto")

def save(self, path: str) -> None:
"""
Save model to a given location.
Expand Down Expand Up @@ -300,7 +307,6 @@ class ActorCriticPolicy(BasePolicy):
:param action_space: (gym.spaces.Space) Action space
:param lr_schedule: (Callable) Learning rate schedule (could be constant)
:param net_arch: ([int or dict]) The specification of the policy and value networks.
:param device: (str or th.device) Device on which the code should run.
:param activation_fn: (Type[nn.Module]) Activation function
:param ortho_init: (bool) Whether to use or not orthogonal initialization
:param use_sde: (bool) Whether to use State Dependent Exploration or not
Expand Down Expand Up @@ -332,7 +338,6 @@ def __init__(
action_space: gym.spaces.Space,
lr_schedule: Callable[[float], float],
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
device: Union[th.device, str] = "auto",
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
Expand All @@ -357,7 +362,6 @@ def __init__(
super(ActorCriticPolicy, self).__init__(
observation_space,
action_space,
device,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
Expand Down Expand Up @@ -445,9 +449,7 @@ def _build(self, lr_schedule: Callable[[float], float]) -> None:
# Note: If net_arch is None and some features extractor is used,
# net_arch here is an empty list and mlp_extractor does not
# really contain any layers (acts like an identity module).
self.mlp_extractor = MlpExtractor(
self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn, device=self.device
)
self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn)

latent_dim_pi = self.mlp_extractor.latent_dim_pi

Expand Down Expand Up @@ -594,7 +596,6 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
:param action_space: (gym.spaces.Space) Action space
:param lr_schedule: (Callable) Learning rate schedule (could be constant)
:param net_arch: ([int or dict]) The specification of the policy and value networks.
:param device: (str or th.device) Device on which the code should run.
:param activation_fn: (Type[nn.Module]) Activation function
:param ortho_init: (bool) Whether to use or not orthogonal initialization
:param use_sde: (bool) Whether to use State Dependent Exploration or not
Expand Down Expand Up @@ -626,7 +627,6 @@ def __init__(
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
device: Union[th.device, str] = "auto",
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
Expand All @@ -646,7 +646,6 @@ def __init__(
action_space,
lr_schedule,
net_arch,
device,
activation_fn,
ortho_init,
use_sde,
Expand Down Expand Up @@ -685,7 +684,6 @@ class ContinuousCritic(BaseModel):
:param activation_fn: (Type[nn.Module]) Activation function
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param device: (Union[th.device, str]) Device on which the code should run.
:param n_critics: (int) Number of critic networks to create.
"""

Expand All @@ -698,15 +696,10 @@ def __init__(
features_dim: int,
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
device: Union[th.device, str] = "auto",
n_critics: int = 2,
):
super().__init__(
observation_space,
action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
device=device,
observation_space, action_space, features_extractor=features_extractor, normalize_images=normalize_images,
)

action_dim = get_action_dim(self.action_space)
Expand Down
17 changes: 2 additions & 15 deletions stable_baselines3/dqn/policies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type

import gym
import torch as th
Expand All @@ -15,7 +15,6 @@ class QNetwork(BasePolicy):
:param observation_space: (gym.spaces.Space) Observation space
:param action_space: (gym.spaces.Space) Action space
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
:param device: (str or th.device) Device on which the code should run.
:param activation_fn: (Type[nn.Module]) Activation function
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
Expand All @@ -28,16 +27,11 @@ def __init__(
features_extractor: nn.Module,
features_dim: int,
net_arch: Optional[List[int]] = None,
device: Union[th.device, str] = "auto",
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
):
super(QNetwork, self).__init__(
observation_space,
action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
device=device,
observation_space, action_space, features_extractor=features_extractor, normalize_images=normalize_images,
)

if net_arch is None:
Expand Down Expand Up @@ -90,7 +84,6 @@ class DQNPolicy(BasePolicy):
:param action_space: (gym.spaces.Space) Action space
:param lr_schedule: (callable) Learning rate schedule (could be constant)
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
:param device: (str or th.device) Device on which the code should run.
:param activation_fn: (Type[nn.Module]) Activation function
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
Expand All @@ -109,7 +102,6 @@ def __init__(
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[int]] = None,
device: Union[th.device, str] = "auto",
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
Expand All @@ -120,7 +112,6 @@ def __init__(
super(DQNPolicy, self).__init__(
observation_space,
action_space,
device,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
Expand All @@ -143,7 +134,6 @@ def __init__(
"net_arch": self.net_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
"device": device,
}

self.q_net, self.q_net_target = None, None
Expand Down Expand Up @@ -204,7 +194,6 @@ class CnnPolicy(DQNPolicy):
:param action_space: (gym.spaces.Space) Action space
:param lr_schedule: (callable) Learning rate schedule (could be constant)
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
:param device: (str or th.device) Device on which the code should run.
:param activation_fn: (Type[nn.Module]) Activation function
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param normalize_images: (bool) Whether to normalize images or not,
Expand All @@ -221,7 +210,6 @@ def __init__(
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[int]] = None,
device: Union[th.device, str] = "auto",
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
Expand All @@ -234,7 +222,6 @@ def __init__(
action_space,
lr_schedule,
net_arch,
device,
activation_fn,
features_extractor_class,
features_extractor_kwargs,
Expand Down
12 changes: 1 addition & 11 deletions stable_baselines3/sac/policies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import gym
import torch as th
Expand Down Expand Up @@ -38,7 +38,6 @@ class Actor(BasePolicy):
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param device: (Union[th.device, str]) Device on which the code should run.
"""

def __init__(
Expand All @@ -56,14 +55,12 @@ def __init__(
use_expln: bool = False,
clip_mean: float = 2.0,
normalize_images: bool = True,
device: Union[th.device, str] = "auto",
):
super(Actor, self).__init__(
observation_space,
action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
device=device,
squash_output=True,
)

Expand Down Expand Up @@ -196,7 +193,6 @@ class SACPolicy(BasePolicy):
:param action_space: (gym.spaces.Space) Action space
:param lr_schedule: (callable) Learning rate schedule (could be constant)
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
:param device: (str or th.device) Device on which the code should run.
:param activation_fn: (Type[nn.Module]) Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
Expand Down Expand Up @@ -225,7 +221,6 @@ def __init__(
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[int]] = None,
device: Union[th.device, str] = "auto",
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
Expand All @@ -242,7 +237,6 @@ def __init__(
super(SACPolicy, self).__init__(
observation_space,
action_space,
device,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
Expand Down Expand Up @@ -270,7 +264,6 @@ def __init__(
"net_arch": self.net_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
"device": device,
}
self.actor_kwargs = self.net_args.copy()
sde_kwargs = {
Expand Down Expand Up @@ -356,7 +349,6 @@ class CnnPolicy(SACPolicy):
:param action_space: (gym.spaces.Space) Action space
:param lr_schedule: (callable) Learning rate schedule (could be constant)
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
:param device: (str or th.device) Device on which the code should run.
:param activation_fn: (Type[nn.Module]) Activation function
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
Expand All @@ -383,7 +375,6 @@ def __init__(
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[int]] = None,
device: Union[th.device, str] = "auto",
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
Expand All @@ -402,7 +393,6 @@ def __init__(
action_space,
lr_schedule,
net_arch,
device,
activation_fn,
use_sde,
log_std_init,
Expand Down

0 comments on commit 42ef6d4

Please sign in to comment.