4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
@@ -1,5 +1,5 @@
 from typing import Dict, Optional, Tuple, List
-from mlagents.torch_utils import torch
+from mlagents.torch_utils import torch, default_device
 import numpy as np
 from collections import defaultdict

@@ -162,7 +162,7 @@ def get_trajectory_value_estimates(
             memory = self.critic_memory_dict[agent_id]
         else:
             memory = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
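Note on the pattern repeated throughout this diff: tensors that used to be created with no explicit device (the recurrent memories here, and the masks, epsilons and converted observations below) now pass device=default_device(), so they are allocated where the model parameters live and no CPU/GPU mismatch occurs when CUDA training is enabled. A minimal sketch of the idea, using a hypothetical stand-in for mlagents.torch_utils.default_device() (the real one returns whatever device the trainer's torch settings selected):

import torch

def default_device() -> torch.device:
    # Hypothetical stand-in; ml-agents resolves this from its torch settings.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allocated directly on the training device, so it can be fed to CUDA model
# parameters without an extra .to(...) copy or a device-mismatch error.
memory = torch.zeros((1, 1, 128), device=default_device())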
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/poca/optimizer_torch.py
@@ -608,12 +608,12 @@ def get_trajectory_and_baseline_value_estimates(
             _init_baseline_mem = self.baseline_memory_dict[agent_id]
         else:
             _init_value_mem = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
             _init_baseline_mem = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
17 changes: 11 additions & 6 deletions ml-agents/mlagents/trainers/policy/torch_policy.py
@@ -69,13 +69,17 @@ def export_memory_size(self) -> int:
         return self._export_m_size

     def _extract_masks(self, decision_requests: DecisionSteps) -> np.ndarray:
+        device = default_device()
         mask = None
         if self.behavior_spec.action_spec.discrete_size > 0:
             num_discrete_flat = np.sum(self.behavior_spec.action_spec.discrete_branches)
-            mask = torch.ones([len(decision_requests), num_discrete_flat])
+            mask = torch.ones(
+                [len(decision_requests), num_discrete_flat], device=device
+            )
             if decision_requests.action_mask is not None:
                 mask = torch.as_tensor(
-                    1 - np.concatenate(decision_requests.action_mask, axis=1)
+                    1 - np.concatenate(decision_requests.action_mask, axis=1),
+                    device=device,
                 )
         return mask
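For context: _extract_masks produces the mask consumed by the discrete action model. It starts as all ones over the flattened discrete branches (every action allowed) and, when the environment supplies decision_requests.action_mask (True meaning an action is blocked), flips it with 1 - mask so that 1 means allowed. A small illustration with made-up branch sizes; only the device= arguments are new in this diff:

import numpy as np
import torch

# Two discrete branches with 3 and 2 actions -> 5 flattened mask slots, one agent.
mask = torch.ones([1, 5])  # default: every action allowed
env_mask = [np.array([[False, True, False]]), np.array([[False, False]])]  # True = blocked
mask = torch.as_tensor(1 - np.concatenate(env_mask, axis=1))
# tensor([[1, 0, 1, 1, 1]]): 1 keeps an action, 0 masks it out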

@@ -91,11 +95,12 @@ def evaluate(
         """
         obs = decision_requests.obs
         masks = self._extract_masks(decision_requests)
-        tensor_obs = [torch.as_tensor(np_ob) for np_ob in obs]
+        device = default_device()
+        tensor_obs = [torch.as_tensor(np_ob, device=device) for np_ob in obs]

-        memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze(
-            0
-        )
+        memories = torch.as_tensor(
+            self.retrieve_memories(global_agent_ids), device=device
+        ).unsqueeze(0)
         with torch.no_grad():
             action, run_out, memories = self.actor.get_action_and_stats(
                 tensor_obs, masks=masks, memories=memories
@@ -143,7 +143,7 @@ def compute_estimate(
         if self._settings.use_actions:
             actions = self.get_action_input(mini_batch)
             dones = torch.as_tensor(
-                mini_batch[BufferKey.DONE], dtype=torch.float
+                mini_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             action_inputs = torch.cat([actions, dones], dim=1)
             hidden, _ = self.encoder(inputs, action_inputs)
@@ -162,7 +162,7 @@ def compute_loss(
         """
         Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
         """
-        total_loss = torch.zeros(1)
+        total_loss = torch.zeros(1, device=default_device())
         stats_dict: Dict[str, np.ndarray] = {}
         policy_estimate, policy_mu = self.compute_estimate(
             policy_batch, use_vail_noise=True
@@ -219,21 +219,23 @@ def compute_gradient_magnitude(
         expert_inputs = self.get_state_inputs(expert_batch)
         interp_inputs = []
         for policy_input, expert_input in zip(policy_inputs, expert_inputs):
-            obs_epsilon = torch.rand(policy_input.shape)
+            obs_epsilon = torch.rand(policy_input.shape, device=policy_input.device)
             interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
             interp_input.requires_grad = True  # For gradient calculation
             interp_inputs.append(interp_input)
         if self._settings.use_actions:
             policy_action = self.get_action_input(policy_batch)
             expert_action = self.get_action_input(expert_batch)
-            action_epsilon = torch.rand(policy_action.shape)
+            action_epsilon = torch.rand(
+                policy_action.shape, device=policy_action.device
+            )
             policy_dones = torch.as_tensor(
-                policy_batch[BufferKey.DONE], dtype=torch.float
+                policy_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             expert_dones = torch.as_tensor(
-                expert_batch[BufferKey.DONE], dtype=torch.float
+                expert_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
-            dones_epsilon = torch.rand(policy_dones.shape)
+            dones_epsilon = torch.rand(policy_dones.shape, device=policy_dones.device)
             action_inputs = torch.cat(
                 [
                     action_epsilon * policy_action
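For context: compute_gradient_magnitude implements the discriminator's gradient penalty, interpolating between policy and expert samples with a random epsilon. The epsilons are now drawn on the device of the tensors they mix (policy_input.device, policy_dones.device) rather than implicitly on CPU. A rough sketch of that interpolation step, with made-up shapes and a CUDA-or-CPU device assumption:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_obs = torch.randn(32, 8, device=device)
expert_obs = torch.randn(32, 8, device=device)

# Mixing coefficient created on the same device as the inputs it scales.
eps = torch.rand(policy_obs.shape, device=policy_obs.device)
interp = eps * policy_obs + (1 - eps) * expert_obs
interp.requires_grad = True  # lets autograd return d(estimate)/d(interp) for the penalty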
6 changes: 4 additions & 2 deletions ml-agents/mlagents/trainers/torch_entities/networks.py
@@ -1,7 +1,7 @@
 from typing import Callable, List, Dict, Tuple, Optional, Union, Any
 import abc

-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device

 from mlagents_envs.base_env import ActionSpec, ObservationSpec, ObservationType
 from mlagents.trainers.torch_entities.action_model import ActionModel
@@ -87,7 +87,9 @@ def update_normalization(self, buffer: AgentBuffer) -> None:
         obs = ObsUtil.from_buffer(buffer, len(self.processors))
         for vec_input, enc in zip(obs, self.processors):
             if isinstance(enc, VectorInput):
-                enc.update_normalization(torch.as_tensor(vec_input.to_ndarray()))
+                enc.update_normalization(
+                    torch.as_tensor(vec_input.to_ndarray(), device=default_device())
+                )

     def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
         if self.normalize:
9 changes: 6 additions & 3 deletions ml-agents/mlagents/trainers/torch_entities/utils.py
@@ -1,5 +1,5 @@
 from typing import List, Optional, Tuple, Dict
-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device
 from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization
 import numpy as np

@@ -233,7 +233,8 @@ def list_to_tensor(
         Converts a list of numpy arrays into a tensor. MUCH faster than
         calling as_tensor on the list directly.
         """
-        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)
+        device = default_device()
+        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype, device=device)

     @staticmethod
     def list_to_tensor_list(
@@ -243,8 +244,10 @@ def list_to_tensor_list(
         Converts a list of numpy arrays into a list of tensors. MUCH faster than
         calling as_tensor on the list directly.
         """
+        device = default_device()
         return [
-            torch.as_tensor(np.asanyarray(_arr), dtype=dtype) for _arr in ndarray_list
+            torch.as_tensor(np.asanyarray(_arr), dtype=dtype, device=device)
+            for _arr in ndarray_list
         ]

     @staticmethod
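For context: ModelUtils.list_to_tensor and list_to_tensor_list funnel a Python list of numpy arrays through np.asanyarray before torch.as_tensor, which is much faster than tensorizing the list directly, and they now place the result straight on the default device. A standalone sketch of the same pattern (the free function and shapes here are illustrative, not part of the PR):

import numpy as np
import torch

def list_to_tensor(arrays, dtype=torch.float32, device=None):
    # Stack through numpy first (fast), then land on the target device in one step.
    return torch.as_tensor(np.asanyarray(arrays), dtype=dtype, device=device)

batch = [np.ones(4, dtype=np.float32) for _ in range(8)]
print(list_to_tensor(batch, device=torch.device("cpu")).shape)  # torch.Size([8, 4])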