# Building a Convolutional Autoencoder

In [1]:
from collections import OrderedDict
from collections.abc import Sequence
from typing import Optional

import torch
from torch import nn
from torchsummary import summary

### Convolutional Encoder

In [2]:
# Seed PyTorch so weight initialisation and the random test tensors below are
# reproducible on Restart & Run All. torch.manual_seed seeds every device
# (CPU and CUDA), so a separate torch.cuda.manual_seed call is unnecessary.
SEED = 42
torch.manual_seed(SEED);

In [3]:
class ConvolutionalBlock(nn.Module):
    """Stack of ``Conv2d -> ReLU -> BatchNorm2d`` layers named ``conv1 .. convN``.

    Layer ``i`` has ``convolutional_filters[i]`` output channels, a square
    kernel of size ``convolutional_kernels[i]`` and stride
    ``convolutional_strides[i]``. Padding is ``(kernel - 1) // 2``, so for odd
    kernels each layer maps a spatial size ``s`` to ``ceil(s / stride)``
    (unchanged at stride 1).

    Args:
        input_shape: ``(channels, height, width)`` of one input sample; only
            ``input_shape[0]`` is used by this class.
        convolutional_filters: output channels of each layer.
        convolutional_kernels: kernel size of each layer.
        convolutional_strides: stride of each layer.

    Raises:
        ValueError: if the three per-layer sequences have different lengths.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 convolutional_filters: Sequence[int],
                 convolutional_kernels: Sequence[int],
                 convolutional_strides: Sequence[int]):
        super().__init__()
        # zip() in the layer builder would silently truncate mismatched
        # sequences, so reject them up front with a clear message.
        if not (len(convolutional_filters) == len(convolutional_kernels)
                == len(convolutional_strides)):
            raise ValueError(
                "convolutional_filters, convolutional_kernels and "
                "convolutional_strides must have the same length, got "
                f"{len(convolutional_filters)}, {len(convolutional_kernels)} "
                f"and {len(convolutional_strides)}")

        self.input_shape = input_shape
        self.convolutional_filters = convolutional_filters
        self.convolutional_kernels = convolutional_kernels
        self.convolutional_strides = convolutional_strides
        self.num_conv_layers = len(self.convolutional_filters)

        self.conv_layers = self._build_convolutional_layers()

    def _build_convolutional_layers(self) -> nn.Sequential:
        """Build the named ``conv1 .. convN`` sub-blocks as one ``nn.Sequential``."""
        layers = OrderedDict()
        in_channels = self.input_shape[0]
        per_layer = zip(self.convolutional_filters,
                        self.convolutional_kernels,
                        self.convolutional_strides)
        for i, (filters, kernel, stride) in enumerate(per_layer, start=1):
            layers[f'conv{i}'] = nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=filters,
                    kernel_size=kernel,
                    stride=stride,
                    padding=(kernel - 1) // 2,
                ),
                nn.ReLU(),
                nn.BatchNorm2d(filters),
            )
            in_channels = filters  # the next layer consumes this layer's channels
        return nn.Sequential(layers)

    def forward(self, x):
        """Apply the convolutional stack to a batch ``x``."""
        return self.conv_layers(x)

    def summary(self):
        """Print the module tree (the ``nn.Module`` repr)."""
        print(self)


class ConvolutionalEncoder(ConvolutionalBlock):
    """A ``ConvolutionalBlock`` followed by ``Flatten -> Linear`` to the latent space.

    Attributes filled in at construction:
        shape_before_bottleneck: ``torch.Size`` of the conv feature map of a
            single sample, i.e. the shape a decoder must reshape back to.
        shape_flattened: number of elements in that feature map (the input
            size of the bottleneck ``Linear``).

    Args:
        input_shape: ``(channels, height, width)`` of one input sample.
        convolutional_filters, convolutional_kernels, convolutional_strides:
            per-layer settings, see ``ConvolutionalBlock``.
        latent_space_dimension: size of the latent vector produced per sample.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 convolutional_filters: Sequence[int],
                 convolutional_kernels: Sequence[int],
                 convolutional_strides: Sequence[int],
                 latent_space_dimension: int = 2):
        super().__init__(
            input_shape=input_shape,
            convolutional_filters=convolutional_filters,
            convolutional_kernels=convolutional_kernels,
            convolutional_strides=convolutional_strides
        )
        self.latent_space_dim = latent_space_dimension
        self.shape_before_bottleneck = None
        self.shape_flattened = None

        self.output_layer = self._build_output_layer()

    def _build_output_layer(self) -> nn.Sequential:
        """Infer the conv output shape with a dummy pass, then build ``Flatten -> Linear``.

        The probe runs in eval mode under ``torch.no_grad()``. In training mode
        the all-zeros probe would be folded into every BatchNorm's
        ``running_mean``/``running_var`` and bump ``num_batches_tracked``, and
        an unused autograd graph would be recorded. The previous mode is
        restored afterwards.
        """
        was_training = self.conv_layers.training
        self.conv_layers.eval()
        try:
            with torch.no_grad():
                conv_out = self.conv_layers(torch.zeros(1, *self.input_shape))
        finally:
            self.conv_layers.train(was_training)
        self.shape_before_bottleneck = conv_out.shape[1:]  # drop the batch dim
        self.shape_flattened = conv_out.numel()  # batch size is 1, so per-sample count
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=self.shape_flattened,
                      out_features=self.latent_space_dim))

    def forward(self, x):
        """Encode a batch ``x`` of shape ``(N, *input_shape)`` to ``(N, latent_space_dim)``."""
        return self.output_layer(self.conv_layers(x))

In [4]:
# Architecture hyper-parameters, shared by the encoder, decoder and
# autoencoder cells further down.
INPUT_SHAPE = [1, 28, 28]          # (channels, height, width) of one sample
conv_filters = [32, 64, 64, 64]    # output channels of each conv layer
conv_kernels = [3] * 4             # 3x3 kernels in every layer
conv_strides = [1, 2, 2, 1]        # the two stride-2 layers shrink 28x28 to 7x7
latent_space_dim = 2

encoder = ConvolutionalEncoder(
    INPUT_SHAPE, conv_filters, conv_kernels, conv_strides,
    latent_space_dimension=latent_space_dim,
)
encoder.summary()

ConvolutionalEncoder(
  (conv_layers): Sequential(
    (conv1): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv4): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (output_layer): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1

In [5]:
# Per-layer output shapes and parameter counts of the encoder.
# NOTE(review): torchsummary appears to collect these by running a forward
# pass; in training mode that would also update BatchNorm running stats.
summary(encoder, tuple(INPUT_SHAPE), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 28, 28]             320
              ReLU-2           [-1, 32, 28, 28]               0
       BatchNorm2d-3           [-1, 32, 28, 28]              64
            Conv2d-4           [-1, 64, 14, 14]          18,496
              ReLU-5           [-1, 64, 14, 14]               0
       BatchNorm2d-6           [-1, 64, 14, 14]             128
            Conv2d-7             [-1, 64, 7, 7]          36,928
              ReLU-8             [-1, 64, 7, 7]               0
       BatchNorm2d-9             [-1, 64, 7, 7]             128
           Conv2d-10             [-1, 64, 7, 7]          36,928
             ReLU-11             [-1, 64, 7, 7]               0
      BatchNorm2d-12             [-1, 64, 7, 7]             128
          Flatten-13                 [-1, 3136]               0
           Linear-14                   

In [6]:
# A random batch of 3 samples shaped like INPUT_SHAPE, to exercise the encoder.
test_tensor = torch.randn(3, *INPUT_SHAPE)
encoder_output = encoder(test_tensor)
assert encoder_output.shape == (3, latent_space_dim)
encoder_output

tensor([[-0.1759, -0.3644],
        [ 1.6390, -0.6403],
        [ 0.0442, -0.0863]], grad_fn=<AddmmBackward0>)

In [7]:
# The last conv layer fixes the channel count of the feature map that is
# flattened into the bottleneck.
assert encoder.shape_before_bottleneck[0] == conv_filters[-1]
encoder.shape_before_bottleneck

torch.Size([64, 7, 7])

Let's now build a decoder from transposed convolutions (`nn.ConvTranspose2d`), mirroring the encoder's layers in reverse order.

In [8]:
class ConvolutionalTransposeBlock(nn.Module):
    """Stack of ``ConvTranspose2d -> ReLU -> BatchNorm2d`` layers named
    ``convTranspose1 .. convTransposeN``.

    Padding is chosen by ``_transpose_padding`` so that each layer multiplies
    the spatial size by its stride. Passing an encoder's per-layer settings in
    reverse order therefore undoes its down-sampling.

    Args:
        shape_before_bottleneck: ``(channels, height, width)`` of the feature
            map fed to the first layer; only ``[0]`` is used by this class.
        convolutional_transpose_filters: output channels of each layer.
        convolutional_transpose_kernels: kernel size of each layer.
        convolutional_transpose_strides: stride of each layer.

    Raises:
        ValueError: if the three per-layer sequences have different lengths.
    """

    def __init__(self,
                 shape_before_bottleneck: Sequence[int],
                 convolutional_transpose_filters: Sequence[int],
                 convolutional_transpose_kernels: Sequence[int],
                 convolutional_transpose_strides: Sequence[int]):
        super().__init__()
        # zip() in the layer builder would silently truncate mismatched
        # sequences, so reject them up front with a clear message.
        if not (len(convolutional_transpose_filters)
                == len(convolutional_transpose_kernels)
                == len(convolutional_transpose_strides)):
            raise ValueError(
                "convolutional_transpose_filters, convolutional_transpose_kernels "
                "and convolutional_transpose_strides must have the same length, got "
                f"{len(convolutional_transpose_filters)}, "
                f"{len(convolutional_transpose_kernels)} and "
                f"{len(convolutional_transpose_strides)}")
        self.shape_before_bottleneck = shape_before_bottleneck
        self.convolutional_transpose_filters = convolutional_transpose_filters
        self.convolutional_transpose_kernels = convolutional_transpose_kernels
        self.convolutional_transpose_strides = convolutional_transpose_strides
        self.num_convT_filters = len(convolutional_transpose_filters)

        self.convT_layers = self._build_convolutional_transpose_layers()

    @staticmethod
    def _transpose_padding(kernel: int, stride: int) -> tuple[int, int]:
        """Return ``(padding, output_padding)`` mapping spatial size ``s`` to ``s * stride``.

        ``ConvTranspose2d`` (dilation 1) outputs
        ``(s - 1) * stride - 2 * padding + kernel + output_padding``. Choosing
        ``output_padding = stride - kernel + 2 * padding`` makes that
        ``s * stride``. With ``padding = (kernel - 1) // 2`` this equals
        ``stride - 1`` for odd kernels.

        NOTE(review): for even kernels the result is ``stride - 2``, which is
        negative at stride 1. Only odd kernels are used in this notebook.
        """
        padding = (kernel - 1) // 2
        return padding, stride - kernel + 2 * padding

    def _build_convolutional_transpose_layers(self) -> nn.Sequential:
        """Build the named ``convTranspose1 .. convTransposeN`` sub-blocks."""
        layers = OrderedDict()
        in_channels = self.shape_before_bottleneck[0]
        per_layer = zip(self.convolutional_transpose_filters,
                        self.convolutional_transpose_kernels,
                        self.convolutional_transpose_strides)
        for i, (filters, kernel, stride) in enumerate(per_layer, start=1):
            padding, output_padding = self._transpose_padding(kernel, stride)
            layers[f'convTranspose{i}'] = nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=filters,
                    kernel_size=kernel,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding
                ),
                nn.ReLU(),
                nn.BatchNorm2d(filters),
            )
            in_channels = filters  # the next layer consumes this layer's channels
        return nn.Sequential(layers)

    def forward(self, x):
        """Apply the transposed-convolution stack to a batch ``x``."""
        return self.convT_layers(x)

    def summary(self):
        """Print the module tree (the ``nn.Module`` repr)."""
        print(self)


class ConvolutionalDecoder(ConvolutionalTransposeBlock):
    """Decoder: ``Linear`` -> reshape -> transposed-conv stack -> ``ConvTranspose2d`` -> ``Sigmoid``.

    Args:
        latent_space_dimension: size of each input latent vector.
        shape_before_bottleneck: ``(channels, height, width)`` the dense
            layer's output is reshaped to before the transposed convolutions.
        convolutional_transpose_filters, convolutional_transpose_kernels,
        convolutional_transpose_strides: per-layer settings, see
            ``ConvolutionalTransposeBlock``.
        out_channels: channels of the decoded output.
    """

    def __init__(self,
                 latent_space_dimension: int,
                 shape_before_bottleneck: Sequence[int],
                 convolutional_transpose_filters: Sequence[int],
                 convolutional_transpose_kernels: Sequence[int],
                 convolutional_transpose_strides: Sequence[int],
                 out_channels: int = 3):
        super().__init__(
            shape_before_bottleneck=shape_before_bottleneck,
            convolutional_transpose_filters=convolutional_transpose_filters,
            convolutional_transpose_kernels=convolutional_transpose_kernels,
            convolutional_transpose_strides=convolutional_transpose_strides
        )
        self.latent_space_dimension = latent_space_dimension
        self.out_channels = out_channels

        self.dense_layer = self._build_dense_layer()
        self.output_layer = self._build_output_layer()

    def _build_dense_layer(self) -> nn.Linear:
        """Linear map from a latent vector to the flattened pre-bottleneck feature map."""
        channels, height, width = self.shape_before_bottleneck
        return nn.Linear(in_features=self.latent_space_dimension,
                         out_features=channels * height * width)

    def _build_output_layer(self) -> nn.Sequential:
        """Final ``ConvTranspose2d`` to ``out_channels``, then ``Sigmoid`` (outputs in [0, 1]).

        This layer uses stride 1, so it only changes the channel count. The
        transpose stack has already applied every entry of
        ``convolutional_transpose_strides``. Reusing the last stride here would
        upsample a second time, making the output larger than the encoder's
        input whenever that stride is not 1.
        """
        kernel = self.convolutional_transpose_kernels[-1]
        padding, output_padding = self._transpose_padding(kernel, stride=1)
        output_convT = nn.ConvTranspose2d(
            in_channels=self.convolutional_transpose_filters[-1],
            out_channels=self.out_channels,
            kernel_size=kernel,
            stride=1,
            padding=padding,
            output_padding=output_padding
        )
        return nn.Sequential(output_convT, nn.Sigmoid())

    def forward(self, x):
        """Decode latent vectors ``x`` of shape ``(N, latent_space_dimension)``."""
        feature_map = self.dense_layer(x).view(x.size(0), *self.shape_before_bottleneck)
        return self.output_layer(self.convT_layers(feature_map))

In [9]:
# Mirror the encoder: reuse its per-layer settings in reverse order, start from
# the feature-map shape the encoder flattens, and emit as many channels as the
# input has.
decoder = ConvolutionalDecoder(
    latent_space_dim,
    encoder.shape_before_bottleneck,
    list(reversed(conv_filters)),
    list(reversed(conv_kernels)),
    list(reversed(conv_strides)),
    out_channels=INPUT_SHAPE[0],
)
decoder.summary()

ConvolutionalDecoder(
  (convT_layers): Sequential(
    (convTranspose1): Sequential(
      (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (convTranspose2): Sequential(
      (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (convTranspose3): Sequential(
      (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (convTranspose4): Sequential(
      (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=

In [10]:
# The decoder consumes latent vectors of size latent_space_dim; take the input
# size from the config rather than from a tensor produced in another cell.
summary(decoder, input_size=(latent_space_dim,), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 3136]           9,408
   ConvTranspose2d-2             [-1, 64, 7, 7]          36,928
              ReLU-3             [-1, 64, 7, 7]               0
       BatchNorm2d-4             [-1, 64, 7, 7]             128
   ConvTranspose2d-5           [-1, 64, 14, 14]          36,928
              ReLU-6           [-1, 64, 14, 14]               0
       BatchNorm2d-7           [-1, 64, 14, 14]             128
   ConvTranspose2d-8           [-1, 64, 28, 28]          36,928
              ReLU-9           [-1, 64, 28, 28]               0
      BatchNorm2d-10           [-1, 64, 28, 28]             128
  ConvTranspose2d-11           [-1, 32, 28, 28]          18,464
             ReLU-12           [-1, 32, 28, 28]               0
      BatchNorm2d-13           [-1, 32, 28, 28]              64
  ConvTranspose2d-14            [-1, 1,

In [11]:
decoder_output = decoder(encoder_output)
# Decoding the encoded batch must give back tensors shaped like the inputs.
assert decoder_output.shape == test_tensor.shape
decoder_output.shape

torch.Size([3, 1, 28, 28])

### Autoencoder

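Let's now construct an autoencoder by chaining the two models: the encoder maps an input $x$ to a low-dimensional latent vector $z$, and the decoder maps $z$ back to a reconstruction $\hat{x}$ with the same shape as $x$:

$$\hat{x} = \mathrm{decoder}\big(\mathrm{encoder}(x)\big)$$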

In [12]:
class ConvolutionalAutoencoder(nn.Module):
    """Convolutional autoencoder computing ``decoder(encoder(x))``.

    The decoder is built from the encoder's ``shape_before_bottleneck`` and
    outputs ``input_shape[0]`` channels. Any decoder setting left as ``None``
    defaults to the matching encoder setting in reverse order.

    Args:
        input_shape: ``(channels, height, width)`` of one input sample.
        convolutional_filters, convolutional_kernels, convolutional_strides:
            per-layer encoder settings.
        latent_space_dimension: size of the latent vector per sample.
        convolutional_transpose_filters, convolutional_transpose_kernels,
        convolutional_transpose_strides: per-layer decoder settings; ``None``
            means "the encoder setting reversed".
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 convolutional_filters: Sequence[int],
                 convolutional_kernels: Sequence[int],
                 convolutional_strides: Sequence[int],
                 latent_space_dimension: int = 2,
                 convolutional_transpose_filters: Optional[Sequence[int]] = None,
                 convolutional_transpose_kernels: Optional[Sequence[int]] = None,
                 convolutional_transpose_strides: Optional[Sequence[int]] = None):
        super().__init__()

        # Default the decoder to a mirror image of the encoder.
        if convolutional_transpose_filters is None:
            convolutional_transpose_filters = convolutional_filters[::-1]
        if convolutional_transpose_kernels is None:
            convolutional_transpose_kernels = convolutional_kernels[::-1]
        if convolutional_transpose_strides is None:
            convolutional_transpose_strides = convolutional_strides[::-1]

        self.encoder = ConvolutionalEncoder(
            input_shape=input_shape,
            convolutional_filters=convolutional_filters,
            convolutional_kernels=convolutional_kernels,
            convolutional_strides=convolutional_strides,
            latent_space_dimension=latent_space_dimension
        )
        self.decoder = ConvolutionalDecoder(
            latent_space_dimension=latent_space_dimension,
            shape_before_bottleneck=self.encoder.shape_before_bottleneck,
            convolutional_transpose_filters=convolutional_transpose_filters,
            convolutional_transpose_kernels=convolutional_transpose_kernels,
            convolutional_transpose_strides=convolutional_transpose_strides,
            out_channels=input_shape[0]
        )

    def forward(self, x):
        """Reconstruct ``x``: encode it to the latent space, then decode."""
        return self.decoder(self.encoder(x))

    def summary(self):
        """Print the module tree (the ``nn.Module`` repr)."""
        print(self)

In [13]:
# Encoder settings are shared with the cells above; the decoder settings fall
# back to their reversed defaults inside ConvolutionalAutoencoder.
# NOTE(review): the latent size here (3) differs from latent_space_dim (2) used
# for the standalone encoder/decoder — confirm that is intentional.
autoencoder = ConvolutionalAutoencoder(
    INPUT_SHAPE, conv_filters, conv_kernels, conv_strides,
    latent_space_dimension=3,
)
autoencoder.summary()

ConvolutionalAutoencoder(
  (encoder): ConvolutionalEncoder(
    (conv_layers): Sequential(
      (conv1): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv2): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU()
        (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv3): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU()
        (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv4): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
 

In [14]:
# Per-layer output shapes and parameter counts of the full encoder -> decoder model.
# NOTE(review): torchsummary appears to collect these by running a forward
# pass; in training mode that would also update BatchNorm running stats.
summary(autoencoder, tuple(INPUT_SHAPE), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 28, 28]             320
              ReLU-2           [-1, 32, 28, 28]               0
       BatchNorm2d-3           [-1, 32, 28, 28]              64
            Conv2d-4           [-1, 64, 14, 14]          18,496
              ReLU-5           [-1, 64, 14, 14]               0
       BatchNorm2d-6           [-1, 64, 14, 14]             128
            Conv2d-7             [-1, 64, 7, 7]          36,928
              ReLU-8             [-1, 64, 7, 7]               0
       BatchNorm2d-9             [-1, 64, 7, 7]             128
           Conv2d-10             [-1, 64, 7, 7]          36,928
             ReLU-11             [-1, 64, 7, 7]               0
      BatchNorm2d-12             [-1, 64, 7, 7]             128
          Flatten-13                 [-1, 3136]               0
           Linear-14                   

In [16]:
autoencoder_output = autoencoder(test_tensor)
# A reconstruction must have exactly the input's shape.
assert autoencoder_output.shape == test_tensor.shape
autoencoder_output.shape

torch.Size([3, 1, 28, 28])