In [10]:
import torch
from models.vae.model import VAEEncoder, VAEDecoder
from torchinfo import summary

# Encoder-Decoder Parameters Analysis

## Conv + Linear HParams

In [12]:
# Hyperparameters for the "Conv + Linear" variant (encoder ends in a fully
# connected layer: encoder_last_layer_fc = True).
# NOTE: the "Conv only HParams" cell assigns these same names, so whichever of
# the two hyperparameter cells ran last decides what the cells below build.
latent_dim = 64
input_dim = 3

# Encoder: six entries per list; the recorded summaries show six Conv2d layers
# ending at a 26x26 feature map for a 128x128 input.
encoder_hidden_dims = [16, 32, 64, 128, 256, 512]
encoder_kernels = [1, 5, 5, 4, 3, 4]
encoder_strides = [1, 1, 1, 2, 2, 1]
encoder_paddings = [0] * 6
encoder_last_layer_fc = True
encoder_last_spatial_dim = 26

# Decoder. The leading -1 entries are presumably placeholders for the Linear
# layer that comes first in this variant -- TODO confirm against VAEDecoder.
# NOTE(review): kernels/strides/paddings hold 7 entries while
# decoder_hidden_dims holds 6; the recorded decoder summary lists one Linear
# plus five ConvTranspose2d layers, so the trailing entry looks unused --
# confirm before relying on it.
decoder_hidden_dims = [512, 256, 128, 64, 32, 16]
decoder_kernels = [-1, 4, 3, 4, 5, 5, 4]
decoder_strides = [-1, 1, 2, 2, 1, 1, 1]
decoder_paddings = [-1] + [0] * 6

## Conv only HParams

In [2]:
# Hyperparameters for the "Conv only" variant (no fully connected layer:
# encoder_last_layer_fc = False).
# NOTE: this cell reassigns the same names as the "Conv + Linear HParams"
# cell; whichever of the two ran last decides what the cells below build.
latent_dim = 64
input_dim = 3

# Encoder: six entries per list; the recorded summaries show six Conv2d layers
# ending at a 26x26 feature map for a 128x128 input.
encoder_hidden_dims = [16, 32, 64, 128, 256, 512]
encoder_kernels = [1, 5, 5, 4, 3, 4]
encoder_strides = [1, 1, 1, 2, 2, 1]
encoder_paddings = [0] * 6
encoder_last_layer_fc = False
encoder_last_spatial_dim = 26

# Decoder: six entries per list, one per ConvTranspose2d layer in the recorded
# summary.
# NOTE(review): the recorded decoder summary for this variant ends at
# [10, 3, 131, 131] while the encoder input is 128x128 -- the last
# kernel-4 / stride-1 layer grows 128 -> 131. If the reconstruction is meant
# to match the input size, that last kernel needs revisiting -- confirm.
decoder_hidden_dims = [256, 128, 64, 32, 16, 16]
decoder_kernels = [4, 3, 4, 5, 5, 4]
decoder_strides = [1, 2, 2, 1, 1, 1]
decoder_paddings = [0] * 6

In [13]:
# Build the encoder from the hyperparameters currently in the namespace
# (i.e. from whichever hyperparameter cell was executed last).
encoder = VAEEncoder(
    input_dim=input_dim,
    encoder_hidden_dims=encoder_hidden_dims,
    encoder_kernels=encoder_kernels,
    encoder_strides=encoder_strides,
    encoder_paddings=encoder_paddings,
    encoder_last_layer_fc=encoder_last_layer_fc,
    encoder_last_spatial_dim=encoder_last_spatial_dim,
    latent_dim=latent_dim,
)

In [14]:
# Dummy input batch: 10 random samples with `input_dim` channels at 128x128.
x = torch.randn((10, input_dim, 128, 128))

In [15]:
# Forward pass on the dummy batch; the cell displays the returned object.
# The recorded output is a Normal whose loc and scale are both [10, 64],
# i.e. it was produced with the Conv + Linear hyperparameters active.
encoder(x)

Normal(loc: torch.Size([10, 64]), scale: torch.Size([10, 64]))

# Encoder (Conv only)

In [13]:
# Layer-by-layer torchinfo summary of the encoder for a batch of 10 inputs
# at 128x128. The channel count comes from `input_dim` instead of a literal 3
# so it cannot drift from the hyperparameter cell (same shape as `x` above).
# The output recorded under this heading (final shape [10, 64, 26, 26]) comes
# from an encoder built with the "Conv only HParams" cell
# (encoder_last_layer_fc = False); rebuild the encoder with that cell active
# before re-running this one.
summary(encoder, input_size=(10, input_dim, 128, 128))

Layer (type:depth-idx)                   Output Shape              Param #
VAEEncoder                               [10, 64, 26, 26]          32,832
├─ModuleList: 1-1                        --                        --
│    └─Conv2d: 2-1                       [10, 16, 128, 128]        64
│    └─Conv2d: 2-2                       [10, 32, 124, 124]        12,832
│    └─Conv2d: 2-3                       [10, 64, 120, 120]        51,264
│    └─Conv2d: 2-4                       [10, 128, 59, 59]         131,200
│    └─Conv2d: 2-5                       [10, 256, 29, 29]         295,168
│    └─Conv2d: 2-6                       [10, 512, 26, 26]         2,097,664
├─Conv2d: 1-2                            [10, 64, 26, 26]          32,832
Total params: 2,653,856
Trainable params: 2,653,856
Non-trainable params: 0
Total mult-adds (G): 30.82
Input size (MB): 1.97
Forward/backward pass size (MB): 218.08
Params size (MB): 10.48
Estimated Total Size (MB): 230.53

