diff --git a/ml-agents/mlagents/trainers/tests/torch/test_layers.py b/ml-agents/mlagents/trainers/tests/torch/test_layers.py
new file mode 100644
index 0000000000..499d0de285
--- /dev/null
+++ b/ml-agents/mlagents/trainers/tests/torch/test_layers.py
@@ -0,0 +1,20 @@
+import torch
+
+from mlagents.trainers.torch.layers import Swish, linear_layer, Initialization
+
+
+def test_swish():
+    layer = Swish()
+    input_tensor = torch.Tensor([[1, 2, 3], [4, 5, 6]])
+    target_tensor = torch.mul(input_tensor, torch.sigmoid(input_tensor))
+    assert torch.all(torch.eq(layer(input_tensor), target_tensor))
+
+
+def test_initialization_layer():
+    torch.manual_seed(0)
+    # Test Zero
+    layer = linear_layer(
+        3, 4, kernel_init=Initialization.Zero, bias_init=Initialization.Zero
+    )
+    assert torch.all(torch.eq(layer.weight.data, torch.zeros_like(layer.weight.data)))
+    assert torch.all(torch.eq(layer.bias.data, torch.zeros_like(layer.bias.data)))
diff --git a/ml-agents/mlagents/trainers/tests/torch/test_networks.py b/ml-agents/mlagents/trainers/tests/torch/test_networks.py
index ff5209b676..5d3cf9a9e5 100644
--- a/ml-agents/mlagents/trainers/tests/torch/test_networks.py
+++ b/ml-agents/mlagents/trainers/tests/torch/test_networks.py
@@ -17,16 +17,17 @@


 def test_networkbody_vector():
+    torch.manual_seed(0)
     obs_size = 4
     network_settings = NetworkSettings()
     obs_shapes = [(obs_size,)]

     networkbody = NetworkBody(obs_shapes, network_settings, encoded_act_size=2)
     optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
-    sample_obs = torch.ones((1, obs_size))
-    sample_act = torch.ones((1, 2))
+    sample_obs = 0.1 * torch.ones((1, obs_size))
+    sample_act = 0.1 * torch.ones((1, 2))

-    for _ in range(100):
+    for _ in range(300):
         encoded, _ = networkbody([sample_obs], [], sample_act)
         assert encoded.shape == (1, network_settings.hidden_units)
         # Try to force output to 1
@@ -77,7 +78,7 @@ def test_networkbody_visual():
     sample_obs = torch.ones((1, 84, 84, 3))
     sample_vec_obs = torch.ones((1, vec_obs_size))

-    for _ in range(100):
+    for _ in range(150):
         encoded, _ = networkbody([sample_vec_obs], [sample_obs])
         assert encoded.shape == (1, network_settings.hidden_units)
         # Try to force output to 1
diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py b/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
index 81c30ef82f..842b039510 100644
--- a/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
+++ b/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
@@ -11,6 +11,7 @@
 from mlagents_envs.base_env import BehaviorSpec
 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.torch.networks import NetworkBody
+from mlagents.trainers.torch.layers import linear_layer, Swish
 from mlagents.trainers.settings import NetworkSettings, EncoderType


@@ -70,22 +71,18 @@ def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
         self._action_flattener = ModelUtils.ActionFlattener(specs)

         self.inverse_model_action_predition = torch.nn.Sequential(
-            torch.nn.Linear(2 * settings.encoding_size, 256),
-            ModelUtils.SwishLayer(),
-            torch.nn.Linear(256, self._action_flattener.flattened_size),
+            linear_layer(2 * settings.encoding_size, 256),
+            Swish(),
+            linear_layer(256, self._action_flattener.flattened_size),
         )
-        self.inverse_model_action_predition[0].bias.data.zero_()
-        self.inverse_model_action_predition[2].bias.data.zero_()

         self.forward_model_next_state_prediction = torch.nn.Sequential(
-            torch.nn.Linear(
+            linear_layer(
                 settings.encoding_size + self._action_flattener.flattened_size, 256
             ),
-            ModelUtils.SwishLayer(),
-            torch.nn.Linear(256, settings.encoding_size),
+            Swish(),
+            linear_layer(256, settings.encoding_size),
         )
-        self.forward_model_next_state_prediction[0].bias.data.zero_()
-        self.forward_model_next_state_prediction[2].bias.data.zero_()

     def get_current_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
         """
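The curiosity change above is the pattern repeated throughout this diff: a hand-built `torch.nn.Linear` followed by manual `bias.data.zero_()` calls is collapsed into `linear_layer`, which performs the initialization at construction time. A minimal sketch (not part of the diff) of why the manual zeroing becomes redundant; the `encoding_size` and `flattened_size` values here are made up for illustration:

```python
import torch

from mlagents.trainers.torch.layers import Swish, linear_layer

encoding_size = 64  # hypothetical; the real value comes from CuriositySettings
flattened_size = 5  # hypothetical flattened action size

# Same shape as inverse_model_action_predition above; bias_init defaults
# to Initialization.Zero, so no follow-up zero_() calls are needed.
inverse_model = torch.nn.Sequential(
    linear_layer(2 * encoding_size, 256),
    Swish(),
    linear_layer(256, flattened_size),
)
assert torch.all(inverse_model[0].bias.data == 0)
assert torch.all(inverse_model[2].bias.data == 0)
```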
diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py b/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
index f3684a338c..dd3a9854c4 100644
--- a/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
+++ b/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
@@ -10,6 +10,7 @@
 from mlagents_envs.base_env import BehaviorSpec
 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.torch.networks import NetworkBody
+from mlagents.trainers.torch.layers import linear_layer, Swish, Initialization
 from mlagents.trainers.settings import NetworkSettings, EncoderType
 from mlagents.trainers.demo_loader import demo_to_buffer


@@ -98,15 +99,11 @@ def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
         )  # + 1 is for done
         self.encoder = torch.nn.Sequential(
-            torch.nn.Linear(encoder_input_size, settings.encoding_size),
-            ModelUtils.SwishLayer(),
-            torch.nn.Linear(settings.encoding_size, settings.encoding_size),
-            ModelUtils.SwishLayer(),
+            linear_layer(encoder_input_size, settings.encoding_size),
+            Swish(),
+            linear_layer(settings.encoding_size, settings.encoding_size),
+            Swish(),
         )
-        torch.nn.init.xavier_normal_(self.encoder[0].weight.data)
-        torch.nn.init.xavier_normal_(self.encoder[2].weight.data)
-        self.encoder[0].bias.data.zero_()
-        self.encoder[2].bias.data.zero_()

         estimator_input_size = settings.encoding_size
         if settings.use_vail:
@@ -114,19 +111,19 @@ def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
             self.z_sigma = torch.nn.Parameter(
                 torch.ones((self.z_size), dtype=torch.float), requires_grad=True
             )
-            self.z_mu_layer = torch.nn.Linear(settings.encoding_size, self.z_size)
-            # self.z_mu_layer.weight.data Needs a variance scale initializer
-            torch.nn.init.xavier_normal_(self.z_mu_layer.weight.data)
-            self.z_mu_layer.bias.data.zero_()
+            self.z_mu_layer = linear_layer(
+                settings.encoding_size,
+                self.z_size,
+                kernel_init=Initialization.KaimingHeNormal,
+                kernel_gain=0.1,
+            )
             self.beta = torch.nn.Parameter(
                 torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False
             )

         self.estimator = torch.nn.Sequential(
-            torch.nn.Linear(estimator_input_size, 1), torch.nn.Sigmoid()
+            linear_layer(estimator_input_size, 1), torch.nn.Sigmoid()
         )
-        torch.nn.init.xavier_normal_(self.estimator[0].weight.data)
-        self.estimator[0].bias.data.zero_()

     def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor:
         """
diff --git a/ml-agents/mlagents/trainers/torch/decoders.py b/ml-agents/mlagents/trainers/torch/decoders.py
index c7f332796d..db54a6232f 100644
--- a/ml-agents/mlagents/trainers/torch/decoders.py
+++ b/ml-agents/mlagents/trainers/torch/decoders.py
@@ -2,6 +2,7 @@

 import torch
 from torch import nn
+from mlagents.trainers.torch.layers import linear_layer


 class ValueHeads(nn.Module):
@@ -11,7 +12,7 @@ def __init__(self, stream_names: List[str], input_size: int, output_size: int =
         _value_heads = {}

         for name in stream_names:
-            value = nn.Linear(input_size, output_size)
+            value = linear_layer(input_size, output_size)
             _value_heads[name] = value
         self.value_heads = nn.ModuleDict(_value_heads)
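The `z_mu_layer` change resolves the removed `# Needs a variance scale initializer` comment: per the `linear_layer` docstring below, TensorFlow's `variance_scaling` with scale 0.01 corresponds to `KaimingHeNormal` with `kernel_gain=0.1`, since the gain multiplies the weights and therefore scales their variance by `gain ** 2` (0.1² = 0.01). A hedged sketch with illustrative sizes:

```python
from mlagents.trainers.torch.layers import Initialization, linear_layer

encoding_size, z_size = 128, 4  # illustrative values only

# Kaiming He normal init, then weights multiplied by 0.1, i.e. variance
# scaled by 0.01 -- matching TF variance_scaling(scale=0.01).
z_mu_layer = linear_layer(
    encoding_size,
    z_size,
    kernel_init=Initialization.KaimingHeNormal,
    kernel_gain=0.1,
)
```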
diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch/distributions.py
index 570460e36d..dc3ce3e44a 100644
--- a/ml-agents/mlagents/trainers/torch/distributions.py
+++ b/ml-agents/mlagents/trainers/torch/distributions.py
@@ -4,6 +4,7 @@
 from torch import nn
 import numpy as np
 import math
+from mlagents.trainers.torch.layers import linear_layer, Initialization

 EPSILON = 1e-7  # Small value to avoid divide by zero
@@ -127,12 +128,22 @@ def __init__(
     ):
         super().__init__()
         self.conditional_sigma = conditional_sigma
-        self.mu = nn.Linear(hidden_size, num_outputs)
+        self.mu = linear_layer(
+            hidden_size,
+            num_outputs,
+            kernel_init=Initialization.KaimingHeNormal,
+            kernel_gain=0.1,
+            bias_init=Initialization.Zero,
+        )
         self.tanh_squash = tanh_squash
-        nn.init.xavier_uniform_(self.mu.weight, gain=0.01)
         if conditional_sigma:
-            self.log_sigma = nn.Linear(hidden_size, num_outputs)
-            nn.init.xavier_uniform(self.log_sigma.weight, gain=0.01)
+            self.log_sigma = linear_layer(
+                hidden_size,
+                num_outputs,
+                kernel_init=Initialization.KaimingHeNormal,
+                kernel_gain=0.1,
+                bias_init=Initialization.Zero,
+            )
         else:
             self.log_sigma = nn.Parameter(
                 torch.zeros(1, num_outputs, requires_grad=True)
@@ -159,8 +170,13 @@ def __init__(self, hidden_size: int, act_sizes: List[int]):
     def _create_policy_branches(self, hidden_size: int) -> nn.ModuleList:
         branches = []
         for size in self.act_sizes:
-            branch_output_layer = nn.Linear(hidden_size, size)
-            nn.init.xavier_uniform_(branch_output_layer.weight, gain=0.01)
+            branch_output_layer = linear_layer(
+                hidden_size,
+                size,
+                kernel_init=Initialization.KaimingHeNormal,
+                kernel_gain=0.1,
+                bias_init=Initialization.Zero,
+            )
             branches.append(branch_output_layer)
         return nn.ModuleList(branches)
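In the policy heads, the old `xavier_uniform_(..., gain=0.01)` and the new `KaimingHeNormal` with `kernel_gain=0.1` serve the same purpose: keep the initial action means and log-sigmas near zero. An illustrative check (assumed behavior, not a test from this diff; the sizes are hypothetical):

```python
import torch

from mlagents.trainers.torch.layers import Initialization, linear_layer

torch.manual_seed(0)
mu = linear_layer(
    128,  # hypothetical hidden_size
    2,    # hypothetical num_outputs
    kernel_init=Initialization.KaimingHeNormal,
    kernel_gain=0.1,
    bias_init=Initialization.Zero,
)
# Outputs at initialization should be small in magnitude, so early
# actions are not saturated before any learning happens.
print(mu(torch.randn(1, 128)).abs().max())
```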
diff --git a/ml-agents/mlagents/trainers/torch/encoders.py b/ml-agents/mlagents/trainers/torch/encoders.py
index 697ed926fd..fc5cae71d8 100644
--- a/ml-agents/mlagents/trainers/torch/encoders.py
+++ b/ml-agents/mlagents/trainers/torch/encoders.py
@@ -1,6 +1,7 @@
 from typing import Tuple, Optional

 from mlagents.trainers.exception import UnityTrainerException
+from mlagents.trainers.torch.layers import linear_layer, Initialization, Swish

 import torch
 from torch import nn
@@ -64,11 +65,6 @@ def pool_out_shape(h_w: Tuple[int, int], kernel_size: int) -> Tuple[int, int]:
     return height, width


-class SwishLayer(torch.nn.Module):
-    def forward(self, data: torch.Tensor) -> torch.Tensor:
-        return torch.mul(data, torch.sigmoid(data))
-
-
 class VectorEncoder(nn.Module):
     def __init__(
         self,
@@ -79,14 +75,28 @@ def __init__(
     ):
         self.normalizer: Optional[Normalizer] = None
         super().__init__()
-        self.layers = [nn.Linear(input_size, hidden_size)]
-        self.layers.append(SwishLayer())
+        self.layers = [
+            linear_layer(
+                input_size,
+                hidden_size,
+                kernel_init=Initialization.KaimingHeNormal,
+                kernel_gain=1.0,
+            )
+        ]
+        self.layers.append(Swish())
         if normalize:
             self.normalizer = Normalizer(input_size)

         for _ in range(num_layers - 1):
-            self.layers.append(nn.Linear(hidden_size, hidden_size))
-            self.layers.append(nn.LeakyReLU())
+            self.layers.append(
+                linear_layer(
+                    hidden_size,
+                    hidden_size,
+                    kernel_init=Initialization.KaimingHeNormal,
+                    kernel_gain=1.0,
+                )
+            )
+            self.layers.append(Swish())
         self.seq_layers = nn.Sequential(*self.layers)

     def forward(self, inputs: torch.Tensor) -> None:
@@ -160,17 +170,26 @@ def __init__(
         conv_2_hw = conv_output_shape(conv_1_hw, 4, 2)
         self.final_flat = conv_2_hw[0] * conv_2_hw[1] * 32

-        self.conv1 = nn.Conv2d(initial_channels, 16, [8, 8], [4, 4])
-        self.conv2 = nn.Conv2d(16, 32, [4, 4], [2, 2])
-        self.dense = nn.Linear(self.final_flat, self.h_size)
+        self.conv_layers = nn.Sequential(
+            nn.Conv2d(initial_channels, 16, [8, 8], [4, 4]),
+            nn.LeakyReLU(),
+            nn.Conv2d(16, 32, [4, 4], [2, 2]),
+            nn.LeakyReLU(),
+        )
+        self.dense = nn.Sequential(
+            linear_layer(
+                self.final_flat,
+                self.h_size,
+                kernel_init=Initialization.KaimingHeNormal,
+                kernel_gain=1.0,
+            ),
+            nn.LeakyReLU(),
+        )

     def forward(self, visual_obs: torch.Tensor) -> None:
-        conv_1 = nn.functional.leaky_relu(self.conv1(visual_obs))
-        conv_2 = nn.functional.leaky_relu(self.conv2(conv_1))
-        # hidden = torch.relu(self.dense(conv_2.view([-1, self.final_flat])))
-        hidden = nn.functional.leaky_relu(
-            self.dense(torch.reshape(conv_2, (-1, self.final_flat)))
-        )
+        hidden = self.conv_layers(visual_obs)
+        hidden = torch.reshape(hidden, (-1, self.final_flat))
+        hidden = self.dense(hidden)
         return hidden


@@ -183,18 +202,28 @@ def __init__(self, height, width, initial_channels, output_size):
         conv_3_hw = conv_output_shape(conv_2_hw, 3, 1)
         self.final_flat = conv_3_hw[0] * conv_3_hw[1] * 64

-        self.conv1 = nn.Conv2d(initial_channels, 32, [8, 8], [4, 4])
-        self.conv2 = nn.Conv2d(32, 64, [4, 4], [2, 2])
-        self.conv3 = nn.Conv2d(64, 64, [3, 3], [1, 1])
-        self.dense = nn.Linear(self.final_flat, self.h_size)
-
-    def forward(self, visual_obs):
-        conv_1 = nn.functional.leaky_relu(self.conv1(visual_obs))
-        conv_2 = nn.functional.leaky_relu(self.conv2(conv_1))
-        conv_3 = nn.functional.leaky_relu(self.conv3(conv_2))
-        hidden = nn.functional.leaky_relu(
-            self.dense(conv_3.view([-1, self.final_flat]))
+        self.conv_layers = nn.Sequential(
+            nn.Conv2d(initial_channels, 32, [8, 8], [4, 4]),
+            nn.LeakyReLU(),
+            nn.Conv2d(32, 64, [4, 4], [2, 2]),
+            nn.LeakyReLU(),
+            nn.Conv2d(64, 64, [3, 3], [1, 1]),
+            nn.LeakyReLU(),
+        )
+        self.dense = nn.Sequential(
+            linear_layer(
+                self.final_flat,
+                self.h_size,
+                kernel_init=Initialization.KaimingHeNormal,
+                kernel_gain=1.0,
+            ),
+            nn.LeakyReLU(),
         )
+
+    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
+        hidden = self.conv_layers(visual_obs)
+        hidden = hidden.view([-1, self.final_flat])
+        hidden = self.dense(hidden)
         return hidden


@@ -214,15 +243,20 @@ def __init__(self, height, width, initial_channels, final_hidden):
         for _ in range(n_blocks):
             self.layers.append(self.make_block(channel))
             last_channel = channel
-        self.layers.append(nn.LeakyReLU())
-        self.dense = nn.Linear(n_channels[-1] * height * width, final_hidden)
+        self.layers.append(Swish())
+        self.dense = linear_layer(
+            n_channels[-1] * height * width,
+            final_hidden,
+            kernel_init=Initialization.KaimingHeNormal,
+            kernel_gain=1.0,
+        )

     @staticmethod
     def make_block(channel):
         block_layers = [
-            nn.LeakyReLU(),
+            Swish(),
             nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
-            nn.LeakyReLU(),
+            Swish(),
             nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
         ]
         return block_layers
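For reference, the `final_flat` values the visual encoders compute follow standard convolution arithmetic. A worked example for the default 84x84 observation through `SimpleVisualEncoder` (hand-computed sketch, assuming no padding and dilation 1):

```python
def conv_out(size: int, kernel: int, stride: int) -> int:
    # Output length of one conv dimension with no padding, dilation 1.
    return (size - kernel) // stride + 1

h = conv_out(84, 8, 4)  # first conv, kernel [8, 8], stride [4, 4] -> 20
h = conv_out(h, 4, 2)   # second conv, kernel [4, 4], stride [2, 2] -> 9
assert h * h * 32 == 2592  # 9 * 9 * 32 features feed the dense layer
```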
diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py
new file mode 100644
index 0000000000..8dbb1cbcb4
--- /dev/null
+++ b/ml-agents/mlagents/trainers/torch/layers.py
@@ -0,0 +1,48 @@
+import torch
+from enum import Enum
+
+
+class Swish(torch.nn.Module):
+    def forward(self, data: torch.Tensor) -> torch.Tensor:
+        return torch.mul(data, torch.sigmoid(data))
+
+
+class Initialization(Enum):
+    Zero = 0
+    XavierGlorotNormal = 1
+    XavierGlorotUniform = 2
+    KaimingHeNormal = 3  # also known as Variance scaling
+    KaimingHeUniform = 4
+
+
+_init_methods = {
+    Initialization.Zero: torch.zero_,
+    Initialization.XavierGlorotNormal: torch.nn.init.xavier_normal_,
+    Initialization.XavierGlorotUniform: torch.nn.init.xavier_uniform_,
+    Initialization.KaimingHeNormal: torch.nn.init.kaiming_normal_,
+    Initialization.KaimingHeUniform: torch.nn.init.kaiming_uniform_,
+}
+
+
+def linear_layer(
+    input_size: int,
+    output_size: int,
+    kernel_init: Initialization = Initialization.XavierGlorotUniform,
+    kernel_gain: float = 1.0,
+    bias_init: Initialization = Initialization.Zero,
+) -> torch.nn.Module:
+    """
+    Creates a torch.nn.Linear module and initializes its weights.
+    :param input_size: The size of the input tensor
+    :param output_size: The size of the output tensor
+    :param kernel_init: The Initialization to use for the weights of the layer
+    :param kernel_gain: The multiplier for the weights of the kernel. Note that in
+    TensorFlow, calling variance_scaling with scale 0.01 is equivalent to calling
+    KaimingHeNormal with kernel_gain of 0.1
+    :param bias_init: The Initialization to use for the bias of the layer
+    """
+    layer = torch.nn.Linear(input_size, output_size)
+    _init_methods[kernel_init](layer.weight.data)
+    layer.weight.data *= kernel_gain
+    _init_methods[bias_init](layer.bias.data)
+    return layer
diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py
index 5388df1630..baa99c98ef 100644
--- a/ml-agents/mlagents/trainers/torch/utils.py
+++ b/ml-agents/mlagents/trainers/torch/utils.py
@@ -25,15 +25,6 @@ class ModelUtils:
         EncoderType.RESNET: 15,
     }

-    @staticmethod
-    def swish(input_activation: torch.Tensor) -> torch.Tensor:
-        """Swish activation function. For more info: https://arxiv.org/abs/1710.05941"""
-        return torch.mul(input_activation, torch.sigmoid(input_activation))
-
-    class SwishLayer(torch.nn.Module):
-        def forward(self, data: torch.Tensor) -> torch.Tensor:
-            return torch.mul(data, torch.sigmoid(data))
-
    class ActionFlattener:
        def __init__(self, behavior_spec: BehaviorSpec):
            self._specs = behavior_spec
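Taken together, `layers.py` becomes the single place where weight initialization lives. A quick usage sketch mirroring the defaults (the sizes here are arbitrary):

```python
import torch

from mlagents.trainers.torch.layers import Initialization, Swish, linear_layer

layer = linear_layer(
    16,
    8,
    kernel_init=Initialization.XavierGlorotUniform,  # the default
    kernel_gain=1.0,
    bias_init=Initialization.Zero,  # the default
)
model = torch.nn.Sequential(layer, Swish())
print(model(torch.ones(1, 16)).shape)  # torch.Size([1, 8])
```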