In [None]:
import torch
from torch import nn
from collections import OrderedDict

In [ ]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an image tensor to a flat latent vector.

    The network is a stack of ``Conv2d -> ReLU -> BatchNorm2d`` stages, one per
    entry of ``conv_filters``, followed by a flatten and a single ``nn.Linear``
    projecting to ``latent_space_dim`` features.

    Args:
        input_shape: Shape of one sample without the batch dimension,
            ``(C, H, W)``.
        conv_filters: Output channel count of each conv stage.
        conv_kernels: Kernel size of each conv stage (indexed in parallel with
            ``conv_filters``).
        conv_strides: Stride of each conv stage (indexed in parallel with
            ``conv_filters``).
        latent_space_dim: Number of output features of the final linear layer.

    Attributes:
        shape_before_bottleneck: ``torch.Size`` ``(C', H', W')`` of the conv
            stack's output for a single sample, i.e. the shape that is
            flattened before the dense layer. A decoder can use it to invert
            the flatten.
    """

    def __init__(self,
                 input_shape,
                 conv_filters,
                 conv_kernels,
                 conv_strides,
                 latent_space_dim):
        super().__init__()

        self.input_shape = input_shape
        self.conv_filters = conv_filters
        self.conv_kernels = conv_kernels
        self.conv_strides = conv_strides
        self.latent_space_dim = latent_space_dim

        self.num_conv_layers = len(self.conv_filters)
        self.shape_before_bottleneck = None

        self.conv_layers = self._build_conv_layers()
        self.flatten = nn.Flatten()
        # Must run after conv_layers exists: the dense layer's input size is
        # discovered by probing the conv stack.
        self.dense_layer = self._build_dense_layer()

    def _build_conv_layers(self):
        """Build the conv stack as an ``nn.Sequential`` of named stages.

        Stage ``i`` is registered as ``conv{i+1}`` and consists of
        ``Conv2d -> ReLU -> BatchNorm2d``. Each stage's input channel count is
        the previous stage's ``conv_filters`` entry (or ``input_shape[0]`` for
        the first stage).
        """
        # input shape (C, H, W)
        in_channels = self.input_shape[0]
        stages = []
        for i in range(self.num_conv_layers):
            out_channels = self.conv_filters[i]
            kernel = self.conv_kernels[i]
            stage = nn.Sequential(
                nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=kernel,
                          stride=self.conv_strides[i],
                          # Preserves H and W for odd kernels at stride 1.
                          padding=(kernel - 1) // 2),
                nn.ReLU(),
                nn.BatchNorm2d(out_channels),
            )
            stages.append((f'conv{i+1}', stage))
            in_channels = out_channels
        return nn.Sequential(OrderedDict(stages))

    def _build_dense_layer(self):
        """Infer the flattened conv output size and build the bottleneck layer.

        A single all-zeros sample is pushed through ``conv_layers`` to
        discover the per-sample output shape, which is stored in
        ``self.shape_before_bottleneck``. The probe runs in eval mode and
        without autograd. Otherwise BatchNorm would fold the dummy batch into
        its running statistics and would reject a batch of one when the
        spatial output is 1x1. The conv stack's previous training flag is
        restored afterwards.

        Returns:
            An ``nn.Linear`` from the flattened conv output to
            ``latent_space_dim`` features.
        """
        was_training = self.conv_layers.training
        self.conv_layers.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.zeros(1, *self.input_shape)
                conv_out = self.conv_layers(dummy_input)
        finally:
            self.conv_layers.train(was_training)
        self.shape_before_bottleneck = conv_out.shape[1:]
        flattened_size = self.shape_before_bottleneck.numel()
        return nn.Linear(in_features=flattened_size,
                         out_features=self.latent_space_dim)

    def forward(self, x):
        """Encode a batch.

        Args:
            x: Tensor of shape ``(N, *input_shape)``.

        Returns:
            Tensor of shape ``(N, latent_space_dim)``.
        """
        return self.dense_layer(self.flatten(self.conv_layers(x)))

    def summary(self):
        """Print the module's layer structure (its ``repr``) to stdout."""
        print(self)