# Building Dense Autoencoders


In [1]:
import math
from collections import OrderedDict
from collections.abc import Sequence
from typing import Optional, Union

import torch
from torch import nn
from torchsummary import summary

## Simple MLP Encoder and Decoder

In [2]:
# Seed the global RNGs so layer initialisation and the dummy tensors below
# are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)  # silently ignored when CUDA is unavailable

In [3]:
class HiddenBlock(nn.Module):
    """Stack of fully connected hidden layers shared by the dense encoder/decoder.

    Each hidden layer is ``Linear -> ReLU -> LayerNorm -> Dropout``. The layers
    are chained, so every layer consumes the previous layer's output. With no
    hidden layers the block is an ``nn.Identity`` pass-through.

    Args:
        input_size: Number of input features seen by the first hidden layer.
        hidden_layers: Width of each hidden layer. ``None`` means no hidden
            layers. A bare ``int`` is treated as a single hidden layer. Any
            other sequence/iterable of ints is copied into a new list.
        dropout: Dropout probability applied after every hidden layer.
    """

    def __init__(self,
                 input_size: int,
                 hidden_layers: Optional[Union[int, Sequence[int]]] = None,
                 dropout: float = 0.2):
        super().__init__()
        if hidden_layers is None:
            self.hidden_layers = []
        elif isinstance(hidden_layers, int):
            self.hidden_layers = [hidden_layers]
        else:
            # Copy so later mutation of the caller's list cannot desynchronise
            # `hidden_layers` / `num_of_layers` from the layers built below.
            self.hidden_layers = list(hidden_layers)
        self.input_size = input_size
        self.num_of_layers = len(self.hidden_layers)
        self.dropout = dropout

        self.dense_layers = self._build_dense_layers()

    def _build_dense_layers(self):
        """Build the hidden stack as a named ``nn.Sequential`` (or ``nn.Identity``)."""
        if self.num_of_layers == 0:
            return nn.Identity()

        layers = []
        in_features = self.input_size
        for index, layer_size in enumerate(self.hidden_layers, start=1):
            layers.append((
                f'dense_layer{index}',
                nn.Sequential(
                    nn.Linear(in_features=in_features, out_features=layer_size),
                    nn.ReLU(),
                    nn.LayerNorm(layer_size),
                    nn.Dropout(p=self.dropout),
                ),
            ))
            # The next hidden layer consumes this layer's output.
            in_features = layer_size
        return nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        """Apply the hidden stack to ``x`` (features on the last dimension)."""
        return self.dense_layers(x)

    def summary(self):
        """Print the module tree (``repr``) of this block."""
        print(self)


class DenseEncoder(HiddenBlock):
    """MLP encoder: flatten the input, run the hidden stack, project to the latent space.

    Args:
        input_shape: Per-sample input shape (without the batch dimension), e.g.
            ``[1, 28, 28]``. Any rank is accepted; the number of input features
            is the product of its entries.
        hidden_layers: Widths of the hidden layers (see ``HiddenBlock``).
        latent_space_dim: Size of the latent code produced by ``forward``.
        dropout: Dropout probability of the hidden layers.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 hidden_layers: Optional[Union[int, Sequence[int]]] = None,
                 latent_space_dim: int = 2,
                 dropout: float = 0.2):
        # nn.Module.__init__ must run before any attribute is assigned on the
        # instance, so compute the flat size in a local first.
        input_size = math.prod(input_shape)
        super().__init__(input_size=input_size,
                         hidden_layers=hidden_layers,
                         dropout=dropout)
        self.input_shape = input_shape
        self.latent_space_dim = latent_space_dim

        self.flatten = nn.Flatten()
        self.output_layer = self._build_output_layer()

    def _build_output_layer(self):
        """Linear projection from the last hidden width (or the flat input) to the latent size."""
        in_features = self.hidden_layers[-1] if self.hidden_layers else self.input_size
        return nn.Linear(in_features=in_features,
                         out_features=self.latent_space_dim)

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, *input_shape)`` into ``(batch, latent_space_dim)``."""
        return self.output_layer(self.dense_layers(self.flatten(x)))


