Refactored ContinuousCritic for SAC/TD3 (#78)
* Refactored ContinuousCritic for SAC/TD3

* Address comments

* Add pybullet notebook
araffin committed Jul 6, 2020
1 parent 4aa66ed commit 3756d05
Showing 8 changed files with 114 additions and 118 deletions.
8 changes: 7 additions & 1 deletion docs/guide/examples.rst
@@ -17,6 +17,7 @@ notebooks:
- `Monitor Training and Plotting`_
- `Atari Games`_
- `RL Baselines zoo`_
- `PyBullet`_

.. - `Hindsight Experience Replay`_
@@ -27,6 +28,7 @@ notebooks:
.. _Atari Games: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
.. _Hindsight Experience Replay: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
.. _RL Baselines zoo: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb
.. _PyBullet: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb

.. |colab| image:: ../_static/img/colab.svg

@@ -291,7 +293,7 @@ PyBullet: Normalizing input features

Normalizing input features may be essential to successful training of an RL agent
(by default, images are scaled but not other types of input),
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`_ environments. For that, a wrapper exists and
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and
will compute a running average and standard deviation of input features (it can do the same for rewards).


@@ -300,6 +302,10 @@ will compute a running average and standard deviation of input features (it can do the same for rewards).
you need to install pybullet with ``pip install pybullet``


.. image:: ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb


.. code-block:: python

    import gym
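
A minimal sketch of such a normalization setup, assuming SB3's ``DummyVecEnv``/``VecNormalize`` API and the ``HalfCheetahBulletEnv-v0`` id registered by ``pybullet_envs`` (an illustration, not the exact code of the collapsed example above):

.. code-block:: python

    import gym
    import pybullet_envs  # registers the Bullet environments with gym

    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

    env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
    # Keep a running mean/std of the observations (and optionally of the rewards)
    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0)

    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=2000)
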
7 changes: 5 additions & 2 deletions docs/misc/changelog.rst
@@ -3,13 +3,15 @@
Changelog
==========

Pre-Release 0.8.0a2 (WIP)
Pre-Release 0.8.0a3 (WIP)
------------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- ``AtariWrapper`` and other Atari wrappers were updated to match SB2 ones
- ``save_replay_buffer`` now receives the file path as argument instead of the folder path (@tirafesi)
- Refactored the ``Critic`` class for ``TD3`` and ``SAC``; it is now called ``ContinuousCritic``
  and has an additional parameter ``n_critics``

New Features:
^^^^^^^^^^^^^
@@ -40,6 +42,7 @@ Documentation:
- Updated notebook links
- Fixed a typo in the Enjoy a Trained Agent section of the RL Baselines3 Zoo README (@blurLake)
- Added Unity reacher to the projects page (@koulakis)
- Added PyBullet colab notebook



@@ -342,4 +345,4 @@ And all the contributors:
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis
@tirafesi @blurLake @koulakis
70 changes: 69 additions & 1 deletion stable_baselines3/common/policies.py
@@ -6,7 +6,7 @@
import torch.nn as nn
import numpy as np

from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space
from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space, get_action_dim
from stable_baselines3.common.torch_layers import (FlattenExtractor, BaseFeaturesExtractor, create_mlp,
NatureCNN, MlpExtractor)
from stable_baselines3.common.utils import get_device, is_vectorized_observation
@@ -617,6 +617,74 @@ def __init__(self,
optimizer_kwargs)


class ContinuousCritic(BasePolicy):
"""
Critic network(s) for DDPG/SAC/TD3.
It represents the action-state value function (Q-value function).
Compared to A2C/PPO critics, this one represents the Q-value
and takes the continuous action as input. It is concatenated with the state
and then fed to the network which outputs a single value: Q(s, a).
For more recent algorithms like SAC/TD3, multiple networks
are created to give different estimates.
By default, it creates two critic networks used to reduce overestimation
thanks to clipped Q-learning (cf TD3 paper).
:param observation_space: (gym.spaces.Space) Observation space
:param action_space: (gym.spaces.Space) Action space
:param net_arch: ([int]) Network architecture
:param features_extractor: (nn.Module) Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise)
:param features_dim: (int) Number of features
:param activation_fn: (Type[nn.Module]) Activation function
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param device: (Union[th.device, str]) Device on which the code should run.
:param n_critics: (int) Number of critic networks to create.
"""

def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
device: Union[th.device, str] = 'auto',
n_critics: int = 2):
super().__init__(observation_space, action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
device=device)

action_dim = get_action_dim(self.action_space)

self.n_critics = n_critics
self.q_networks = []
for idx in range(n_critics):
q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
q_net = nn.Sequential(*q_net)
self.add_module(f'qf{idx}', q_net)
self.q_networks.append(q_net)

def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
# Learn the features extractor using the policy loss only
with th.no_grad():
features = self.extract_features(obs)
qvalue_input = th.cat([features, actions], dim=1)
return tuple(q_net(qvalue_input) for q_net in self.q_networks)

def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
"""
Only predict the Q-value using the first network.
This allows reducing computation when not all the estimates are needed
(e.g. when updating the policy in TD3).
"""
with th.no_grad():
features = self.extract_features(obs)
return self.q_networks[0](th.cat([features, actions], dim=1))


def create_sde_features_extractor(features_dim: int,
sde_net_arch: List[int],
activation_fn: Type[nn.Module]) -> Tuple[nn.Sequential, int]:
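
A minimal sketch of exercising the new ``ContinuousCritic`` on its own, assuming the ``Pendulum-v0`` environment and the ``FlattenExtractor`` from ``stable_baselines3.common.torch_layers``; the final ``th.min`` only illustrates the clipped Q-learning idea mentioned in the docstring:

.. code-block:: python

    import gym
    import torch as th

    from stable_baselines3.common.policies import ContinuousCritic
    from stable_baselines3.common.torch_layers import FlattenExtractor

    env = gym.make("Pendulum-v0")
    features_extractor = FlattenExtractor(env.observation_space)

    critic = ContinuousCritic(
        env.observation_space,
        env.action_space,
        net_arch=[256, 256],
        features_extractor=features_extractor,
        features_dim=features_extractor.features_dim,
        n_critics=2,
    )

    obs = th.as_tensor(env.reset(), dtype=th.float32).unsqueeze(0)
    action = th.as_tensor(env.action_space.sample(), dtype=th.float32).unsqueeze(0)

    # forward() returns one Q-value tensor per critic network
    q_values = critic(obs, action)
    # Clipped Q-learning keeps the element-wise minimum over the critics
    min_q, _ = th.min(th.cat(q_values, dim=1), dim=1, keepdim=True)

Inside SAC/TD3, ``make_critic`` builds this object from ``critic_kwargs``, so it is normally not instantiated by hand.
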
69 changes: 15 additions & 54 deletions stable_baselines3/sac/policies.py
@@ -5,7 +5,7 @@
import torch.nn as nn

from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor
from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor, ContinuousCritic
from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution

@@ -179,54 +179,6 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.forward(observation, deterministic)


class Critic(BasePolicy):
"""
Critic network (q-value function) for SAC.
:param observation_space: (gym.spaces.Space) Observation space
:param action_space: (gym.spaces.Space) Action space
:param net_arch: ([int]) Network architecture
:param features_extractor: (nn.Module) Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise)
:param features_dim: (int) Number of features
:param activation_fn: (Type[nn.Module]) Activation function
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param device: (Union[th.device, str]) Device on which the code should run.
"""

def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
device: Union[th.device, str] = 'auto'):
super(Critic, self).__init__(observation_space, action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
device=device)

action_dim = get_action_dim(self.action_space)

q1_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
self.q1_net = nn.Sequential(*q1_net)

q2_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
self.q2_net = nn.Sequential(*q2_net)

self.q_networks = [self.q1_net, self.q2_net]

def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
# Learn the features extractor using the policy loss only
# this is much faster
with th.no_grad():
features = self.extract_features(obs)
qvalue_input = th.cat([features, action], dim=1)
return [q_net(qvalue_input) for q_net in self.q_networks]


class SACPolicy(BasePolicy):
"""
Policy class (with both actor and critic) for SAC.
@@ -255,6 +207,7 @@ class SACPolicy(BasePolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: (int) Number of critic networks to create.
"""

def __init__(self, observation_space: gym.spaces.Space,
@@ -272,7 +225,8 @@ def __init__(self, observation_space: gym.spaces.Space,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2):
super(SACPolicy, self).__init__(observation_space, action_space,
device,
features_extractor_class,
@@ -313,6 +267,9 @@ def __init__(self, observation_space: gym.spaces.Space,
'clip_mean': clip_mean
}
self.actor_kwargs.update(sde_kwargs)
self.critic_kwargs = self.net_args.copy()
self.critic_kwargs.update({'n_critics': n_critics})

self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None

@@ -345,6 +302,7 @@ def _get_data(self) -> Dict[str, Any]:
sde_net_arch=self.actor_kwargs['sde_net_arch'],
use_expln=self.actor_kwargs['use_expln'],
clip_mean=self.actor_kwargs['clip_mean'],
n_critics=self.critic_kwargs['n_critics'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
@@ -364,8 +322,8 @@ def reset_noise(self, batch_size: int = 1) -> None:
def make_actor(self) -> Actor:
return Actor(**self.actor_kwargs).to(self.device)

def make_critic(self) -> Critic:
return Critic(**self.net_args).to(self.device)
def make_critic(self) -> ContinuousCritic:
return ContinuousCritic(**self.critic_kwargs).to(self.device)

def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
@@ -403,6 +361,7 @@ class CnnPolicy(SACPolicy):
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: (int) Number of critic networks to create.
"""

def __init__(self, observation_space: gym.spaces.Space,
Expand All @@ -420,7 +379,8 @@ def __init__(self, observation_space: gym.spaces.Space,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2):
super(CnnPolicy, self).__init__(observation_space,
action_space,
lr_schedule,
@@ -436,7 +396,8 @@ def __init__(self, observation_space: gym.spaces.Space,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs)
optimizer_kwargs,
n_critics)


register_policy("MlpPolicy", MlpPolicy)
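
A minimal sketch of how ``n_critics`` reaches the critic from user code, assuming the standard ``policy_kwargs`` mechanism and ``Pendulum-v0``:

.. code-block:: python

    from stable_baselines3 import SAC

    # n_critics flows from policy_kwargs to SACPolicy.critic_kwargs and then
    # into ContinuousCritic; SAC itself asserts n_critics == 2 for now
    # (see the sac.py hunk below), so only the default value is accepted.
    model = SAC("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(n_critics=2), verbose=1)
    model.learn(total_timesteps=1000)
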
2 changes: 1 addition & 1 deletion stable_baselines3/sac/sac.py
@@ -118,7 +118,7 @@ def __init__(self, policy: Union[str, Type[SACPolicy]],
def _setup_model(self) -> None:
super(SAC, self)._setup_model()
self._create_aliases()

assert self.critic.n_critics == 2, "SAC only supports `n_critics=2` for now"
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == 'auto':
# automatically set target entropy if needed