In [1]:
import torch

In [21]:

class ResidualBlock(torch.nn.Module):
    """A Conv2D block with skip connections.

    A single ResidualBlock module computes the following:
        `y = relu(x + norm(conv(relu(norm(conv(x))))))`
    where `x` is the input, `y` is the output, `norm` is a 2D batch norm and `conv` is
    a 2D convolution with kernel of size 3, stride 1 and padding 1.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """Init the ResidualBlock.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
        """
        super().__init__()


        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm1 = torch.nn.BatchNorm2d(num_features=out_channels)
        self.norm2 = torch.nn.BatchNorm2d(num_features=out_channels)
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the forward pass of the ResidualBlock.

        Args:
            x: The input.

        Returns:
            The output after applying the residual block. See the class description
            for more details.
        """
        y = self.conv1(x)
        y = self.norm1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.norm2(y)
        print(y.shape, x.shape)
        return self.relu(y + x)

In [22]:
class Conv2DEncoder(torch.nn.Module):
    """An image encoder based on 2D convolutions.

    Based on: https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py
    """

    def __init__(
        self,
        *,
        input_channels: int = 3,
        output_channels: int = 512,
        hidden_size: int = 64,
        num_layers: int = 3,
        num_resnet_blocks: int = 1,
    ):
        """Init the encoder.

        Args:
            input_channels: Number of input channels.
            output_channels: Number of the output channels.
            hidden_size: Number of channels in the intermediate hidden layers.
            num_layers: Number of hidden layers.
            num_resnet_blocks: Number of resnet blocks added after each layer.
        """
        super().__init__()
        layers_list: list[torch.nn.Module] = [
            torch.nn.Conv2d(input_channels, hidden_size, kernel_size=1)
        ]
        for _ in range(num_layers):
            layers_list.extend(
                [ResidualBlock(hidden_size, hidden_size) for _ in range(num_resnet_blocks)]
            )
            layers_list.append(
                torch.nn.Conv2d(hidden_size, hidden_size, kernel_size=4, stride=2, padding=1)
            )
            layers_list.append(torch.nn.ReLU())

        layers_list.append(torch.nn.Conv2d(hidden_size, output_channels, kernel_size=1))
        self.layers = torch.nn.Sequential(*layers_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode an image.

        Args:
            x: The input image of shape `(batch, input_channels, in_width, in_height)`

        Returns:
            The encoder image of shape `(batch, output_channels, out_width, out_height)`
        """
        return self.layers(x)

In [23]:
enc_model = Conv2DEncoder()

In [24]:
print(enc_model)

Conv2DEncoder(
  (layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (2): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (5

In [25]:
x = torch.randn(1, 3, 256, 256)
y = enc_model(x)

torch.Size([1, 64, 256, 256]) torch.Size([1, 64, 256, 256])
torch.Size([1, 64, 128, 128]) torch.Size([1, 64, 128, 128])
torch.Size([1, 64, 64, 64]) torch.Size([1, 64, 64, 64])


In [14]:
y.shape

torch.Size([1, 512, 32, 32])