class DenseDecoder(HiddenBlock):
    """MLP decoder: run the hidden stack on a latent code and project back to ``output_shape``.

    The final linear projection is followed by a sigmoid.

    Args:
        output_shape: Per-sample output shape (without batch dimensions), e.g.
            ``[1, 28, 28]``. Any rank is accepted.
        hidden_layers: Widths of the hidden layers (see ``HiddenBlock``).
        latent_space_dim: Size of the latent code consumed by ``forward``.
        dropout: Dropout probability of the hidden layers.
    """

    def __init__(self,
                 output_shape: Sequence[int],
                 hidden_layers: Optional[Union[int, Sequence[int]]] = None,
                 latent_space_dim: int = 2,
                 dropout: float = 0.2):
        super().__init__(input_size=latent_space_dim,
                         hidden_layers=hidden_layers,
                         dropout=dropout)
        self.output_shape = output_shape
        self.latent_space_dim = latent_space_dim
        self.output_layer = self._build_output_layer()

    def _build_output_layer(self):
        """Linear projection to ``prod(output_shape)`` features followed by a sigmoid."""
        in_features = self.hidden_layers[-1] if self.hidden_layers else self.latent_space_dim
        out_features = math.prod(self.output_shape)
        return nn.Sequential(nn.Linear(in_features=in_features,
                                       out_features=out_features),
                             nn.Sigmoid())

    def forward(self, x):
        """Decode latent codes ``(*batch, latent_space_dim)`` into ``(*batch, *output_shape)``.

        The reshape keeps all leading dimensions of ``x``, not only
        ``x.size(0)``. This also supports an unbatched latent vector and extra
        batch dimensions; for a ``(batch, latent)`` input the result is
        unchanged.
        """
        flat = self.output_layer(self.dense_layers(x))
        return flat.view(*x.shape[:-1], *self.output_shape)

In [4]:
# Hyper-parameters shared by the cells below.
input_shape = [1, 28, 28]
hidden_layers = [256]
latent_space_dim = 2

# A random batch holding a single sample of shape `input_shape`
# (drawn before any layer is built, so the RNG sequence is unchanged).
dummy_tensor = torch.randn(1, *input_shape)

# Encoder with the default (empty) hidden stack: just Flatten + Linear.
DenseEncoder(input_shape=input_shape).summary()

DenseEncoder(
  (dense_layers): Identity()
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (output_layer): Linear(in_features=784, out_features=2, bias=True)
)


In [5]:
# Encoder with the shared hidden layers; pass the dummy batch through it to
# check the shape of the latent code.
model = DenseEncoder(
    input_shape=input_shape,
    hidden_layers=hidden_layers,
    latent_space_dim=latent_space_dim
)
model.summary()
# NOTE(review): the module is still in training mode here (no .eval() call),
# so dropout is active. The result is only fed to the decoder cell for a
# shape check.
latent_dummy_tensor = model(dummy_tensor)
latent_dummy_tensor.shape

