From b79d390cefeb5f7673292b2e5710faee7dce5dd2 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 16 Mar 2021 11:56:44 -0400 Subject: [PATCH 1/6] Move critic to default device --- ml-agents/mlagents/trainers/poca/optimizer_torch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/poca/optimizer_torch.py b/ml-agents/mlagents/trainers/poca/optimizer_torch.py index 07ff16e1a2..5282d190f7 100644 --- a/ml-agents/mlagents/trainers/poca/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/poca/optimizer_torch.py @@ -4,7 +4,7 @@ ) import numpy as np import math -from mlagents.torch_utils import torch +from mlagents.torch_utils import torch, default_device from mlagents.trainers.buffer import ( AgentBuffer, @@ -155,6 +155,8 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): network_settings=trainer_settings.network_settings, action_spec=policy.behavior_spec.action_spec, ) + # Move to GPU if needed + self._critic.to(default_device()) params = list(self.policy.actor.parameters()) + list(self.critic.parameters()) self.hyperparameters: POCASettings = cast( From 9c51601d3947c7fd059a75c1bdc948def9ed6bc4 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 16 Mar 2021 13:40:32 -0400 Subject: [PATCH 2/6] Make sure to clone onto default device --- ml-agents/mlagents/trainers/torch/networks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index b19a383287..ab508e006f 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -1,7 +1,7 @@ from typing import Callable, List, Dict, Tuple, Optional, Union import abc -from mlagents.torch_utils import torch, nn +from mlagents.torch_utils import torch, nn, default_device from mlagents_envs.base_env import ActionSpec, ObservationSpec from mlagents.trainers.torch.action_model import ActionModel @@ -281,7 +281,8 @@ def _copy_and_remove_nans_from_obs( for i_agent, single_agent_obs in enumerate(all_obs): no_nan_obs = [] for obs in single_agent_obs: - new_obs = obs.clone() + with default_device(): + new_obs = obs.clone() new_obs[ attention_mask.type(torch.BoolTensor)[:, i_agent], :: ] = 0.0 # Remoove NaNs fast From deb059056ad7e593f93fae1ff09c0337c78f31e5 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 16 Mar 2021 13:48:36 -0400 Subject: [PATCH 3/6] Add some debug stuff --- ml-agents/mlagents/trainers/torch/networks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index ab508e006f..950a3799b8 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -1,7 +1,7 @@ from typing import Callable, List, Dict, Tuple, Optional, Union import abc -from mlagents.torch_utils import torch, nn, default_device +from mlagents.torch_utils import torch, nn from mlagents_envs.base_env import ActionSpec, ObservationSpec from mlagents.trainers.torch.action_model import ActionModel @@ -281,8 +281,10 @@ def _copy_and_remove_nans_from_obs( for i_agent, single_agent_obs in enumerate(all_obs): no_nan_obs = [] for obs in single_agent_obs: - with default_device(): + # Clone to same device as obs + with torch.cuda.device_of(obs): new_obs = obs.clone() + print(obs.get_device(), new_obs.get_device()) new_obs[ attention_mask.type(torch.BoolTensor)[:, i_agent], :: ] = 0.0 # Remoove NaNs fast From 5a1987ce8e33695ffd8e034316bfdc504a0d9716 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 16 Mar 2021 13:52:47 -0400 Subject: [PATCH 4/6] Some more debug --- ml-agents/mlagents/trainers/torch/networks.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 950a3799b8..95714a51f9 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -269,6 +269,7 @@ def _get_masks_from_nans(self, obs_tensors: List[torch.Tensor]) -> torch.Tensor: ) # Get the mask from NaNs attn_mask = only_first_obs_flat.isnan().type(torch.FloatTensor) + print(attn_mask.get_device(), only_first_obs_flat.get_device()) return attn_mask def _copy_and_remove_nans_from_obs( @@ -281,10 +282,7 @@ def _copy_and_remove_nans_from_obs( for i_agent, single_agent_obs in enumerate(all_obs): no_nan_obs = [] for obs in single_agent_obs: - # Clone to same device as obs - with torch.cuda.device_of(obs): - new_obs = obs.clone() - print(obs.get_device(), new_obs.get_device()) + new_obs = obs.clone() new_obs[ attention_mask.type(torch.BoolTensor)[:, i_agent], :: ] = 0.0 # Remoove NaNs fast From dca2e415951071d22147e8a483759d1b875d0d4c Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 16 Mar 2021 13:58:28 -0400 Subject: [PATCH 5/6] Fix issue --- ml-agents/mlagents/trainers/torch/networks.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 95714a51f9..cf90f2966d 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -268,8 +268,7 @@ def _get_masks_from_nans(self, obs_tensors: List[torch.Tensor]) -> torch.Tensor: [_obs.flatten(start_dim=1)[:, 0] for _obs in only_first_obs], dim=1 ) # Get the mask from NaNs - attn_mask = only_first_obs_flat.isnan().type(torch.FloatTensor) - print(attn_mask.get_device(), only_first_obs_flat.get_device()) + attn_mask = only_first_obs_flat.isnan().float() return attn_mask def _copy_and_remove_nans_from_obs( From b602f30a76c62c032672ad076e2ca1b4e60d4474 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 16 Mar 2021 14:58:51 -0400 Subject: [PATCH 6/6] Fix bool tensor too --- ml-agents/mlagents/trainers/torch/networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index cf90f2966d..56768f8782 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -283,7 +283,7 @@ def _copy_and_remove_nans_from_obs( for obs in single_agent_obs: new_obs = obs.clone() new_obs[ - attention_mask.type(torch.BoolTensor)[:, i_agent], :: + attention_mask.bool()[:, i_agent], :: ] = 0.0 # Remoove NaNs fast no_nan_obs.append(new_obs) obs_with_no_nans.append(no_nan_obs)