From 990ad69692927afab313465d378e25e1db341929 Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Thu, 6 Aug 2020 13:23:55 -0700 Subject: [PATCH 1/5] Layer initialization + swish as a layer --- .../trainers/tests/torch/test_layers.py | 18 ++++++++++ ml-agents/mlagents/trainers/torch/layers.py | 36 +++++++++++++++++++ ml-agents/mlagents/trainers/torch/utils.py | 5 --- 3 files changed, 54 insertions(+), 5 deletions(-) create mode 100644 ml-agents/mlagents/trainers/tests/torch/test_layers.py create mode 100644 ml-agents/mlagents/trainers/torch/layers.py diff --git a/ml-agents/mlagents/trainers/tests/torch/test_layers.py b/ml-agents/mlagents/trainers/tests/torch/test_layers.py new file mode 100644 index 0000000000..40d2004fd4 --- /dev/null +++ b/ml-agents/mlagents/trainers/tests/torch/test_layers.py @@ -0,0 +1,18 @@ +import torch + +from mlagents.trainers.torch.layers import Swish, linear_layer, Initialization + + +def test_swish(): + layer = Swish() + input_tensor = torch.Tensor([[1, 2, 3], [4, 5, 6]]) + target_tensor = torch.mul(input_tensor, torch.sigmoid(input_tensor)) + assert torch.all(torch.eq(layer(input_tensor), target_tensor)) + + +def test_initialization_layer(): + torch.manual_seed(0) + # Test Zero + layer = linear_layer(3, 4, Initialization.Zero, Initialization.Zero) + assert torch.all(torch.eq(layer.weight.data, torch.zeros_like(layer.weight.data))) + assert torch.all(torch.eq(layer.bias.data, torch.zeros_like(layer.bias.data))) diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py new file mode 100644 index 0000000000..c0911f9951 --- /dev/null +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -0,0 +1,36 @@ +import torch +from enum import Enum + + +class Swish(torch.nn.Module): + def forward(self, data: torch.Tensor) -> torch.Tensor: + return torch.mul(data, torch.sigmoid(data)) + + +class Initialization(Enum): + Zero = 0 + XavierGlorotNormal = 1 + XavierGlorotUniform = 2 + KaimingHeNormal = 3 + 
KaimingHeUniform = 4 + + +_init_methods = { + Initialization.Zero: torch.zero_, + Initialization.XavierGlorotNormal: torch.nn.init.xavier_normal_, + Initialization.XavierGlorotUniform: torch.nn.init.xavier_uniform_, + Initialization.KaimingHeNormal: torch.nn.init.kaiming_normal_, + Initialization.KaimingHeUniform: torch.nn.init.kaiming_uniform_, +} + + +def linear_layer( + input_size: int, + output_size: int, + kernel_init: Initialization = Initialization.XavierGlorotUniform, + bias_init: Initialization = Initialization.Zero, +) -> torch.nn.Module: + layer = torch.nn.Linear(input_size, output_size) + _init_methods[kernel_init](layer.weight.data) + _init_methods[bias_init](layer.bias.data) + return layer diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py index 5d815cea0b..ecf2cb31b7 100644 --- a/ml-agents/mlagents/trainers/torch/utils.py +++ b/ml-agents/mlagents/trainers/torch/utils.py @@ -24,11 +24,6 @@ class ModelUtils: EncoderType.RESNET: 15, } - @staticmethod - def swish(input_activation: torch.Tensor) -> torch.Tensor: - """Swish activation function. 
For more info: https://arxiv.org/abs/1710.05941""" - return torch.mul(input_activation, torch.sigmoid(input_activation)) - @staticmethod def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module: ENCODER_FUNCTION_BY_TYPE = { From 84ada0de5865f0d8bc440a3173e7589ba249eeab Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Thu, 6 Aug 2020 16:03:03 -0700 Subject: [PATCH 2/5] integrating with the existing layers --- ml-agents/mlagents/trainers/torch/decoders.py | 3 +- .../mlagents/trainers/torch/distributions.py | 28 ++++++++++--- ml-agents/mlagents/trainers/torch/encoders.py | 40 ++++++++++++++++--- ml-agents/mlagents/trainers/torch/layers.py | 14 ++++++- 4 files changed, 72 insertions(+), 13 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/decoders.py b/ml-agents/mlagents/trainers/torch/decoders.py index c7f332796d..db54a6232f 100644 --- a/ml-agents/mlagents/trainers/torch/decoders.py +++ b/ml-agents/mlagents/trainers/torch/decoders.py @@ -2,6 +2,7 @@ import torch from torch import nn +from mlagents.trainers.torch.layers import linear_layer class ValueHeads(nn.Module): @@ -11,7 +12,7 @@ def __init__(self, stream_names: List[str], input_size: int, output_size: int = _value_heads = {} for name in stream_names: - value = nn.Linear(input_size, output_size) + value = linear_layer(input_size, output_size) _value_heads[name] = value self.value_heads = nn.ModuleDict(_value_heads) diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch/distributions.py index c83ae4649e..bf22ee7c5a 100644 --- a/ml-agents/mlagents/trainers/torch/distributions.py +++ b/ml-agents/mlagents/trainers/torch/distributions.py @@ -4,6 +4,7 @@ from torch import nn import numpy as np import math +from mlagents.trainers.torch.layers import linear_layer, Initialization EPSILON = 1e-7 # Small value to avoid divide by zero @@ -122,12 +123,22 @@ def __init__( ): super().__init__() self.conditional_sigma = conditional_sigma - self.mu = 
nn.Linear(hidden_size, num_outputs) + self.mu = linear_layer( + hidden_size, + num_outputs, + kernel_init=Initialization.KaimingHeNormal, + kernel_gain=0.1, + bias_init=Initialization.Zero, + ) self.tanh_squash = tanh_squash - nn.init.xavier_uniform_(self.mu.weight, gain=0.01) if conditional_sigma: - self.log_sigma = nn.Linear(hidden_size, num_outputs) - nn.init.xavier_uniform(self.log_sigma.weight, gain=0.01) + self.log_sigma = linear_layer( + hidden_size, + num_outputs, + kernel_init=Initialization.KaimingHeNormal, + kernel_gain=0.1, + bias_init=Initialization.Zero, + ) else: self.log_sigma = nn.Parameter( torch.zeros(1, num_outputs, requires_grad=True) @@ -154,8 +165,13 @@ def __init__(self, hidden_size: int, act_sizes: List[int]): def _create_policy_branches(self, hidden_size: int) -> nn.ModuleList: branches = [] for size in self.act_sizes: - branch_output_layer = nn.Linear(hidden_size, size) - nn.init.xavier_uniform_(branch_output_layer.weight, gain=0.01) + branch_output_layer = linear_layer( + hidden_size, + size, + kernel_init=Initialization.KaimingHeNormal, + kernel_gain=0.1, + bias_init=Initialization.Zero, + ) branches.append(branch_output_layer) return nn.ModuleList(branches) diff --git a/ml-agents/mlagents/trainers/torch/encoders.py b/ml-agents/mlagents/trainers/torch/encoders.py index dd9543987e..a0d9ef236a 100644 --- a/ml-agents/mlagents/trainers/torch/encoders.py +++ b/ml-agents/mlagents/trainers/torch/encoders.py @@ -1,6 +1,7 @@ from typing import Tuple, Optional from mlagents.trainers.exception import UnityTrainerException +from mlagents.trainers.torch.layers import linear_layer, Initialization import torch from torch import nn @@ -74,12 +75,26 @@ def __init__( ): self.normalizer: Optional[Normalizer] = None super().__init__() - self.layers = [nn.Linear(input_size, hidden_size)] + self.layers = [ + linear_layer( + input_size, + hidden_size, + kernel_init=Initialization.KaimingHeNormal, + kernel_gain=1.0, + ) + ] if normalize: self.normalizer = 
Normalizer(input_size) for _ in range(num_layers - 1): - self.layers.append(nn.Linear(hidden_size, hidden_size)) + self.layers.append( + linear_layer( + hidden_size, + hidden_size, + kernel_init=Initialization.KaimingHeNormal, + kernel_gain=1.0, + ) + ) self.layers.append(nn.LeakyReLU()) self.seq_layers = nn.Sequential(*self.layers) @@ -156,7 +171,12 @@ def __init__( self.conv1 = nn.Conv2d(initial_channels, 16, [8, 8], [4, 4]) self.conv2 = nn.Conv2d(16, 32, [4, 4], [2, 2]) - self.dense = nn.Linear(self.final_flat, self.h_size) + self.dense = linear_layer( + self.final_flat, + self.h_size, + kernel_init=Initialization.KaimingHeNormal, + kernel_gain=1.0, + ) def forward(self, visual_obs: torch.Tensor) -> None: conv_1 = nn.functional.leaky_relu(self.conv1(visual_obs)) @@ -180,7 +200,12 @@ def __init__(self, height, width, initial_channels, output_size): self.conv1 = nn.Conv2d(initial_channels, 32, [8, 8], [4, 4]) self.conv2 = nn.Conv2d(32, 64, [4, 4], [2, 2]) self.conv3 = nn.Conv2d(64, 64, [3, 3], [1, 1]) - self.dense = nn.Linear(self.final_flat, self.h_size) + self.dense = linear_layer( + self.final_flat, + self.h_size, + kernel_init=Initialization.KaimingHeNormal, + kernel_gain=1.0, + ) def forward(self, visual_obs): conv_1 = nn.functional.leaky_relu(self.conv1(visual_obs)) @@ -209,7 +234,12 @@ def __init__(self, height, width, initial_channels, final_hidden): self.layers.append(self.make_block(channel)) last_channel = channel self.layers.append(nn.LeakyReLU()) - self.dense = nn.Linear(n_channels[-1] * height * width, final_hidden) + self.dense = linear_layer( + n_channels[-1] * height * width, + final_hidden, + kernel_init=Initialization.KaimingHeNormal, + kernel_gain=1.0, + ) @staticmethod def make_block(channel): diff --git a/ml-agents/mlagents/trainers/torch/layers.py b/ml-agents/mlagents/trainers/torch/layers.py index c0911f9951..8dbb1cbcb4 100644 --- a/ml-agents/mlagents/trainers/torch/layers.py +++ b/ml-agents/mlagents/trainers/torch/layers.py @@ -11,7 +11,7 
@@ class Initialization(Enum): Zero = 0 XavierGlorotNormal = 1 XavierGlorotUniform = 2 - KaimingHeNormal = 3 + KaimingHeNormal = 3 # also known as Variance scaling KaimingHeUniform = 4 @@ -28,9 +28,21 @@ def linear_layer( input_size: int, output_size: int, kernel_init: Initialization = Initialization.XavierGlorotUniform, + kernel_gain: float = 1.0, bias_init: Initialization = Initialization.Zero, ) -> torch.nn.Module: + """ + Creates a torch.nn.Linear module and initializes its weights and bias. + :param input_size: The size of the input tensor + :param output_size: The size of the output tensor + :param kernel_init: The Initialization to use for the weights of the layer + :param kernel_gain: The multiplier applied to the kernel weights after they are + initialized. Note that TensorFlow's variance_scaling with scale 0.01 matches + KaimingHeNormal with kernel_gain of 0.1 only up to PyTorch's Kaiming gain (sqrt(2)) + :param bias_init: The Initialization to use for the bias of the layer + """ layer = torch.nn.Linear(input_size, output_size) _init_methods[kernel_init](layer.weight.data) + layer.weight.data *= kernel_gain _init_methods[bias_init](layer.bias.data) return layer From 19f4c050feaabb2e2bd71b1767001c0e5c8ef5c8 Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Thu, 6 Aug 2020 17:00:19 -0700 Subject: [PATCH 3/5] fixing tests --- ml-agents/mlagents/trainers/tests/torch/test_layers.py | 4 +++- ml-agents/mlagents/trainers/tests/torch/test_networks.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_layers.py b/ml-agents/mlagents/trainers/tests/torch/test_layers.py index 40d2004fd4..499d0de285 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_layers.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_layers.py @@ -13,6 +13,8 @@ def test_swish(): def test_initialization_layer(): torch.manual_seed(0) # Test Zero - layer = linear_layer(3, 4, Initialization.Zero, Initialization.Zero) + layer = linear_layer( + 3, 4, 
kernel_init=Initialization.Zero, bias_init=Initialization.Zero + ) assert torch.all(torch.eq(layer.weight.data, torch.zeros_like(layer.weight.data))) assert torch.all(torch.eq(layer.bias.data, torch.zeros_like(layer.bias.data))) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_networks.py b/ml-agents/mlagents/trainers/tests/torch/test_networks.py index ff5209b676..974a3d6bd1 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_networks.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_networks.py @@ -26,7 +26,7 @@ def test_networkbody_vector(): sample_obs = torch.ones((1, obs_size)) sample_act = torch.ones((1, 2)) - for _ in range(100): + for _ in range(200): encoded, _ = networkbody([sample_obs], [], sample_act) assert encoded.shape == (1, network_settings.hidden_units) # Try to force output to 1 @@ -77,7 +77,7 @@ def test_networkbody_visual(): sample_obs = torch.ones((1, 84, 84, 3)) sample_vec_obs = torch.ones((1, vec_obs_size)) - for _ in range(100): + for _ in range(150): encoded, _ = networkbody([sample_vec_obs], [sample_obs]) assert encoded.shape == (1, network_settings.hidden_units) # Try to force output to 1 From b2228e6517cab6a4309e2c76cf6a99eed5d2d9c4 Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Thu, 6 Aug 2020 19:48:24 -0700 Subject: [PATCH 4/5] setting the seed for a test --- ml-agents/mlagents/trainers/tests/torch/test_networks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_networks.py b/ml-agents/mlagents/trainers/tests/torch/test_networks.py index 974a3d6bd1..d7ecc3743c 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_networks.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_networks.py @@ -17,6 +17,7 @@ def test_networkbody_vector(): + torch.manual_seed(0) obs_size = 4 network_settings = NetworkSettings() obs_shapes = [(obs_size,)] @@ -26,7 +27,7 @@ def test_networkbody_vector(): sample_obs = torch.ones((1, obs_size)) sample_act = 
torch.ones((1, 2)) - for _ in range(200): + for _ in range(300): encoded, _ = networkbody([sample_obs], [], sample_act) assert encoded.shape == (1, network_settings.hidden_units) # Try to force output to 1 From ca9667ae1cc8836c0cc7b05b228aa3b4d1539ed7 Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Fri, 7 Aug 2020 12:13:26 -0700 Subject: [PATCH 5/5] Using swish and fixing tests --- .../mlagents/trainers/tests/torch/test_networks.py | 4 ++-- ml-agents/mlagents/trainers/torch/encoders.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_networks.py b/ml-agents/mlagents/trainers/tests/torch/test_networks.py index d7ecc3743c..5d3cf9a9e5 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_networks.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_networks.py @@ -24,8 +24,8 @@ def test_networkbody_vector(): networkbody = NetworkBody(obs_shapes, network_settings, encoded_act_size=2) optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3) - sample_obs = torch.ones((1, obs_size)) - sample_act = torch.ones((1, 2)) + sample_obs = 0.1 * torch.ones((1, obs_size)) + sample_act = 0.1 * torch.ones((1, 2)) for _ in range(300): encoded, _ = networkbody([sample_obs], [], sample_act) diff --git a/ml-agents/mlagents/trainers/torch/encoders.py b/ml-agents/mlagents/trainers/torch/encoders.py index a0d9ef236a..af18f73ea7 100644 --- a/ml-agents/mlagents/trainers/torch/encoders.py +++ b/ml-agents/mlagents/trainers/torch/encoders.py @@ -1,7 +1,7 @@ from typing import Tuple, Optional from mlagents.trainers.exception import UnityTrainerException -from mlagents.trainers.torch.layers import linear_layer, Initialization +from mlagents.trainers.torch.layers import linear_layer, Initialization, Swish import torch from torch import nn @@ -95,7 +95,7 @@ def __init__( kernel_gain=1.0, ) ) - self.layers.append(nn.LeakyReLU()) + self.layers.append(Swish()) self.seq_layers = nn.Sequential(*self.layers) def 
forward(self, inputs: torch.Tensor) -> None: @@ -233,7 +233,7 @@ def __init__(self, height, width, initial_channels, final_hidden): for _ in range(n_blocks): self.layers.append(self.make_block(channel)) last_channel = channel - self.layers.append(nn.LeakyReLU()) + self.layers.append(Swish()) self.dense = linear_layer( n_channels[-1] * height * width, final_hidden, @@ -244,9 +244,9 @@ def __init__(self, height, width, initial_channels, final_hidden): @staticmethod def make_block(channel): block_layers = [ - nn.LeakyReLU(), + Swish(), nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1), - nn.LeakyReLU(), + Swish(), nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1), ] return block_layers