# Decoder

> A customisable PyTorch variational decoder model.

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| default_exp Models.VariationalDecoder

In [None]:
#| export
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class VariationalDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout, use_norm):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.use_batch_norm = use_norm
        
        # create a list of layers
        layers = []

        # input layer
        layers.append(nn.Linear(self.input_size, self.hidden_size))
        layers.append(nn.LeakyReLU(0.2))
        if self.dropout > 0:
            layers.append(nn.Dropout(p=self.dropout))

        # hidden layers
        layers.append(nn.Linear(self.hidden_size, self.hidden_size))
        if self.use_batch_norm:
            layers.append(nn.InstanceNorm1d(self.hidden_size))
        layers.append(nn.LeakyReLU(0.2))
        if self.dropout > 0:
            layers.append(nn.Dropout(p=self.dropout))
        
        # output layer
        layers.append(nn.Linear(self.hidden_size, self.output_size))
        layers.append(nn.Sigmoid())

        # create the model using Sequential
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)



In [None]:
# Small example decoder: 10 latent features -> 30 outputs, with dropout and normalisation on.
variational_decoder = VariationalDecoder(
    input_size=10,
    hidden_size=20,
    output_size=30,
    dropout=0.5,
    use_norm=True,
)

In [None]:
# Show the module's repr as the cell output to inspect the generated layer stack.
variational_decoder

VariationalDecoder(
  (model): Sequential(
    (0): Linear(in_features=10, out_features=20, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=20, out_features=20, bias=True)
    (4): InstanceNorm1d(20, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Dropout(p=0.5, inplace=False)
    (7): Linear(in_features=20, out_features=30, bias=True)
    (8): Sigmoid()
  )
)

In [None]:
#| hide
# Regenerate the library module(s) from this notebook's exported cells.
import nbdev
nbdev.nbdev_export()