diff --git a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
index 9ee3845515..40663b38bd 100644
--- a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
+++ b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
@@ -1,5 +1,5 @@
 from typing import Dict, Optional, Tuple, List
-from mlagents.torch_utils import torch
+from mlagents.torch_utils import torch, default_device
 import numpy as np
 from collections import defaultdict
 
@@ -162,7 +162,7 @@ def get_trajectory_value_estimates(
             memory = self.critic_memory_dict[agent_id]
         else:
             memory = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
diff --git a/ml-agents/mlagents/trainers/poca/optimizer_torch.py b/ml-agents/mlagents/trainers/poca/optimizer_torch.py
index de17f3d3b2..9f5ecf11d3 100644
--- a/ml-agents/mlagents/trainers/poca/optimizer_torch.py
+++ b/ml-agents/mlagents/trainers/poca/optimizer_torch.py
@@ -608,12 +608,12 @@ def get_trajectory_and_baseline_value_estimates(
             _init_baseline_mem = self.baseline_memory_dict[agent_id]
         else:
             _init_value_mem = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
             _init_baseline_mem = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py
index fceacda6e9..f7cdb75c4e 100644
--- a/ml-agents/mlagents/trainers/policy/torch_policy.py
+++ b/ml-agents/mlagents/trainers/policy/torch_policy.py
@@ -69,13 +69,17 @@ def export_memory_size(self) -> int:
         return self._export_m_size
 
     def _extract_masks(self, decision_requests: DecisionSteps) -> np.ndarray:
+        device = default_device()
         mask = None
         if self.behavior_spec.action_spec.discrete_size > 0:
             num_discrete_flat = np.sum(self.behavior_spec.action_spec.discrete_branches)
-            mask = torch.ones([len(decision_requests), num_discrete_flat])
+            mask = torch.ones(
+                [len(decision_requests), num_discrete_flat], device=device
+            )
             if decision_requests.action_mask is not None:
                 mask = torch.as_tensor(
-                    1 - np.concatenate(decision_requests.action_mask, axis=1)
+                    1 - np.concatenate(decision_requests.action_mask, axis=1),
+                    device=device,
                 )
         return mask
 
@@ -91,11 +95,12 @@ def evaluate(
         """
         obs = decision_requests.obs
         masks = self._extract_masks(decision_requests)
-        tensor_obs = [torch.as_tensor(np_ob) for np_ob in obs]
+        device = default_device()
+        tensor_obs = [torch.as_tensor(np_ob, device=device) for np_ob in obs]
 
-        memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze(
-            0
-        )
+        memories = torch.as_tensor(
+            self.retrieve_memories(global_agent_ids), device=device
+        ).unsqueeze(0)
         with torch.no_grad():
             action, run_out, memories = self.actor.get_action_and_stats(
                 tensor_obs, masks=masks, memories=memories
diff --git a/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py
index 0ae77ba143..906f9e32c1 100644
--- a/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py
+++ b/ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py
@@ -143,7 +143,7 @@ def compute_estimate(
         if self._settings.use_actions:
             actions = self.get_action_input(mini_batch)
             dones = torch.as_tensor(
-                mini_batch[BufferKey.DONE], dtype=torch.float
+                mini_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             action_inputs = torch.cat([actions, dones], dim=1)
             hidden, _ = self.encoder(inputs, action_inputs)
@@ -162,7 +162,7 @@ def compute_loss(
         """
         Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
         """
-        total_loss = torch.zeros(1)
+        total_loss = torch.zeros(1, device=default_device())
         stats_dict: Dict[str, np.ndarray] = {}
         policy_estimate, policy_mu = self.compute_estimate(
             policy_batch, use_vail_noise=True
@@ -219,21 +219,23 @@ def compute_gradient_magnitude(
         expert_inputs = self.get_state_inputs(expert_batch)
         interp_inputs = []
         for policy_input, expert_input in zip(policy_inputs, expert_inputs):
-            obs_epsilon = torch.rand(policy_input.shape)
+            obs_epsilon = torch.rand(policy_input.shape, device=policy_input.device)
             interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
             interp_input.requires_grad = True  # For gradient calculation
             interp_inputs.append(interp_input)
         if self._settings.use_actions:
             policy_action = self.get_action_input(policy_batch)
             expert_action = self.get_action_input(expert_batch)
-            action_epsilon = torch.rand(policy_action.shape)
+            action_epsilon = torch.rand(
+                policy_action.shape, device=policy_action.device
+            )
             policy_dones = torch.as_tensor(
-                policy_batch[BufferKey.DONE], dtype=torch.float
+                policy_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             expert_dones = torch.as_tensor(
-                expert_batch[BufferKey.DONE], dtype=torch.float
+                expert_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
-            dones_epsilon = torch.rand(policy_dones.shape)
+            dones_epsilon = torch.rand(policy_dones.shape, device=policy_dones.device)
             action_inputs = torch.cat(
                 [
                     action_epsilon * policy_action
diff --git a/ml-agents/mlagents/trainers/torch_entities/networks.py b/ml-agents/mlagents/trainers/torch_entities/networks.py
index 555268075c..196d23698f 100644
--- a/ml-agents/mlagents/trainers/torch_entities/networks.py
+++ b/ml-agents/mlagents/trainers/torch_entities/networks.py
@@ -1,7 +1,7 @@
 from typing import Callable, List, Dict, Tuple, Optional, Union, Any
 import abc
 
-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device
 
 from mlagents_envs.base_env import ActionSpec, ObservationSpec, ObservationType
 from mlagents.trainers.torch_entities.action_model import ActionModel
@@ -87,7 +87,9 @@ def update_normalization(self, buffer: AgentBuffer) -> None:
         obs = ObsUtil.from_buffer(buffer, len(self.processors))
         for vec_input, enc in zip(obs, self.processors):
             if isinstance(enc, VectorInput):
-                enc.update_normalization(torch.as_tensor(vec_input.to_ndarray()))
+                enc.update_normalization(
+                    torch.as_tensor(vec_input.to_ndarray(), device=default_device())
+                )
 
     def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
         if self.normalize:
diff --git a/ml-agents/mlagents/trainers/torch_entities/utils.py b/ml-agents/mlagents/trainers/torch_entities/utils.py
index d5381cbecb..7f2cea40ab 100644
--- a/ml-agents/mlagents/trainers/torch_entities/utils.py
+++ b/ml-agents/mlagents/trainers/torch_entities/utils.py
@@ -1,5 +1,5 @@
 from typing import List, Optional, Tuple, Dict
-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device
 from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization
 import numpy as np
 
@@ -233,7 +233,8 @@ def list_to_tensor(
         Converts a list of numpy arrays into a tensor. MUCH faster than
         calling as_tensor on the list directly.
        """
-        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)
+        device = default_device()
+        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype, device=device)
 
     @staticmethod
     def list_to_tensor_list(
@@ -243,8 +244,10 @@ def list_to_tensor_list(
        Converts a list of numpy arrays into a list of tensors. MUCH faster than
        calling as_tensor on the list directly.
        """
+        device = default_device()
        return [
-            torch.as_tensor(np.asanyarray(_arr), dtype=dtype) for _arr in ndarray_list
+            torch.as_tensor(np.asanyarray(_arr), dtype=dtype, device=device)
+            for _arr in ndarray_list
        ]
 
     @staticmethod
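The patch applies one pattern throughout: pass the target device at tensor-creation time, either `device=default_device()` or the `.device` of a tensor that already exists, instead of allocating on the CPU and relying on a later copy or an implicit move inside the network. The sketch below illustrates the idea in plain PyTorch under stated assumptions: `resolve_device` is a hypothetical stand-in for `mlagents.torch_utils.default_device()` (assumed here to return the device the trainer placed its networks on), and the shapes and module are made up for illustration.

import torch
from torch import nn

def resolve_device() -> torch.device:
    # Hypothetical stand-in for mlagents.torch_utils.default_device():
    # returns the device the trainer's networks live on.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = resolve_device()
critic = nn.Linear(4, 1).to(device)  # network parameters live on `device`

# Before the patch, tensors like these were built without `device=`, so they
# landed on the CPU and either caused a device-mismatch error or forced a
# host-to-device copy when fed to a GPU-resident network.
obs = torch.as_tensor([[0.1, 0.2, 0.3, 0.4]], dtype=torch.float, device=device)
memory = torch.zeros((1, 1, 8), device=device)   # e.g. recurrent memory init
mask = torch.ones([1, 3], device=device)         # e.g. discrete action mask

value = critic(obs)  # all operands share one device; no implicit copies
print(value.device, memory.device, mask.device)

Note the two variants of the pattern in the diff: tensors created from scratch consult the global `default_device()`, while the GAIL interpolation noise reads the device from the tensor it will be combined with (`policy_input.device`, `policy_action.device`, `policy_dones.device`), which keeps the noise co-located with its operand without consulting the global setting.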