diff --git a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py index 8f3404927e..4cba4c9a2b 100644 --- a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py +++ b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py @@ -1,16 +1,15 @@ from typing import Dict, Optional, Tuple, List import torch import numpy as np -from mlagents_envs.base_env import DecisionSteps from mlagents.trainers.buffer import AgentBuffer +from mlagents.trainers.trajectory import SplitObservations from mlagents.trainers.torch.components.bc.module import BCModule from mlagents.trainers.torch.components.reward_providers import create_reward_provider from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.optimizer import Optimizer from mlagents.trainers.settings import TrainerSettings -from mlagents.trainers.trajectory import SplitObservations from mlagents.trainers.torch.utils import ModelUtils @@ -50,35 +49,6 @@ def create_reward_signals(self, reward_signal_configs): reward_signal, self.policy.behavior_spec, settings ) - def get_value_estimates( - self, decision_requests: DecisionSteps, idx: int, done: bool - ) -> Dict[str, float]: - """ - Generates value estimates for bootstrapping. - :param decision_requests: - :param idx: Index in BrainInfo of agent. - :param done: Whether or not this is the last element of the episode, - in which case the value estimate will be 0. - :return: The value estimate dictionary with key being the name of the reward signal - and the value the corresponding value estimate. - """ - vec_vis_obs = SplitObservations.from_observations(decision_requests.obs) - - value_estimates = self.policy.actor_critic.critic_pass( - np.expand_dims(vec_vis_obs.vector_observations[idx], 0), - np.expand_dims(vec_vis_obs.visual_observations[idx], 0), - ) - - value_estimates = {k: float(v) for k, v in value_estimates.items()} - - # If we're done, reassign all of the value estimates that need terminal states. 
- if done: - for k in value_estimates: - if not self.reward_signals[k].ignore_done: - value_estimates[k] = 0.0 - - return value_estimates - def get_trajectory_value_estimates( self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]: @@ -93,18 +63,23 @@ def get_trajectory_value_estimates( else: visual_obs = [] - memory = torch.zeros([1, len(vector_obs[0]), self.policy.m_size]) + memory = torch.zeros([1, 1, self.policy.m_size]) - next_obs = np.concatenate(next_obs, axis=-1) - next_obs = [ModelUtils.list_to_tensor(next_obs).unsqueeze(0)] - next_memory = torch.zeros([1, 1, self.policy.m_size]) + vec_vis_obs = SplitObservations.from_observations(next_obs) + next_vec_obs = [ + ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0) + ] + next_vis_obs = [ + ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0) + for _vis_ob in vec_vis_obs.visual_observations + ] - value_estimates = self.policy.actor_critic.critic_pass( - vector_obs, visual_obs, memory + value_estimates, next_memory = self.policy.actor_critic.critic_pass( + vector_obs, visual_obs, memory, sequence_length=batch.num_experiences ) - next_value_estimate = self.policy.actor_critic.critic_pass( - next_obs, next_obs, next_memory + next_value_estimate, _ = self.policy.actor_critic.critic_pass( + next_vec_obs, next_vis_obs, next_memory, sequence_length=1 ) for name, estimate in value_estimates.items(): diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py index d6dd822646..9d935d4676 100644 --- a/ml-agents/mlagents/trainers/policy/torch_policy.py +++ b/ml-agents/mlagents/trainers/policy/torch_policy.py @@ -186,7 +186,7 @@ def evaluate( run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0) run_out["learning_rate"] = 0.0 if self.use_recurrent: - run_out["memories"] = memories.detach().cpu().numpy() + run_out["memory_out"] = memories.detach().cpu().numpy().squeeze(0) return run_out def get_action( diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py index 9bbb7b51c8..e162166481 100644 --- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py @@ -61,12 +61,15 @@ def ppo_value_loss( old_values: Dict[str, torch.Tensor], returns: Dict[str, torch.Tensor], epsilon: float, + loss_masks: torch.Tensor, ) -> torch.Tensor: """ - Creates training-specific Tensorflow ops for PPO models. - :param returns: - :param old_values: - :param values: + Evaluates value loss for PPO. + :param values: Value output of the current network. + :param old_values: Value stored with experiences in buffer. + :param returns: Computed returns. + :param epsilon: Clipping value for value estimate. + :param loss_mask: Mask for losses. Used with LSTM to ignore 0'ed out experiences. 
""" value_losses = [] for name, head in values.items(): @@ -77,18 +80,24 @@ def ppo_value_loss( ) v_opt_a = (returns_tensor - head) ** 2 v_opt_b = (returns_tensor - clipped_value_estimate) ** 2 - value_loss = torch.mean(torch.max(v_opt_a, v_opt_b)) + value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks) value_losses.append(value_loss) value_loss = torch.mean(torch.stack(value_losses)) return value_loss - def ppo_policy_loss(self, advantages, log_probs, old_log_probs, masks): + def ppo_policy_loss( + self, + advantages: torch.Tensor, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + loss_masks: torch.Tensor, + ) -> torch.Tensor: """ - Creates training-specific Tensorflow ops for PPO models. - :param masks: - :param advantages: + Evaluate PPO policy loss. + :param advantages: Computed advantages. :param log_probs: Current policy probabilities :param old_log_probs: Past policy probabilities + :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences. """ advantage = advantages.unsqueeze(-1) @@ -99,7 +108,9 @@ def ppo_policy_loss(self, advantages, log_probs, old_log_probs, masks): p_opt_b = ( torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage ) - policy_loss = -torch.mean(torch.min(p_opt_a, p_opt_b)) + policy_loss = -1 * ModelUtils.masked_mean( + torch.min(p_opt_a, p_opt_b).flatten(), loss_masks + ) return policy_loss @timed @@ -153,14 +164,21 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: memories=memories, seq_len=self.policy.sequence_length, ) - value_loss = self.ppo_value_loss(values, old_values, returns, decay_eps) + loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool) + value_loss = self.ppo_value_loss( + values, old_values, returns, decay_eps, loss_masks + ) policy_loss = self.ppo_policy_loss( ModelUtils.list_to_tensor(batch["advantages"]), log_probs, ModelUtils.list_to_tensor(batch["action_probs"]), - ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32), + loss_masks, + ) + loss = ( + policy_loss + + 0.5 * value_loss + - decay_bet * ModelUtils.masked_mean(entropy.flatten(), loss_masks) ) - loss = policy_loss + 0.5 * value_loss - decay_bet * torch.mean(entropy) # Set optimizer learning rate ModelUtils.update_learning_rate(self.optimizer, decay_lr) diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index 9c3ced80a7..2a26d715f5 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -1,7 +1,8 @@ import numpy as np -from typing import Dict, List, Mapping, cast, Tuple +from typing import Dict, List, Mapping, cast, Tuple, Optional import torch from torch import nn +import attr from mlagents_envs.logging_util import get_logger from mlagents_envs.base_env import ActionType @@ -56,10 +57,24 @@ def forward( self, vec_inputs: List[torch.Tensor], vis_inputs: List[torch.Tensor], - actions: torch.Tensor = None, + actions: Optional[torch.Tensor] = None, + memories: Optional[torch.Tensor] = None, + sequence_length: int = 1, ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: - q1_out, _ = self.q1_network(vec_inputs, vis_inputs, actions=actions) - q2_out, _ = self.q2_network(vec_inputs, vis_inputs, actions=actions) + q1_out, _ = self.q1_network( + vec_inputs, + vis_inputs, + actions=actions, + memories=memories, + sequence_length=sequence_length, + ) + q2_out, _ = self.q2_network( + vec_inputs, + vis_inputs, + 
actions=actions, + memories=memories, + sequence_length=sequence_length, + ) return q1_out, q2_out def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): @@ -87,17 +102,28 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): for name in self.stream_names } + # Critics should have 1/2 of the memory of the policy + critic_memory = policy_network_settings.memory + if critic_memory is not None: + critic_memory = attr.evolve( + critic_memory, memory_size=critic_memory.memory_size // 2 + ) + value_network_settings = attr.evolve( + policy_network_settings, memory=critic_memory + ) + self.value_network = TorchSACOptimizer.PolicyValueNetwork( self.stream_names, self.policy.behavior_spec.observation_shapes, - policy_network_settings, + value_network_settings, self.policy.behavior_spec.action_type, self.act_size, ) + self.target_network = ValueNetwork( self.stream_names, self.policy.behavior_spec.observation_shapes, - policy_network_settings, + value_network_settings, ) self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0) @@ -168,11 +194,11 @@ def sac_q_loss( * self.gammas[i] * target_values[name] ) - _q1_loss = 0.5 * torch.mean( - loss_masks * torch.nn.functional.mse_loss(q_backup, q1_stream) + _q1_loss = 0.5 * ModelUtils.masked_mean( + torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks ) - _q2_loss = 0.5 * torch.mean( - loss_masks * torch.nn.functional.mse_loss(q_backup, q2_stream) + _q2_loss = 0.5 * ModelUtils.masked_mean( + torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks ) q1_losses.append(_q1_loss) @@ -232,9 +258,8 @@ def sac_value_loss( v_backup = min_policy_qs[name] - torch.sum( _ent_coef * log_probs, dim=1 ) - # print(log_probs, v_backup, _ent_coef, loss_masks) - value_loss = 0.5 * torch.mean( - loss_masks * torch.nn.functional.mse_loss(values[name], v_backup) + value_loss = 0.5 * ModelUtils.masked_mean( + torch.nn.functional.mse_loss(values[name], v_backup), loss_masks ) value_losses.append(value_loss) else: @@ -253,9 +278,9 @@ def sac_value_loss( v_backup = min_policy_qs[name] - torch.mean( branched_ent_bonus, axis=0 ) - value_loss = 0.5 * torch.mean( - loss_masks - * torch.nn.functional.mse_loss(values[name], v_backup.squeeze()) + value_loss = 0.5 * ModelUtils.masked_mean( + torch.nn.functional.mse_loss(values[name], v_backup.squeeze()), + loss_masks, ) value_losses.append(value_loss) value_loss = torch.mean(torch.stack(value_losses)) @@ -275,7 +300,7 @@ def sac_policy_loss( if not discrete: mean_q1 = mean_q1.unsqueeze(1) batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1) - policy_loss = torch.mean(loss_masks * batch_policy_loss) + policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks) else: action_probs = log_probs.exp() branched_per_action_ent = ModelUtils.break_into_branches( @@ -322,9 +347,8 @@ def sac_entropy_loss( target_current_diff = torch.squeeze( target_current_diff_branched, axis=2 ) - entropy_loss = -torch.mean( - loss_masks - * torch.mean(self._log_ent_coef * target_current_diff, axis=1) + entropy_loss = -1 * ModelUtils.masked_mean( + torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks ) return entropy_loss @@ -369,12 +393,28 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: else: actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) - memories = [ + memories_list = [ ModelUtils.list_to_tensor(batch["memory"][i]) for i in range(0, len(batch["memory"]), self.policy.sequence_length) ] - 
if len(memories) > 0: - memories = torch.stack(memories).unsqueeze(0) + # LSTM shouldn't have sequence length <1, but stop it from going out of the index if true. + offset = 1 if self.policy.sequence_length > 1 else 0 + next_memories_list = [ + ModelUtils.list_to_tensor( + batch["memory"][i][self.policy.m_size // 2 :] + ) # only pass value part of memory to target network + for i in range(offset, len(batch["memory"]), self.policy.sequence_length) + ] + + if len(memories_list) > 0: + memories = torch.stack(memories_list).unsqueeze(0) + next_memories = torch.stack(next_memories_list).unsqueeze(0) + else: + memories = None + next_memories = None + # Q network memories are 0'ed out, since we don't have them during inference. + q_memories = torch.zeros_like(next_memories) + vis_obs: List[torch.Tensor] = [] next_vis_obs: List[torch.Tensor] = [] if self.policy.use_vis_obs: @@ -415,19 +455,46 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: ) if self.policy.use_continuous_act: squeezed_actions = actions.squeeze(-1) - q1p_out, q2p_out = self.value_network(vec_obs, vis_obs, sampled_actions) - q1_out, q2_out = self.value_network(vec_obs, vis_obs, squeezed_actions) + q1p_out, q2p_out = self.value_network( + vec_obs, + vis_obs, + sampled_actions, + memories=q_memories, + sequence_length=self.policy.sequence_length, + ) + q1_out, q2_out = self.value_network( + vec_obs, + vis_obs, + squeezed_actions, + memories=q_memories, + sequence_length=self.policy.sequence_length, + ) q1_stream, q2_stream = q1_out, q2_out else: with torch.no_grad(): - q1p_out, q2p_out = self.value_network(vec_obs, vis_obs) - q1_out, q2_out = self.value_network(vec_obs, vis_obs) + q1p_out, q2p_out = self.value_network( + vec_obs, + vis_obs, + memories=q_memories, + sequence_length=self.policy.sequence_length, + ) + q1_out, q2_out = self.value_network( + vec_obs, + vis_obs, + memories=q_memories, + sequence_length=self.policy.sequence_length, + ) q1_stream = self._condense_q_streams(q1_out, actions) q2_stream = self._condense_q_streams(q2_out, actions) with torch.no_grad(): - target_values, _ = self.target_network(next_vec_obs, next_vis_obs) - masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32) + target_values, _ = self.target_network( + next_vec_obs, + next_vis_obs, + memories=next_memories, + sequence_length=self.policy.sequence_length, + ) + masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool) use_discrete = not self.policy.use_continuous_act dones = ModelUtils.list_to_tensor(batch["done"]) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_layers.py b/ml-agents/mlagents/trainers/tests/torch/test_layers.py index 499d0de285..6d1132aa2e 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_layers.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_layers.py @@ -1,6 +1,11 @@ import torch -from mlagents.trainers.torch.layers import Swish, linear_layer, Initialization +from mlagents.trainers.torch.layers import ( + Swish, + linear_layer, + lstm_layer, + Initialization, +) def test_swish(): @@ -18,3 +23,18 @@ def test_initialization_layer(): ) assert torch.all(torch.eq(layer.weight.data, torch.zeros_like(layer.weight.data))) assert torch.all(torch.eq(layer.bias.data, torch.zeros_like(layer.bias.data))) + + +def test_lstm_layer(): + torch.manual_seed(0) + # Test zero for LSTM + layer = lstm_layer( + 4, 4, kernel_init=Initialization.Zero, bias_init=Initialization.Zero + ) + for name, param in layer.named_parameters(): + if "weight" in name: + assert 
torch.all(torch.eq(param.data, torch.zeros_like(param.data))) + elif "bias" in name: + assert torch.all( + torch.eq(param.data[4:8], torch.ones_like(param.data[4:8])) + ) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_networks.py b/ml-agents/mlagents/trainers/tests/torch/test_networks.py index 5d3cf9a9e5..06f8b1ab25 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_networks.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_networks.py @@ -45,16 +45,16 @@ def test_networkbody_lstm(): obs_size = 4 seq_len = 16 network_settings = NetworkSettings( - memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=4) + memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=12) ) obs_shapes = [(obs_size,)] networkbody = NetworkBody(obs_shapes, network_settings) - optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3) + optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-4) sample_obs = torch.ones((1, seq_len, obs_size)) - for _ in range(100): - encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 4)) + for _ in range(200): + encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 12)) # Try to force output to 1 loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape)) optimizer.zero_grad() @@ -196,15 +196,20 @@ def test_actor_critic(ac_type, lstm): # memories isn't always set to None, the network should be able to # deal with that. # Test critic pass - value_out = actor.critic_pass([sample_obs], [], memories=memories) + value_out, memories_out = actor.critic_pass([sample_obs], [], memories=memories) for stream in stream_names: if lstm: assert value_out[stream].shape == (network_settings.memory.sequence_length,) + assert memories_out.shape == memories.shape else: assert value_out[stream].shape == (1,) # Test get_dist_and_value - dists, value_out, _ = actor.get_dist_and_value([sample_obs], [], memories=memories) + dists, value_out, mem_out = actor.get_dist_and_value( + [sample_obs], [], memories=memories + ) + if mem_out is not None: + assert mem_out.shape == memories.shape for dist in dists: assert isinstance(dist, GaussianDistInstance) for stream in stream_names: diff --git a/ml-agents/mlagents/trainers/tests/torch/test_utils.py b/ml-agents/mlagents/trainers/tests/torch/test_utils.py index 70306a89f3..0275581d08 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_utils.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_utils.py @@ -198,3 +198,19 @@ def test_get_probs_and_entropy(): assert entropies.shape == (1, len(dist_list)) # Make sure the first action has high probability than the others. 
assert log_probs.flatten()[0] > log_probs.flatten()[1] + + +def test_masked_mean(): + test_input = torch.tensor([1, 2, 3, 4, 5]) + masks = torch.ones_like(test_input).bool() + mean = ModelUtils.masked_mean(test_input, masks=masks) + assert mean == 3.0 + + masks = torch.tensor([False, False, True, True, True]) + mean = ModelUtils.masked_mean(test_input, masks=masks) + assert mean == 4.0 + + # Make sure it works if all masks are off + masks = torch.tensor([False, False, False, False, False]) + mean = ModelUtils.masked_mean(test_input, masks=masks) + assert mean == 0.0 diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py index 8dbb1cbcb4..707d4748a5 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -46,3 +46,39 @@ def linear_layer( layer.weight.data *= kernel_gain _init_methods[bias_init](layer.bias.data) return layer + + +def lstm_layer( + input_size: int, + hidden_size: int, + num_layers: int = 1, + batch_first: bool = True, + forget_bias: float = 1.0, + kernel_init: Initialization = Initialization.XavierGlorotUniform, + bias_init: Initialization = Initialization.Zero, +) -> torch.nn.Module: + """ + Creates a torch.nn.LSTM and initializes its weights and biases. Provides a + forget_bias offset like is done in TensorFlow. + """ + lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=batch_first) + # Add forget_bias to forget gate bias + for name, param in lstm.named_parameters(): + # Each weight and bias is a concatenation of 4 matrices + if "weight" in name: + for idx in range(4): + block_size = param.shape[0] // 4 + _init_methods[kernel_init]( + param.data[idx * block_size : (idx + 1) * block_size] + ) + if "bias" in name: + for idx in range(4): + block_size = param.shape[0] // 4 + _init_methods[bias_init]( + param.data[idx * block_size : (idx + 1) * block_size] + ) + if idx == 1: + param.data[idx * block_size : (idx + 1) * block_size].add_( + forget_bias + ) + return lstm diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 15f0d92e2d..bfe9c0ade3 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -14,6 +14,7 @@ from mlagents.trainers.settings import NetworkSettings from mlagents.trainers.torch.utils import ModelUtils from mlagents.trainers.torch.decoders import ValueHeads +from mlagents.trainers.torch.layers import lstm_layer ActivationFunction = Callable[[torch.Tensor], torch.Tensor] EncoderFunction = Callable[ @@ -50,7 +51,7 @@ def __init__( ) if self.use_lstm: - self.lstm = nn.LSTM(self.h_size, self.m_size // 2, 1) + self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True) else: self.lstm = None @@ -101,13 +102,11 @@ def forward( raise Exception("No valid inputs to network.") if self.use_lstm: - encoding = encoding.view([sequence_length, -1, self.h_size]) + # Resize to (batch, sequence length, encoding size) + encoding = encoding.reshape([-1, sequence_length, self.h_size]) memories = torch.split(memories, self.m_size // 2, dim=-1) - encoding, memories = self.lstm( - encoding.contiguous(), - (memories[0].contiguous(), memories[1].contiguous()), - ) - encoding = encoding.view([-1, self.m_size // 2]) + encoding, memories = self.lstm(encoding, memories) + encoding = encoding.reshape([-1, self.m_size // 2]) memories = torch.cat(memories, dim=-1) return encoding, memories @@ -210,7 +209,8 @@ def critic_pass( vec_inputs: 
List[torch.Tensor], vis_inputs: List[torch.Tensor], memories: Optional[torch.Tensor] = None, - ) -> Dict[str, torch.Tensor]: + sequence_length: int = 1, + ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: """ Get value outputs for the given obs. :param vec_inputs: List of vector inputs as tensors. @@ -360,9 +360,12 @@ def critic_pass( vec_inputs: List[torch.Tensor], vis_inputs: List[torch.Tensor], memories: Optional[torch.Tensor] = None, - ) -> Dict[str, torch.Tensor]: - encoding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories) - return self.value_heads(encoding) + sequence_length: int = 1, + ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: + encoding, memories_out = self.network_body( + vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length + ) + return self.value_heads(encoding), memories_out def get_dist_and_value( self, @@ -427,16 +430,21 @@ def critic_pass( vec_inputs: List[torch.Tensor], vis_inputs: List[torch.Tensor], memories: Optional[torch.Tensor] = None, - ) -> Dict[str, torch.Tensor]: + sequence_length: int = 1, + ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: + actor_mem, critic_mem = None, None if self.use_lstm: # Use only the back half of memories for critic - _, critic_mem = torch.split(memories, self.half_mem_size, -1) - else: - critic_mem = None - value_outputs, _memories = self.critic( - vec_inputs, vis_inputs, memories=critic_mem + actor_mem, critic_mem = torch.split(memories, self.half_mem_size, -1) + value_outputs, critic_mem_out = self.critic( + vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length ) - return value_outputs + if actor_mem is not None: + # Make memories with the actor mem unchanged + memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1) + else: + memories_out = None + return value_outputs, memories_out def get_dist_and_value( self, @@ -463,7 +471,7 @@ def get_dist_and_value( vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length ) if self.use_lstm: - mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=1) + mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1) else: mem_out = None return dists, value_outputs, mem_out diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py index baa99c98ef..0e855ea79b 100644 --- a/ml-agents/mlagents/trainers/torch/utils.py +++ b/ml-agents/mlagents/trainers/torch/utils.py @@ -284,3 +284,13 @@ def get_probs_and_entropy( else: all_probs = torch.cat(all_probs_list, dim=-1) return log_probs, entropies, all_probs + + @staticmethod + def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor: + """ + Returns the mean of the tensor but ignoring the values specified by masks. + Used for masking out loss functions. + :param tensor: Tensor which needs mean computation. + :param masks: Boolean tensor of masks with same dimension as tensor. + """ + return (tensor * masks).sum() / torch.clamp(masks.float().sum(), min=1.0)
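
Note: the snippet below is a standalone sketch, not part of the patch, showing how the helpers introduced above behave. It uses the module paths from the diff and assumes PyTorch's standard LSTM parameter layout, in which each weight and bias tensor packs the input, forget, cell, and output gate blocks in that order; shapes and values are illustrative only.

# Standalone usage sketch for the helpers added in this patch (illustrative only).
import torch

from mlagents.trainers.torch.layers import lstm_layer
from mlagents.trainers.torch.utils import ModelUtils

# lstm_layer wraps torch.nn.LSTM and, as in TensorFlow, offsets the forget-gate
# bias. With PyTorch's [input, forget, cell, output] block layout, the forget
# block of each bias tensor is bias[hidden_size : 2 * hidden_size].
hidden_size = 4
lstm = lstm_layer(input_size=3, hidden_size=hidden_size, forget_bias=1.0)
for name, param in lstm.named_parameters():
    if "bias" in name:
        # Default bias_init is Zero, so the forget block is all 1.0 here.
        print(name, param.data[hidden_size : 2 * hidden_size])

# masked_mean averages only the entries whose mask is True; this is what lets
# the PPO/SAC losses above ignore the zero-padded steps of short LSTM sequences.
losses = torch.tensor([1.0, 2.0, 3.0, 4.0])
masks = torch.tensor([True, True, False, False])
print(ModelUtils.masked_mean(losses, masks))  # tensor(1.5000)

# The actor-critic memory tensor concatenates actor and critic halves; critic_pass
# splits off the critic half with torch.split before running the critic network.
m_size = 8
memories = torch.arange(m_size, dtype=torch.float32).reshape(1, 1, m_size)
actor_mem, critic_mem = torch.split(memories, m_size // 2, dim=-1)
print(actor_mem.shape, critic_mem.shape)  # torch.Size([1, 1, 4]) twice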