# Variational Auto-Encoder (VAE)

Diederik P. Kingma and Max Welling. Auto-Encoding Variational Bayes. 2013. https://arxiv.org/abs/1312.6114

A VAE differs from a plain autoencoder in two ways.
 - A VAE is a probabilistic autoencoder: the encoder outputs a distribution over latent codes and a code is sampled from it, so even after training the output for a given input is partly random.
 - A VAE is a generative model: decoding latent codes drawn from the prior produces new samples that resemble the training data.

Dataset: CIFAR10

We implement the VAE with a convolutional neural network (CNN).

A VAE embeds the input in a low-dimensional latent space with the probabilistic encoder $q_{\phi}(z|x)$ and reconstructs the original input from the latent embedding with the probabilistic decoder $p_{\theta}(x|z)$.

The goal is to maximize the marginal log likelihood of the data $x^{(1)}, ..., x^{(N)}$, which is a sum over the marginal log likelihoods of the individual datapoints.
$$\log p_{\theta}(x^{(1)}, ..., x^{(N)}) = \sum_{i=1}^{N}\log p_{\theta}(x^{(i)})$$

The individual marginal likelihood can be written as
$$
\begin{align*}
\log p_\theta (x^{(i)}) &= \mathbb{E}_{z \sim q_\phi (z|x^{(i)})}\left[\log p_\theta (x^{(i)})\right]  \\
                        &= \mathbb{E}_{z} \left[\log \frac{p_\theta (x^{(i)}|z)p_\theta (z)}{p_\theta (z|x^{(i)})}\right]\\
                        &= \mathbb{E}_{z} \left[\log \left( \frac{p_\theta (x^{(i)}|z)p_\theta (z)}{p_\theta (z|x^{(i)})}
                                \cdot \frac{q_\phi (z|x^{(i)})}{q_\phi (z|x^{(i)})} \right) \right]  \\
                        &= \mathbb{E}_{z} \left[\log  p_\theta (x^{(i)}|z) \right] 
                        - \mathbb{E}_{z} \left[\log \frac{q_\phi (z|x^{(i)})}{p_\theta (z)} \right] 
                        + \mathbb{E}_{z} \left[\log \frac{q_\phi (z|x^{(i)})}{p_\theta (z|x^{(i)})} \right]
                        \\
                        &= \mathbb{E}_{z} \left[\log  p_\theta (x^{(i)}|z) \right]
                        - D_{KL} \left(q_\phi (z|x^{(i)}) || p_\theta (z) \right)
                        + D_{KL} \left(q_\phi (z|x^{(i)}) || p_\theta (z|x^{(i)}) \right) \\
                        &\ge \mathbb{E}_{z} \left[\log  p_\theta (x^{(i)}|z) \right]
                        - D_{KL} \left(q_\phi (z|x^{(i)}) || p_\theta (z) \right)
\end{align*}
$$

Because the dropped KL term is non-negative, the marginal log likelihood is lower bounded by the last line, which is also known as the *variational lower bound* (or evidence lower bound, ELBO).
The first term of the bound is the expected log likelihood of the reconstruction from the decoder, whereas the second term is the KL divergence of the approximate posterior $q_\phi (z|x^{(i)})$ from the prior $p_\theta (z)$.
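
The bound can be checked numerically on a toy model where every quantity has a closed form. The sketch below is not part of the CIFAR10 pipeline; it assumes a scalar model with prior $p(z)=\mathcal{N}(0,1)$ and decoder $p(x|z)=\mathcal{N}(z,1)$, for which $p(x)=\mathcal{N}(0,2)$ and the exact posterior is $\mathcal{N}(x/2, 1/2)$. The helper names (`log_marginal`, `kl_gauss`, `elbo`) are illustrative only.

```python
import math

def log_marginal(x):
    # p(z) = N(0, 1), p(x|z) = N(z, 1)  =>  p(x) = N(0, 2)
    return -0.5 * math.log(2 * math.pi * 2.0) - x * x / 4.0

def kl_gauss(m1, v1, m2, v2):
    # KL( N(m1, v1) || N(m2, v2) ) for scalar Gaussians with variances v1, v2
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def elbo(x, m, v):
    # E_q[log p(x|z)] for q = N(m, v), using E_q[(x - z)^2] = (x - m)^2 + v
    recon = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + v)
    # subtract KL(q || prior)
    return recon - kl_gauss(m, v, 0.0, 1.0)

x = 1.0
post_mean, post_var = x / 2.0, 0.5  # exact posterior p(z|x) for this toy model

for m, v in [(0.0, 1.0), (0.3, 0.8), (post_mean, post_var)]:
    gap = log_marginal(x) - elbo(x, m, v)
    kl_post = kl_gauss(m, v, post_mean, post_var)
    print(f"q=N({m}, {v}): ELBO={elbo(x, m, v):.4f}  gap={gap:.4f}  KL(q||posterior)={kl_post:.4f}")
```

For each choice of $q$, the gap between $\log p(x)$ and the bound equals the KL divergence from $q$ to the exact posterior, and it vanishes when $q$ is the exact posterior.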

In [14]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from torchvision.utils import save_image

torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
batch_size = 100

img_shape = (3, 32, 32)
img_dim  = 3 * 32 * 32
latent_dim = 128

lr = 1e-3

n_epochs = 10

In [3]:
data_path = '../data'

In [4]:
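# ToTensor alone keeps pixels in [0, 1]; the sigmoid decoder output and the
# binary cross entropy loss used below both assume targets in that range.
transform = transforms.Compose([transforms.ToTensor()])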

loader_kwargs = {'num_workers': os.cpu_count()//2, 'pin_memory': True} 

train_data = datasets.CIFAR10(data_path, transform=transform, train=True, download=True)
test_data  = datasets.CIFAR10(data_path, transform=transform, train=False, download=True)

train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, **loader_kwargs)
test_loader  = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False, **loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


## Encoder
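Given an input $x$, the encoder first maps it into a hidden representation
$$h_e = \mathrm{ReLU}(W_e \cdot x)$$
with an affine transform $W_e$.
Then $h_e$ is mapped to the low-dimensional latent parameters $\mu$ and $\log \sigma^2$ as
$$
\mu = W_\mu \cdot h_e \\
\log \sigma^2 = W_\sigma \cdot h_e
$$
In the implementation below, the affine transform and ReLU are replaced by two convolutional layers with LeakyReLU activations.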

In [5]:
class Encoder(nn.Module):
    """
        Produces the parameters of normal distribution q, 
        mean and log of variance.
    """
    def __init__(self, img_channels=3, feature_dim=32*32*32, latent_dim=latent_dim):
        super(Encoder, self).__init__()

        self.enc = nn.Sequential(
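            nn.Conv2d(img_channels, 16, 3, padding=1),
            # inplace must be passed by keyword: the first positional argument is negative_slope
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.LeakyReLU(inplace=True),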
            nn.Flatten()
        )
        self.FC_mean = nn.Linear(feature_dim, latent_dim)
        self.FC_var = nn.Linear(feature_dim, latent_dim)
        
        
    def forward(self, x):
        x = self.enc(x)
        mean = self.FC_mean(x)
        log_var = self.FC_var(x) 
        
        return mean, log_var


## Decoder
Given a sampled latent $z$, the decoder first maps it into the hidden space as
$$h_d = \mathrm{ReLU}(W_d \cdot z)$$
with an affine transform $W_d$.
Then $h_d$ is reconstructed into the input image as
$$ x' = \mathrm{sigmoid}(W_r \cdot h_d) $$

In [6]:
class Decoder(nn.Module):
    def __init__(self, img_channels=3, feature_dim=32*32*32, latent_dim=latent_dim):
        super(Decoder, self).__init__()
        self.decFC1 = nn.Linear(latent_dim, feature_dim)
        self.decConv1 = nn.ConvTranspose2d(32, 16, 3, padding=1)
        self.decConv2 = nn.ConvTranspose2d(16, img_channels, 3, padding=1)

        
    def forward(self, x):
        x = F.leaky_relu(self.decFC1(x))
        x = x.view(-1, 32, 32, 32)
        x = F.leaky_relu(self.decConv1(x))
        x = torch.sigmoid(self.decConv2(x))
        return x
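
The value `feature_dim=32*32*32` and the `view(-1, 32, 32, 32)` call rely on every layer preserving the 32x32 spatial size. A small sketch of that bookkeeping, using the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` with dilation 1 (the helper names are illustrative, not from the notebook):

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    # output size of nn.Conv2d along one spatial axis (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose2d_out(size, kernel, stride=1, padding=0, output_padding=0):
    # output size of nn.ConvTranspose2d along one spatial axis (dilation 1)
    return (size - 1) * stride - 2 * padding + kernel + output_padding

side = 32
for _ in range(2):                  # the two encoder convolutions (3x3, padding=1)
    side = conv2d_out(side, kernel=3, padding=1)
feature_dim = 32 * side * side      # 32 channels after the second convolution
print(side, feature_dim)            # 32 32768

for _ in range(2):                  # the two decoder transposed convolutions
    side = conv_transpose2d_out(side, kernel=3, padding=1)
print(side)                         # 32
```

With a 3x3 kernel, stride 1, and `padding=1`, both layer types leave the spatial size unchanged, so the flattened encoder output has 32 * 32 * 32 = 32768 features.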

## Reparameterization Trick
Gradients cannot be backpropagated directly through a random sampling step, so we use the reparameterization trick. In the VAE, the encoder's distribution over $z$ is assumed to be Gaussian, $\mathcal{N}(\mu, \sigma ^2)$. Instead of sampling $z$ from it directly, we introduce an auxiliary noise variable $\epsilon$ and write $z$ as a deterministic, differentiable function of $\mu$, $\sigma$, and $\epsilon$, so that gradient descent can be applied to $\mu$ and $\sigma$.
$$
z = \mu + \sigma \odot \epsilon \\
\epsilon \sim \mathcal{N}(0, I)
$$
The `reparam` method below returns a sampled $z$ given the mean $\mu$ and the standard deviation $\sigma = \exp(\tfrac{1}{2} \log \sigma ^2)$, which `forward` computes from the encoder output `log_var`.
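
To see why this helps, the sketch below estimates the gradient of $\mathbb{E}[f(z)]$ for the test function $f(z)=z^2$, whose expectation $\mu^2+\sigma^2$ has known gradients $2\mu$ and $2\sigma$. It is plain Python, independent of the model, and the values of `mu`, `sigma`, and `n` are arbitrary choices for illustration.

```python
import random

rng = random.Random(0)
mu, sigma = 1.5, 0.8
n = 200_000

grad_mu = 0.0
grad_sigma = 0.0
for _ in range(n):
    eps = rng.gauss(0.0, 1.0)   # noise drawn independently of mu and sigma
    z = mu + sigma * eps        # reparameterized sample, differentiable in mu and sigma
    df_dz = 2.0 * z             # derivative of f(z) = z^2
    grad_mu += df_dz * 1.0      # chain rule with dz/dmu = 1
    grad_sigma += df_dz * eps   # chain rule with dz/dsigma = eps
grad_mu /= n
grad_sigma /= n

print(f"d/dmu:    estimate {grad_mu:.3f}, exact {2 * mu:.3f}")
print(f"d/dsigma: estimate {grad_sigma:.3f}, exact {2 * sigma:.3f}")
```

Because the randomness lives entirely in `eps`, each sample contributes an ordinary chain-rule derivative, which is what automatic differentiation computes when `z` is built this way.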

In [7]:
class VAE(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(VAE, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder
        
    def reparam(self, mean, std):
        # z = mean + std * epsilon, with epsilon ~ N(0, I)
        epsilon = torch.randn_like(std)
        z = mean + std * epsilon
        return z
        
                
    def forward(self, x):
        mean, log_var = self.Encoder(x)
        z = self.reparam(mean, torch.exp(0.5 * log_var)) # exp(0.5 * log_var) converts log variance to standard deviation
        x_hat = self.Decoder(z)
        
        return x_hat, mean, log_var

In [8]:
encoder = Encoder()
decoder = Decoder()

model = VAE(Encoder=encoder, Decoder=decoder).to(device)
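
Once such a model is trained, new images are generated by drawing $z$ from the prior and passing it through the decoder alone. The sketch below shows only the mechanics: `TinyDecoder` is a hypothetical, untrained stand-in that mirrors the layer layout of the `Decoder` above so the block runs on its own, and its outputs are therefore noise rather than realistic images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDecoder(nn.Module):
    # hypothetical stand-in with the same layer layout as the Decoder above (untrained)
    def __init__(self, latent_dim=128):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 32 * 32 * 32)
        self.conv1 = nn.ConvTranspose2d(32, 16, 3, padding=1)
        self.conv2 = nn.ConvTranspose2d(16, 3, 3, padding=1)

    def forward(self, z):
        h = F.leaky_relu(self.fc(z))
        h = h.view(-1, 32, 32, 32)
        h = F.leaky_relu(self.conv1(h))
        return torch.sigmoid(self.conv2(h))

torch.manual_seed(0)
decoder = TinyDecoder().eval()
with torch.no_grad():
    z = torch.randn(4, 128)   # z ~ N(0, I), the assumed prior
    samples = decoder(z)      # decode latent codes into image space
print(samples.shape)          # torch.Size([4, 3, 32, 32])
```

With the notebook's own objects, the analogous call would be `model.Decoder(torch.randn(n, latent_dim).to(device))` after training.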

## Loss Functions
Our goal is to maximize the variational lower bound derived above:

$$
\max _{\phi, \theta} \mathbb{E}_{z} \left[\log p_\theta (x^{(i)}|z)\right] - D_{KL}\left( q_\phi (z|x^{(i)}) || p_\theta (z)\right)
$$
which is equivalent to
$$
\min _{\phi, \theta} \mathbb{E}_{z} \left[-\log p_\theta (x^{(i)}|z)\right] + D_{KL}\left(q_\phi (z|x^{(i)}) || p_\theta (z)\right)
$$
Assuming that $p_\theta (x^{(i)}|z)$ is a Bernoulli distribution over pixel intensities in $[0, 1]$, the negative log likelihood becomes the binary cross entropy. For the KL divergence, we assume that the prior distribution of $z$ is the standard normal distribution and that $q_\phi (z|x^{(i)})$ is a Gaussian with diagonal covariance. According to Appendix B of [Kingma and Welling, 2013], the KL divergence term then has the closed form
$$
D_{KL}\left( q_\phi (z|x^{(i)}) || p_\theta (z) \right) = -\frac{1}{2} \sum_{j=1}^{J} \left(1+\log \sigma _j^2 -\mu _j^2 - \sigma _j^2 \right)
$$
where $J$ is the latent dimension.
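
As a sanity check on the closed form, the sketch below compares it with a Monte Carlo estimate of $\mathbb{E}_{q}[\log q(z) - \log p(z)]$. It is plain Python and uses the same `mean` / `log_var` parameterization as the encoder output; the function names and the example values are illustrative.

```python
import math
import random

def kl_to_standard_normal(mean, log_var):
    # -1/2 * sum_j (1 + log sigma_j^2 - mu_j^2 - sigma_j^2)
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mean, log_var))

def kl_monte_carlo(mean, log_var, n=100_000, seed=0):
    # average of log q(z) - log p(z) over samples z ~ q; the log(2*pi) constants cancel
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        for m, lv in zip(mean, log_var):
            eps = rng.gauss(0.0, 1.0)
            z = m + math.exp(0.5 * lv) * eps
            log_q = -0.5 * lv - 0.5 * eps * eps   # (z - m)^2 / sigma^2 equals eps^2
            log_p = -0.5 * z * z
            total += log_q - log_p
    return total / n

mean, log_var = [1.0, -0.5], [0.0, math.log(0.25)]
closed = kl_to_standard_normal(mean, log_var)
estimate = kl_monte_carlo(mean, log_var)
print(f"closed form {closed:.4f}, Monte Carlo {estimate:.4f}")
```

The divergence is zero exactly when $\mu_j = 0$ and $\sigma_j^2 = 1$ in every dimension, which is why this term pulls the encoder's output toward the prior.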

In [9]:
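def loss_function(x, x_hat, mean, log_var):
    # Bernoulli negative log likelihood, summed over pixels and over the batch
    reproduction_loss = F.binary_cross_entropy(x_hat, x, reduction='sum')
    # closed-form KL(q(z|x) || N(0, I)), summed over latent dimensions and over the batch
    KLD = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())

    return reproduction_loss + KLD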

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [10]:
saved_dir = 'vae_images'
os.makedirs(saved_dir, exist_ok= True)

In [11]:
loss_list = []

for epoch in range(n_epochs):
    for batch_idx, data in enumerate(train_loader):
        img, label = data
        img = img.view(-1, *img_shape)  # -1 also handles a final batch smaller than batch_size
        img = img.to(device)

        img_hat, mean, log_var = model(img)
        loss = loss_function(img, img_hat, mean, log_var)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())

    # ===== save images and print logs =====
    save_image(img_hat.detach()[:25], f"{saved_dir}/{epoch+1}.png", nrow=5, normalize=True)
    print(f'epoch: {epoch+1}/{n_epochs}, loss: {loss.item() :.4f}')

epoch: 1/10, loss: -6026667.0000
epoch: 2/10, loss: -5808132.5000
epoch: 3/10, loss: -5369639.0000
epoch: 4/10, loss: -6513082.5000
epoch: 5/10, loss: -6273664.0000
epoch: 6/10, loss: -5821231.5000
epoch: 7/10, loss: -5952860.0000
epoch: 8/10, loss: -6387846.0000
epoch: 9/10, loss: -5942550.5000
epoch: 10/10, loss: -6444833.5000
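
The logged losses above are negative. With targets in $[0, 1]$, the binary cross entropy and the KL term are both non-negative, so their sum cannot be negative. A negative total is what one would expect when targets fall outside $[0, 1]$, for example after `Normalize([0.5], [0.5])`, which maps pixels to $[-1, 1]$. The sketch below illustrates this with a scalar version of the loss; `bce` is an illustrative helper, not the PyTorch function.

```python
import math

def bce(x, x_hat):
    # Bernoulli negative log likelihood of target x under predicted probability x_hat
    return -(x * math.log(x_hat) + (1.0 - x) * math.log(1.0 - x_hat))

# targets inside [0, 1]: the loss is non-negative for any prediction in (0, 1)
for target in (0.0, 0.3, 1.0):
    print(target, [round(bce(target, p), 4) for p in (0.01, 0.5, 0.99)])

# target outside [0, 1]: the loss turns negative and keeps decreasing as x_hat -> 0
for p in (0.5, 0.01, 1e-6):
    print(-1.0, p, round(bce(-1.0, p), 4))
```

For a target of -1 the expression reduces to $\log \hat{x} - 2\log(1-\hat{x})$, which is unbounded below as $\hat{x} \to 0$, so the optimizer can lower the loss without improving the reconstruction.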
