Skip to content

Commit

Permalink
Merge branch 'master' into sde
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Aug 3, 2020
2 parents af0f3a6 + cceffd5 commit b948b7f
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/guide/custom_policy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using ``policy_kwargs`` parameter:
from stable_baselines3 import PPO
# Custom MLP policy of two layers of size 32 each with tanh activation function
# Custom MLP policy of two layers of size 32 each with Relu activation function
policy_kwargs = dict(activation_fn=th.nn.ReLU, net_arch=[32, 32])
# Create the agent
model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
Expand Down
10 changes: 8 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
Changelog
==========

Pre-Release 0.8.0a6 (WIP)
Pre-Release 0.8.0 (2020-08-03)
------------------------------

**DQN, DDPG, bug fixes and performance matching for Atari games**

Breaking Changes:
^^^^^^^^^^^^^^^^^
- ``AtariWrapper`` and other Atari wrappers were updated to match SB2 ones
Expand Down Expand Up @@ -34,6 +36,8 @@ Bug Fixes:
- Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37)
- Fixed approximate entropy calculation in PPO and A2C. (@andyshih12)
- Fixed DQN target network sharing feature extractor with the main network.
- Fixed storing correct ``dones`` in on-policy algorithm rollout collection. (@andyshih12)
- Fixed number of filters in final convolutional layer in NatureCNN to match original implementation.

Deprecations:
^^^^^^^^^^^^^
Expand All @@ -49,6 +53,7 @@ Others:
- Ignored errors from newer pytype version
- Added a check when using ``gSDE``
- Removed codacy dependency from Dockerfile
- Added ``common.sb2_compat.RMSpropTFLike`` optimizer, which corresponds closer to the implementation of RMSprop from Tensorflow.

Documentation:
^^^^^^^^^^^^^^
Expand All @@ -57,6 +62,7 @@ Documentation:
- Added Unity reacher to the projects page (@koulakis)
- Added PyBullet colab notebook
- Fixed typo in PPO example code (@joeljosephjin)
- Fixed typo in custom policy doc (@RaphaelWag)



Expand Down Expand Up @@ -357,4 +363,4 @@ And all the contributors:
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag
8 changes: 8 additions & 0 deletions docs/modules/a2c.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ A synchronous, deterministic variant of `Asynchronous Advantage Actor Critic (A3
It uses multiple workers to avoid the use of a replay buffer.


.. warning::

If you find training unstable or want to match performance of stable-baselines A2C, consider using
``RMSpropTFLike`` optimizer from ``stable_baselines3.common.sb2_compat.rmsprop_tf_like``.
You can change optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike))``.
Read more `here <https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241>`_.


Notes
-----

Expand Down
1 change: 1 addition & 0 deletions stable_baselines3/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def train(self) -> None:
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)

# This will only loop once (get all data in one go)
for rollout_data in self.rollout_buffer.get(batch_size=None):

actions = rollout_data.actions
Expand Down
2 changes: 2 additions & 0 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
self.tensorboard_log = tensorboard_log
self.lr_schedule = None # type: Optional[Callable]
self._last_obs = None # type: Optional[np.ndarray]
self._last_dones = None # type: Optional[np.ndarray]
# When using VecNormalize:
self._last_original_obs = None # type: Optional[np.ndarray]
self._episode_num = 0
Expand Down Expand Up @@ -474,6 +475,7 @@ def _setup_learn(
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
self._last_obs = self.env.reset()
self._last_dones = np.zeros((self.env.num_envs,), dtype=np.bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
self._last_original_obs = self._vec_normalize_env.get_original_obs()
Expand Down
3 changes: 2 additions & 1 deletion stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ def collect_rollouts(
if isinstance(self.action_space, gym.spaces.Discrete):
# Reshape in case of discrete action
actions = actions.reshape(-1, 1)
rollout_buffer.add(self._last_obs, actions, rewards, dones, values, log_probs)
rollout_buffer.add(self._last_obs, actions, rewards, self._last_dones, values, log_probs)
self._last_obs = new_obs
self._last_dones = dones

rollout_buffer.compute_returns_and_advantage(values, dones=dones)

Expand Down
Empty file.
126 changes: 126 additions & 0 deletions stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import torch
from torch.optim import Optimizer


class RMSpropTFLike(Optimizer):
r"""Implements RMSprop algorithm with closer match to Tensorflow version.
For reproducibility with original stable-baselines. Use this
version with e.g. A2C for stabler learning than with the PyTorch
RMSProp. Based on the PyTorch v1.5.0 implementation of RMSprop.
See a more throughout conversion in pytorch-image-models repository:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py
Changes to the original RMSprop:
- Move epsilon inside square root
- Initialize squared gradient to ones rather than zeros
Proposed by G. Hinton in his
`course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
The centered version first appears in `Generating Sequences
With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
The implementation here takes the square root of the gradient average before
adding epsilon (note that TensorFlow interchanges these two operations). The effective
learning rate is thus :math:`\alpha/(\sqrt{v} + \epsilon)` where :math:`\alpha`
is the scheduled learning rate and :math:`v` is the weighted moving average
of the squared gradient.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-2)
momentum (float, optional): momentum factor (default: 0)
alpha (float, optional): smoothing constant (default: 0.99)
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
centered (bool, optional) : if ``True``, compute the centered RMSProp,
the gradient is normalized by an estimation of its variance
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
"""

def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= momentum:
raise ValueError("Invalid momentum value: {}".format(momentum))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if not 0.0 <= alpha:
raise ValueError("Invalid alpha value: {}".format(alpha))

defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
super(RMSpropTFLike, self).__init__(params, defaults)

def __setstate__(self, state):
super(RMSpropTFLike, self).__setstate__(state)
for group in self.param_groups:
group.setdefault("momentum", 0)
group.setdefault("centered", False)

@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("RMSpropTF does not support sparse gradients")
state = self.state[p]

# State initialization
if len(state) == 0:
state["step"] = 0
# PyTorch initialized to zeros here
state["square_avg"] = torch.ones_like(p, memory_format=torch.preserve_format)
if group["momentum"] > 0:
state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
if group["centered"]:
state["grad_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)

square_avg = state["square_avg"]
alpha = group["alpha"]

state["step"] += 1

if group["weight_decay"] != 0:
grad = grad.add(p, alpha=group["weight_decay"])

square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

if group["centered"]:
grad_avg = state["grad_avg"]
grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
# PyTorch added epsilon after square root
# avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps'])
avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add_(group["eps"]).sqrt_()
else:
# PyTorch added epsilon after square root
# avg = square_avg.sqrt().add_(group['eps'])
avg = square_avg.add(group["eps"]).sqrt_()

if group["momentum"] > 0:
buf = state["momentum_buffer"]
buf.mul_(group["momentum"]).addcdiv_(grad, avg)
p.add_(buf, alpha=-group["lr"])
else:
p.addcdiv_(grad, avg, value=-group["lr"])

return loss
2 changes: 1 addition & 1 deletion stable_baselines3/common/torch_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Flatten(),
)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.8.0a6
0.8.0
6 changes: 6 additions & 0 deletions tests/test_custom_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch as th

from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike


@pytest.mark.parametrize(
Expand Down Expand Up @@ -32,3 +33,8 @@ def test_custom_offpolicy(model_class, net_arch):
def test_custom_optimizer(model_class, optimizer_kwargs):
policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
_ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(1000)


def test_tf_like_rmsprop_optimizer():
policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
_ = A2C("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(1000)

0 comments on commit b948b7f

Please sign in to comment.