# Variational Autoencoder

As its name suggests, the variational autoencoder (VAE) extends the autoencoder. An autoencoder has an encoder that maps the input $x$ to a latent variable $z$, and a decoder that reconstructs $\hat{x}$ from $z$. If $\hat{x}$ is close to $x$, most of the important information about the input must be contained in $z$. The learned feature $z$ can therefore serve as a representation of the input.

<img src="src/autoencoder.png" width="30%" height="30%" />

A VAE makes the autoencoder probabilistic. It defines the density of $x$ through the latent variable $z$ as:

$$
\begin{equation}
p_\theta(x) = \int p_\theta(z)p_\theta(x|z)dz
\end{equation}
$$ 
In practice, we choose a simple prior such as a standard Gaussian for $p_\theta(z)$ and use a neural network to parameterize the complex conditional distribution $p_\theta(x|z)$. The parameters $\theta$ are learned by maximizing the likelihood of the training data.
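To make equation (1) concrete, here is a minimal sketch on a toy one-dimensional model. The model is an assumption chosen only because the integral has a closed form: $p(z)=\mathcal{N}(0,1)$ and $p(x|z)=\mathcal{N}(z,1)$, which gives $p(x)=\mathcal{N}(0,2)$. The naive Monte Carlo estimate averages $p(x|z)$ over samples drawn from the prior. This works in one dimension, but it tends to need far too many samples once $z$ is high-dimensional and $p(x|z)$ is a neural network. That is the motivation for the bound derived next.

```python
import math
import random


def normal_pdf(x, mean, var):
    """Density of a 1-D Gaussian N(mean, var) evaluated at x."""
    return math.exp(-(x - mean) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)


def mc_marginal(x, num_samples, seed=0):
    """Estimate p(x) = E_{z ~ p(z)}[p(x|z)] for the toy model
    p(z) = N(0, 1), p(x|z) = N(z, 1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)          # z ~ p(z)
        total += normal_pdf(x, z, 1.0)   # p(x|z) = N(x; z, 1)
    return total / num_samples


x = 1.5
exact = normal_pdf(x, 0.0, 2.0)          # closed form: p(x) = N(x; 0, 2)
estimate = mc_marginal(x, 100_000)
print(f"exact={exact:.4f}  monte carlo={estimate:.4f}")
```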

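The integral over $z$ is intractable. With the help of the encoder and Bayes' rule, however, we can optimize a lower bound of $\log p_\theta(x)$ instead. The encoder outputs a conditional distribution $q_\phi(z|x)$ that approximates the true posterior $p_\theta(z|x)$. Because $\log p_\theta(x^i)$ does not depend on $z$, we can take its expectation under $q_\phi(z|x^i)$. Substituting $p_\theta(x^i) = p_\theta(z)\,p_\theta(x^i|z)/p_\theta(z|x^i)$ (Bayes' rule), the log-likelihood of a training sample $x^i$ can be written as: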

$$
\begin{align}
\log p_\theta(x^{i}) &= E_{z\sim q_\phi(z|x^{i})} \Big[\log \frac{p_\theta(z)\, p_\theta(x^{i}|z)}{p_\theta(z|x^i)} \frac{q_\phi(z|x^{i})}{q_\phi(z|x^i)}\Big]  \\
 &= E_z\big[\log p_\theta(x^i|z)\big] - E_z\Big[\log \frac{q_\phi(z|x^i)}{p_\theta(z)}\Big] + E_z\Big[\log \frac{q_\phi(z|x^i)}{p_\theta(z|x^i)}\Big] \\
 &= E_z\big[\log p_\theta(x^i|z)\big] - D_{KL}(q_\phi(z|x^i) \| p_\theta(z)) + D_{KL}(q_\phi(z|x^i) \| p_\theta(z|x^i))
\end{align}
$$

The third term $D_{KL}(q_\phi(z|x^i) \| p_\theta(z|x^i))$ in equation $4$ is still intractable, but a KL divergence is always non-negative. Dropping it therefore leaves the first two terms, which together form a lower bound of $\log p_\theta(x^i)$ (the evidence lower bound, or ELBO). We maximize this bound instead.
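The decomposition can be checked numerically on a toy model where every term has a closed form. As an illustration-only assumption, take $p(z)=\mathcal{N}(0,1)$ and $p(x|z)=\mathcal{N}(z,1)$. Then the true posterior is $p(z|x)=\mathcal{N}(x/2, 1/2)$ and $p(x)=\mathcal{N}(0,2)$. For several choices of a Gaussian $q(z|x)=\mathcal{N}(m, s^2)$, the sketch below computes the three terms of equation $4$. The helper names are hypothetical.

```python
import math


def kl_gauss(m1, v1, m2, v2):
    """KL( N(m1, v1) || N(m2, v2) ) for 1-D Gaussians given mean and variance."""
    return 0.5 * math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / (2 * v2) - 0.5


def decompose(x, m, s2):
    """Split log p(x) into the terms of equation (4) for the toy model
    p(z) = N(0, 1), p(x|z) = N(z, 1) and an arbitrary q(z|x) = N(m, s2)."""
    # E_{z~q}[log N(x; z, 1)] is closed form because E_q[(x - z)^2] = (x - m)^2 + s2
    recon = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    kl_prior = kl_gauss(m, s2, 0.0, 1.0)     # D_KL(q || p(z))
    kl_post = kl_gauss(m, s2, x / 2, 0.5)    # D_KL(q || p(z|x)), true posterior N(x/2, 1/2)
    log_px = -0.5 * math.log(2 * math.pi * 2.0) - x ** 2 / 4   # log N(x; 0, 2)
    return log_px, recon - kl_prior, kl_post


for m, s2 in [(0.0, 1.0), (0.75, 0.5), (-1.0, 2.0)]:
    log_px, elbo, gap = decompose(1.5, m, s2)
    print(f"q=N({m}, {s2}): log p(x)={log_px:.4f}  ELBO={elbo:.4f}  gap={gap:.4f}")
```

In every case, `elbo + gap` reproduces `log_px` and the gap is non-negative. The bound becomes tight (gap of zero) only when $q$ equals the true posterior. Here that is $\mathcal{N}(0.75, 0.5)$ for $x = 1.5$.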

Maximizing this lower bound therefore splits into two parts. We maximize $E_z\big[\log p_\theta(x^i|z)\big]$, which means reconstructing the input as well as possible. We also minimize $D_{KL}(q_\phi(z|x^i) \| p_\theta(z))$, which keeps the approximate posterior close to the prior.

The resulting VAE architecture looks like this:

<img src="src/vae.png" width="30%" height="30%" />

Unlike in a plain autoencoder, the encoder and decoder no longer output a deterministic $z$ or $\hat{x}$. Instead, they output probability distributions, and $z$ and $x$ are sampled from these distributions. The reparameterization trick rewrites the sampling of $z$ so that it is differentiable with respect to the encoder outputs. Let's walk through a VAE project on MNIST to see how this works in practice.
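The generative path described above (sample $z$ from the prior, then sample $x$ from the decoder's distribution) can be sketched in plain Python. `toy_decoder` is a hypothetical stand-in with fixed weights and only 4 "pixels"; in the notebook, the decoder is the trained `VAE.decode` network.

```python
import math
import random


def toy_decoder(z):
    """Hypothetical stand-in for p_theta(x|z): maps a 2-d latent vector to
    Bernoulli probabilities for 4 'pixels' via a fixed linear layer + sigmoid."""
    weights = [[1.0, -0.5], [-1.0, 0.5], [0.5, 0.5], [-0.5, -1.0]]
    return [1 / (1 + math.exp(-(w[0] * z[0] + w[1] * z[1]))) for w in weights]


def generate(rng, latent_dim=2):
    z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]   # z ~ p(z) = N(0, I)
    probs = toy_decoder(z)                                   # parameters of p(x|z)
    x = [1 if rng.random() < p else 0 for p in probs]        # x ~ Bernoulli(probs)
    return z, probs, x


rng = random.Random(0)
z, probs, x = generate(rng)
print("z =", [round(v, 3) for v in z])
print("Bernoulli probs =", [round(p, 3) for p in probs])
print("sampled x =", x)
```

The notebook's last cell stops one step earlier. It saves the decoder's probabilities directly instead of drawing binary pixels from them, which gives smoother-looking images.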

In [None]:
# import packages
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

from easydict import EasyDict as edict
import os

In [None]:
# initialize parameters and dataset
args = edict()
args.batch_size = 128
args.epochs = 10
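device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# save_image does not create missing folders, so create the output directory up front
os.makedirs('VAE_result', exist_ok=True)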

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data/MNIST', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data/MNIST', train=False, transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True)

In [None]:
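# define VAE
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        # encoder: 784 pixels -> 400 hidden units -> (mu, logvar) of a 20-d Gaussian q(z|x)
        self.en1 = nn.Linear(784, 400)
        self.en2_mu = nn.Linear(400, 20)
        self.en2_logvar = nn.Linear(400, 20)
        # decoder: 20-d latent code -> 400 hidden units -> 784 Bernoulli probabilities
        self.de1 = nn.Linear(20, 400)
        self.de2 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.en1(x))
        return self.en2_mu(h1), self.en2_logvar(h1)

    def reparameterize(self, mu, logvar):
        # z = mu + eps * std with eps ~ N(0, I); gradients flow through mu and logvar
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.de1(z))
        # sigmoid keeps every output in (0, 1) so it can act as a Bernoulli probability
        return torch.sigmoid(self.de2(h3))

    def forward(self, x):
        # flatten 28x28 images into 784-d vectors
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar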
    


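Here, the approximate posterior $q_\phi(z|x)$ is a diagonal Gaussian, and the prior $p(z)$ is the standard Gaussian $\mathcal{N}(0, I)$. The decoder distribution $p_\theta(x|z)$ is a Bernoulli distribution over each pixel. That is why the encoder has two outputs (the mean $\mu$ and the log-variance $\log\sigma^2$), while the decoder has only one (the Bernoulli probabilities). The `reparameterize` function draws $z = \mu + \epsilon \cdot \sigma$ with $\epsilon \sim \mathcal{N}(0, I)$, which keeps $z$ differentiable with respect to $\mu$ and $\log\sigma^2$.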
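To see why the reparameterization trick works, here is a plain-Python stand-in for `VAE.reparameterize`. It uses scalars instead of tensors, and the values of `mu` and `logvar` are arbitrary. The sketch checks two things. First, the samples follow $\mathcal{N}(\mu, \sigma^2)$. Second, once $\epsilon$ is held fixed, $z$ is an ordinary function of $\mu$ and $\log\sigma^2$, with derivatives $1$ and $\tfrac{1}{2}\epsilon\sigma$. Finite differences stand in here for what PyTorch autograd computes in the real model.

```python
import math
import random


def reparameterize(mu, logvar, eps):
    """z = mu + eps * std with std = exp(0.5 * logvar), as in VAE.reparameterize."""
    return mu + eps * math.exp(0.5 * logvar)


mu, logvar = 0.3, math.log(0.25)      # std = 0.5
std = math.exp(0.5 * logvar)
rng = random.Random(42)

# 1) The samples have the intended distribution N(mu, exp(logvar)).
zs = [reparameterize(mu, logvar, rng.gauss(0.0, 1.0)) for _ in range(100_000)]
mean = sum(zs) / len(zs)
var = sum((z - mean) ** 2 for z in zs) / len(zs)

# 2) With eps held fixed, z is a smooth function of (mu, logvar):
#    dz/dmu = 1 and dz/dlogvar = 0.5 * eps * std (checked by central differences).
eps, h = 0.8, 1e-6
dz_dmu = (reparameterize(mu + h, logvar, eps) - reparameterize(mu - h, logvar, eps)) / (2 * h)
dz_dlogvar = (reparameterize(mu, logvar + h, eps) - reparameterize(mu, logvar - h, eps)) / (2 * h)
print(f"mean={mean:.3f} var={var:.3f}")
print(f"dz/dmu={dz_dmu:.4f}  dz/dlogvar={dz_dlogvar:.4f} (expected {0.5 * eps * std:.4f})")
```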

In [None]:
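def loss_function(recon_x, x, mu, logvar):
    # reconstruction term: single-sample estimate of -E_z[log p(x|z)] for a Bernoulli decoder,
    # summed over pixels and over the batch
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # closed-form D_KL(q(z|x) || N(0, I)), summed over latent dimensions and over the batch
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # negative ELBO, which the optimizer minimizes
    return BCE + KLD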

As introduced above, the objective has two terms, which correspond to `BCE` and `KLD` in the loss function. Strictly speaking, a Bernoulli distribution is not quite appropriate here because the MNIST pixels (scaled to $[0, 1]$ by `ToTensor`) are not binary. However, most pixel values are close to 0 or 1, so we treat them as approximately binary. Let $y$ denote the Bernoulli probabilities output by the decoder. The binary cross entropy between the input and $y$ is then the negative of the first term in equation $4$, estimated with a single sample of $z$ and summed over pixels:

$$
\begin{align}
\log p_\theta(x|z) &= \sum_{j=1}^{784} \log\big(y_j^{x_j}(1-y_j)^{1-x_j}\big) \\
&= \sum_{j=1}^{784} \big[x_j\log y_j + (1-x_j)\log(1-y_j)\big]
\end{align}
$$
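The equality between BCE and the negative Bernoulli log-likelihood, and what happens with a grey (non-binary) pixel, can be checked with a small plain-Python sketch. `bce_sum` and `bernoulli_log_lik` are hypothetical helpers written for this illustration, not PyTorch functions. The probabilities `y` are arbitrary example values.

```python
import math


def bce_sum(y, x):
    """Summed binary cross entropy: y are predicted probabilities, x are targets."""
    return -sum(xi * math.log(yi) + (1 - xi) * math.log(1 - yi) for yi, xi in zip(y, x))


def bernoulli_log_lik(y, x):
    """log p(x|z) = sum_j log(y_j^x_j * (1 - y_j)^(1 - x_j)) for binary x."""
    return sum(math.log(yi ** xi * (1 - yi) ** (1 - xi)) for yi, xi in zip(y, x))


y = [0.9, 0.2, 0.6]          # decoder output (Bernoulli probabilities)
x_binary = [1, 0, 1]         # binary pixels
print(bce_sum(y, x_binary), -bernoulli_log_lik(y, x_binary))   # the two values agree

# With a non-binary target (e.g. a grey MNIST pixel), BCE is still minimized at
# y = x, but its minimum is the entropy of Bernoulli(x) rather than 0.
x_grey = [0.3]
best = min((bce_sum([p / 1000], x_grey), p / 1000) for p in range(1, 1000))
print(best)   # minimum at y = 0.3, value about 0.611
```

So the approximation in the text costs nothing in terms of where the optimum lies. The only effect is that the reconstruction loss on grey pixels cannot reach zero.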

Next, consider the second term in equation $4$, the KL divergence between $q_\phi(z|x)$ and $p(z)$. Both are Gaussian, so for a single latent dimension let $q \sim \mathcal{N}(\mu_1, \sigma_1^2)$ and $p \sim \mathcal{N}(\mu_2, \sigma_2^2)$. The KL divergence is then as follows (the derivation for two Gaussian distributions is given later; here we only state the result):
$$
\begin{align}
D_{KL}(q_\phi(z|x)\|p(z)) &= -\frac{1}{2} + \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2}  \\
&= -\frac{1}{2}\big(1 + \log\sigma_1^2 - \sigma_1^2 - \mu_1^2\big) \qquad \text{where } \mu_2 = 0,\ \sigma_2 = 1
\end{align}
$$
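The step from the general formula to the `KLD` line in `loss_function` can be checked in plain Python. The sketch draws random values for `mu` and `logvar` over 20 latent dimensions, matching the VAE above. It then sums the general Gaussian KL with $\sigma_1 = e^{\frac{1}{2}\log\sigma_1^2}$, $\mu_2 = 0$, $\sigma_2 = 1$ and compares the result with the `KLD` expression. The helper names are hypothetical.

```python
import math
import random


def kl_general(mu1, sigma1, mu2, sigma2):
    """D_KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)), first line of the formula above."""
    return -0.5 + math.log(sigma2 / sigma1) + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)


def kld_term(mu, logvar):
    """The KLD expression from loss_function, summed over latent dimensions."""
    return -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, logvar))


rng = random.Random(1)
mu = [rng.uniform(-2, 2) for _ in range(20)]        # 20 latent dims, as in the VAE
logvar = [rng.uniform(-3, 1) for _ in range(20)]

per_dim = sum(kl_general(m, math.exp(0.5 * lv), 0.0, 1.0) for m, lv in zip(mu, logvar))
print(per_dim, kld_term(mu, logvar))   # the two values agree
```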

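Summing over the 20 latent dimensions and writing `logvar` for $\log\sigma_1^2$, the second line is exactly `KLD` in the loss function. `BCE` is the negative reconstruction term, so `BCE + KLD` is the negative lower bound, summed over the batch. The derivation of the KL divergence between two Gaussian distributions is shown below. You may skip it if you only want the basic idea of VAE.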

For $p \sim \mathcal{N}(\mu_1, \sigma_1^2)$ and $q \sim \mathcal{N}(\mu_2, \sigma_2^2)$:

$$
\begin{align}
D_{KL}(p\|q) &= -\int p(x)\log q(x)\,dx + \int p(x)\log p(x)\,dx \\
&= -\int p(x) \log\Big[\frac{1}{(2\pi\sigma_2^2)^{1/2}}e^{-\frac{(x-\mu_2)^2}{2\sigma_2^2}}\Big]dx + \int p(x)\log\Big[\frac{1}{(2\pi\sigma_1^2)^{1/2}}e^{-\frac{(x-\mu_1)^2}{2\sigma_1^2}}\Big]dx \\
&= \frac{1}{2}\log(2\pi\sigma_2^2) + \int p(x)\frac{(x-\mu_2)^2}{2\sigma_2^2}dx - \frac{1}{2}\log(2\pi\sigma_1^2) - \int p(x)\frac{(x-\mu_1)^2}{2\sigma_1^2}dx \\
&= \log\frac{\sigma_2}{\sigma_1} + \frac{1}{2\sigma_2^2}\Big[\int p(x)x^2dx - \int p(x)2x\mu_2dx + \int p(x)\mu_2^2dx\Big] - \frac{1}{2} \\
&= \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + \mu_1^2 - 2\mu_1\mu_2 + \mu_2^2}{2\sigma_2^2} - \frac{1}{2}
\end{align}
$$

The fourth line uses $\int p(x)(x-\mu_1)^2dx = \sigma_1^2$. The last line uses $\int p(x)x\,dx = \mu_1$ and $\int p(x)x^2dx = \sigma_1^2 + \mu_1^2$.
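As a sanity check of the derivation, the closed form can be compared with a direct Monte Carlo estimate of $E_{x\sim p}[\log p(x) - \log q(x)]$. The parameter values and the sample count in this sketch are arbitrary choices, and the helper names are hypothetical. With 200,000 samples, the two numbers should agree to roughly two decimal places.

```python
import math
import random


def log_normal(x, mu, sigma):
    """log N(x; mu, sigma^2)."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)


def kl_closed_form(mu1, sigma1, mu2, sigma2):
    """Last line of the derivation: D_KL(N(mu1, sigma1^2) || N(mu2, sigma2^2))."""
    return (math.log(sigma2 / sigma1)
            + (sigma1 ** 2 + mu1 ** 2 - 2 * mu1 * mu2 + mu2 ** 2) / (2 * sigma2 ** 2)
            - 0.5)


def kl_monte_carlo(mu1, sigma1, mu2, sigma2, n=200_000, seed=0):
    """Estimate E_{x~p}[log p(x) - log q(x)] from samples of p."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x = rng.gauss(mu1, sigma1)
        total += log_normal(x, mu1, sigma1) - log_normal(x, mu2, sigma2)
    return total / n


params = (0.5, 0.8, -0.3, 1.2)   # mu1, sigma1, mu2, sigma2 (arbitrary test values)
closed = kl_closed_form(*params)
mc = kl_monte_carlo(*params)
print(f"closed form={closed:.4f}  monte carlo={mc:.4f}")
```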


In [None]:
# Then we create the model and optimizer
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Define the train and test functions
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                # save the first n test images next to their reconstructions
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                        recon_batch.view(-1, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                           'VAE_result/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))


In [None]:
# Train, evaluate, and save random samples after every epoch
for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        sample = torch.randn(64, 20).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   'VAE_result/sample_' + str(epoch) + '.png')


The reconstructions and random samples for each epoch are saved in the `VAE_result` directory.