diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py index 6d2afc72eb..0abdbddeaa 100644 --- a/ml-agents/mlagents/trainers/policy/torch_policy.py +++ b/ml-agents/mlagents/trainers/policy/torch_policy.py @@ -79,9 +79,21 @@ def __init__( conditional_sigma=self.condition_sigma_on_obs, tanh_squash=tanh_squash, ) + # Save the m_size needed for export + self._export_m_size = self.m_size + # m_size needed for training is determined by network, not trainer settings + self.m_size = self.actor_critic.memory_size self.actor_critic.to("cpu") + @property + def export_memory_size(self) -> int: + """ + Returns the memory size of the exported ONNX policy. This only includes the memory + of the Actor and not any auxiliary networks. + """ + return self._export_m_size + def _split_decision_step( self, decision_requests: DecisionSteps ) -> Tuple[SplitObservations, np.ndarray]: diff --git a/ml-agents/mlagents/trainers/tests/torch/test_layers.py b/ml-agents/mlagents/trainers/tests/torch/test_layers.py index 6d1132aa2e..2086d6dd13 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_layers.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_layers.py @@ -5,6 +5,7 @@ linear_layer, lstm_layer, Initialization, + LSTM, ) @@ -38,3 +39,21 @@ def test_lstm_layer(): assert torch.all( torch.eq(param.data[4:8], torch.ones_like(param.data[4:8])) ) + + +def test_lstm_class(): + torch.manual_seed(0) + input_size = 12 + memory_size = 64 + batch_size = 8 + seq_len = 16 + lstm = LSTM(input_size, memory_size) + + assert lstm.memory_size == memory_size + + sample_input = torch.ones((batch_size, seq_len, input_size)) + sample_memories = torch.ones((1, batch_size, memory_size)) + out, mem = lstm(sample_input, sample_memories) + # Hidden size should be half of memory_size + assert out.shape == (batch_size, seq_len, memory_size // 2) + assert mem.shape == (1, batch_size, memory_size) diff --git 
a/ml-agents/mlagents/trainers/tests/torch/test_networks.py b/ml-agents/mlagents/trainers/tests/torch/test_networks.py index 06f8b1ab25..2a1d9bbcd2 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_networks.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_networks.py @@ -184,11 +184,7 @@ def test_actor_critic(ac_type, lstm): if lstm: sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size)) memories = torch.ones( - ( - 1, - network_settings.memory.sequence_length, - network_settings.memory.memory_size, - ) + (1, network_settings.memory.sequence_length, actor.memory_size) ) else: sample_obs = torch.ones((1, obs_size)) diff --git a/ml-agents/mlagents/trainers/torch/components/bc/module.py b/ml-agents/mlagents/trainers/torch/components/bc/module.py index 13e251a49f..61f7f03758 100644 --- a/ml-agents/mlagents/trainers/torch/components/bc/module.py +++ b/ml-agents/mlagents/trainers/torch/components/bc/module.py @@ -150,9 +150,7 @@ def _update_batch( memories = [] if self.policy.use_recurrent: - memories = torch.zeros( - 1, self.n_sequences, self.policy.actor_critic.half_mem_size * 2 - ) + memories = torch.zeros(1, self.n_sequences, self.policy.m_size) if self.policy.use_vis_obs: vis_obs = [] diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py index 707d4748a5..4f9c56b0a5 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -1,4 +1,6 @@ import torch +import abc +from typing import Tuple from enum import Enum @@ -82,3 +84,66 @@ def lstm_layer( forget_bias ) return lstm + + +class MemoryModule(torch.nn.Module): + @abc.abstractproperty + def memory_size(self) -> int: + """ + Size of memory that is required at the start of a sequence. + """ + pass + + @abc.abstractmethod + def forward( + self, input_tensor: torch.Tensor, memories: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Pass a sequence to the memory module. 
+ :input_tensor: Tensor of shape (batch_size, seq_length, size) that represents the input. + :memories: Tensor of initial memories. + :return: Tuple of output, final memories. + """ + pass + + +class LSTM(MemoryModule): + """ + Memory module that implements LSTM. + """ + + def __init__( + self, + input_size: int, + memory_size: int, + num_layers: int = 1, + forget_bias: float = 1.0, + kernel_init: Initialization = Initialization.XavierGlorotUniform, + bias_init: Initialization = Initialization.Zero, + ): + super().__init__() + # We set hidden size to half of memory_size since the initial memory + # will be divided between the hidden state and initial cell state. + self.hidden_size = memory_size // 2 + self.lstm = lstm_layer( + input_size, + self.hidden_size, + num_layers, + True, + forget_bias, + kernel_init, + bias_init, + ) + + @property + def memory_size(self) -> int: + return 2 * self.hidden_size + + def forward( + self, input_tensor: torch.Tensor, memories: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + h0, c0 = torch.split(memories, self.hidden_size, dim=-1) + hidden = (h0, c0) + lstm_out, hidden_out = self.lstm(input_tensor, hidden) + output_mem = torch.cat(hidden_out, dim=-1) + return lstm_out, output_mem diff --git a/ml-agents/mlagents/trainers/torch/model_serialization.py b/ml-agents/mlagents/trainers/torch/model_serialization.py index 5e1f16cdec..e8cbd2672f 100644 --- a/ml-agents/mlagents/trainers/torch/model_serialization.py +++ b/ml-agents/mlagents/trainers/torch/model_serialization.py @@ -19,7 +19,7 @@ def __init__(self, policy): else [] ) dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)]) - dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size]) + dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.export_memory_size]) # Need to pass all posslible inputs since currently keyword arguments is not # supported by torch.nn.export() diff --git a/ml-agents/mlagents/trainers/torch/networks.py 
b/ml-agents/mlagents/trainers/torch/networks.py index 41c8536e43..585fcfa7f7 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -1,5 +1,4 @@ from typing import Callable, List, Dict, Tuple, Optional -import attr import abc import torch @@ -14,7 +13,7 @@ from mlagents.trainers.settings import NetworkSettings from mlagents.trainers.torch.utils import ModelUtils from mlagents.trainers.torch.decoders import ValueHeads -from mlagents.trainers.torch.layers import lstm_layer +from mlagents.trainers.torch.layers import LSTM ActivationFunction = Callable[[torch.Tensor], torch.Tensor] EncoderFunction = Callable[ @@ -51,9 +50,9 @@ def __init__( ) if self.use_lstm: - self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True) + self.lstm = LSTM(self.h_size, self.m_size) else: - self.lstm = None + self.lstm = None # type: ignore def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None: for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders): @@ -64,6 +63,10 @@ def copy_normalization(self, other_network: "NetworkBody") -> None: for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders): n1.copy_normalization(n2) + @property + def memory_size(self) -> int: + return self.lstm.memory_size if self.use_lstm else 0 + def forward( self, vec_inputs: List[torch.Tensor], @@ -99,10 +102,8 @@ def forward( if self.use_lstm: # Resize to (batch, sequence length, encoding size) encoding = encoding.reshape([-1, sequence_length, self.h_size]) - memories = torch.split(memories, self.m_size // 2, dim=-1) encoding, memories = self.lstm(encoding, memories) encoding = encoding.reshape([-1, self.m_size // 2]) - memories = torch.cat(memories, dim=-1) return encoding, memories @@ -127,6 +128,10 @@ def __init__( encoding_size = network_settings.hidden_units self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream) + @property + def memory_size(self) -> int: + return 
self.network_body.memory_size + def forward( self, vec_inputs: List[torch.Tensor], @@ -237,6 +242,14 @@ def get_dist_and_value( """ pass + @abc.abstractproperty + def memory_size(self): + """ + Returns the size of the memory (same size used as input and output in the other + methods) used by this Actor. + """ + pass + class SimpleActor(nn.Module, Actor): def __init__( @@ -252,7 +265,6 @@ def __init__( self.act_type = act_type self.act_size = act_size self.version_number = torch.nn.Parameter(torch.Tensor([2.0])) - self.memory_size = torch.nn.Parameter(torch.Tensor([0])) self.is_continuous_int = torch.nn.Parameter( torch.Tensor([int(act_type == ActionType.CONTINUOUS)]) ) @@ -262,6 +274,7 @@ def __init__( self.encoding_size = network_settings.memory.memory_size // 2 else: self.encoding_size = network_settings.hidden_units + if self.act_type == ActionType.CONTINUOUS: self.distribution = GaussianDistribution( self.encoding_size, @@ -274,6 +287,10 @@ def __init__( self.encoding_size, act_size ) + @property + def memory_size(self) -> int: + return self.network_body.memory_size + def update_normalization(self, vector_obs: List[torch.Tensor]) -> None: self.network_body.update_normalization(vector_obs) @@ -326,7 +343,7 @@ def forward( sampled_actions, log_probs, self.version_number, - self.memory_size, + torch.Tensor([self.network_body.memory_size]), self.is_continuous_int, self.act_size_vector, ) @@ -400,29 +417,20 @@ def __init__( # Give the Actor only half the memories. Note we previously validate # that memory_size must be a multiple of 4. 
self.use_lstm = network_settings.memory is not None - if network_settings.memory is not None: - self.half_mem_size = network_settings.memory.memory_size // 2 - new_memory_settings = attr.evolve( - network_settings.memory, memory_size=self.half_mem_size - ) - use_network_settings = attr.evolve( - network_settings, memory=new_memory_settings - ) - else: - use_network_settings = network_settings - self.half_mem_size = 0 super().__init__( observation_shapes, - use_network_settings, + network_settings, act_type, act_size, conditional_sigma, tanh_squash, ) self.stream_names = stream_names - self.critic = ValueNetwork( - stream_names, observation_shapes, use_network_settings - ) + self.critic = ValueNetwork(stream_names, observation_shapes, network_settings) + + @property + def memory_size(self) -> int: + return self.network_body.memory_size + self.critic.memory_size def critic_pass( self, @@ -434,7 +442,7 @@ def critic_pass( actor_mem, critic_mem = None, None if self.use_lstm: # Use only the back half of memories for critic - actor_mem, critic_mem = torch.split(memories, self.half_mem_size, -1) + actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1) value_outputs, critic_mem_out = self.critic( vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length ) @@ -455,7 +463,7 @@ def get_dist_and_value( ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]: if self.use_lstm: # Use only the back half of memories for critic and actor - actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1) + actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1) else: critic_mem = None actor_mem = None