In [1]:
import torch as th
import torch.nn as nn
from torchinfo import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Convolutional encoder: four stride-2 convolutions followed by a linear
# projection to a 512-dimensional vector.
# For a 128x128 input the spatial size shrinks 128 -> 63 -> 31 -> 14 -> 6,
# which is why the Linear layer expects 128 * 6 * 6 input features.
# NOTE(review): there is no activation between the last Conv2d and the Linear
# layer — confirm this is intentional.
encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2),
    nn.ReLU(inplace=True),
    nn.Conv2d(16, 32, kernel_size=3, stride=2),
    nn.ReLU(inplace=True),
    nn.Conv2d(32, 64, kernel_size=4, stride=2),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 128, kernel_size=4, stride=2),
    nn.Flatten(),
    nn.Linear(128 * 6 * 6, 512),
    nn.ReLU(inplace=True),
)

# Show the layer layout and the per-layer parameter counts.
print(encoder)
print(summary(encoder))

# Shape smoke test with a single all-zero 128x128 RGB image.
x = th.zeros((1, 3, 128, 128))

encoder(x).shape

Sequential(
  (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2))
  (1): ReLU(inplace=True)
  (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2))
  (3): ReLU(inplace=True)
  (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (5): ReLU(inplace=True)
  (6): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
  (7): Flatten(start_dim=1, end_dim=-1)
  (8): Linear(in_features=4608, out_features=512, bias=True)
  (9): ReLU(inplace=True)
)
Layer (type:depth-idx)                   Param #
Sequential                               --
├─Conv2d: 1-1                            448
├─ReLU: 1-2                              --
├─Conv2d: 1-3                            4,640
├─ReLU: 1-4                              --
├─Conv2d: 1-5                            32,832
├─ReLU: 1-6                              --
├─Conv2d: 1-7                            131,200
├─Flatten: 1-8                           --
├─Linear: 1-9                            2,359,808
├─ReLU: 1-10                             --


torch.Size([1, 512])

In [3]:
class ReshapeLayer(nn.Module):
    """Reshape every sample of a batch to a fixed shape.

    The batch dimension (dim 0) is preserved; all remaining dimensions are
    replaced by ``output_shape``. Intended for use inside ``nn.Sequential``,
    e.g. to turn the flat output of a ``Linear`` layer back into a
    ``(C, H, W)`` feature map.

    Args:
        output_shape: Iterable of ints giving the per-sample target shape
            (without the batch dimension).
    """

    def __init__(self, output_shape):
        super().__init__()
        # Store an immutable copy so later mutation of the caller's list
        # cannot silently change the behaviour of this layer.
        self.output_shape = tuple(output_shape)

    def forward(self, x):
        """Return ``x`` reshaped to ``(batch, *output_shape)``."""
        batch_size = x.shape[0]
        return x.reshape(batch_size, *self.output_shape)

    def extra_repr(self):
        """Expose the target shape in ``print(model)`` output."""
        return f"output_shape={self.output_shape}"

# Transposed-convolution decoder: a 1024-dimensional input vector is projected
# to a 64x6x6 feature map and upsampled by three stride-2 ConvTranspose2d
# layers (spatial size 6 -> 14 -> 31 -> 64) to a 3-channel image.
# NOTE(review): the output is 64x64 while the encoder cell above is checked
# with a 128x128 input — confirm the size mismatch is intended.
decoder = nn.Sequential(
    nn.Linear(1024, 64 * 6 * 6),
    nn.ReLU(inplace=True),
    ReshapeLayer([64, 6, 6]),
    nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, output_padding=1),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, output_padding=1),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2),
)

# Show the layer layout and the per-layer parameter counts.
print(decoder)
print(summary(decoder))

# Shape smoke test with a single all-zero 1024-dimensional vector.
x = th.zeros(1, 1024)

decoder(x).shape

Sequential(
  (0): Linear(in_features=1024, out_features=2304, bias=True)
  (1): ReLU(inplace=True)
  (2): ReshapeLayer()
  (3): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), output_padding=(1, 1))
  (4): ReLU(inplace=True)
  (5): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), output_padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2))
)
Layer (type:depth-idx)                   Param #
Sequential                               --
├─Linear: 1-1                            2,361,600
├─ReLU: 1-2                              --
├─ReshapeLayer: 1-3                      --
├─ConvTranspose2d: 1-4                   36,928
├─ReLU: 1-5                              --
├─ConvTranspose2d: 1-6                   32,800
├─ReLU: 1-7                              --
├─ConvTranspose2d: 1-8                   1,539
Total params: 2,432,867
Trainable params: 2,432,867
Non-trainable params: 0


torch.Size([1, 3, 64, 64])

## AE with a larger latent that is used for reconstruction, while a 512-dim vector is still passed to the RNN