Skip to content

Commit

Permalink
Fix gSDE loading issue in test mode (#45)
Browse files Browse the repository at this point in the history
* Fix gSDE loading issue in test mode

* Forward `reset_noise` method

* Re-add `make_actor`

* Reformat
  • Loading branch information
araffin committed Jun 8, 2020
1 parent 353ea81 commit 11d33eb
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 6 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Pre-Release 0.7.0a0 (WIP)
Pre-Release 0.7.0a1 (WIP)
------------------------------

Breaking Changes:
Expand All @@ -18,6 +18,7 @@ Bug Fixes:
^^^^^^^^^^
- Fixed ``render()`` method for ``VecEnvs``
- Fixed ``seed()``` method for ``SubprocVecEnv``
- Fixed loading on GPU for testing when using gSDE and ``deterministic=False``

Deprecations:
^^^^^^^^^^^^^
Expand Down
4 changes: 4 additions & 0 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,10 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs):
for name in tensors:
recursive_setattr(model, name, tensors[name])

# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
if model.use_sde:
model.policy.reset_noise()
return model

@staticmethod
Expand Down
16 changes: 12 additions & 4 deletions stable_baselines3/sac/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def get_std(self) -> th.Tensor:
:return: (th.Tensor)
"""
assert isinstance(self.action_dist, StateDependentNoiseDistribution), \
'get_std() is only available when using gSDE'
msg = 'get_std() is only available when using gSDE'
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
return self.action_dist.get_std(self.log_std)

def reset_noise(self, batch_size: int = 1) -> None:
Expand All @@ -138,8 +138,8 @@ def reset_noise(self, batch_size: int = 1) -> None:
:param batch_size: (int)
"""
assert isinstance(self.action_dist, StateDependentNoiseDistribution), \
'reset_noise() is only available when using gSDE'
msg = 'reset_noise() is only available when using gSDE'
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)

def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
Expand Down Expand Up @@ -354,6 +354,14 @@ def _get_data(self) -> Dict[str, Any]:
))
return data

def reset_noise(self, batch_size: int = 1) -> None:
"""
Sample new weights for the exploration matrix, when using gSDE.
:param batch_size: (int)
"""
self.actor.reset_noise(batch_size=batch_size)

def make_actor(self) -> Actor:
return Actor(**self.actor_kwargs).to(self.device)

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.7.0a0
0.7.0a1

0 comments on commit 11d33eb

Please sign in to comment.