In [1]:
import torch as th
import torch.nn as nn
from torchinfo import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
encoder = nn.Sequential(
    nn.Conv2d(
        in_channels=3,
        out_channels=48,
        kernel_size=4,
        stride=2
    ),
    nn.ELU(True),

    nn.Conv2d(
        in_channels=48,
        out_channels=96,
        kernel_size=4,
        stride=2
    ),
    nn.ELU(True),

    nn.Conv2d(
        in_channels=96,
        out_channels=192,
        kernel_size=4,
        stride=2
    ),
    nn.ELU(True),

    nn.Conv2d(
        in_channels=192,
        out_channels=384,
        kernel_size=4,
        stride=2
    ),
    nn.ELU(True),

    nn.Conv2d(
        in_channels=384,
        out_channels=384,
        kernel_size=4,
        stride=2
    ),
    nn.ELU(True),

    nn.Flatten(),
    nn.Linear(1536, 512)

) ; encoder
# print(encoder)
print(summary(encoder))

x = th.zeros(1, 3, 128, 128)

encoder(x).shape

Layer (type:depth-idx)                   Param #
Sequential                               --
├─Conv2d: 1-1                            2,352
├─ELU: 1-2                               --
├─Conv2d: 1-3                            73,824
├─ELU: 1-4                               --
├─Conv2d: 1-5                            295,104
├─ELU: 1-6                               --
├─Conv2d: 1-7                            1,180,032
├─ELU: 1-8                               --
├─Conv2d: 1-9                            2,359,680
├─ELU: 1-10                              --
├─Flatten: 1-11                          --
├─Linear: 1-12                           786,944
Total params: 4,697,936
Trainable params: 4,697,936
Non-trainable params: 0


torch.Size([1, 512])

In [42]:
class ReshapeLayer(nn.Module):
    def __init__(self, output_shape):
        super().__init__()
        self.output_shape = output_shape
    
    def forward(self, x):
        B = x.shape[0]
        return x.reshape(B, *self.output_shape)

decoder = nn.Sequential(
    nn.Linear(512, 1536),
    nn.ELU(True),
    ReshapeLayer([1536, 1, 1]),

    nn.ConvTranspose2d(
        in_channels=1536,
        out_channels=192,
        kernel_size=5,
        stride=2
    ),
    nn.ELU(True),

    nn.ConvTranspose2d(
        in_channels=192,
        out_channels=96,
        kernel_size=5,
        stride=2
    ),
    nn.ELU(True),

    nn.ConvTranspose2d(
        in_channels=96,
        out_channels=48,
        kernel_size=6,
        stride=2
    ),
    nn.ELU(True),

    nn.ConvTranspose2d(
        in_channels=48,
        out_channels=16,
        kernel_size=6,
        stride=2
    ),
    nn.ELU(True),

    nn.ConvTranspose2d(
        in_channels=16,
        out_channels=3,
        kernel_size=2,
        stride=2
    )

) ; decoder
print(decoder)
print(summary(decoder))

x = th.zeros([1, 512])

decoder(x).shape

Sequential(
  (0): Linear(in_features=512, out_features=1536, bias=True)
  (1): ELU(alpha=True)
  (2): ReshapeLayer()
  (3): ConvTranspose2d(1536, 192, kernel_size=(5, 5), stride=(2, 2))
  (4): ELU(alpha=True)
  (5): ConvTranspose2d(192, 96, kernel_size=(5, 5), stride=(2, 2))
  (6): ELU(alpha=True)
  (7): ConvTranspose2d(96, 48, kernel_size=(6, 6), stride=(2, 2))
  (8): ELU(alpha=True)
  (9): ConvTranspose2d(48, 16, kernel_size=(6, 6), stride=(2, 2))
  (10): ELU(alpha=True)
  (11): ConvTranspose2d(16, 3, kernel_size=(2, 2), stride=(2, 2))
)
Layer (type:depth-idx)                   Param #
Sequential                               --
├─Linear: 1-1                            787,968
├─ELU: 1-2                               --
├─ReshapeLayer: 1-3                      --
├─ConvTranspose2d: 1-4                   7,372,992
├─ELU: 1-5                               --
├─ConvTranspose2d: 1-6                   460,896
├─ELU: 1-7                               --
├─ConvTranspose2d: 1-8             

torch.Size([1, 3, 128, 128])

## Decoder tries to invert the Encoder structure

In [10]:
class ReshapeLayer(nn.Module):
    def __init__(self, output_shape):
        super().__init__()
        self.output_shape = output_shape
    
    def forward(self, x):
        B = x.shape[0]
        return x.reshape(B, *self.output_shape)

decoder = nn.Sequential(
    nn.Linear(512, 1536),
    nn.ELU(True),
    ReshapeLayer([384, 2, 2]),

    nn.ConvTranspose2d(
        in_channels=384,
        out_channels=384,
        kernel_size=4,
        stride=2
    ),
    nn.ELU(True),

    nn.ConvTranspose2d(
        in_channels=384,
        out_channels=192,
        kernel_size=4,
        stride=2
    ),
    nn.ELU(True),

    nn.ConvTranspose2d(
        in_channels=192,
        out_channels=96,
        kernel_size=4,
        stride=2
    ),
    nn.ELU(True),

    nn.ConvTranspose2d(
        in_channels=96,
        out_channels=48,
        kernel_size=4,
        stride=2
    ),
    nn.ELU(True),

    nn.ConvTranspose2d(
        in_channels=48,
        out_channels=3,
        kernel_size=6,
        stride=2
    )

) ; decoder
print(decoder)
print(summary(decoder))

x = th.zeros([1, 512])

decoder(x).shape

Sequential(
  (0): Linear(in_features=512, out_features=1536, bias=True)
  (1): ELU(alpha=True)
  (2): ReshapeLayer()
  (3): ConvTranspose2d(384, 384, kernel_size=(4, 4), stride=(2, 2))
  (4): ELU(alpha=True)
  (5): ConvTranspose2d(384, 192, kernel_size=(4, 4), stride=(2, 2))
  (6): ELU(alpha=True)
  (7): ConvTranspose2d(192, 96, kernel_size=(4, 4), stride=(2, 2))
  (8): ELU(alpha=True)
  (9): ConvTranspose2d(96, 48, kernel_size=(4, 4), stride=(2, 2))
  (10): ELU(alpha=True)
  (11): ConvTranspose2d(48, 3, kernel_size=(6, 6), stride=(2, 2))
)
Layer (type:depth-idx)                   Param #
Sequential                               --
├─Linear: 1-1                            787,968
├─ELU: 1-2                               --
├─ReshapeLayer: 1-3                      --
├─ConvTranspose2d: 1-4                   2,359,680
├─ELU: 1-5                               --
├─ConvTranspose2d: 1-6                   1,179,840
├─ELU: 1-7                               --
├─ConvTranspose2d: 1-8          

torch.Size([1, 3, 128, 128])