# Dense Variational Autoencoder

In [1]:
import operator
from functools import reduce
from typing import Optional

import torch
from torch import nn
from torchsummary import summary

In [2]:
import sys
sys.path.append('..')
from ModelClasses.Autoencoders import DenseBlock, DenseDecoder

In [3]:
# Pin the random seed so that weight initialisation and random tensors
# created later in the notebook are reproducible from run to run.
torch.manual_seed(seed=42)
torch.cuda.manual_seed(seed=42)

In [4]:
class DenseVariationalEncoder(DenseBlock):
    """Dense encoder that outputs the parameters of a latent Gaussian.

    The input is flattened, passed through ``self.dense_layers`` (not
    defined in this class; expected to be provided by ``DenseBlock``) and
    then projected by two parallel linear heads: one producing the latent
    mean and one producing the latent log-variance.
    """

    def __init__(self,
                 input_shape: tuple[int, ...],
                 hidden_layers: Optional[tuple[int, ...]] = None,
                 latent_space_dimension: int = 2,
                 dropout: float = 0.2):
        """Build the dense stack and the two variational heads.

        Args:
            input_shape: Dimension sizes of a single sample, without the
                batch dimension. Their product is used as the flattened
                input size.
            hidden_layers: Forwarded unchanged to ``DenseBlock``; how
                ``None`` is handled is decided there.
            latent_space_dimension: Number of output features of both the
                mean head and the log-variance head.
            dropout: Forwarded unchanged to ``DenseBlock``.
        """
        # Stored ahead of the parent initializer, same order as before.
        self.input_shape = input_shape
        # Flattened feature count = product of all sample dimensions.
        input_size = reduce(operator.mul, input_shape)
        super().__init__(
            input_size=input_size,
            hidden_layers=hidden_layers,
            dropout=dropout)
        self.latent_space_dim = latent_space_dimension

        # nn.Flatten() keeps dim 0 and collapses the rest.
        self.flatten = nn.Flatten()
        self.dense_mu, self.dense_logvar = self._build_variational_layer()

    def _build_variational_layer(self) -> tuple[nn.Linear, nn.Linear]:
        """Create the mean and log-variance output layers.

        Reads ``self.num_of_layers``, ``self.input_size`` and
        ``self.hidden_layers``, which are not assigned in this class and
        are expected to be set by ``DenseBlock.__init__``.

        Returns:
            ``(mu, logvar)``: two independent ``nn.Linear`` layers with
            identical shapes.
        """
        # With no hidden layers the heads sit directly on the flat input;
        # otherwise they consume the output of the last hidden layer.
        if self.num_of_layers == 0:
            in_features = self.input_size
        else:
            in_features = self.hidden_layers[-1]
        mu = nn.Linear(in_features=in_features,
                       out_features=self.latent_space_dim)
        logvar = nn.Linear(in_features=in_features,
                           out_features=self.latent_space_dim)
        return mu, logvar

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` and return ``(mu, logvar)``.

        Both heads are applied to the same hidden representation.
        """
        x = self.dense_layers(self.flatten(x))
        return self.dense_mu(x), self.dense_logvar(x)

    def summary(self) -> None:
        """Print the module structure (the module's ``repr``)."""
        print(self)

In [5]:
# Shape of a single sample (the batch dimension is not included).
INPUT_SHAPE = [1, 28, 28]
# Output widths of the encoder's hidden dense layers, in order.
hidden_layers = [256, 128]
# Number of latent dimensions (size of the mu and logvar outputs).
latent_space_dim = 2

# Build the encoder from the settings above (dropout is left at the class
# default) and print its module structure.
encoder = DenseVariationalEncoder(
    input_shape=INPUT_SHAPE,
    hidden_layers=hidden_layers,
    latent_space_dimension=latent_space_dim
)
encoder.summary()

DenseVariationalEncoder(
  (dense_layers): Sequential(
    (dense_layer1): Sequential(
      (0): Linear(in_features=784, out_features=256, bias=True)
      (1): ReLU()
      (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (3): Dropout(p=0.2, inplace=False)
    )
    (dense_layer2): Sequential(
      (0): Linear(in_features=256, out_features=128, bias=True)
      (1): ReLU()
      (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (3): Dropout(p=0.2, inplace=False)
    )
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (dense_mu): Linear(in_features=128, out_features=2, bias=True)
  (dense_logvar): Linear(in_features=128, out_features=2, bias=True)
)


In [6]:
# Per-layer output shapes and parameter counts via torchsummary, run on CPU.
# input_size is the per-sample shape as a tuple (INPUT_SHAPE is a list).
summary(encoder, input_size=tuple(INPUT_SHAPE), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Flatten-1                  [-1, 784]               0
            Linear-2                  [-1, 256]         200,960
              ReLU-3                  [-1, 256]               0
         LayerNorm-4                  [-1, 256]             512
           Dropout-5                  [-1, 256]               0
            Linear-6                  [-1, 128]          32,896
              ReLU-7                  [-1, 128]               0
         LayerNorm-8                  [-1, 128]             256
           Dropout-9                  [-1, 128]               0
           Linear-10                    [-1, 2]             258
           Linear-11                    [-1, 2]             258
Total params: 235,140
Trainable params: 235,140
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/

In [7]:
# Smoke test: run a random batch of 3 samples through the encoder.
test_tensor = torch.randn(3, *INPUT_SHAPE)
encoder_mu, encoder_logvar = encoder(test_tensor)
# Each tensor has one row per sample and one column per latent dimension.
print(encoder_mu)
print(encoder_logvar)

tensor([[-0.1363, -0.0058],
        [ 0.0648, -0.6677],
        [ 0.6452,  0.2945]], grad_fn=<AddmmBackward0>)
tensor([[ 0.1427,  1.1920],
        [-1.5359,  0.3155],
        [ 0.1362, -0.0895]], grad_fn=<AddmmBackward0>)


In [8]:
class DenseVariationalAutoencoder(nn.Module):
    """Variational autoencoder built from a dense encoder and dense decoder.

    The encoder yields a mean and a log-variance per latent dimension, a
    latent sample is drawn with the reparameterization trick, and the
    decoder maps that sample back to the input shape.
    """

    def __init__(self,
                 input_shape: tuple[int, ...],
                 encoder_hidden_layers: Optional[tuple[int, ...]] = None,
                 decoder_hidden_layers: Optional[tuple[int, ...]] = None,
                 latent_space_dimension: int = 2,
                 dropout: float = 0.1):
        """Create the encoder and the decoder.

        Args:
            input_shape: Dimension sizes of a single sample, without the
                batch dimension. Also passed to the decoder as its
                ``output_shape``.
            encoder_hidden_layers: Hidden layer widths forwarded to the
                encoder.
            decoder_hidden_layers: Hidden layer widths forwarded to the
                decoder. When ``None``, the encoder's ``hidden_layers``
                attribute is used in reverse order.
            latent_space_dimension: Size of the latent vector shared by
                encoder and decoder.
            dropout: Dropout value passed to both encoder and decoder.
        """
        super().__init__()

        self.encoder = DenseVariationalEncoder(
            input_shape=input_shape,
            hidden_layers=encoder_hidden_layers,
            latent_space_dimension=latent_space_dimension,
            dropout=dropout)

        # Default to a decoder that mirrors the encoder's layer widths.
        if decoder_hidden_layers is None:
            decoder_hidden_layers = self.encoder.hidden_layers[::-1]
        self.decoder = DenseDecoder(
            output_shape=input_shape,
            hidden_layers=decoder_hidden_layers,
            latent_space_dimension=latent_space_dimension,
            dropout=dropout)

    @staticmethod
    def reparametrization(mu: torch.Tensor,
                          logvar: torch.Tensor) -> torch.Tensor:
        """Draw a latent sample ``mu + eps * exp(0.5 * logvar)``.

        ``eps`` is standard-normal noise with the shape of ``logvar``.
        Sampling this way keeps the result differentiable with respect to
        ``mu`` and ``logvar``.
        """
        # exp(0.5 * log(sigma^2)) == sigma
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + (eps * std)

    def encode_and_reparametrize(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return a latent sample drawn from the result."""
        return self.reparametrization(*self.encoder(x))

    def forward_mu_logvar(self,
                          mu: torch.Tensor,
                          logvar: torch.Tensor) -> torch.Tensor:
        """Sample a latent vector from ``(mu, logvar)`` and decode it."""
        z = self.reparametrization(mu, logvar)
        return self.decoder(z)

    def forward(self, x: torch.Tensor
                ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the full autoencoder.

        Returns:
            ``(reconstruction, mu, logvar)``. ``mu`` and ``logvar`` are
            returned alongside the decoder output so callers can use them
            as well.
        """
        mu, logvar = self.encoder(x)
        return self.forward_mu_logvar(mu, logvar), mu, logvar

    def summary(self) -> None:
        """Print the module structure (the module's ``repr``)."""
        print(self)

In [9]:
# Only the encoder widths are given: with decoder_hidden_layers omitted the
# decoder uses the encoder's hidden layers in reverse order. Dropout is left
# at the class default.
vae = DenseVariationalAutoencoder(
    input_shape=INPUT_SHAPE,
    encoder_hidden_layers=hidden_layers,
    latent_space_dimension=latent_space_dim
)
vae.summary()

DenseVariationalAutoencoder(
  (encoder): DenseVariationalEncoder(
    (dense_layers): Sequential(
      (dense_layer1): Sequential(
        (0): Linear(in_features=784, out_features=256, bias=True)
        (1): ReLU()
        (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dense_layer2): Sequential(
        (0): Linear(in_features=256, out_features=128, bias=True)
        (1): ReLU()
        (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (dense_mu): Linear(in_features=128, out_features=2, bias=True)
    (dense_logvar): Linear(in_features=128, out_features=2, bias=True)
  )
  (decoder): DenseDecoder(
    (dense_layers): Sequential(
      (dense_layer1): Sequential(
        (0): Linear(in_features=2, out_features=128, bias=True)
        (1): ReLU()
        (2): LayerNorm((128,), eps=1e-05, elementwis

In [10]:
# Per-layer output shapes and parameter counts for the full VAE, run on CPU.
summary(vae, input_size=tuple(INPUT_SHAPE), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Flatten-1                  [-1, 784]               0
            Linear-2                  [-1, 256]         200,960
              ReLU-3                  [-1, 256]               0
         LayerNorm-4                  [-1, 256]             512
           Dropout-5                  [-1, 256]               0
            Linear-6                  [-1, 128]          32,896
              ReLU-7                  [-1, 128]               0
         LayerNorm-8                  [-1, 128]             256
           Dropout-9                  [-1, 128]               0
           Linear-10                    [-1, 2]             258
           Linear-11                    [-1, 2]             258
DenseVariationalEncoder-12         [[-1, 2], [-1, 2]]               0
           Linear-13                  [-1, 128]             384
             ReLU-14             

In [11]:
# Sample from the (mu, logvar) produced earlier by the standalone encoder and
# decode with the VAE's decoder; only the output shape is inspected.
decoder_output = vae.forward_mu_logvar(encoder_mu, encoder_logvar)
decoder_output.shape

torch.Size([3, 1, 28, 28])

In [12]:
# Full forward pass through the VAE; mu and logvar are discarded here and
# only the reconstruction's shape is inspected.
decoder_output, _, _ = vae(test_tensor)
decoder_output.shape

torch.Size([3, 1, 28, 28])