Skip to content

Commit

Permalink
Optimized polyak updates (#106)
Browse files Browse the repository at this point in the history
* quick polyak updates

* changelog

* typing

* reverted autoformatting

* reverted autofmt

* Update stable_baselines3/common/utils.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* parameter names in test

* cleanup

* Merge branch 'master' into polyak

* Update changelog

* Apply suggestions from code review

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update stable_baselines3/common/utils.py

* Update utils.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
m-rph and araffin committed Jul 17, 2020
1 parent 23afedb commit dbe8cfc
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 14 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Pre-Release 0.8.0a4 (WIP)
Pre-Release 0.8.0a5 (WIP)
------------------------------

Breaking Changes:
Expand Down Expand Up @@ -40,6 +40,7 @@ Others:
- Split the ``collect_rollout()`` method for off-policy algorithms
- Added ``_on_step()`` for off-policy base class
- Optimized replay buffer size by removing the need for the ``next_observations`` numpy array
- Optimized polyak updates (1.5x-1.95x speedup) through in-place operations (@PartiallyTyped)
- Switch to ``black`` codestyle and added ``make format``, ``make check-codestyle`` and ``commit-checks``
- Ignored errors from newer pytype version
- Added a check when using ``gSDE``
Expand Down
24 changes: 23 additions & 1 deletion stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import random
from collections import deque
from typing import Callable, Optional, Union
from typing import Callable, Iterable, Optional, Union

import gym
import numpy as np
Expand Down Expand Up @@ -284,3 +284,25 @@ def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
:return:
"""
return np.nan if len(arr) == 0 else np.mean(arr)


def polyak_update(params: Iterable[th.nn.Parameter], target_params: Iterable[th.nn.Parameter], tau: float) -> None:
    """
    Perform a Polyak average update on ``target_params`` using ``params``:
    target parameters are slowly updated towards the main parameters.
    ``tau``, the soft update coefficient controls the interpolation:
    ``tau=1`` corresponds to copying the parameters to the target ones whereas nothing happens when ``tau=0``.
    The Polyak update is done in place, with ``no_grad``, and therefore does not create intermediate tensors,
    or a computation graph, reducing memory cost and improving performance. We scale the target params
    by ``1-tau`` (in-place), add the new weights, scaled by ``tau`` and store the result of the sum in the target
    params (in place).
    See https://github.com/DLR-RM/stable-baselines3/issues/93

    Both iterables are consumed once; they must yield the same number of tensors, paired by position.

    :param params: (Iterable[th.nn.Parameter]) parameters to use to update the target params
    :param target_params: (Iterable[th.nn.Parameter]) parameters to update
    :param tau: (float) the soft update coefficient ("Polyak update", between 0 and 1)
    :raises ValueError: if ``params`` and ``target_params`` do not contain the same number of tensors
    """
    # Materialize the iterables (only references are stored, no tensor is copied) so that the
    # lengths can be compared *before* any target tensor is modified. A bare ``zip`` stops
    # silently at the shorter input, which would leave part of the target network un-updated.
    source_list = list(params)
    target_list = list(target_params)
    if len(source_list) != len(target_list):
        raise ValueError(
            "polyak_update: `params` and `target_params` must have the same length, "
            f"got {len(source_list)} and {len(target_list)}"
        )

    with th.no_grad():
        for param, target_param in zip(source_list, target_list):
            # target <- (1 - tau) * target, in place
            target_param.data.mul_(1 - tau)
            # target <- target + tau * param, written straight into the target storage
            th.add(target_param.data, param.data, alpha=tau, out=target_param.data)
5 changes: 2 additions & 3 deletions stable_baselines3/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from stable_baselines3.common import logger
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import get_linear_fn
from stable_baselines3.common.utils import get_linear_fn, polyak_update
from stable_baselines3.dqn.policies import DQNPolicy


Expand Down Expand Up @@ -138,8 +138,7 @@ def _on_step(self):
This method is called in ``collect_rollout()`` after each step in the environment.
"""
if self.num_timesteps % self.target_update_interval == 0:
for param, target_param in zip(self.q_net.parameters(), self.q_net_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)

self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
logger.record("rollout/exploration rate", self.exploration_rate)
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.sac.policies import SACPolicy


Expand Down Expand Up @@ -256,8 +257,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:

# Update target networks
if gradient_step % self.target_update_interval == 0:
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)

self._n_updates += gradient_steps

Expand Down
9 changes: 3 additions & 6 deletions stable_baselines3/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.td3.policies import TD3Policy


Expand Down Expand Up @@ -166,12 +167,8 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
actor_loss.backward()
self.actor.optimizer.step()

# Update the frozen target networks
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.tau)

self._n_updates += gradient_steps
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.8.0a4
0.8.0a5
15 changes: 15 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import gym
import numpy as np
import pytest
import torch as th

from stable_baselines3 import A2C
from stable_baselines3.common.atari_wrappers import ClipRewardEnv
from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv


Expand Down Expand Up @@ -152,3 +154,16 @@ def test_vec_noise():
vec.noises = [base] * (num_envs - 1)
assert all(isinstance(noise, type(base)) for noise in vec.noises)
assert len(vec.noises) == num_envs


def test_polyak():
    """Check that ``polyak_update`` agrees with the naive, out-of-place soft update formula."""
    tau = 0.1
    shape = (5, 5)

    # Pair handed to the function under test: ``source`` is read, ``updated`` is modified in place.
    source = th.nn.Parameter(th.ones(shape))
    updated = th.nn.Parameter(th.zeros(shape))
    # Identical pair, updated below with the reference formula.
    ref_source = th.nn.Parameter(th.ones(shape))
    ref_updated = th.nn.Parameter(th.zeros(shape))

    polyak_update([source], [updated], tau)

    with th.no_grad():
        blended = tau * ref_source.data + (1 - tau) * ref_updated.data
        ref_updated.data.copy_(blended)

    # The source parameters must be left untouched ...
    assert th.allclose(source, ref_source)
    # ... and the in-place result must match the reference computation.
    assert th.allclose(updated, ref_updated)

0 comments on commit dbe8cfc

Please sign in to comment.