# Variational Autoencoder 

An autoencoder (AE) maps a high-dimensional signal into a lower-dimensional latent space [1]. For signal generation, however, it is not straightforward to sample a point from that latent space and decode it into a high-dimensional signal that follows the distribution of the training data.

## Autoencoder

A simple autoencoder architecture consists of three major components:
- Encoder 
- Latent Representation 
- Decoder
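
The three components above can be sketched as a small PyTorch module. This is only an illustrative sketch: the class name `AutoEncoder` and its layer layout are assumptions, not code from this notebook, and the layer sizes simply mirror the hyperparameters used later for the VAE.

```python
import torch
import torch.nn as nn

class AutoEncoder(nn.Module):
    """Minimal deterministic autoencoder (hypothetical example)."""

    def __init__(self, input_dim=784, hidden_dim=200, latent_dim=20):
        super().__init__()
        # Encoder: compress the input into a latent vector
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )
        # Decoder: map the latent vector back to the input space
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),
        )

    def forward(self, x):
        z = self.encoder(x)      # latent representation: one vector per input
        x_hat = self.decoder(z)  # reconstruction of the input
        return x_hat, z

torch.manual_seed(0)
ae = AutoEncoder()
x = torch.rand(4, 784)           # dummy batch standing in for flattened 28x28 images
x_hat, z = ae(x)
print(z.shape, x_hat.shape)
```

Each input is mapped to exactly one point `z`, so nothing constrains how the latent space is organised between the encoded training points.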

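Unlike an AE, which maps each input to a single latent vector, a variational autoencoder (VAE) represents the latent code of each input as a probability distribution. In the implementation later in this document, the encoder outputs the mean and log-variance of a diagonal Gaussian.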

## Evidence Lower Bound 


The Evidence Lower Bound (ELBO) is a quantity used primarily in variational inference, a method for approximating complex probability distributions. In probabilistic modeling, we often want to infer the posterior distribution of latent variables given observed data. However, computing the true posterior is usually intractable. Variational inference instead approximates the posterior with a simpler distribution chosen from a parameterized family, such as Gaussians.

The ELBO is a lower bound on the log marginal likelihood of the data (the evidence). The gap between the two is exactly the Kullback-Leibler (KL) divergence between the approximate posterior and the true posterior, which is never negative. Maximizing the ELBO with respect to the parameters of the approximate posterior therefore minimizes this divergence, and maximizing it with respect to the model parameters raises a lower bound on the log marginal likelihood.
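
The relationship described above can be written out explicitly. The notation is an assumption introduced here for illustration: $q_\phi(z \mid x)$ is the approximate posterior (the encoder), $p_\theta(x \mid z)$ the likelihood (the decoder), and $p(z)$ the prior over latent variables.

$$
\log p_\theta(x)
= \underbrace{\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big]
- D_{\mathrm{KL}}\big(q_\phi(z \mid x)\,\|\,p(z)\big)}_{\text{ELBO}}
+ D_{\mathrm{KL}}\big(q_\phi(z \mid x)\,\|\,p_\theta(z \mid x)\big)
$$

Because the last term is non-negative, $\log p_\theta(x) \ge \text{ELBO}$. The first ELBO term rewards accurate reconstruction, and the second keeps the approximate posterior close to the prior.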

## Variational Autoencoder

In [None]:
import numpy as np
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader 
#from torchvision.utils import save_image, make_grid
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

datapath = r'C:\Users\User\Documents\repos\data'

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_dim = 784
hidden_dim = 200
latent_dim = 20
num_epochs = 10
batch_size = 32
lr = 3e-4 

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(), 
    #transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(datapath, train=True, transform=transform, download=False)
test_dataset = datasets.MNIST(datapath, train=False, transform=transform, download=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        # Despite its name, fc_var outputs the log-variance (see encoder and forward)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)
        
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, input_dim)
        
        self.relu = nn.ReLU() 
        

    def encoder(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        mu = self.fc_mean(x)
        logvar = self.fc_var(x)
        return mu, logvar 


    def decoder(self, x):
        x = self.relu(self.fc3(x))
        x = self.relu(self.fc4(x))
        x = torch.sigmoid(self.fc5(x))
        return x 
    
    
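    def reparameterize(self, mean, std):
        # Reparameterization trick: z = mean + std * epsilon, with epsilon ~ N(0, I).
        # The second argument is the standard deviation, not the variance.
        epsilon = torch.randn_like(std)  # created on the same device and dtype as std
        z = mean + epsilon * std
        return z 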
    
    
    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, torch.exp(0.5 * logvar)) 
        recon = self.decoder(z)
        return recon, mu, logvar 
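
The `reparameterize` step can be examined in isolation. The snippet below is a minimal sketch with made-up values for `mu` and `logvar` (they are assumptions, not taken from the model). It suggests two properties: the samples follow the intended Gaussian, and gradients still reach both parameters because the randomness lives only in `eps`.

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([0.5, -1.0], requires_grad=True)
logvar = torch.tensor([0.0, -2.0], requires_grad=True)

std = torch.exp(0.5 * logvar)   # sigma = exp(log(sigma^2) / 2)
eps = torch.randn(10000, 2)     # noise drawn independently of mu and logvar
z = mu + eps * std              # z ~ N(mu, sigma^2), differentiable in mu and logvar

z.mean().backward()             # backpropagate through the sampling step
print(z.mean(dim=0), z.std(dim=0))
print(mu.grad, logvar.grad)
```

Sampling `z` directly from a distribution object without this rewrite would not, in general, give a gradient path back to the encoder outputs.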
    

In [None]:
model = VAE(input_dim, hidden_dim, latent_dim).to(device) 
optimizer = optim.Adam(model.parameters(), lr=lr)
#loss_fn = nn.BCELoss(reduction="sum")

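def loss_function(x, x_hat, mean, log_var):
    # Negative ELBO, summed over the batch: reconstruction term plus KL term
    reconstruction_loss = F.binary_cross_entropy(x_hat, x, reduction='sum')
    # Closed-form KL divergence between N(mean, exp(log_var)) and the prior N(0, I)
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return reconstruction_loss + KLD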
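
The KL term in `loss_function` is the closed-form divergence between a diagonal Gaussian and a standard normal prior. As a sanity check, the sketch below compares that formula for a single latent dimension against a numerical integral. The helper names `kl_closed_form` and `kl_numeric` and the test values are assumptions made for this example; only the Python standard library is used.

```python
import math

def kl_closed_form(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, 1) ) for one latent dimension."""
    return -0.5 * (1 + logvar - mu ** 2 - math.exp(logvar))

def kl_numeric(mu, logvar, lo=-12.0, hi=12.0, n=20_000):
    """Same KL via a midpoint Riemann sum of q(z) * (log q(z) - log p(z))."""
    var = math.exp(logvar)
    dz = (hi - lo) / n
    total = 0.0
    for i in range(n):
        z = lo + (i + 0.5) * dz
        log_q = -0.5 * (math.log(2 * math.pi * var) + (z - mu) ** 2 / var)
        log_p = -0.5 * (math.log(2 * math.pi) + z ** 2)
        total += math.exp(log_q) * (log_q - log_p) * dz
    return total

mu, logvar = 0.7, -0.5
closed = kl_closed_form(mu, logvar)
numeric = kl_numeric(mu, logvar)
print(closed, numeric)
```

The divergence is zero only when `mu = 0` and `logvar = 0`, that is, when the approximate posterior equals the prior.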

In [None]:
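for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for x, _ in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        # Flatten each 28x28 image; -1 keeps this correct if the last batch is smaller
        x = x.view(-1, input_dim).to(device)

        optimizer.zero_grad()
        x_hat, mean, log_var = model(x)

        loss = loss_function(x, x_hat, mean, log_var)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # The loss is summed over pixels and samples, so report the per-sample average
    print(f"Epoch {epoch + 1}: average loss = {total_loss / len(train_loader.dataset):.2f}")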
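
Once a VAE is trained, new signals can be generated by drawing latent codes from the prior and decoding them. The sketch below shows only the mechanics. The `decoder` here is an untrained, hypothetical stand-in so that the snippet runs on its own; with the notebook's model one would presumably call `model.decoder(z.to(device))` instead.

```python
import torch
import torch.nn as nn

latent_dim, hidden_dim, input_dim = 20, 200, 784

# Stand-in for a trained decoder (hypothetical and untrained, for illustration only)
decoder = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
    nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),
)

decoder.eval()
with torch.no_grad():
    z = torch.randn(16, latent_dim)           # latent codes drawn from the prior N(0, I)
    samples = decoder(z).view(-1, 1, 28, 28)  # reshape to a batch of 1x28x28 images
print(samples.shape)
```

Sampling from the prior is meaningful because the KL term in the loss pulls the encoder's distributions toward N(0, I) during training.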


## Reference 

- https://kvfrans.com/variational-autoencoders-explained/
- https://www.jeremyjordan.me/variational-autoencoders/ 