78 changes: 50 additions & 28 deletions ml-agents/mlagents/trainers/torch/encoders.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union

from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.torch.layers import linear_layer, Initialization, Swish
@@ -45,21 +45,44 @@ def copy_from(self, other_normalizer: "Normalizer") -> None:
self.running_variance.copy_(other_normalizer.running_variance.data)


-def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
+def conv_output_shape(
+    h_w: Tuple[int, int],
+    kernel_size: Union[int, Tuple[int, int]] = 1,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+) -> Tuple[int, int]:
"""
Calculates the output shape (height and width) of the output of a convolution layer.
kernel_size, stride, padding and dilation correspond to the inputs of the
torch.nn.Conv2d layer (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)
:param h_w: The height and width of the input.
:param kernel_size: The size of the kernel of the convolution (can be an int or a
tuple [width, height])
:param stride: The stride of the convolution
:param padding: The padding of the convolution
:param dilation: The dilation of the convolution
"""
from math import floor

-    if type(kernel_size) is not tuple:
-        kernel_size = (kernel_size, kernel_size)
+    if not isinstance(kernel_size, tuple):
+        kernel_size = (int(kernel_size), int(kernel_size))
h = floor(
-        ((h_w[0] + (2 * pad) - (dilation * (kernel_size[0] - 1)) - 1) / stride) + 1
+        ((h_w[0] + (2 * padding) - (dilation * (kernel_size[0] - 1)) - 1) / stride) + 1
)
w = floor(
-        ((h_w[1] + (2 * pad) - (dilation * (kernel_size[1] - 1)) - 1) / stride) + 1
+        ((h_w[1] + (2 * padding) - (dilation * (kernel_size[1] - 1)) - 1) / stride) + 1
)
return h, w
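
A quick sanity check for the helper (a sketch, assuming torch is available; the 84x84 input and 8x8/stride-4 kernel are illustrative sizes, not from this PR): the predicted shape should agree with what torch.nn.Conv2d actually produces.

    import torch
    import torch.nn as nn

    h, w = conv_output_shape((84, 84), kernel_size=8, stride=4)   # (20, 20)
    conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=8, stride=4)
    out = conv(torch.zeros(1, 3, 84, 84))
    assert (out.shape[2], out.shape[3]) == (h, w)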


+def pool_out_shape(h_w: Tuple[int, int], kernel_size: int) -> Tuple[int, int]:
+    """
+    Calculates the output shape (height and width) of a max pooling layer with a
+    stride of 2. kernel_size corresponds to the inputs of the
+    torch.nn.MaxPool2d layer (https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html)
+    :param h_w: The height and width of the input.
+    :param kernel_size: The size of the kernel of the pooling layer
+    """
+    height = (h_w[0] - kernel_size) // 2 + 1
+    width = (h_w[1] - kernel_size) // 2 + 1
+    return height, width
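
The hard-coded stride of 2 matches the MaxPool2d([3, 3], [2, 2]) calls in this encoder. A small check, as a sketch (the 20x20/16-channel sizes are illustrative assumptions):

    import torch
    import torch.nn as nn

    h, w = pool_out_shape((20, 20), kernel_size=3)                 # (9, 9)
    out = nn.MaxPool2d([3, 3], [2, 2])(torch.zeros(1, 16, 20, 20))
    assert (out.shape[2], out.shape[3]) == (h, w)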
@@ -227,6 +250,25 @@ def forward(self, visual_obs: torch.Tensor) -> None:
return hidden


+class ResNetBlock(nn.Module):
+    def __init__(self, channel: int):
+        """
+        Creates a ResNet Block.
+        :param channel: The number of channels in the input (and output) tensors of the
+        convolutions
+        """
+        super().__init__()
+        self.layers = nn.Sequential(
+            Swish(),
+            nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
+            Swish(),
+            nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
+        )
+
+    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
+        return input_tensor + self.layers(input_tensor)
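
Because the block is now an nn.Module wrapping an nn.Sequential, its convolutions register as submodules automatically, and the 3x3/stride-1/padding-1 convolutions keep height and width unchanged, so the residual sum is always valid. A sketch (assuming ml-agents' Swish, which holds no parameters; the sizes are illustrative):

    import torch

    block = ResNetBlock(channel=32)
    x = torch.randn(4, 32, 10, 10)
    assert block(x).shape == x.shape               # residual add preserves the shape
    assert len(list(block.parameters())) == 4      # two convs, each with weight + bias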


class ResNetVisualEncoder(nn.Module):
def __init__(self, height, width, initial_channels, final_hidden):
super().__init__()
@@ -241,7 +283,7 @@ def __init__(self, height, width, initial_channels, final_hidden):
self.layers.append(nn.MaxPool2d([3, 3], [2, 2]))
height, width = pool_out_shape((height, width), 3)
for _ in range(n_blocks):
-            self.layers.append(self.make_block(channel))
+            self.layers.append(ResNetBlock(channel))
last_channel = channel
self.layers.append(Swish())
self.dense = linear_layer(
@@ -251,30 +293,10 @@ def __init__(self, height, width, initial_channels, final_hidden):
kernel_gain=1.0,
)

-    @staticmethod
-    def make_block(channel):
-        block_layers = [
-            Swish(),
-            nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
-            Swish(),
-            nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1),
-        ]
-        return block_layers
-
-    @staticmethod
-    def forward_block(input_hidden, block_layers):
-        hidden = input_hidden
-        for layer in block_layers:
-            hidden = layer(hidden)
-        return hidden + input_hidden

def forward(self, visual_obs):
batch_size = visual_obs.shape[0]
hidden = visual_obs
for layer in self.layers:
-            if isinstance(layer, nn.Module):
-                hidden = layer(hidden)
-            elif isinstance(layer, list):
-                hidden = self.forward_block(hidden, layer)
+            hidden = layer(hidden)
Contributor: So much better 👍
before_out = hidden.view(batch_size, -1)
return torch.relu(self.dense(before_out))
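
With the loop reduced to a uniform hidden = layer(hidden), every entry in self.layers is treated as a module. End-to-end usage, as a sketch (assuming channels-first input, which the Conv2d layers in the visible code require, and illustrative sizes):

    import torch

    encoder = ResNetVisualEncoder(height=84, width=84, initial_channels=3, final_hidden=256)
    obs = torch.zeros(8, 3, 84, 84)     # batch of 8 dummy RGB frames
    embedding = encoder(obs)
    assert embedding.shape == (8, 256)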