# Variational Encoder

> A customisable PyTorch variational encoder model that maps inputs to latent mean (`mu`) and log-variance (`logvar`) vectors.

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| default_exp Models.VariationalEncoder

In [None]:
#| export
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class VariationalEncoder(nn.Module):
    """Variational encoder PyTorch model.

    An MLP backbone maps the input to a hidden representation of width
    ``hidden_sizes[-1]``; two independent linear heads then project that
    representation to the latent mean (``mu``) and log-variance (``logvar``).

    Args:
        input_size: Number of input features (size of the input's last dim).
        hidden_sizes: Widths of the hidden layers, in order. Must be non-empty.
        latent_size: Size of the last dimension of ``mu`` and ``logvar``.
        dropout: Dropout probability applied after every activation; a value
            ``<= 0`` disables dropout entirely.
        use_norm: If true, insert a parameter-free per-sample normalisation
            over the features after every hidden ``Linear`` except the first
            (input) one.

    Raises:
        ValueError: If ``hidden_sizes`` is empty.
    """
    def __init__(self, input_size, hidden_sizes, latent_size, dropout, use_norm):
        super().__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.latent_size = latent_size
        self.dropout = dropout
        self.use_norm = use_norm

        if len(self.hidden_sizes) == 0:
            raise ValueError("hidden_sizes must contain at least one layer width")

        # Widths from the input through the last hidden layer; each
        # consecutive (fan_in, fan_out) pair becomes one Linear layer.
        sizes = [self.input_size, *self.hidden_sizes]
        layers = []
        for i, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # Normalisation is applied to hidden layers only, not to the
            # output of the input layer (i == 0).
            if self.use_norm and i > 0:
                # Linear puts features on the last dim, which is what
                # LayerNorm normalises over. InstanceNorm1d expects features at
                # dim -2 and warns about a num_features mismatch on (batch,
                # features) input; without affine params both compute the same
                # per-sample mean/variance over the features.
                layers.append(nn.LayerNorm(fan_out, elementwise_affine=False))
            # nn.ReLU's first argument is `inplace`, so ReLU(0.2) was just an
            # in-place ReLU; 0.2 is intended as the leaky negative slope.
            layers.append(nn.LeakyReLU(0.2))
            if self.dropout > 0:
                layers.append(nn.Dropout(p=self.dropout))

        # The backbone must end at hidden_sizes[-1] features: the mu/logvar
        # heads below consume exactly that width and perform the projection to
        # latent_size. An extra Linear to latent_size here breaks forward().
        self.model = nn.Sequential(*layers)
        self.mu_layer = nn.Linear(self.hidden_sizes[-1], self.latent_size)
        self.logvar_layer = nn.Linear(self.hidden_sizes[-1], self.latent_size)

    def forward(self, x):
        """Encode ``x`` into latent ``(mu, logvar)``.

        Args:
            x: Tensor whose last dimension has ``input_size`` features,
                e.g. ``(batch, input_size)``.

        Returns:
            Tuple ``(mu, logvar)``; each keeps the leading dimensions of ``x``
            and has a last dimension of ``latent_size``.
        """
        h = self.model(x)
        mu = self.mu_layer(h)
        logvar = self.logvar_layer(h)
        return mu, logvar

In [None]:
import torch

# Pick the best available accelerator instead of hard-coding Apple's "mps"
# backend, so this cell also runs on CUDA machines and CPU-only hosts.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

variational_encoder = VariationalEncoder(
    input_size=13431,
    hidden_sizes=[256, 100],
    latent_size=64,
    dropout=0.1,
    use_norm=True,
).to(device)


In [None]:
# inspect the model: displaying the module renders its layer-by-layer architecture
variational_encoder

VariationalEncoder(
  (model): Sequential(
    (0): Linear(in_features=13431, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=100, bias=True)
    (4): InstanceNorm1d(100, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
    (6): Dropout(p=0.1, inplace=False)
    (7): Linear(in_features=100, out_features=64, bias=True)
  )
  (mu_layer): Linear(in_features=100, out_features=64, bias=True)
  (logvar_layer): Linear(in_features=100, out_features=64, bias=True)
)

In [None]:
#| hide
# Export the `#| export` cells of this notebook to the library module.
import nbdev

nbdev.nbdev_export()