diff --git a/ml-agents/mlagents/trainers/torch/encoders.py b/ml-agents/mlagents/trainers/torch/encoders.py
index fc5cae71d8..982d0ba8b7 100644
--- a/ml-agents/mlagents/trainers/torch/encoders.py
+++ b/ml-agents/mlagents/trainers/torch/encoders.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
 
 from mlagents.trainers.exception import UnityTrainerException
 from mlagents.trainers.torch.layers import linear_layer, Initialization, Swish
@@ -45,21 +45,44 @@ def copy_from(self, other_normalizer: "Normalizer") -> None:
         self.running_variance.copy_(other_normalizer.running_variance.data)
 
 
-def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
+def conv_output_shape(
+    h_w: Tuple[int, int],
+    kernel_size: Union[int, Tuple[int, int]] = 1,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+) -> Tuple[int, int]:
+    """
+    Calculates the output shape (height and width) of the output of a convolution layer.
+    kernel_size, stride, padding and dilation correspond to the inputs of the
+    torch.nn.Conv2d layer (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)
+    :param h_w: The height and width of the input.
+    :param kernel_size: The size of the kernel of the convolution (can be an int or a
+    tuple [width, height])
+    :param stride: The stride of the convolution
+    :param padding: The padding of the convolution
+    :param dilation: The dilation of the convolution
+    """
     from math import floor
 
-    if type(kernel_size) is not tuple:
-        kernel_size = (kernel_size, kernel_size)
+    if not isinstance(kernel_size, tuple):
+        kernel_size = (int(kernel_size), int(kernel_size))
     h = floor(
-        ((h_w[0] + (2 * pad) - (dilation * (kernel_size[0] - 1)) - 1) / stride) + 1
+        ((h_w[0] + (2 * padding) - (dilation * (kernel_size[0] - 1)) - 1) / stride) + 1
     )
     w = floor(
-        ((h_w[1] + (2 * pad) - (dilation * (kernel_size[1] - 1)) - 1) / stride) + 1
+        ((h_w[1] + (2 * padding) - (dilation * (kernel_size[1] - 1)) - 1) / stride) + 1
     )
     return h, w
 
 
 def pool_out_shape(h_w: Tuple[int, int], kernel_size: int) -> Tuple[int, int]:
+    """
+    Calculates the output shape (height and width) of the output of a max pooling layer.
+    kernel_size corresponds to the inputs of the
+    torch.nn.MaxPool2d layer (https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html)
+    :param kernel_size: The size of the kernel of the convolution
+    """
     height = (h_w[0] - kernel_size) // 2 + 1
     width = (h_w[1] - kernel_size) // 2 + 1
     return height, width
@@ -227,6 +250,25 @@ def forward(self, visual_obs: torch.Tensor) -> None:
         return hidden
 
 
+class ResNetBlock(nn.Module):
+    def __init__(self, channel: int):
+        """
+        Creates a ResNet Block.
+        :param channel: The number of channels in the input (and output) tensors of the
+        convolutions
+        """
+        super().__init__()
+        self.layers = nn.Sequential(
+            Swish(),
+            nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
+            Swish(),
+            nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
+        )
+
+    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        return input_tensor + self.layers(input_tensor)
+
+
 class ResNetVisualEncoder(nn.Module):
     def __init__(self, height, width, initial_channels, final_hidden):
         super().__init__()
@@ -241,7 +283,7 @@ def __init__(self, height, width, initial_channels, final_hidden):
             self.layers.append(nn.MaxPool2d([3, 3], [2, 2]))
             height, width = pool_out_shape((height, width), 3)
             for _ in range(n_blocks):
-                self.layers.append(self.make_block(channel))
+                self.layers.append(ResNetBlock(channel))
             last_channel = channel
         self.layers.append(Swish())
         self.dense = linear_layer(
@@ -251,30 +293,10 @@ def __init__(self, height, width, initial_channels, final_hidden):
             kernel_gain=1.0,
         )
 
-    @staticmethod
-    def make_block(channel):
-        block_layers = [
-            Swish(),
-            nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
-            Swish(),
-            nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
-        ]
-        return block_layers
-
-    @staticmethod
-    def forward_block(input_hidden, block_layers):
-        hidden = input_hidden
-        for layer in block_layers:
-            hidden = layer(hidden)
-        return hidden + input_hidden
-
     def forward(self, visual_obs):
         batch_size = visual_obs.shape[0]
         hidden = visual_obs
         for layer in self.layers:
-            if isinstance(layer, nn.Module):
-                hidden = layer(hidden)
-            elif isinstance(layer, list):
-                hidden = self.forward_block(hidden, layer)
+            hidden = layer(hidden)
         before_out = hidden.view(batch_size, -1)
         return torch.relu(self.dense(before_out))
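
Not part of the diff above: a minimal sanity-check sketch, assuming PyTorch is installed and that the module path from the diff header (mlagents.trainers.torch.encoders) is importable after this change. It checks that the new ResNetBlock preserves its input shape and that conv_output_shape / pool_out_shape agree with the corresponding torch layers for one example configuration.

import torch
from torch import nn

# Module path taken from the diff header; ResNetBlock, conv_output_shape and
# pool_out_shape are the names added or updated by this change.
from mlagents.trainers.torch.encoders import (
    ResNetBlock,
    conv_output_shape,
    pool_out_shape,
)

# The residual block keeps batch, channel, height and width unchanged.
x = torch.zeros(1, 32, 21, 21)
block = ResNetBlock(channel=32)
assert block(x).shape == x.shape

# Shape helpers vs. the actual torch layers: a kernel-3, stride-1, padding-1 conv
# followed by the kernel-3, stride-2 max pool used by ResNetVisualEncoder.
h_w = (84, 84)
conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
pool = nn.MaxPool2d(kernel_size=3, stride=2)
out = pool(conv(torch.zeros(1, 3, *h_w)))
expected = pool_out_shape(conv_output_shape(h_w, kernel_size=3, stride=1, padding=1), 3)
assert tuple(out.shape[-2:]) == expected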