From d523c8cc0ac0214963eee5323f717c67c647e540 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Fri, 7 Aug 2020 16:14:21 -0700 Subject: [PATCH 01/27] Running LSTM for SAC --- .../mlagents/trainers/sac/optimizer_torch.py | 100 +++++++++++++++--- 1 file changed, 85 insertions(+), 15 deletions(-) diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index 9c3ced80a7..84df073fc0 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -1,7 +1,8 @@ import numpy as np -from typing import Dict, List, Mapping, cast, Tuple +from typing import Dict, List, Mapping, cast, Tuple, Optional import torch from torch import nn +import attr from mlagents_envs.logging_util import get_logger from mlagents_envs.base_env import ActionType @@ -56,10 +57,24 @@ def forward( self, vec_inputs: List[torch.Tensor], vis_inputs: List[torch.Tensor], - actions: torch.Tensor = None, + actions: Optional[torch.Tensor] = None, + memories: Optional[torch.Tensor] = None, + sequence_length: int = 1, ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: - q1_out, _ = self.q1_network(vec_inputs, vis_inputs, actions=actions) - q2_out, _ = self.q2_network(vec_inputs, vis_inputs, actions=actions) + q1_out, _ = self.q1_network( + vec_inputs, + vis_inputs, + actions=actions, + memories=memories, + sequence_length=sequence_length, + ) + q2_out, _ = self.q2_network( + vec_inputs, + vis_inputs, + actions=actions, + memories=memories, + sequence_length=sequence_length, + ) return q1_out, q2_out def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): @@ -87,17 +102,28 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): for name in self.stream_names } + # Critics should have 1/2 of the memory of the policy + critic_memory = policy_network_settings.memory + if critic_memory is not None: + critic_memory = attr.evolve( + critic_memory, memory_size=critic_memory.memory_size // 2 + ) + value_network_settings = attr.evolve( + policy_network_settings, memory=critic_memory + ) + self.value_network = TorchSACOptimizer.PolicyValueNetwork( self.stream_names, self.policy.behavior_spec.observation_shapes, - policy_network_settings, + value_network_settings, self.policy.behavior_spec.action_type, self.act_size, ) + self.target_network = ValueNetwork( self.stream_names, self.policy.behavior_spec.observation_shapes, - policy_network_settings, + value_network_settings, ) self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0) @@ -232,7 +258,6 @@ def sac_value_loss( v_backup = min_policy_qs[name] - torch.sum( _ent_coef * log_probs, dim=1 ) - # print(log_probs, v_backup, _ent_coef, loss_masks) value_loss = 0.5 * torch.mean( loss_masks * torch.nn.functional.mse_loss(values[name], v_backup) ) @@ -369,12 +394,30 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: else: actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) - memories = [ + memories_list = [ ModelUtils.list_to_tensor(batch["memory"][i]) for i in range(0, len(batch["memory"]), self.policy.sequence_length) ] - if len(memories) > 0: - memories = torch.stack(memories).unsqueeze(0) + # LSTM shouldn't have sequence length <1, but stop it from going out of the index if true. 
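+ # With the offset, each "next" memory below is the one stored a single step into its
+ # sequence, i.e. the state that pairs with the next observations fed to the target network.
+ # The offset is dropped when sequence_length == 1 so the index stays inside the buffer.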
+ offset = 1 if self.policy.sequence_length > 1 else 0 + next_memories_list = [ + ModelUtils.list_to_tensor( + batch["memory"][i][: self.policy.m_size // 2] + ) # only pass value part of memory to target network + for i in range(offset, len(batch["memory"]), self.policy.sequence_length) + ] + + if len(memories_list) > 0: + memories = torch.stack(memories_list).unsqueeze(0) + next_memories = torch.stack(next_memories_list).unsqueeze(0) + else: + memories = None + next_memories = None + # Q network memories are 0'ed out, since we don't have them during inference. + q_memories = torch.zeros( + (memories.shape[0], memories.shape[1], memories.shape[2] // 2) + ) + vis_obs: List[torch.Tensor] = [] next_vis_obs: List[torch.Tensor] = [] if self.policy.use_vis_obs: @@ -415,18 +458,45 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: ) if self.policy.use_continuous_act: squeezed_actions = actions.squeeze(-1) - q1p_out, q2p_out = self.value_network(vec_obs, vis_obs, sampled_actions) - q1_out, q2_out = self.value_network(vec_obs, vis_obs, squeezed_actions) + q1p_out, q2p_out = self.value_network( + vec_obs, + vis_obs, + sampled_actions, + memories=q_memories, + sequence_length=self.policy.sequence_length, + ) + q1_out, q2_out = self.value_network( + vec_obs, + vis_obs, + squeezed_actions, + memories=q_memories, + sequence_length=self.policy.sequence_length, + ) q1_stream, q2_stream = q1_out, q2_out else: with torch.no_grad(): - q1p_out, q2p_out = self.value_network(vec_obs, vis_obs) - q1_out, q2_out = self.value_network(vec_obs, vis_obs) + q1p_out, q2p_out = self.value_network( + vec_obs, + vis_obs, + memories=q_memories, + sequence_length=self.policy.sequence_length, + ) + q1_out, q2_out = self.value_network( + vec_obs, + vis_obs, + memories=q_memories, + sequence_length=self.policy.sequence_length, + ) q1_stream = self._condense_q_streams(q1_out, actions) q2_stream = self._condense_q_streams(q2_out, actions) with torch.no_grad(): - target_values, _ = self.target_network(next_vec_obs, next_vis_obs) + target_values, _ = self.target_network( + next_vec_obs, + next_vis_obs, + memories=next_memories, + sequence_length=self.policy.sequence_length, + ) masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32) use_discrete = not self.policy.use_continuous_act dones = ModelUtils.list_to_tensor(batch["done"]) From f2873b296fe60fefb86545405c6076e26d10af5e Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 10 Aug 2020 10:38:19 -0700 Subject: [PATCH 02/27] Use correct half of memories --- ml-agents/mlagents/trainers/sac/optimizer_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index 84df073fc0..3ba92e6c1b 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -402,7 +402,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: offset = 1 if self.policy.sequence_length > 1 else 0 next_memories_list = [ ModelUtils.list_to_tensor( - batch["memory"][i][: self.policy.m_size // 2] + batch["memory"][i][self.policy.m_size // 2 :] ) # only pass value part of memory to target network for i in range(offset, len(batch["memory"]), self.policy.sequence_length) ] From b97b1e535af6006bd864249693c78460bc721877 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 10 Aug 2020 17:40:47 -0700 Subject: [PATCH 03/27] Fix policy memory storinig --- 
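The slicing changes above assume a fixed layout for each stored memory vector: the actor/policy LSTM state occupies the first half and the critic/value state the second half, which is why the target value network now receives the slice starting at m_size // 2. A minimal sketch of that convention, using made-up tensors and sizes rather than the trainer's actual buffers:

    import torch

    m_size = 256                                  # total memory stored per agent step
    stored_memory = torch.randn(m_size)           # stand-in for batch["memory"][i]

    policy_half = stored_memory[: m_size // 2]    # actor/policy LSTM state
    value_half = stored_memory[m_size // 2 :]     # critic/value LSTM state (target network slice)

    assert policy_half.shape[0] == value_half.shape[0] == m_size // 2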
ml-agents/mlagents/trainers/policy/torch_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py index d6dd822646..9d935d4676 100644 --- a/ml-agents/mlagents/trainers/policy/torch_policy.py +++ b/ml-agents/mlagents/trainers/policy/torch_policy.py @@ -186,7 +186,7 @@ def evaluate( run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0) run_out["learning_rate"] = 0.0 if self.use_recurrent: - run_out["memories"] = memories.detach().cpu().numpy() + run_out["memory_out"] = memories.detach().cpu().numpy().squeeze(0) return run_out def get_action( From cd509ddbbdc3e4025040105f9d521034a37fdf3a Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 10 Aug 2020 18:12:27 -0700 Subject: [PATCH 04/27] Fix SeparateActorCritic and add test --- ml-agents/mlagents/trainers/tests/torch/test_networks.py | 6 +++++- ml-agents/mlagents/trainers/torch/networks.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_networks.py b/ml-agents/mlagents/trainers/tests/torch/test_networks.py index ff5209b676..19030aeafc 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_networks.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_networks.py @@ -203,7 +203,11 @@ def test_actor_critic(ac_type, lstm): assert value_out[stream].shape == (1,) # Test get_dist_and_value - dists, value_out, _ = actor.get_dist_and_value([sample_obs], [], memories=memories) + dists, value_out, mem_out = actor.get_dist_and_value( + [sample_obs], [], memories=memories + ) + if mem_out is not None: + assert mem_out.shape == memories.shape for dist in dists: assert isinstance(dist, GaussianDistInstance) for stream in stream_names: diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 15f0d92e2d..aff61ce0f6 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -463,7 +463,7 @@ def get_dist_and_value( vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length ) if self.use_lstm: - mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=1) + mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1) else: mem_out = None return dists, value_outputs, mem_out From 07bb4c0818887dab74e75236f51f6ca1533bde80 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 11 Aug 2020 19:15:30 -0700 Subject: [PATCH 05/27] Use loss masks in PPO. --- .../mlagents/trainers/ppo/optimizer_torch.py | 44 +++++++++++++------ 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py index 9bbb7b51c8..13fdc12e01 100644 --- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py @@ -61,12 +61,15 @@ def ppo_value_loss( old_values: Dict[str, torch.Tensor], returns: Dict[str, torch.Tensor], epsilon: float, + loss_masks: torch.Tensor, ) -> torch.Tensor: """ - Creates training-specific Tensorflow ops for PPO models. - :param returns: - :param old_values: - :param values: + Evaluates value loss for PPO. + :param values: Value output of the current network. + :param old_values: Value stored with experiences in buffer. + :param returns: Computed returns. + :param epsilon: Clipping value for value estimate. + :param loss_mask: Mask for losses. Used with LSTM to ignore 0'ed out experiences. 
""" value_losses = [] for name, head in values.items(): @@ -77,18 +80,25 @@ def ppo_value_loss( ) v_opt_a = (returns_tensor - head) ** 2 v_opt_b = (returns_tensor - clipped_value_estimate) ** 2 - value_loss = torch.mean(torch.max(v_opt_a, v_opt_b)) + masked_loss = torch.max(v_opt_a, v_opt_b) * loss_masks + value_loss = torch.mean(masked_loss) value_losses.append(value_loss) value_loss = torch.mean(torch.stack(value_losses)) return value_loss - def ppo_policy_loss(self, advantages, log_probs, old_log_probs, masks): + def ppo_policy_loss( + self, + advantages: torch.Tensor, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + loss_masks: torch.Tensor, + ) -> torch.Tensor: """ - Creates training-specific Tensorflow ops for PPO models. - :param masks: - :param advantages: + Evaluate PPO policy loss. + :param advantages: Computed advantages. :param log_probs: Current policy probabilities :param old_log_probs: Past policy probabilities + :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences. """ advantage = advantages.unsqueeze(-1) @@ -99,7 +109,8 @@ def ppo_policy_loss(self, advantages, log_probs, old_log_probs, masks): p_opt_b = ( torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage ) - policy_loss = -torch.mean(torch.min(p_opt_a, p_opt_b)) + masked_loss = torch.min(p_opt_a, p_opt_b) * loss_masks + policy_loss = -torch.mean(masked_loss) return policy_loss @timed @@ -153,14 +164,21 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: memories=memories, seq_len=self.policy.sequence_length, ) - value_loss = self.ppo_value_loss(values, old_values, returns, decay_eps) + loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32) + value_loss = self.ppo_value_loss( + values, old_values, returns, decay_eps, loss_masks + ) policy_loss = self.ppo_policy_loss( ModelUtils.list_to_tensor(batch["advantages"]), log_probs, ModelUtils.list_to_tensor(batch["action_probs"]), - ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32), + loss_masks, + ) + loss = ( + policy_loss + + 0.5 * value_loss + - decay_bet * torch.mean(entropy * loss_masks) ) - loss = policy_loss + 0.5 * value_loss - decay_bet * torch.mean(entropy) # Set optimizer learning rate ModelUtils.update_learning_rate(self.optimizer, decay_lr) From 0a3c795cee4197003d5c7f5e96b03925e14f9896 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 11 Aug 2020 19:46:34 -0700 Subject: [PATCH 06/27] Proper shape of masks --- ml-agents/mlagents/trainers/ppo/optimizer_torch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py index 13fdc12e01..1232429f36 100644 --- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py @@ -109,7 +109,7 @@ def ppo_policy_loss( p_opt_b = ( torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage ) - masked_loss = torch.min(p_opt_a, p_opt_b) * loss_masks + masked_loss = torch.min(p_opt_a, p_opt_b).flatten() * loss_masks policy_loss = -torch.mean(masked_loss) return policy_loss @@ -164,7 +164,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: memories=memories, seq_len=self.policy.sequence_length, ) - loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32) + loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.float32) value_loss = self.ppo_value_loss( values, old_values, 
returns, decay_eps, loss_masks ) @@ -177,7 +177,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: loss = ( policy_loss + 0.5 * value_loss - - decay_bet * torch.mean(entropy * loss_masks) + - decay_bet * torch.mean(entropy.flatten() * loss_masks) ) # Set optimizer learning rate From 2337d15c09e2cdd91ec5d0c08eb63a2ba5d5e297 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 11 Aug 2020 20:04:48 -0700 Subject: [PATCH 07/27] Proper mask mean for PPO --- ml-agents/mlagents/trainers/ppo/optimizer_torch.py | 14 +++++++------- ml-agents/mlagents/trainers/torch/utils.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py index 1232429f36..b330bce0fb 100644 --- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py @@ -80,8 +80,7 @@ def ppo_value_loss( ) v_opt_a = (returns_tensor - head) ** 2 v_opt_b = (returns_tensor - clipped_value_estimate) ** 2 - masked_loss = torch.max(v_opt_a, v_opt_b) * loss_masks - value_loss = torch.mean(masked_loss) + value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks) value_losses.append(value_loss) value_loss = torch.mean(torch.stack(value_losses)) return value_loss @@ -109,8 +108,9 @@ def ppo_policy_loss( p_opt_b = ( torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage ) - masked_loss = torch.min(p_opt_a, p_opt_b).flatten() * loss_masks - policy_loss = -torch.mean(masked_loss) + policy_loss = -1 * ModelUtils.masked_mean( + torch.min(p_opt_a, p_opt_b).flatten(), loss_masks + ) return policy_loss @timed @@ -138,7 +138,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: if self.policy.use_continuous_act: actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1) else: - actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) + actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.bool) memories = [ ModelUtils.list_to_tensor(batch["memory"][i]) @@ -164,7 +164,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: memories=memories, seq_len=self.policy.sequence_length, ) - loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.float32) + loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool) value_loss = self.ppo_value_loss( values, old_values, returns, decay_eps, loss_masks ) @@ -177,7 +177,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: loss = ( policy_loss + 0.5 * value_loss - - decay_bet * torch.mean(entropy.flatten() * loss_masks) + - decay_bet * ModelUtils.masked_mean(entropy.flatten(), loss_masks) ) # Set optimizer learning rate diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py index baa99c98ef..ba6b7d57a0 100644 --- a/ml-agents/mlagents/trainers/torch/utils.py +++ b/ml-agents/mlagents/trainers/torch/utils.py @@ -284,3 +284,13 @@ def get_probs_and_entropy( else: all_probs = torch.cat(all_probs_list, dim=-1) return log_probs, entropies, all_probs + + @staticmethod + def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor: + """ + Returns the mean of the tensor but ignoring the values specified by masks. + Used for masking out loss functions. + :param tensor: Tensor which needs mean computation. + :param masks: Boolean tensor of masks with same dimension as tensor. 
+ """ + return (tensor * masks).sum() / masks.float().sum() From 1f69102f45dd5bfd7a0c8e63fdee701af7a25632 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 11 Aug 2020 20:17:47 -0700 Subject: [PATCH 08/27] Fix dtype for actions --- ml-agents/mlagents/trainers/ppo/optimizer_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py index b330bce0fb..e162166481 100644 --- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py @@ -138,7 +138,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: if self.policy.use_continuous_act: actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1) else: - actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.bool) + actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) memories = [ ModelUtils.list_to_tensor(batch["memory"][i]) From c0a77f76b6f9b1c6a3187b9c56e49a7228605e7b Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Wed, 12 Aug 2020 14:13:17 -0700 Subject: [PATCH 09/27] Proper initialization and SAC masking --- .../mlagents/trainers/sac/optimizer_torch.py | 27 +++++++++---------- ml-agents/mlagents/trainers/torch/layers.py | 24 +++++++++++++++++ ml-agents/mlagents/trainers/torch/networks.py | 13 +++++---- 3 files changed, 43 insertions(+), 21 deletions(-) diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index 3ba92e6c1b..6d110f99f5 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -194,11 +194,11 @@ def sac_q_loss( * self.gammas[i] * target_values[name] ) - _q1_loss = 0.5 * torch.mean( - loss_masks * torch.nn.functional.mse_loss(q_backup, q1_stream) + _q1_loss = 0.5 * ModelUtils.masked_mean( + torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks ) - _q2_loss = 0.5 * torch.mean( - loss_masks * torch.nn.functional.mse_loss(q_backup, q2_stream) + _q2_loss = 0.5 * ModelUtils.masked_mean( + torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks ) q1_losses.append(_q1_loss) @@ -258,8 +258,8 @@ def sac_value_loss( v_backup = min_policy_qs[name] - torch.sum( _ent_coef * log_probs, dim=1 ) - value_loss = 0.5 * torch.mean( - loss_masks * torch.nn.functional.mse_loss(values[name], v_backup) + value_loss = 0.5 * ModelUtils.masked_mean( + torch.nn.functional.mse_loss(values[name], v_backup), loss_masks ) value_losses.append(value_loss) else: @@ -278,9 +278,9 @@ def sac_value_loss( v_backup = min_policy_qs[name] - torch.mean( branched_ent_bonus, axis=0 ) - value_loss = 0.5 * torch.mean( - loss_masks - * torch.nn.functional.mse_loss(values[name], v_backup.squeeze()) + value_loss = 0.5 * ModelUtils.masked_mean( + torch.nn.functional.mse_loss(values[name], v_backup.squeeze()), + loss_masks, ) value_losses.append(value_loss) value_loss = torch.mean(torch.stack(value_losses)) @@ -300,7 +300,7 @@ def sac_policy_loss( if not discrete: mean_q1 = mean_q1.unsqueeze(1) batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1) - policy_loss = torch.mean(loss_masks * batch_policy_loss) + policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks) else: action_probs = log_probs.exp() branched_per_action_ent = ModelUtils.break_into_branches( @@ -347,9 +347,8 @@ def sac_entropy_loss( target_current_diff = torch.squeeze( target_current_diff_branched, axis=2 ) - 
entropy_loss = -torch.mean( - loss_masks - * torch.mean(self._log_ent_coef * target_current_diff, axis=1) + entropy_loss = -1 * ModelUtils.masked_mean( + torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks ) return entropy_loss @@ -497,7 +496,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: memories=next_memories, sequence_length=self.policy.sequence_length, ) - masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32) + masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool) use_discrete = not self.policy.use_continuous_act dones = ModelUtils.list_to_tensor(batch["done"]) diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py index 8dbb1cbcb4..d1c68887df 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -46,3 +46,27 @@ def linear_layer( layer.weight.data *= kernel_gain _init_methods[bias_init](layer.bias.data) return layer + + +def lstm_layer( + input_size: int, + hidden_size: int, + num_layers: int = 1, + batch_first: bool = True, + forget_bias: float = 1.0, + kernel_init: Initialization = Initialization.XavierGlorotUniform, + bias_init: Initialization = Initialization.Zero, +) -> torch.nn.Module: + """ + Creates a torch.nn.LSTM and initializes its weights and biases. Provides a + forget_bias offset like is done in TensorFlow. + """ + lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=batch_first) + # Add forget_bias to forget gate bias + for name, param in lstm.named_parameters(): + if "weight" in name: + _init_methods[kernel_init](param.data) + elif "bias" in name: + _init_methods[bias_init](param.data) + param.data[hidden_size : 2 * hidden_size].add_(forget_bias) + return lstm diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index aff61ce0f6..b60cfbe543 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -14,6 +14,7 @@ from mlagents.trainers.settings import NetworkSettings from mlagents.trainers.torch.utils import ModelUtils from mlagents.trainers.torch.decoders import ValueHeads +from mlagents.trainers.torch.layers import lstm_layer ActivationFunction = Callable[[torch.Tensor], torch.Tensor] EncoderFunction = Callable[ @@ -50,7 +51,7 @@ def __init__( ) if self.use_lstm: - self.lstm = nn.LSTM(self.h_size, self.m_size // 2, 1) + self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True) else: self.lstm = None @@ -101,13 +102,11 @@ def forward( raise Exception("No valid inputs to network.") if self.use_lstm: - encoding = encoding.view([sequence_length, -1, self.h_size]) + # Resize to (batch, sequence length, encoding size) + encoding = encoding.reshape([-1, sequence_length, self.h_size]) memories = torch.split(memories, self.m_size // 2, dim=-1) - encoding, memories = self.lstm( - encoding.contiguous(), - (memories[0].contiguous(), memories[1].contiguous()), - ) - encoding = encoding.view([-1, self.m_size // 2]) + encoding, memories = self.lstm(encoding, (memories[0], memories[1])) + encoding = encoding.reshape([-1, self.m_size // 2]) memories = torch.cat(memories, dim=-1) return encoding, memories From f404834ea551eed1ec649af72591ce59ba1cf8d8 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Wed, 12 Aug 2020 15:53:13 -0700 Subject: [PATCH 10/27] Experimental amrl layer --- ml-agents/mlagents/trainers/torch/layers.py | 39 +++++++++++++++++++ 
ml-agents/mlagents/trainers/torch/networks.py | 6 +-- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py index d1c68887df..4a8ed07374 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -70,3 +70,42 @@ def lstm_layer( _init_methods[bias_init](param.data) param.data[hidden_size : 2 * hidden_size].add_(forget_bias) return lstm + + +class AMRLMax(torch.nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + batch_first: bool = True, + forget_bias: float = 1.0, + kernel_init: Initialization = Initialization.XavierGlorotUniform, + bias_init: Initialization = Initialization.Zero, + ): + super().__init__() + self.lstm = lstm_layer( + input_size, + hidden_size, + num_layers, + batch_first, + forget_bias, + kernel_init, + bias_init, + ) + self.hidden_size = hidden_size + + def forward(self, input_tensor, h0_c0): + hidden = h0_c0 + all_out = [] + m = None + for t in range(input_tensor.shape[1]): + out, hidden = self.lstm(input_tensor[:, t : t + 1, :], hidden) + h_half, other_half = torch.split(out, self.hidden_size // 2, dim=-1) + if m is None: + m = h_half + else: + m = torch.max(m, h_half) + out = torch.cat([m, other_half]) + all_out.append(out) + return torch.cat(all_out, dim=1), hidden diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index b60cfbe543..264cef52a7 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -14,7 +14,7 @@ from mlagents.trainers.settings import NetworkSettings from mlagents.trainers.torch.utils import ModelUtils from mlagents.trainers.torch.decoders import ValueHeads -from mlagents.trainers.torch.layers import lstm_layer +from mlagents.trainers.torch.layers import AMRLMax ActivationFunction = Callable[[torch.Tensor], torch.Tensor] EncoderFunction = Callable[ @@ -51,9 +51,9 @@ def __init__( ) if self.use_lstm: - self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True) + self.lstm = AMRLMax(self.h_size, self.m_size // 2, batch_first=True) else: - self.lstm = None + self.lstm = None # type: ignore def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None: for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders): From beab310ec471d2aa06b534a38e6f04965b76e3a4 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Wed, 12 Aug 2020 16:13:59 -0700 Subject: [PATCH 11/27] Add extra FF layer --- ml-agents/mlagents/trainers/torch/layers.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py index 4a8ed07374..bb94a28edd 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -82,6 +82,7 @@ def __init__( forget_bias: float = 1.0, kernel_init: Initialization = Initialization.XavierGlorotUniform, bias_init: Initialization = Initialization.Zero, + num_post_layers: int = 1, ): super().__init__() self.lstm = lstm_layer( @@ -94,6 +95,18 @@ def __init__( bias_init, ) self.hidden_size = hidden_size + self.layers = [] + for _ in range(num_post_layers): + self.layers.append( + linear_layer( + input_size, + hidden_size, + kernel_init=Initialization.KaimingHeNormal, + kernel_gain=1.0, + ) + ) + self.layers.append(Swish()) + self.seq_layers = torch.nn.Sequential(*self.layers) def 
forward(self, input_tensor, h0_c0): hidden = h0_c0 @@ -108,4 +121,5 @@ def forward(self, input_tensor, h0_c0): m = torch.max(m, h_half) out = torch.cat([m, other_half]) all_out.append(out) - return torch.cat(all_out, dim=1), hidden + full_out = self.seq_layers(torch.cat(all_out, dim=1)) + return full_out, hidden From 6fece65403d59d23a4f220966a61d4c9b3cdd14c Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Wed, 12 Aug 2020 16:58:30 -0700 Subject: [PATCH 12/27] Faster implementation --- ml-agents/mlagents/trainers/torch/layers.py | 25 ++++++++++++--------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py index bb94a28edd..ce56347c29 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -99,7 +99,7 @@ def __init__( for _ in range(num_post_layers): self.layers.append( linear_layer( - input_size, + hidden_size, hidden_size, kernel_init=Initialization.KaimingHeNormal, kernel_gain=1.0, @@ -110,16 +110,19 @@ def __init__( def forward(self, input_tensor, h0_c0): hidden = h0_c0 - all_out = [] + all_c = [] m = None - for t in range(input_tensor.shape[1]): - out, hidden = self.lstm(input_tensor[:, t : t + 1, :], hidden) - h_half, other_half = torch.split(out, self.hidden_size // 2, dim=-1) + lstm_out, hidden = self.lstm(input_tensor, hidden) + h_half, other_half = torch.split(lstm_out, self.hidden_size // 2, dim=-1) + for t in range(h_half.shape[1]): + h_half_subt = h_half[:, t : t + 1, :] if m is None: - m = h_half + m = h_half_subt else: - m = torch.max(m, h_half) - out = torch.cat([m, other_half]) - all_out.append(out) - full_out = self.seq_layers(torch.cat(all_out, dim=1)) - return full_out, hidden + m = torch.max(m, h_half_subt) + all_c.append(m) + concat_c = torch.cat(all_c, dim=1) + concat_out = torch.cat([concat_c, other_half], dim=-1) + full_out = self.seq_layers(concat_out.reshape([-1, self.hidden_size])) + full_out = full_out.reshape([-1, input_tensor.shape[1], self.hidden_size]) + return concat_out, hidden From eac1dc96dbebb443e5dba1e6846eb8a24568eee8 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Wed, 12 Aug 2020 17:11:56 -0700 Subject: [PATCH 13/27] Add comment --- ml-agents/mlagents/trainers/torch/layers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py index ce56347c29..58c549f0d5 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -73,6 +73,11 @@ def lstm_layer( class AMRLMax(torch.nn.Module): + """ + Implements Aggregation for LSTM as described here: + https://www.microsoft.com/en-us/research/publication/amrl-aggregated-memory-for-reinforcement-learning/ + """ + def __init__( self, input_size: int, From d2e31aadf26145b01a4419b8aea9a7f19c9a174f Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Wed, 12 Aug 2020 18:16:44 -0700 Subject: [PATCH 14/27] Passthrough max --- ml-agents/mlagents/trainers/torch/layers.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py index 58c549f0d5..eba616cd19 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -124,10 +124,19 @@ def forward(self, input_tensor, h0_c0): if m is None: m = h_half_subt else: - m = torch.max(m, h_half_subt) + m = AMRLMax.PassthroughMax.apply(m, 
h_half_subt) all_c.append(m) concat_c = torch.cat(all_c, dim=1) concat_out = torch.cat([concat_c, other_half], dim=-1) full_out = self.seq_layers(concat_out.reshape([-1, self.hidden_size])) full_out = full_out.reshape([-1, input_tensor.shape[1], self.hidden_size]) return concat_out, hidden + + class PassthroughMax(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor1, tensor2): + return torch.max(tensor1, tensor2) + + @staticmethod + def backward(ctx, grad_output): + return grad_output.clone(), grad_output.clone() From bf485a2006adae9595b42c8be461b2b71d7c614c Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Fri, 14 Aug 2020 11:59:31 -0700 Subject: [PATCH 15/27] Memory size abstraction and fixes --- .../mlagents/trainers/policy/torch_policy.py | 1 + ml-agents/mlagents/trainers/torch/layers.py | 26 ++++++++++++------- ml-agents/mlagents/trainers/torch/networks.py | 26 ++++++++++++++----- 3 files changed, 38 insertions(+), 15 deletions(-) diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py index 9d935d4676..4d645e1fba 100644 --- a/ml-agents/mlagents/trainers/policy/torch_policy.py +++ b/ml-agents/mlagents/trainers/policy/torch_policy.py @@ -83,6 +83,7 @@ def __init__( conditional_sigma=self.condition_sigma_on_obs, tanh_squash=tanh_squash, ) + self.m_size = self.actor_critic.memory_size self.actor_critic.to(TestingConfiguration.device) diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py index b1303338fb..2d353db7df 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -125,24 +125,32 @@ def __init__( self.layers.append(Swish()) self.seq_layers = torch.nn.Sequential(*self.layers) - def forward(self, input_tensor, h0_c0): - hidden = h0_c0 + @property + def memory_size(self) -> int: + return self.hidden_size // 2 + 2 * self.hidden_size + + def forward(self, input_tensor, memories): + # memories is 1/2 * hidden_size (accumulant) + hidden_size/2 (h0) + hidden_size/2 (c0) + acc, h0, c0 = torch.split( + memories, + [self.hidden_size // 2, self.hidden_size, self.hidden_size], + dim=-1, + ) + hidden = (h0, c0) all_c = [] - m = None - lstm_out, hidden = self.lstm(input_tensor, hidden) + m = acc.permute([1, 0, 2]) + lstm_out, (h0_out, c0_out) = self.lstm(input_tensor, hidden) h_half, other_half = torch.split(lstm_out, self.hidden_size // 2, dim=-1) for t in range(h_half.shape[1]): h_half_subt = h_half[:, t : t + 1, :] - if m is None: - m = h_half_subt - else: - m = AMRLMax.PassthroughMax.apply(m, h_half_subt) + m = AMRLMax.PassthroughMax.apply(m, h_half_subt) all_c.append(m) concat_c = torch.cat(all_c, dim=1) concat_out = torch.cat([concat_c, other_half], dim=-1) full_out = self.seq_layers(concat_out.reshape([-1, self.hidden_size])) full_out = full_out.reshape([-1, input_tensor.shape[1], self.hidden_size]) - return concat_out, hidden + output_mem = torch.cat([m.permute([1, 0, 2]), h0_out, c0_out], dim=-1) + return concat_out, output_mem class PassthroughMax(torch.autograd.Function): @staticmethod diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 591568aece..87ff8c84cc 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -14,7 +14,7 @@ from mlagents.trainers.settings import NetworkSettings from mlagents.trainers.torch.utils import ModelUtils from mlagents.trainers.torch.decoders import ValueHeads 
-from mlagents.trainers.torch.layers import lstm_layer +from mlagents.trainers.torch.layers import AMRLMax ActivationFunction = Callable[[torch.Tensor], torch.Tensor] EncoderFunction = Callable[ @@ -51,7 +51,7 @@ def __init__( ) if self.use_lstm: - self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True) + self.lstm = AMRLMax(self.h_size, self.m_size // 2, batch_first=True) else: self.lstm = None # type: ignore @@ -104,10 +104,10 @@ def forward( if self.use_lstm: # Resize to (batch, sequence length, encoding size) encoding = encoding.reshape([-1, sequence_length, self.h_size]) - memories = torch.split(memories, self.m_size // 2, dim=-1) + # memories = torch.split(memories, self.m_size // 2, dim=-1) encoding, memories = self.lstm(encoding, memories) encoding = encoding.reshape([-1, self.m_size // 2]) - memories = torch.cat(memories, dim=-1) + # memories = torch.cat(memories, dim=-1) return encoding, memories @@ -257,7 +257,7 @@ def __init__( self.act_type = act_type self.act_size = act_size self.version_number = torch.nn.Parameter(torch.Tensor([2.0])) - self.memory_size = torch.nn.Parameter(torch.Tensor([0])) + self.memory_size_param = torch.nn.Parameter(torch.Tensor([0])) self.is_continuous_int = torch.nn.Parameter( torch.Tensor([int(act_type == ActionType.CONTINUOUS)]) ) @@ -279,6 +279,13 @@ def __init__( self.encoding_size, act_size ) + @property + def memory_size(self) -> int: + if self.network_body.lstm is not None: + return self.network_body.lstm.memory_size + else: + return 0 + def update_normalization(self, vector_obs: List[torch.Tensor]) -> None: self.network_body.update_normalization(vector_obs) @@ -327,7 +334,7 @@ def forward( sampled_actions, dists[0].pdf(sampled_actions), self.version_number, - self.memory_size, + self.memory_size_param, self.is_continuous_int, self.act_size_vector, ) @@ -425,6 +432,13 @@ def __init__( stream_names, observation_shapes, use_network_settings ) + @property + def memory_size(self) -> int: + if self.network_body.lstm is not None: + return 2 * self.network_body.lstm.memory_size + else: + return 0 + def critic_pass( self, vec_inputs: List[torch.Tensor], From bd90e29c7b402f6a838d52d5b8baf645bc1bd16b Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Fri, 14 Aug 2020 14:25:28 -0700 Subject: [PATCH 16/27] Fix SeparateActorCritic --- ml-agents/mlagents/trainers/torch/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 87ff8c84cc..8d73912667 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -449,7 +449,7 @@ def critic_pass( actor_mem, critic_mem = None, None if self.use_lstm: # Use only the back half of memories for critic - actor_mem, critic_mem = torch.split(memories, self.half_mem_size, -1) + actor_mem, critic_mem = torch.split(memories, self.memory_size, -1) value_outputs, critic_mem_out = self.critic( vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length ) @@ -470,7 +470,7 @@ def get_dist_and_value( ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]: if self.use_lstm: # Use only the back half of memories for critic and actor - actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1) + actor_mem, critic_mem = torch.split(memories, self.memory_size, dim=-1) else: critic_mem = None actor_mem = None From 7f4ea51b226759c3c620bc1aa2772fef6ead2b5a Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: 
Fri, 14 Aug 2020 15:00:17 -0700 Subject: [PATCH 17/27] Fix SeparateActorCritic --- ml-agents/mlagents/trainers/torch/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 8d73912667..f606aafd81 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -449,7 +449,7 @@ def critic_pass( actor_mem, critic_mem = None, None if self.use_lstm: # Use only the back half of memories for critic - actor_mem, critic_mem = torch.split(memories, self.memory_size, -1) + actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1) value_outputs, critic_mem_out = self.critic( vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length ) @@ -470,7 +470,7 @@ def get_dist_and_value( ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]: if self.use_lstm: # Use only the back half of memories for critic and actor - actor_mem, critic_mem = torch.split(memories, self.memory_size, dim=-1) + actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1) else: critic_mem = None actor_mem = None From 317454a8da228f0f1762ae849accc82e40ed4caf Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 17 Aug 2020 17:37:41 -0700 Subject: [PATCH 18/27] LSTM class --- ml-agents/mlagents/trainers/torch/layers.py | 85 ++++++++----------- ml-agents/mlagents/trainers/torch/networks.py | 4 +- 2 files changed, 36 insertions(+), 53 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py index 2d353db7df..bddced1831 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -1,4 +1,6 @@ import torch +import abc +from typing import Tuple from enum import Enum @@ -84,79 +86,60 @@ def lstm_layer( return lstm -class AMRLMax(torch.nn.Module): +class MemoryModule(torch.nn.Module): + @abc.abstractproperty + def memory_size(self) -> int: + """ + Size of memory that is required at the start of a sequence. + """ + pass + + @abc.abstractmethod + def forward( + self, input_tensor: torch.Tensor, memories: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Pass a sequence to the memory module. + :input_tensor: Tensor of shape (batch_size, seq_length, size) that represents the input. + :memories: Tensor of initial memories. + :return: Tuple of output, final memories. + """ + pass + + +class LSTM(MemoryModule): """ - Implements Aggregation for LSTM as described here: - https://www.microsoft.com/en-us/research/publication/amrl-aggregated-memory-for-reinforcement-learning/ + Memory module that implements LSTM. 
""" def __init__( self, input_size: int, - hidden_size: int, + memory_size: int, num_layers: int = 1, - batch_first: bool = True, forget_bias: float = 1.0, kernel_init: Initialization = Initialization.XavierGlorotUniform, bias_init: Initialization = Initialization.Zero, - num_post_layers: int = 1, ): super().__init__() + self.hidden_size = memory_size // 2 self.lstm = lstm_layer( input_size, - hidden_size, + self.hidden_size, num_layers, - batch_first, + True, forget_bias, kernel_init, bias_init, ) - self.hidden_size = hidden_size - self.layers = [] - for _ in range(num_post_layers): - self.layers.append( - linear_layer( - hidden_size, - hidden_size, - kernel_init=Initialization.KaimingHeNormal, - kernel_gain=1.0, - ) - ) - self.layers.append(Swish()) - self.seq_layers = torch.nn.Sequential(*self.layers) @property def memory_size(self) -> int: - return self.hidden_size // 2 + 2 * self.hidden_size + return 2 * self.hidden_size def forward(self, input_tensor, memories): - # memories is 1/2 * hidden_size (accumulant) + hidden_size/2 (h0) + hidden_size/2 (c0) - acc, h0, c0 = torch.split( - memories, - [self.hidden_size // 2, self.hidden_size, self.hidden_size], - dim=-1, - ) + h0, c0 = torch.split(memories, self.hidden_size, dim=-1) hidden = (h0, c0) - all_c = [] - m = acc.permute([1, 0, 2]) - lstm_out, (h0_out, c0_out) = self.lstm(input_tensor, hidden) - h_half, other_half = torch.split(lstm_out, self.hidden_size // 2, dim=-1) - for t in range(h_half.shape[1]): - h_half_subt = h_half[:, t : t + 1, :] - m = AMRLMax.PassthroughMax.apply(m, h_half_subt) - all_c.append(m) - concat_c = torch.cat(all_c, dim=1) - concat_out = torch.cat([concat_c, other_half], dim=-1) - full_out = self.seq_layers(concat_out.reshape([-1, self.hidden_size])) - full_out = full_out.reshape([-1, input_tensor.shape[1], self.hidden_size]) - output_mem = torch.cat([m.permute([1, 0, 2]), h0_out, c0_out], dim=-1) - return concat_out, output_mem - - class PassthroughMax(torch.autograd.Function): - @staticmethod - def forward(ctx, tensor1, tensor2): - return torch.max(tensor1, tensor2) - - @staticmethod - def backward(ctx, grad_output): - return grad_output.clone(), grad_output.clone() + lstm_out, hidden_out = self.lstm(input_tensor, hidden) + output_mem = torch.cat(hidden_out, dim=-1) + return lstm_out, output_mem diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index f606aafd81..a179416989 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -14,7 +14,7 @@ from mlagents.trainers.settings import NetworkSettings from mlagents.trainers.torch.utils import ModelUtils from mlagents.trainers.torch.decoders import ValueHeads -from mlagents.trainers.torch.layers import AMRLMax +from mlagents.trainers.torch.layers import LSTM ActivationFunction = Callable[[torch.Tensor], torch.Tensor] EncoderFunction = Callable[ @@ -51,7 +51,7 @@ def __init__( ) if self.use_lstm: - self.lstm = AMRLMax(self.h_size, self.m_size // 2, batch_first=True) + self.lstm = LSTM(self.h_size, self.m_size) else: self.lstm = None # type: ignore From 848a87545b2b6f6d95d95beadc55c75eec806329 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 17 Aug 2020 17:51:12 -0700 Subject: [PATCH 19/27] Fix SeparateActorCritic export --- .../mlagents/trainers/policy/torch_policy.py | 11 +++++++ .../trainers/torch/model_serialization.py | 2 +- ml-agents/mlagents/trainers/torch/networks.py | 31 ++++++++----------- 3 files changed, 25 insertions(+), 19 deletions(-) 
diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py index 6144c68c42..d1863cdc0d 100644 --- a/ml-agents/mlagents/trainers/policy/torch_policy.py +++ b/ml-agents/mlagents/trainers/policy/torch_policy.py @@ -82,10 +82,21 @@ def __init__( conditional_sigma=self.condition_sigma_on_obs, tanh_squash=tanh_squash, ) + # Save the m_size needed for export + self._export_m_size = self.m_size + # m_size needed for training is determined by network, not trainer settings self.m_size = self.actor_critic.memory_size self.actor_critic.to(TestingConfiguration.device) + @property + def export_memory_size(self) -> int: + """ + Returns the memory size of the exported ONNX policy. This only includes the memory + of the Actor and not any auxillary networks. + """ + return self._export_m_size + def split_decision_step(self, decision_requests): vec_vis_obs = SplitObservations.from_observations(decision_requests.obs) mask = None diff --git a/ml-agents/mlagents/trainers/torch/model_serialization.py b/ml-agents/mlagents/trainers/torch/model_serialization.py index 5e1f16cdec..e8cbd2672f 100644 --- a/ml-agents/mlagents/trainers/torch/model_serialization.py +++ b/ml-agents/mlagents/trainers/torch/model_serialization.py @@ -19,7 +19,7 @@ def __init__(self, policy): else [] ) dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)]) - dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size]) + dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.export_memory_size]) # Need to pass all posslible inputs since currently keyword arguments is not # supported by torch.nn.export() diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 9ce482ef0a..130824bc3c 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -1,5 +1,4 @@ from typing import Callable, List, Dict, Tuple, Optional -import attr import abc import torch @@ -99,10 +98,8 @@ def forward( if self.use_lstm: # Resize to (batch, sequence length, encoding size) encoding = encoding.reshape([-1, sequence_length, self.h_size]) - # memories = torch.split(memories, self.m_size // 2, dim=-1) encoding, memories = self.lstm(encoding, memories) encoding = encoding.reshape([-1, self.m_size // 2]) - # memories = torch.cat(memories, dim=-1) return encoding, memories @@ -407,29 +404,27 @@ def __init__( # Give the Actor only half the memories. Note we previously validate # that memory_size must be a multiple of 4. 
self.use_lstm = network_settings.memory is not None - if network_settings.memory is not None: - self.half_mem_size = network_settings.memory.memory_size // 2 - new_memory_settings = attr.evolve( - network_settings.memory, memory_size=self.half_mem_size - ) - use_network_settings = attr.evolve( - network_settings, memory=new_memory_settings - ) - else: - use_network_settings = network_settings - self.half_mem_size = 0 + # if network_settings.memory is not None: + # self.half_mem_size = network_settings.memory.memory_size // 2 + # new_memory_settings = attr.evolve( + # network_settings.memory, memory_size=self.half_mem_size + # ) + # use_network_settings = attr.evolve( + # network_settings, memory=new_memory_settings + # ) + # else: + # use_network_settings = network_settings + # self.half_mem_size = 0 super().__init__( observation_shapes, - use_network_settings, + network_settings, act_type, act_size, conditional_sigma, tanh_squash, ) self.stream_names = stream_names - self.critic = ValueNetwork( - stream_names, observation_shapes, use_network_settings - ) + self.critic = ValueNetwork(stream_names, observation_shapes, network_settings) @property def memory_size(self) -> int: From 295094b376657a1b1bb3199606dae1e7c2057b68 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 17 Aug 2020 17:58:19 -0700 Subject: [PATCH 20/27] Add abstract method to Actor --- ml-agents/mlagents/trainers/torch/networks.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 130824bc3c..8b00bf9f76 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -234,6 +234,14 @@ def get_dist_and_value( """ pass + @abc.abstractproperty + def memory_size(self): + """ + Returns the size of the memory (same size used as input and output in the other + methods) used by this Actor. + """ + pass + class SimpleActor(nn.Module, Actor): def __init__( From 72bca86504f513115a31c7db1f891cbbb2e84663 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 17 Aug 2020 18:06:13 -0700 Subject: [PATCH 21/27] Fix BC module --- ml-agents/mlagents/trainers/torch/components/bc/module.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/components/bc/module.py b/ml-agents/mlagents/trainers/torch/components/bc/module.py index 13e251a49f..61f7f03758 100644 --- a/ml-agents/mlagents/trainers/torch/components/bc/module.py +++ b/ml-agents/mlagents/trainers/torch/components/bc/module.py @@ -150,9 +150,7 @@ def _update_batch( memories = [] if self.policy.use_recurrent: - memories = torch.zeros( - 1, self.n_sequences, self.policy.actor_critic.half_mem_size * 2 - ) + memories = torch.zeros(1, self.n_sequences, self.policy.m_size) if self.policy.use_vis_obs: vis_obs = [] From 683001ff5e5f6711808be97741714d7116293f6d Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 17 Aug 2020 18:07:55 -0700 Subject: [PATCH 22/27] Remove some comments --- ml-agents/mlagents/trainers/torch/networks.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 8b00bf9f76..2d5869fcdb 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -412,17 +412,6 @@ def __init__( # Give the Actor only half the memories. Note we previously validate # that memory_size must be a multiple of 4. 
self.use_lstm = network_settings.memory is not None - # if network_settings.memory is not None: - # self.half_mem_size = network_settings.memory.memory_size // 2 - # new_memory_settings = attr.evolve( - # network_settings.memory, memory_size=self.half_mem_size - # ) - # use_network_settings = attr.evolve( - # network_settings, memory=new_memory_settings - # ) - # else: - # use_network_settings = network_settings - # self.half_mem_size = 0 super().__init__( observation_shapes, network_settings, From 77a47f4ad2dcea56b6fb23033b170168b2da05c2 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 17 Aug 2020 19:05:18 -0700 Subject: [PATCH 23/27] Fix network tests --- ml-agents/mlagents/trainers/tests/torch/test_networks.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_networks.py b/ml-agents/mlagents/trainers/tests/torch/test_networks.py index 06f8b1ab25..2a1d9bbcd2 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_networks.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_networks.py @@ -184,11 +184,7 @@ def test_actor_critic(ac_type, lstm): if lstm: sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size)) memories = torch.ones( - ( - 1, - network_settings.memory.sequence_length, - network_settings.memory.memory_size, - ) + (1, network_settings.memory.sequence_length, actor.memory_size) ) else: sample_obs = torch.ones((1, obs_size)) From 7dba7bf3a6e7c8e3370aaea8ff0be3f80c3dba9c Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 18 Aug 2020 10:52:51 -0700 Subject: [PATCH 24/27] Clean up memory_size logic --- ml-agents/mlagents/trainers/torch/networks.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 2d5869fcdb..6e2b756428 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -63,6 +63,10 @@ def copy_normalization(self, other_network: "NetworkBody") -> None: for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders): n1.copy_normalization(n2) + @property + def memory_size(self) -> int: + return self.lstm.memory_size if self.use_lstm else 0 + def forward( self, vec_inputs: List[torch.Tensor], @@ -124,6 +128,10 @@ def __init__( encoding_size = network_settings.hidden_units self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream) + @property + def memory_size(self) -> int: + return self.network_body.memory_size + def forward( self, vec_inputs: List[torch.Tensor], @@ -281,10 +289,7 @@ def __init__( @property def memory_size(self) -> int: - if self.network_body.lstm is not None: - return self.network_body.lstm.memory_size - else: - return 0 + return self.network_body.memory_size def update_normalization(self, vector_obs: List[torch.Tensor]) -> None: self.network_body.update_normalization(vector_obs) @@ -425,10 +430,7 @@ def __init__( @property def memory_size(self) -> int: - if self.network_body.lstm is not None: - return 2 * self.network_body.lstm.memory_size - else: - return 0 + return self.network_body.memory_size + self.critic.memory_size def critic_pass( self, From 1d6a08c2b7e5b0aa20db8e0d4de1ad0275cd9a61 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 18 Aug 2020 11:03:40 -0700 Subject: [PATCH 25/27] Cleanup, add test --- .../trainers/tests/torch/test_layers.py | 19 +++++++++++++++++++ ml-agents/mlagents/trainers/torch/layers.py | 6 +++++- 2 files changed, 
24 insertions(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_layers.py b/ml-agents/mlagents/trainers/tests/torch/test_layers.py index 6d1132aa2e..2086d6dd13 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_layers.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_layers.py @@ -5,6 +5,7 @@ linear_layer, lstm_layer, Initialization, + LSTM, ) @@ -38,3 +39,21 @@ def test_lstm_layer(): assert torch.all( torch.eq(param.data[4:8], torch.ones_like(param.data[4:8])) ) + + +def test_lstm_class(): + torch.manual_seed(0) + input_size = 12 + memory_size = 64 + batch_size = 8 + seq_len = 16 + lstm = LSTM(input_size, memory_size) + + assert lstm.memory_size == memory_size + + sample_input = torch.ones((batch_size, seq_len, input_size)) + sample_memories = torch.ones((1, batch_size, memory_size)) + out, mem = lstm(sample_input, sample_memories) + # Hidden size should be half of memory_size + assert out.shape == (batch_size, seq_len, memory_size // 2) + assert mem.shape == (1, batch_size, memory_size) diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py index bddced1831..4f9c56b0a5 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -122,6 +122,8 @@ def __init__( bias_init: Initialization = Initialization.Zero, ): super().__init__() + # We set hidden size to half of memory_size since the initial memory + # will be divided between the hidden state and initial cell state. self.hidden_size = memory_size // 2 self.lstm = lstm_layer( input_size, @@ -137,7 +139,9 @@ def __init__( def memory_size(self) -> int: return 2 * self.hidden_size - def forward(self, input_tensor, memories): + def forward( + self, input_tensor: torch.Tensor, memories: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: h0, c0 = torch.split(memories, self.hidden_size, dim=-1) hidden = (h0, c0) lstm_out, hidden_out = self.lstm(input_tensor, hidden) From 9bb065c0607a053a1d7428494f308e1008b378d8 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 18 Aug 2020 16:23:24 -0700 Subject: [PATCH 26/27] Properly export memory size --- ml-agents/mlagents/trainers/torch/networks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 6e2b756428..b579316c0b 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -265,7 +265,6 @@ def __init__( self.act_type = act_type self.act_size = act_size self.version_number = torch.nn.Parameter(torch.Tensor([2.0])) - self.memory_size_param = torch.nn.Parameter(torch.Tensor([0])) self.is_continuous_int = torch.nn.Parameter( torch.Tensor([int(act_type == ActionType.CONTINUOUS)]) ) @@ -275,6 +274,8 @@ def __init__( self.encoding_size = network_settings.memory.memory_size // 2 else: self.encoding_size = network_settings.hidden_units + self.memory_size_param = torch.nn.Parameter(torch.Tensor([self.memory_size])) + if self.act_type == ActionType.CONTINUOUS: self.distribution = GaussianDistribution( self.encoding_size, From 6eb09d303726d774a7c0fd1a98d3a896111a7d9a Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Tue, 18 Aug 2020 18:19:01 -0700 Subject: [PATCH 27/27] Fix exporting again --- ml-agents/mlagents/trainers/torch/networks.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 
b579316c0b..585fcfa7f7 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -274,7 +274,6 @@ def __init__( self.encoding_size = network_settings.memory.memory_size // 2 else: self.encoding_size = network_settings.hidden_units - self.memory_size_param = torch.nn.Parameter(torch.Tensor([self.memory_size])) if self.act_type == ActionType.CONTINUOUS: self.distribution = GaussianDistribution( @@ -344,7 +343,7 @@ def forward( sampled_actions, log_probs, self.version_number, - self.memory_size_param, + torch.Tensor([self.network_body.memory_size]), self.is_continuous_int, self.act_size_vector, )
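Taken together, these patches settle on a simple memory convention for the recurrent SeparateActorCritic: the memory tensor handed to the policy is the actor's LSTM state concatenated with the critic's, split in half on the last dimension on the way in (critic_pass, get_dist_and_value) and concatenated again on the way out. A minimal sketch of that bookkeeping with stand-in tensors, not the actual network calls:

    import torch

    memory_size = 128                           # SeparateActorCritic.memory_size (actor half + critic half)
    batch = 4
    memories = torch.zeros(1, batch, memory_size)

    # Split as in critic_pass / get_dist_and_value: first half actor, second half critic.
    actor_mem, critic_mem = torch.split(memories, memory_size // 2, dim=-1)

    # Stand-ins for the memories each LSTM would return after processing a sequence.
    actor_mem_out, critic_mem_out = actor_mem, critic_mem

    # Recombined memory returned to the trainer, mirroring get_dist_and_value.
    mem_out = torch.cat([actor_mem_out, critic_mem_out], dim=-1)
    assert mem_out.shape == memories.shape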