Skip to content

Commit

Permalink
Construct tensors directly on GPU (#1218)
Browse files Browse the repository at this point in the history
* Replace .to(device) when possible

* fix numpy dep

* black

* Add warning for device != cpu and copy=False

* Update changelog

* Remove warning

* Update buffers.py
  • Loading branch information
qgallouedec committed Dec 19, 2022
1 parent 0c1bc0b commit 68a40e0
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Others:
- Fixed ``stable_baselines3/common/atari_wrappers.py`` type hints
- Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong)
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
- Set tensors construction directly on the device

Documentation:
^^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gym==0.21", # Fixed version due to breaking changes in 0.22
"numpy",
"numpy<1.24", # Required for gym==0.21
"torch>=1.11",
'typing_extensions>=4.0,<5; python_version < "3.8.0"',
# For saving models
Expand Down
8 changes: 4 additions & 4 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
Note: it copies the data by default
:param array:
:param copy: Whether to copy or not the data
(may be useful to avoid changing things be reference)
:param copy: Whether to copy or not the data (may be useful to avoid changing things
by reference). This argument is inoperative if the device is not the CPU.
:return:
"""
if copy:
return th.tensor(array).to(self.device)
return th.as_tensor(array).to(self.device)
return th.tensor(array, device=self.device)
return th.as_tensor(array, device=self.device)

@staticmethod
def _normalize_obs(
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def load_from_vector(self, vector: np.ndarray) -> None:
:param vector:
"""
th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters())
th.nn.utils.vector_to_parameters(th.FloatTensor(vector, device=self.device), self.parameters())

def parameters_to_vector(self) -> np.ndarray:
"""
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,9 +461,9 @@ def obs_as_tensor(
:return: PyTorch tensor of the observation on a desired device.
"""
if isinstance(obs, np.ndarray):
return th.as_tensor(obs).to(device)
return th.as_tensor(obs, device=device)
elif isinstance(obs, dict):
return {key: th.as_tensor(_obs).to(device) for (key, _obs) in obs.items()}
return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
else:
raise Exception(f"Unrecognized type of observation {type(obs)}")

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _setup_model(self) -> None:
# Force conversion to float
# this will throw an error if a malformed string (different from 'auto')
# is passed
self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)
self.ent_coef_tensor = th.tensor(float(self.ent_coef), device=self.device)

def _create_aliases(self) -> None:
self.actor = self.policy.actor
Expand Down

0 comments on commit 68a40e0

Please sign in to comment.