Skip to content

Commit

Permalink
Separate feature extractor networks for DQN networks (#132)
Browse files Browse the repository at this point in the history
* Separate feature extractor networks for DQN networks

* [ci skip] Bump version

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
Miffyli and araffin committed Jul 30, 2020
1 parent 8f9aaae commit 77cb3dd
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Pre-Release 0.8.0a5 (WIP)
Pre-Release 0.8.0a6 (WIP)
------------------------------

Breaking Changes:
Expand Down Expand Up @@ -33,6 +33,7 @@ Bug Fixes:
- Use ``cloudpickle.load`` instead of ``pickle.load`` in ``CloudpickleWrapper``. (@shwang)
- Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37)
- Fixed approximate entropy calculation in PPO and A2C. (@andyshih12)
- Fixed DQN target network sharing feature extractor with the main network.

Deprecations:
^^^^^^^^^^^^^
Expand Down
9 changes: 4 additions & 5 deletions stable_baselines3/dqn/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,17 +133,13 @@ def __init__(
else:
net_arch = []

self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.net_arch = net_arch
self.activation_fn = activation_fn
self.normalize_images = normalize_images

self.net_args = {
"observation_space": self.observation_space,
"action_space": self.action_space,
"features_extractor": self.features_extractor,
"features_dim": self.features_dim,
"net_arch": self.net_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
Expand All @@ -169,7 +165,10 @@ def _build(self, lr_schedule: Callable) -> None:
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

def make_q_net(self) -> QNetwork:
return QNetwork(**self.net_args).to(self.device)
# Make sure we always have separate networks for feature extractors etc
features_extractor = self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
features_dim = features_extractor.features_dim
return QNetwork(features_extractor=features_extractor, features_dim=features_dim, **self.net_args).to(self.device)

def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.8.0a5
0.8.0a6

0 comments on commit 77cb3dd

Please sign in to comment.