From 482370e654dd408c367203cbfc2ced98bf6e419b Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 27 Jul 2020 14:16:51 -0700 Subject: [PATCH 1/5] Refactor normalizers and encoders --- .../mlagents/trainers/sac/optimizer_torch.py | 4 +- ml-agents/mlagents/trainers/torch/encoders.py | 95 ++++++++++--- ml-agents/mlagents/trainers/torch/networks.py | 129 +++++------------- ml-agents/mlagents/trainers/torch/utils.py | 28 ++-- 4 files changed, 126 insertions(+), 130 deletions(-) diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index 4e6c781397..bde8b14b80 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -370,10 +370,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: next_vis_obs.append(next_vis_ob) # Copy normalizers from policy - self.value_network.q1_network.copy_normalization( + self.value_network.q1_network.network_body.copy_normalization( self.policy.actor_critic.network_body ) - self.value_network.q2_network.copy_normalization( + self.value_network.q2_network.network_body.copy_normalization( self.policy.actor_critic.network_body ) self.target_network.network_body.copy_normalization( diff --git a/ml-agents/mlagents/trainers/torch/encoders.py b/ml-agents/mlagents/trainers/torch/encoders.py index 4a90daf673..ee56f2b256 100644 --- a/ml-agents/mlagents/trainers/torch/encoders.py +++ b/ml-agents/mlagents/trainers/torch/encoders.py @@ -1,28 +1,19 @@ -import torch -from torch import nn +from typing import Tuple, Optional +from mlagents.trainers.exception import UnityTrainerException -class VectorEncoder(nn.Module): - def __init__(self, input_size, hidden_size, num_layers, **kwargs): - super().__init__(**kwargs) - self.layers = [nn.Linear(input_size, hidden_size)] - for _ in range(num_layers - 1): - self.layers.append(nn.Linear(hidden_size, hidden_size)) - self.layers.append(nn.ReLU()) - self.seq_layers = nn.Sequential(*self.layers) - - def forward(self, inputs): - return self.seq_layers(inputs) +import torch +from torch import nn class Normalizer(nn.Module): - def __init__(self, vec_obs_size, **kwargs): - super().__init__(**kwargs) + def __init__(self, vec_obs_size: int): + super().__init__() self.normalization_steps = torch.tensor(1) self.running_mean = torch.zeros(vec_obs_size) self.running_variance = torch.ones(vec_obs_size) - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: normalized_state = torch.clamp( (inputs - self.running_mean) / torch.sqrt(self.running_variance / self.normalization_steps), @@ -31,7 +22,7 @@ def forward(self, inputs): ) return normalized_state - def update(self, vector_input): + def update(self, vector_input: torch.Tensor) -> None: steps_increment = vector_input.size()[0] total_new_steps = self.normalization_steps + steps_increment @@ -66,14 +57,78 @@ def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1): return h, w -def pool_out_shape(h_w, kernel_size): +def pool_out_shape(h_w: Tuple[int, int], kernel_size: int) -> Tuple[int, int]: height = (h_w[0] - kernel_size) // 2 + 1 width = (h_w[1] - kernel_size) // 2 + 1 return height, width +class VectorEncoder(nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int, + normalize: bool = False, + ): + self.normalizer: Optional[Normalizer] = None + super().__init__() + self.layers = [nn.Linear(input_size, hidden_size)] + if normalize: + self.normalizer = Normalizer(input_size) + + for _ in range(num_layers - 1): + self.layers.append(nn.Linear(hidden_size, hidden_size)) + self.layers.append(nn.ReLU()) + self.seq_layers = nn.Sequential(*self.layers) + + def forward(self, inputs: torch.Tensor) -> None: + if self.normalizer is not None: + inputs = self.normalizer(inputs) + return self.seq_layers(inputs) + + def copy_normalization(self, other_encoder: "VectorEncoder") -> None: + if self.normalizer is not None and other_encoder.normalizer is not None: + self.normalizer.copy_from(other_encoder.normalizer) + + def update_normalization(self, inputs: torch.Tensor) -> None: + if self.normalizer is not None: + self.normalizer.update(inputs) + + +class ActionVectorEncoder(VectorEncoder): + def __init__( + self, + input_size: int, + hidden_size: int, + num_actions: int, + num_layers: int, + normalize: bool = False, + ): + super().__init__( + input_size + num_actions, hidden_size, num_layers, normalize=False + ) + if normalize: + self.normalizer = Normalizer(input_size) + else: + self.normalizer = None + + def forward( # pylint: disable=W0221 + self, inputs: torch.Tensor, actions: Optional[torch.Tensor] = None + ) -> None: + if actions is None: + raise UnityTrainerException( + "Attempted to call an ActionVectorEncoder without an action." + ) # Fix mypy errors about method parameters. + if self.normalizer is not None: + inputs = self.normalizer(inputs) + return self.seq_layers(torch.cat([inputs, actions], dim=-1)) + + class SimpleVisualEncoder(nn.Module): - def __init__(self, height, width, initial_channels, output_size): + def __init__( + self, height: int, width: int, initial_channels: int, output_size: int + ): super().__init__() self.h_size = output_size conv_1_hw = conv_output_shape((height, width), 8, 4) @@ -84,7 +139,7 @@ def __init__(self, height, width, initial_channels, output_size): self.conv2 = nn.Conv2d(16, 32, [4, 4], [2, 2]) self.dense = nn.Linear(self.final_flat, self.h_size) - def forward(self, visual_obs): + def forward(self, visual_obs: torch.Tensor) -> None: conv_1 = torch.relu(self.conv1(visual_obs)) conv_2 = torch.relu(self.conv2(conv_1)) # hidden = torch.relu(self.dense(conv_2.view([-1, self.final_flat]))) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index c0ebd1f2a4..0e472601a5 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -1,4 +1,4 @@ -from typing import Callable, NamedTuple, List, Dict, Tuple +from typing import Callable, List, Dict, Tuple, Optional import torch from torch import nn @@ -20,17 +20,12 @@ EPSILON = 1e-7 -class NormalizerTensors(NamedTuple): - steps: torch.Tensor - running_mean: torch.Tensor - running_variance: torch.Tensor - - class NetworkBody(nn.Module): def __init__( self, observation_shapes: List[Tuple[int, ...]], network_settings: NetworkSettings, + encoded_act_size: int = 0, ): super().__init__() self.normalize = network_settings.normalize @@ -42,16 +37,13 @@ def __init__( else 0 ) - ( - self.visual_encoders, - self.vector_encoders, - self.vector_normalizers, - ) = ModelUtils.create_encoders( + self.visual_encoders, self.vector_encoders = ModelUtils.create_encoders( observation_shapes, self.h_size, network_settings.num_layers, network_settings.vis_encode_type, - action_size=0, + encoded_act_size=encoded_act_size, + normalize=self.normalize, ) if self.use_lstm: @@ -60,24 +52,29 @@ def __init__( self.lstm = None def update_normalization(self, vec_inputs): - if self.normalize: - for idx, vec_input in enumerate(vec_inputs): - self.vector_normalizers[idx].update(vec_input) + for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders): + vec_enc.update_normalization(vec_input) def copy_normalization(self, other_network: "NetworkBody") -> None: if self.normalize: - for n1, n2 in zip( - self.vector_normalizers, other_network.vector_normalizers - ): - n1.copy_from(n2) + for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders): + n1.copy_normalization(n2) - def forward(self, vec_inputs, vis_inputs, memories=None, sequence_length=1): + def forward( + self, + vec_inputs: torch.Tensor, + vis_inputs: torch.Tensor, + actions: Optional[torch.Tensor] = None, + memories: Optional[torch.Tensor] = None, + sequence_length: int = 1, + ) -> Tuple[torch.Tensor, torch.Tensor]: vec_embeds = [] for idx, encoder in enumerate(self.vector_encoders): vec_input = vec_inputs[idx] - if self.normalize: - vec_input = self.vector_normalizers[idx](vec_input) - hidden = encoder(vec_input) + if actions is not None: + hidden = encoder(vec_input, actions) + else: + hidden = encoder(vec_input) vec_embeds.append(hidden) vis_embeds = [] @@ -113,7 +110,7 @@ def forward(self, vec_inputs, vis_inputs, memories=None, sequence_length=1): return embedding, memories -class QNetwork(NetworkBody): +class QNetwork(nn.Module): def __init__( # pylint: disable=W0231 self, stream_names: List[str], @@ -122,89 +119,31 @@ def __init__( # pylint: disable=W0231 act_type: ActionType, act_size: List[int], ): + # This is not a typo, we want to call __init__ of nn.Module nn.Module.__init__(self) - self.normalize = network_settings.normalize - self.use_lstm = network_settings.memory is not None - self.h_size = network_settings.hidden_units - self.m_size = ( - network_settings.memory.memory_size - if network_settings.memory is not None - else 0 - ) - - ( - self.visual_encoders, - self.vector_encoders, - self.vector_normalizers, - ) = ModelUtils.create_encoders( - observation_shapes, - self.h_size, - network_settings.num_layers, - network_settings.vis_encode_type, - action_size=sum(act_size) if act_type == ActionType.CONTINUOUS else 0, - ) - - if self.use_lstm: - self.lstm = nn.LSTM(self.h_size, self.m_size // 2, 1) - else: - self.lstm = None if act_type == ActionType.DISCRETE: + self.network_body = NetworkBody(observation_shapes, network_settings) self.q_heads = ValueHeads( stream_names, network_settings.hidden_units, sum(act_size) ) else: + self.network_body = NetworkBody( + observation_shapes, network_settings, encoded_act_size=sum(act_size) + ) self.q_heads = ValueHeads(stream_names, network_settings.hidden_units) def forward( # pylint: disable=W0221 self, vec_inputs: List[torch.Tensor], vis_inputs: List[torch.Tensor], - memories: torch.Tensor = None, + actions: Optional[torch.Tensor] = None, + memories: Optional[torch.Tensor] = None, sequence_length: int = 1, - actions: torch.Tensor = None, ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: - vec_embeds = [] - for i, (enc, norm) in enumerate( - zip(self.vector_encoders, self.vector_normalizers) - ): - vec_input = vec_inputs[i] - if self.normalize: - vec_input = norm(vec_input) - if actions is not None: - hidden = enc(torch.cat([vec_input, actions], dim=-1)) - else: - hidden = enc(vec_input) - vec_embeds.append(hidden) - - vis_embeds = [] - for idx, encoder in enumerate(self.visual_encoders): - vis_input = vis_inputs[idx] - vis_input = vis_input.permute([0, 3, 1, 2]) - hidden = encoder(vis_input) - vis_embeds.append(hidden) - - # embedding = vec_embeds[0] - if len(vec_embeds) > 0 and len(vis_embeds) > 0: - vec_embeds_tensor = torch.stack(vec_embeds, dim=-1).sum(dim=-1) - vis_embeds_tensor = torch.stack(vis_embeds, dim=-1).sum(dim=-1) - embedding = torch.stack([vec_embeds_tensor, vis_embeds_tensor], dim=-1).sum( - dim=-1 - ) - elif len(vec_embeds) > 0: - embedding = torch.stack(vec_embeds, dim=-1).sum(dim=-1) - elif len(vis_embeds) > 0: - embedding = torch.stack(vis_embeds, dim=-1).sum(dim=-1) - else: - raise Exception("No valid inputs to network.") - - if self.lstm is not None: - embedding = embedding.view([sequence_length, -1, self.h_size]) - memories_tensor = torch.split(memories, self.m_size // 2, dim=-1) - embedding, memories = self.lstm(embedding, memories_tensor) - embedding = embedding.view([-1, self.m_size // 2]) - memories = torch.cat(memories_tensor, dim=-1) - + embedding, memories = self.network_body( + vec_inputs, vis_inputs, actions, memories, sequence_length + ) output, _ = self.q_heads(embedding) return output, memories @@ -292,7 +231,7 @@ def get_dist_and_value( self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1 ): embedding, memories = self.network_body( - vec_inputs, vis_inputs, memories, sequence_length + vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length ) if self.act_type == ActionType.CONTINUOUS: dists = self.distribution(embedding) @@ -308,7 +247,7 @@ def forward( self, vec_inputs, vis_inputs=None, masks=None, memories=None, sequence_length=1 ): embedding, memories = self.network_body( - vec_inputs, vis_inputs, memories, sequence_length + vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length ) dists, value_outputs, memories = self.get_dist_and_value( vec_inputs, vis_inputs, masks, memories, sequence_length diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py index 3a363afa9a..70c96e1d0c 100644 --- a/ml-agents/mlagents/trainers/torch/utils.py +++ b/ml-agents/mlagents/trainers/torch/utils.py @@ -8,8 +8,8 @@ ResNetVisualEncoder, NatureVisualEncoder, VectorEncoder, + ActionVectorEncoder, ) -from mlagents.trainers.torch.encoders import Normalizer from mlagents.trainers.settings import EncoderType from mlagents.trainers.exception import UnityTrainerException @@ -56,8 +56,9 @@ def create_encoders( h_size: int, num_layers: int, vis_encode_type: EncoderType, - action_size: int = 0, - ) -> Tuple[nn.ModuleList, nn.ModuleList, nn.ModuleList]: + encoded_act_size: int = 0, + normalize: bool = False, + ) -> Tuple[nn.ModuleList, nn.ModuleList]: """ Creates visual and vector encoders, along with their normalizers. :param observation_shapes: List of Tuples that represent the action dimensions. @@ -70,7 +71,6 @@ def create_encoders( """ visual_encoders: List[nn.Module] = [] vector_encoders: List[nn.Module] = [] - vector_normalizers: List[nn.Module] = [] visual_encoder_class = ModelUtils.get_encoder_for_type(vis_encode_type) vector_size = 0 @@ -87,15 +87,17 @@ def create_encoders( raise UnityTrainerException( f"Unsupported shape of {dimension} for observation {i}" ) - vector_normalizers.append(Normalizer(vector_size)) - vector_encoders.append( - VectorEncoder(vector_size + action_size, h_size, num_layers) - ) - return ( - nn.ModuleList(visual_encoders), - nn.ModuleList(vector_encoders), - nn.ModuleList(vector_normalizers), - ) + if encoded_act_size > 0: + vector_encoders.append( + ActionVectorEncoder( + vector_size, h_size, encoded_act_size, num_layers, normalize + ) + ) + else: + vector_encoders.append( + VectorEncoder(vector_size, h_size, num_layers, normalize) + ) + return nn.ModuleList(visual_encoders), nn.ModuleList(vector_encoders) @staticmethod def list_to_tensor( From c3b12f9528fd9f566e64c37c12724760a24a0646 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 27 Jul 2020 14:55:40 -0700 Subject: [PATCH 2/5] Unify Critic and ValueNetwork --- .../mlagents/trainers/sac/optimizer_torch.py | 26 +++++++--- ml-agents/mlagents/trainers/torch/networks.py | 49 ++++++------------- 2 files changed, 35 insertions(+), 40 deletions(-) diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index bde8b14b80..b5653a9f65 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -8,7 +8,7 @@ from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.settings import NetworkSettings -from mlagents.trainers.torch.networks import Critic, QNetwork +from mlagents.trainers.torch.networks import ValueNetwork from mlagents.trainers.torch.utils import ModelUtils from mlagents.trainers.buffer import AgentBuffer from mlagents_envs.timers import timed @@ -31,11 +31,25 @@ def __init__( act_size: List[int], ): super().__init__() - self.q1_network = QNetwork( - stream_names, observation_shapes, network_settings, act_type, act_size + if act_type == ActionType.CONTINUOUS: + num_value_outs = 1 + num_action_ins = sum(act_size) + else: + num_value_outs = sum(act_size) + num_action_ins = 0 + self.q1_network = ValueNetwork( + stream_names, + observation_shapes, + network_settings, + num_action_ins, + num_value_outs, ) - self.q2_network = QNetwork( - stream_names, observation_shapes, network_settings, act_type, act_size + self.q2_network = ValueNetwork( + stream_names, + observation_shapes, + network_settings, + num_action_ins, + num_value_outs, ) def forward( @@ -86,7 +100,7 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): self.policy.behavior_spec.action_type, self.act_size, ) - self.target_network = Critic( + self.target_network = ValueNetwork( self.stream_names, self.policy.behavior_spec.observation_shapes, policy_network_settings, diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 0e472601a5..a9fb11b109 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -110,30 +110,26 @@ def forward( return embedding, memories -class QNetwork(nn.Module): - def __init__( # pylint: disable=W0231 +class ValueNetwork(nn.Module): + def __init__( self, stream_names: List[str], observation_shapes: List[Tuple[int, ...]], network_settings: NetworkSettings, - act_type: ActionType, - act_size: List[int], + encoded_act_size: int = 0, + outputs_per_stream: int = 1, ): # This is not a typo, we want to call __init__ of nn.Module nn.Module.__init__(self) - if act_type == ActionType.DISCRETE: - self.network_body = NetworkBody(observation_shapes, network_settings) - self.q_heads = ValueHeads( - stream_names, network_settings.hidden_units, sum(act_size) - ) - else: - self.network_body = NetworkBody( - observation_shapes, network_settings, encoded_act_size=sum(act_size) - ) - self.q_heads = ValueHeads(stream_names, network_settings.hidden_units) + self.network_body = NetworkBody( + observation_shapes, network_settings, encoded_act_size=encoded_act_size + ) + self.value_heads = ValueHeads( + stream_names, network_settings.hidden_units, outputs_per_stream + ) - def forward( # pylint: disable=W0221 + def forward( self, vec_inputs: List[torch.Tensor], vis_inputs: List[torch.Tensor], @@ -144,7 +140,7 @@ def forward( # pylint: disable=W0221 embedding, memories = self.network_body( vec_inputs, vis_inputs, actions, memories, sequence_length ) - output, _ = self.q_heads(embedding) + output, _ = self.value_heads(embedding) return output, memories @@ -183,7 +179,9 @@ def __init__( else: self.distribution = MultiCategoricalDistribution(embedding_size, act_size) if separate_critic: - self.critic = Critic(stream_names, observation_shapes, network_settings) + self.critic = ValueNetwork( + stream_names, observation_shapes, network_settings + ) else: self.stream_names = stream_names self.value_heads = ValueHeads(stream_names, embedding_size) @@ -264,23 +262,6 @@ def forward( ) -class Critic(nn.Module): - def __init__( - self, - stream_names: List[str], - observation_shapes: List[Tuple[int, ...]], - network_settings: NetworkSettings, - ): - super().__init__() - self.network_body = NetworkBody(observation_shapes, network_settings) - self.stream_names = stream_names - self.value_heads = ValueHeads(stream_names, network_settings.hidden_units) - - def forward(self, vec_inputs, vis_inputs): - embedding, _ = self.network_body(vec_inputs, vis_inputs) - return self.value_heads(embedding) - - class GlobalSteps(nn.Module): def __init__(self): super().__init__() From b59330414b4fdf2b4501c6bc96f377e0f71165a1 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Wed, 29 Jul 2020 10:01:27 -0700 Subject: [PATCH 3/5] Rename ActionVectorEncoder --- ml-agents/mlagents/trainers/torch/encoders.py | 17 ++++++++++------- ml-agents/mlagents/trainers/torch/utils.py | 4 ++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/encoders.py b/ml-agents/mlagents/trainers/torch/encoders.py index ee56f2b256..85c008f300 100644 --- a/ml-agents/mlagents/trainers/torch/encoders.py +++ b/ml-agents/mlagents/trainers/torch/encoders.py @@ -96,17 +96,20 @@ def update_normalization(self, inputs: torch.Tensor) -> None: self.normalizer.update(inputs) -class ActionVectorEncoder(VectorEncoder): +class VectorAndUnnormalizedInputEncoder(VectorEncoder): def __init__( self, input_size: int, hidden_size: int, - num_actions: int, + unnormalized_input_size: int, num_layers: int, normalize: bool = False, ): super().__init__( - input_size + num_actions, hidden_size, num_layers, normalize=False + input_size + unnormalized_input_size, + hidden_size, + num_layers, + normalize=False, ) if normalize: self.normalizer = Normalizer(input_size) @@ -114,15 +117,15 @@ def __init__( self.normalizer = None def forward( # pylint: disable=W0221 - self, inputs: torch.Tensor, actions: Optional[torch.Tensor] = None + self, inputs: torch.Tensor, unnormalized_inputs: Optional[torch.Tensor] = None ) -> None: - if actions is None: + if unnormalized_inputs is None: raise UnityTrainerException( - "Attempted to call an ActionVectorEncoder without an action." + "Attempted to call an VectorAndUnnormalizedInputEncoder without an unnormalized input." ) # Fix mypy errors about method parameters. if self.normalizer is not None: inputs = self.normalizer(inputs) - return self.seq_layers(torch.cat([inputs, actions], dim=-1)) + return self.seq_layers(torch.cat([inputs, unnormalized_inputs], dim=-1)) class SimpleVisualEncoder(nn.Module): diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py index 70c96e1d0c..65d5ac1dce 100644 --- a/ml-agents/mlagents/trainers/torch/utils.py +++ b/ml-agents/mlagents/trainers/torch/utils.py @@ -8,7 +8,7 @@ ResNetVisualEncoder, NatureVisualEncoder, VectorEncoder, - ActionVectorEncoder, + VectorAndUnnormalizedInputEncoder, ) from mlagents.trainers.settings import EncoderType from mlagents.trainers.exception import UnityTrainerException @@ -89,7 +89,7 @@ def create_encoders( ) if encoded_act_size > 0: vector_encoders.append( - ActionVectorEncoder( + VectorAndUnnormalizedInputEncoder( vector_size, h_size, encoded_act_size, num_layers, normalize ) ) From 64a0cf67e44d8c674d6aa1d1ed2f1a75c4f71408 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Wed, 29 Jul 2020 10:21:30 -0700 Subject: [PATCH 4/5] Update docstring of create_encoders --- ml-agents/mlagents/trainers/torch/networks.py | 2 +- ml-agents/mlagents/trainers/torch/utils.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index a9fb11b109..d44f99cbec 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -42,7 +42,7 @@ def __init__( self.h_size, network_settings.num_layers, network_settings.vis_encode_type, - encoded_act_size=encoded_act_size, + unnormalized_inputs=encoded_act_size, normalize=self.normalize, ) diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py index 65d5ac1dce..d628ac3ebf 100644 --- a/ml-agents/mlagents/trainers/torch/utils.py +++ b/ml-agents/mlagents/trainers/torch/utils.py @@ -56,7 +56,7 @@ def create_encoders( h_size: int, num_layers: int, vis_encode_type: EncoderType, - encoded_act_size: int = 0, + unnormalized_inputs: int = 0, normalize: bool = False, ) -> Tuple[nn.ModuleList, nn.ModuleList]: """ @@ -67,7 +67,10 @@ def create_encoders( :param h_size: Number of hidden units per layer. :param num_layers: Depth of MLP per encoder. :param vis_encode_type: Type of visual encoder to use. - :return: Tuple of visual encoders, vector encoders, and vector normalizers, each as a list. + :param unnormalized_inputs: Vector inputs that should not be normalized, and added to the vector + obs. + :param normalize: Normalize all vector inputs. + :return: Tuple of visual encoders and vector encoders each as a list. """ visual_encoders: List[nn.Module] = [] vector_encoders: List[nn.Module] = [] @@ -87,10 +90,10 @@ def create_encoders( raise UnityTrainerException( f"Unsupported shape of {dimension} for observation {i}" ) - if encoded_act_size > 0: + if unnormalized_inputs > 0: vector_encoders.append( VectorAndUnnormalizedInputEncoder( - vector_size, h_size, encoded_act_size, num_layers, normalize + vector_size, h_size, unnormalized_inputs, num_layers, normalize ) ) else: From 04b0a7333fe776887622f0b5d3910e76a9d49148 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Wed, 29 Jul 2020 10:55:16 -0700 Subject: [PATCH 5/5] Add docstring to UnnormalizedInputEncoder --- ml-agents/mlagents/trainers/torch/encoders.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ml-agents/mlagents/trainers/torch/encoders.py b/ml-agents/mlagents/trainers/torch/encoders.py index 85c008f300..0607fbcca5 100644 --- a/ml-agents/mlagents/trainers/torch/encoders.py +++ b/ml-agents/mlagents/trainers/torch/encoders.py @@ -97,6 +97,21 @@ def update_normalization(self, inputs: torch.Tensor) -> None: class VectorAndUnnormalizedInputEncoder(VectorEncoder): + """ + Encoder for concatenated vector input (can be normalized) and unnormalized vector input. + This is used for passing inputs to the network that should not be normalized, such as + actions in the case of a Q function or task parameterizations. It will result in an encoder with + this structure: + ____________ ____________ ____________ + | Vector | | Normalize | | Fully | + | | --> | | --> | Connected | ___________ + |____________| |____________| | | | Output | + ____________ | | --> | | + |Unnormalized| | | |___________| + | Input | ---------------------> | | + |____________| |____________| + """ + def __init__( self, input_size: int,