# Encoder (Conv + Linear)

In [14]:
# Layer-by-layer torchinfo summary of the encoder for a batch of 10 inputs
# at 128x128. The channel count comes from `input_dim` instead of a literal 3
# so it cannot drift from the hyperparameter cell (same shape as `x` above).
# The output recorded under this heading (Linear head, final shape [10, 64])
# comes from an encoder built with the "Conv + Linear HParams" cell
# (encoder_last_layer_fc = True); rebuild the encoder with that cell active
# before re-running this one.
summary(encoder, input_size=(10, input_dim, 128, 128))

Layer (type:depth-idx)                   Output Shape              Param #
VAEEncoder                               [10, 64]                  22,151,232
├─ModuleList: 1-1                        --                        --
│    └─Conv2d: 2-1                       [10, 16, 128, 128]        64
│    └─Conv2d: 2-2                       [10, 32, 124, 124]        12,832
│    └─Conv2d: 2-3                       [10, 64, 120, 120]        51,264
│    └─Conv2d: 2-4                       [10, 128, 59, 59]         131,200
│    └─Conv2d: 2-5                       [10, 256, 29, 29]         295,168
│    └─Conv2d: 2-6                       [10, 512, 26, 26]         2,097,664
├─Linear: 1-2                            [10, 64]                  22,151,232
Total params: 46,890,656
Trainable params: 46,890,656
Non-trainable params: 0
Total mult-adds (G): 30.82
Input size (MB): 1.97
Forward/backward pass size (MB): 214.63
Params size (MB): 98.96
Estimated Total Size (MB): 315.55

In [16]:
# Build the decoder from the hyperparameters currently in the namespace
# (i.e. from whichever hyperparameter cell was executed last). The output
# sigmoid is explicitly disabled here.
decoder = VAEDecoder(
    input_dim=input_dim,
    encoder_last_layer_fc=encoder_last_layer_fc,
    encoder_last_spatial_dim=encoder_last_spatial_dim,
    latent_dim=latent_dim,
    decoder_hidden_dims=decoder_hidden_dims,
    decoder_kernels=decoder_kernels,
    decoder_strides=decoder_strides,
    decoder_paddings=decoder_paddings,
    decoder_output_sigmoid=False,
)

In [17]:
# Encode the dummy batch; the encoder returns a distribution object
# (displayed as a Normal in the earlier forward-pass cell).
q_z = encoder(x)
# Draw one latent sample per input via rsample().
z = q_z.rsample()
# Display the latent shape. The recorded output, torch.Size([10, 64]), was
# produced with the Conv + Linear hyperparameters active.
z.shape

torch.Size([10, 64])

# Decoder (Conv only)

In [9]:
# Layer-by-layer torchinfo summary of the decoder when it is fed a spatial
# latent map of shape (10, latent_dim, S, S), S = encoder_last_spatial_dim.
# The output recorded under this heading comes from a decoder built with the
# "Conv only HParams" cell and ends at [10, 3, 131, 131] (not 128x128).
summary(
    decoder,
    input_size=(
        10,
        latent_dim,
        encoder_last_spatial_dim,
        encoder_last_spatial_dim,
    ),
)

Layer (type:depth-idx)                   Output Shape              Param #
VAEDecoder                               [10, 3, 131, 131]         --
├─ModuleList: 1-1                        --                        --
│    └─ConvTranspose2d: 2-1              [10, 256, 29, 29]         262,400
│    └─ConvTranspose2d: 2-2              [10, 128, 59, 59]         295,040
│    └─ConvTranspose2d: 2-3              [10, 64, 120, 120]        131,136
│    └─ConvTranspose2d: 2-4              [10, 32, 124, 124]        51,232
│    └─ConvTranspose2d: 2-5              [10, 16, 128, 128]        12,816
│    └─ConvTranspose2d: 2-6              [10, 16, 131, 131]        4,112
├─ConvTranspose2d: 1-2                   [10, 3, 131, 131]         51
Total params: 756,787
Trainable params: 756,787
Non-trainable params: 0
Total mult-adds (G): 42.05
Input size (MB): 1.73
Forward/backward pass size (MB): 213.02
Params size (MB): 3.03
Estimated Total Size (MB): 217.77

# Decoder (Conv + Linear)

In [18]:
# Layer-by-layer torchinfo summary of the decoder when it is fed a flat
# latent batch of shape (10, latent_dim). The output recorded under this
# heading (Linear first layer, final shape [10, 3, 128, 128]) comes from a
# decoder built with the "Conv + Linear HParams" cell
# (encoder_last_layer_fc = True).
summary(decoder, input_size=(10, latent_dim))

Layer (type:depth-idx)                   Output Shape              Param #
VAEDecoder                               [10, 3, 128, 128]         --
├─ModuleList: 1-1                        --                        --
│    └─Linear: 2-1                       [10, 346112]              22,497,280
│    └─ConvTranspose2d: 2-2              [10, 256, 29, 29]         2,097,408
│    └─ConvTranspose2d: 2-3              [10, 128, 59, 59]         295,040
│    └─ConvTranspose2d: 2-4              [10, 64, 120, 120]        131,136
│    └─ConvTranspose2d: 2-5              [10, 32, 124, 124]        51,232
│    └─ConvTranspose2d: 2-6              [10, 16, 128, 128]        12,816
├─ConvTranspose2d: 1-2                   [10, 3, 128, 128]         51
Total params: 25,084,963
Trainable params: 25,084,963
Non-trainable params: 0
Total mult-adds (G): 57.00
Input size (MB): 0.00
Forward/backward pass size (MB): 218.55
Params size (MB): 100.34
Estimated Total Size (MB): 318.89