DenseEncoder(
  (dense_layers): Sequential(
    (dense_layer1): Sequential(
      (0): Linear(in_features=784, out_features=256, bias=True)
      (1): ReLU()
      (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (3): Dropout(p=0.2, inplace=False)
    )
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (output_layer): Linear(in_features=256, out_features=2, bias=True)
)


torch.Size([1, 2])

In [6]:
# Decoder mirroring the encoder (hidden widths reversed); decoding the latent
# batch should give back a tensor with the input's shape.
model = DenseDecoder(
    output_shape=input_shape,
    hidden_layers=hidden_layers[::-1],
    latent_space_dim=latent_space_dim
)
model.summary()
model(latent_dummy_tensor).shape

DenseDecoder(
  (dense_layers): Sequential(
    (dense_layer1): Sequential(
      (0): Linear(in_features=2, out_features=256, bias=True)
      (1): ReLU()
      (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (3): Dropout(p=0.2, inplace=False)
    )
  )
  (output_layer): Sequential(
    (0): Linear(in_features=256, out_features=784, bias=True)
    (1): Sigmoid()
  )
)


torch.Size([1, 1, 28, 28])

Now let's combine the encoder and decoder into an autoencoder.

In [7]:
class DenseAutoencoder(nn.Module):
    """Dense (MLP) autoencoder built from a ``DenseEncoder`` and a ``DenseDecoder``.

    Args:
        input_shape: Per-sample input shape (without the batch dimension). The
            decoder reconstructs tensors of this same shape.
        encoder_hidden_layers: Hidden-layer widths of the encoder.
        decoder_hidden_layers: Hidden-layer widths of the decoder. ``None``
            mirrors the encoder (its hidden widths in reverse order). Pass an
            empty list for a decoder without hidden layers.
        latent_space_dim: Size of the bottleneck code.
        dropout: Dropout probability used by both halves. Note that this
            default (0.1) is lower than the 0.2 default of the standalone
            encoder/decoder classes.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 encoder_hidden_layers: Optional[Union[int, Sequence[int]]] = None,
                 decoder_hidden_layers: Optional[Union[int, Sequence[int]]] = None,
                 latent_space_dim: int = 2,
                 dropout: float = 0.1):
        super().__init__()

        self.encoder = DenseEncoder(input_shape=input_shape,
                                    hidden_layers=encoder_hidden_layers,
                                    latent_space_dim=latent_space_dim,
                                    dropout=dropout)

        if decoder_hidden_layers is None:
            # Mirror the encoder using its normalised widths: an int or None
            # argument has already been turned into a list by the encoder.
            decoder_hidden_layers = self.encoder.hidden_layers[::-1]
        self.decoder = DenseDecoder(output_shape=input_shape,
                                    hidden_layers=decoder_hidden_layers,
                                    latent_space_dim=latent_space_dim,
                                    dropout=dropout)

    def forward(self, x):
        """Reconstruct ``x``: encode it to the latent space, then decode it back."""
        return self.decoder(self.encoder(x))

    def summary(self):
        """Print the module tree of the whole autoencoder."""
        print(self)

In [8]:
# Full autoencoder. decoder_hidden_layers is omitted, so the decoder mirrors
# the encoder's hidden widths; the reconstruction should match the input shape.
model = DenseAutoencoder(
    input_shape=input_shape,
    encoder_hidden_layers=hidden_layers,
    latent_space_dim=latent_space_dim
)
model.summary()
model(dummy_tensor).shape

DenseAutoencoder(
  (encoder): DenseEncoder(
    (dense_layers): Sequential(
      (dense_layer1): Sequential(
        (0): Linear(in_features=784, out_features=256, bias=True)
        (1): ReLU()
        (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (output_layer): Linear(in_features=256, out_features=2, bias=True)
  )
  (decoder): DenseDecoder(
    (dense_layers): Sequential(
      (dense_layer1): Sequential(
        (0): Linear(in_features=2, out_features=256, bias=True)
        (1): ReLU()
        (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
    (output_layer): Sequential(
      (0): Linear(in_features=256, out_features=784, bias=True)
      (1): Sigmoid()
    )
  )
)


torch.Size([1, 1, 28, 28])

In [9]:
# Layer-by-layer table (output shapes and parameter counts) from torchsummary.
# input_size is the per-sample shape, without the batch dimension.
summary(model, input_size=tuple(input_shape), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Flatten-1                  [-1, 784]               0
            Linear-2                  [-1, 256]         200,960
              ReLU-3                  [-1, 256]               0
         LayerNorm-4                  [-1, 256]             512
           Dropout-5                  [-1, 256]               0
            Linear-6                    [-1, 2]             514
      DenseEncoder-7                    [-1, 2]               0
            Linear-8                  [-1, 256]             768
              ReLU-9                  [-1, 256]               0
        LayerNorm-10                  [-1, 256]             512
          Dropout-11                  [-1, 256]               0
           Linear-12                  [-1, 784]         201,488
          Sigmoid-13                  [-1, 784]               0
     DenseDecoder-14            [-1, 1,