Skip to content

Commit

Permalink
Fix load_from_tensor (#1231)
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Dec 22, 2022
1 parent 5549b34 commit 3c028f3
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 1.7.0a8 (WIP)
Release 1.7.0a9 (WIP)
--------------------------

Breaking Changes:
Expand Down Expand Up @@ -39,6 +39,7 @@ Bug Fixes:
- Fixed ``Self`` return type using ``TypeVar``
- Fixed the env checker, the key was not passed when checking images from Dict observation space
- Fixed ``normalize_images`` which was not passed to parent class in some cases
- Fixed ``load_from_vector`` that was broken with newer PyTorch version when passing PyTorch tensor

Deprecations:
^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def load_from_vector(self, vector: np.ndarray) -> None:
:param vector:
"""
th.nn.utils.vector_to_parameters(th.FloatTensor(vector, device=self.device), self.parameters())
th.nn.utils.vector_to_parameters(th.as_tensor(vector, dtype=th.float, device=self.device), self.parameters())

def parameters_to_vector(self) -> np.ndarray:
"""
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.7.0a8
1.7.0a9

0 comments on commit 3c028f3

Please sign in to comment.