# VAE

> Variational Autoencoder model that takes a variational encoder and a decoder as arguments.

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| default_exp Models.VAE

In [None]:
#| export
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class VAE(nn.Module):
    """A Variational Autoencoder (VAE) built from a user-supplied encoder and decoder.

    The encoder is called as ``encoder(x)`` and must return a pair
    ``(mu, logvar)``. A latent ``z`` is drawn from that pair with the
    reparameterization trick, and the decoder is called as ``decoder(z)``.
    """

    def __init__(self, encoder, decoder):
        """
        Args:
            encoder: callable (typically an ``nn.Module``) mapping ``x`` to
                ``(mu, logvar)``, where ``logvar`` is the log-variance.
            decoder: callable (typically an ``nn.Module``) mapping a latent
                sample ``z`` back to a reconstruction.
        """
        super().__init__()
        # Assigning nn.Modules here registers them as submodules, so their
        # parameters appear in ``self.parameters()``.
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(0.5 * logvar)``.

        Sampling noise through ``eps`` keeps ``z`` differentiable with respect
        to ``mu`` and ``logvar``. Note: a fresh sample is drawn in both train
        and eval mode; ``self.training`` is not consulted.

        Args:
            mu: mean tensor produced by the encoder.
            logvar: log-variance tensor (same shape as ``mu``).

        Returns:
            A tensor shaped like ``mu`` containing the latent sample.
        """
        # exp(0.5 * log(sigma^2)) == sigma
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Returns:
            Tuple ``(x_hat, mu, logvar)``. ``mu`` and ``logvar`` are returned
            so the caller can compute the KL term of the VAE loss.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decoder(z)
        return x_hat, mu, logvar
    

In [None]:
#| hide
# Regenerate the library .py modules from the notebooks' `#| export` cells;
# this notebook's exported cells go to `Models.VAE` (see `#| default_exp`).
import nbdev; nbdev.nbdev_export()