Skip to content

Commit

Permalink
Fix DQN target update interval for multi-env (#1463)
Browse files Browse the repository at this point in the history
* Calculating target update interval per environment in `_on_step()`. See GitHub issue #1373

* Added changelog entry and changed test comment

* Added requested changes from code review

* Update version

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
tobirohrer and araffin committed Apr 27, 2023
1 parent dc09d81 commit 6cbb2c9
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 7 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.0.0a5 (WIP)
Release 2.0.0a6 (WIP)
--------------------------

**Gymnasium support**
Expand Down Expand Up @@ -35,6 +35,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fixed ``VecExtractDictObs`` not handling terminal observation (@WeberSamuel)
- Fixed ``target_update_interval`` being changed when loading a DQN model (@tobirohrer)

Deprecations:
^^^^^^^^^^^^^
Expand Down
9 changes: 4 additions & 5 deletions stable_baselines3/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ def _setup_model(self) -> None:
self.exploration_final_eps,
self.exploration_fraction,
)
# Account for multiple environments
# each call to step() corresponds to n_envs transitions

if self.n_envs > 1:
if self.n_envs > self.target_update_interval:
warnings.warn(
Expand All @@ -162,8 +161,6 @@ def _setup_model(self) -> None:
f"which corresponds to {self.n_envs} steps."
)

self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)

def _create_aliases(self) -> None:
self.q_net = self.policy.q_net
self.q_net_target = self.policy.q_net_target
Expand All @@ -174,7 +171,9 @@ def _on_step(self) -> None:
This method is called in ``collect_rollouts()`` after each step in the environment.
"""
self._n_calls += 1
if self._n_calls % self.target_update_interval == 0:
# Account for multiple environments
# each call to step() corresponds to n_envs transitions
if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0:
polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)
# Copy running stats, see GH issue #996
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.0a5
2.0.0a6
12 changes: 12 additions & 0 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
from stable_baselines3.common.utils import get_device
Expand Down Expand Up @@ -730,3 +731,14 @@ def test_load_invalid_object(tmp_path):
with warnings.catch_warnings(record=True) as record:
PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0))
assert len(record) == 0


def test_dqn_target_update_interval(tmp_path):
    """Check that ``target_update_interval`` survives a save/load round trip.

    A DQN model built on a two-environment vectorized env is saved, discarded
    and reloaded; the reloaded model must still report the value it was
    constructed with. See GH Issue #1373.
    """
    save_path = tmp_path / "dqn_cartpole"
    vec_env = make_vec_env(env_id="CartPole-v1", n_envs=2)
    agent = DQN("MlpPolicy", vec_env, verbose=1, target_update_interval=100)
    agent.save(save_path)
    # Drop the in-memory model so the assertion only sees the reloaded one.
    del agent
    reloaded = DQN.load(save_path)
    # Clean up the saved ``.zip`` file before checking the result.
    os.remove(tmp_path / "dqn_cartpole.zip")
    assert reloaded.target_update_interval == 100

0 comments on commit 6cbb2c9

Please sign in to comment.