# Background on VAEs

![vae](figures/vae.png)

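Reparameterization trick: a diagonal-Gaussian sample can be rewritten as a deterministic, differentiable function of its parameters plus independent noise, so gradients can flow through $\pmb{\mu}$ and $\pmb{\sigma}$:

$$
\pmb{z} \sim \mathcal{N}\left(\pmb{\mu}, \operatorname{diag}(\pmb{\sigma}^2)\right) \Longleftrightarrow \pmb{z} = \pmb{\mu} + \pmb{\sigma} \odot \pmb{\varepsilon}, \quad \pmb{\varepsilon} \sim \mathcal{N}(\pmb{0}, \pmb{I})
$$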

$$
\begin{aligned}
\mathbb{D}_{\text{KL}}\left( q_{\phi}(\pmb{z} \mid \pmb{x}^i) \Vert p_{\theta}(\pmb{z} \mid \pmb{x}^i) \right) &= \int_{\pmb{z}} q_{\phi}(\pmb{z} \mid \pmb{x}^{i}) \log \frac{q_{\phi} (\pmb{z} \mid \pmb{x}^i)}{p_{\theta}(\pmb{z} \mid \pmb{x}^i)} \, d\pmb{z} \\
&= \int_{\pmb{z}} q_{\phi}(\pmb{z} \mid \pmb{x}^i) \log q_{\phi}(\pmb{z} \mid \pmb{x}^i) \, d\pmb{z} - \int_{\pmb{z}} q_{\phi}(\pmb{z} \mid \pmb{x}^i) \log p_{\theta}(\pmb{x}^i \mid \pmb{z}) \, d\pmb{z} - \int_{\pmb{z}} q_{\phi}(\pmb{z} \mid \pmb{x}^i) \log p_{\theta}(\pmb{z}) \, d\pmb{z} + \int_{\pmb{z}} q_{\phi}(\pmb{z} \mid \pmb{x}^i) \log p_{\theta}(\pmb{x}^i) \, d\pmb{z} \\
&= \mathbb{D}_{\text{KL}}\left(q_{\phi}(\pmb{z} \mid \pmb{x}^i) \Vert p_{\theta}(\pmb{z}) \right) - \mathbb{E}_{q_{\phi}(\pmb{z} \mid \pmb{x}^i)} \left[ \log p_{\theta}(\pmb{x}^i \mid \pmb{z}) \right] + \log p_{\theta}(\pmb{x}^i)
\end{aligned}
$$

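where we used Bayes' rule $p_{\theta}(\pmb{z} \mid \pmb{x}^{i}) = p_{\theta}(\pmb{x}^i \mid \pmb{z}) p_{\theta}(\pmb{z}) / p_{\theta}(\pmb{x}^i)$ to split the logarithm, and the normalization $\int_{\pmb{z}} q_{\phi}(\pmb{z} \mid \pmb{x}^i) \, d\pmb{z} = 1$, which reduces the last integral to $\log p_{\theta}(\pmb{x}^i)$ because that term does not depend on $\pmb{z}$.
Rearranging, we have: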

$$
\log \, p_{\theta}(\pmb{x}^i) = \mathbb{D}_{\text{KL}}\left( q_{\phi}(\pmb{z} \mid \pmb{x}^i) \Vert p_{\theta}(\pmb{z} \mid \pmb{x}^i) \right)+ \underset{\texttt{ELBO}}{\underbrace{\mathbb{E}_{q_{\phi}(\pmb{z} \mid \pmb{x}^i)} \left[ \log p_{\theta}(\pmb{x}^i \mid \pmb{z}) \right] - \mathbb{D}_{\text{KL}}\left(q_{\phi}(\pmb{z} \mid \pmb{x}^i) \Vert p_{\theta}(\pmb{z}) \right)}}
$$

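Why is it called ELBO? The KL divergence is always non-negative, so dropping the first term of the identity above gives a lower bound on the log evidence $\log p_{\theta}(\pmb{x}^i)$ — the *Evidence Lower BOund*:
$$
\log p_{\theta}(\pmb{x}^i) \geq \mathbb{E}_{q_{\phi}(\pmb{z} \mid \pmb{x}^i)} \left[ \log p_{\theta}(\pmb{x}^i \mid \pmb{z}) \right] - \mathbb{D}_{\text{KL}}\left(q_{\phi}(\pmb{z} \mid \pmb{x}^i) \Vert p_{\theta}(\pmb{z}) \right)
$$
The bound is tight exactly when $q_{\phi}(\pmb{z} \mid \pmb{x}^i)$ equals the true posterior $p_{\theta}(\pmb{z} \mid \pmb{x}^i)$.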


In [None]:
from __future__ import division, absolute_import, print_function
import os
import numpy as np
from matplotlib import pyplot as plt
from functools import partial


import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# DataLoader options shared by both loaders. Pinned host memory only speeds
# up copies to a CUDA device, so it is requested only when one is in use.
kwargs = {'num_workers': 1, 'pin_memory': device.type == 'cuda'}
batch_size = 128   # mini-batch size for training and evaluation
epochs = 10        # training epochs per experiment
log_interval = 10  # print a progress line every `log_interval` batches
latent_dim = 50    # dimensionality of the latent code z

In [None]:
# MNIST images as float tensors in [0, 1]. Training batches are reshuffled
# every epoch; the test set is read in a fixed order so evaluation, and the
# reconstruction grid test() builds from the first test batch, use the same
# images every epoch and are comparable across epochs.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=False, **kwargs)

In [None]:
class VAE(nn.Module):
    """Fully-connected variational auto-encoder for flattened 28x28 images.

    Encoder: 784 -> 400 (ReLU) -> two parallel heads of size ``latent_dim``
    producing the mean and log-variance of q(z | x).
    Decoder: ``latent_dim`` -> 400 (ReLU) -> 784 (sigmoid), i.e. per-pixel
    probabilities.
    """

    def __init__(self):
        super(VAE, self).__init__()
        pixels, hidden = 784, 400
        # Attribute names are part of the state_dict and fc3 is inspected by
        # get_KLD_weights, so keep them (and their creation order) stable.
        self.fc1 = nn.Linear(pixels, hidden)
        self.fc21 = nn.Linear(hidden, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden)
        self.fc4 = nn.Linear(hidden, pixels)

    def encode(self, x):
        """Map flattened images to (mu, logvar) of q(z | x)."""
        features = F.relu(self.fc1(x))
        mean = self.fc21(features)
        log_variance = self.fc22(features)
        return mean, log_variance

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I), sigma = exp(logvar / 2)."""
        sigma = logvar.mul(0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes to per-pixel probabilities of shape (N, 784)."""
        features = F.relu(self.fc3(z))
        logits = self.fc4(features)
        return torch.sigmoid(logits)

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for a batch of images."""
        flat = x.view(-1, 784)
        mean, log_variance = self.encode(flat)
        code = self.reparameterize(mean, log_variance)
        reconstruction = self.decode(code)
        return reconstruction, mean, log_variance

In [None]:
# Fresh VAE on the selected device, optimised with Adam (learning rate 1e-3).
model = VAE().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
def loss_function_vanilla(recon_x, x, mu, logvar):
    """Negative ELBO summed over the batch.

    Reconstruction term: binary cross-entropy between ``recon_x`` and ``x``
    (flattened to 784 columns), summed over pixels and samples.
    Regulariser: closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over
    samples and latent dimensions -- see Appendix B of Kingma and Welling,
    Auto-Encoding Variational Bayes, ICLR 2014 (https://arxiv.org/abs/1312.6114):
    KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
    """
    target = x.view(-1, 784)
    reconstruction = F.binary_cross_entropy(recon_x, target, reduction='sum')
    variance = logvar.exp()
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - variance)
    return reconstruction + kl_divergence

In [None]:
def train(epoch, loss_function):
    """Train the global ``model`` for one epoch over ``train_loader``.

    ``loss_function(recon_batch, data, mu, logvar)`` is expected to return the
    loss summed over the batch: the progress line divides it by the batch
    size, and the epoch summary divides the running total by the number of
    training samples.
    """
    model.train()
    epoch_total = 0
    for step, (batch, _labels) in enumerate(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(batch)
        batch_loss = loss_function(recon, batch, mu, logvar)
        batch_loss.backward()
        epoch_total += batch_loss.item()
        optimizer.step()
        if step % log_interval != 0:
            continue
        seen = step * len(batch)
        progress = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), progress,
            batch_loss.item() / len(batch)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, epoch_total / len(train_loader.dataset)))


def test(epoch, loss_function, tag):
    """Evaluate the global ``model`` on ``test_loader`` and save reconstructions.

    Args:
        epoch: epoch number; used for logging and in the output file name.
        loss_function: callable ``(recon_batch, data, mu, logvar)`` returning
            the loss summed over the batch. The accumulated total is divided
            by the number of test samples to report a per-sample loss.
        tag: sub-directory of ``results/`` that receives
            ``reconstruction_<epoch>.png`` (created if missing).
    """
    model.eval()
    test_loss = 0
    out_dir = os.path.join('results', tag)
    # save_image does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # view(-1, ...) instead of view(batch_size, ...): the first
                # batch may hold fewer than batch_size images.
                comparison = torch.cat([data[:n],
                                        recon_batch.view(-1, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                           os.path.join(out_dir, 'reconstruction_{}.png'.format(epoch)),
                           nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [None]:
tag = 'vanilla'
# save_image does not create directories; make sure results/<tag>/ exists
# before test() and the sampling step below write into it.
os.makedirs(os.path.join('results', tag), exist_ok=True)
for epoch in range(1, epochs + 1):
    train(epoch, loss_function_vanilla)
    test(epoch, loss_function_vanilla, tag)
    with torch.no_grad():
        # Decode 64 draws from the N(0, I) prior to visualise generated digits.
        sample = torch.randn(64, latent_dim).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   'results/{}/sample_'.format(tag) + str(epoch) + '.png')

In [None]:
def get_KLD_weights(model, data_loader=None):
    """Return per-latent-dimension KL statistics and decoder weight norms.

    Args:
        model: module whose forward returns ``(recon, mu, logvar)`` and which
            has a linear layer ``fc3`` taking the latent code as input.
        data_loader: iterable of ``(data, label)`` batches; defaults to the
            notebook-level ``test_loader``.

    Returns:
        ``(kld, weight_norms)`` where
        ``kld`` is a float32 array of shape ``(1, latent_dim)`` holding
        KL(q(z_d | x) || N(0, 1)) for every latent dimension d, averaged over
        all samples in ``data_loader``. A value near zero means the posterior
        for that dimension stays at the prior.
        ``weight_norms`` is the L2 norm of each column of ``model.fc3.weight``,
        i.e. how strongly the decoder reads each latent dimension.
    """
    if data_loader is None:
        data_loader = test_loader
    run_device = next(model.parameters()).device
    fc3_weight = model.fc3.weight.detach()

    model.eval()
    # fc3 maps latent_dim -> hidden, so its weight has latent_dim columns.
    kld_sum = torch.zeros(1, fc3_weight.size(1), device=run_device)
    n_samples = 0
    with torch.no_grad():
        for data, _ in data_loader:
            data = data.to(run_device)
            _, mu, logvar = model(data)
            # Closed-form per-sample, per-dimension KL to the standard normal.
            kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
            kld_sum += kld.sum(dim=0, keepdim=True)
            n_samples += data.size(0)

    # True dataset mean (the previous running "/= 2" weighted late batches most).
    kld_mean = kld_sum / max(n_samples, 1)
    weight_norms = torch.norm(fc3_weight, dim=0)
    return (kld_mean.cpu().numpy().astype(np.float32),
            weight_norms.cpu().numpy())
        

In [None]:
# Per-dimension KL statistic and fc3 weight norms for the vanilla model.
t_KLD, weights = get_KLD_weights(model=model)

In [None]:
# One bar chart per statistic, one bar per latent dimension: first the KL
# values from get_KLD_weights, then the decoder (fc3) weight-column norms.
for bar_heights in (t_KLD.squeeze(), weights):
    plt.bar(np.arange(latent_dim), bar_heights)
    plt.show()

In [None]:
def loss_function_beta_scheduled(recon_x, x, mu, logvar, epoch,
                                 total_epochs=None, max_beta=1.0):
    """Negative ELBO with a linearly annealed KL weight (KL warm-up).

    Returns ``BCE + beta * KL`` summed over the batch, where
    ``beta = max_beta * epoch / total_epochs``.

    Args:
        recon_x, x, mu, logvar: as for ``loss_function_vanilla``.
        epoch: current (1-based) epoch number.
        total_epochs: length of the schedule; defaults to the notebook-level
            ``epochs`` so existing callers are unchanged.
        max_beta: KL weight reached when ``epoch == total_epochs``.

    Raises:
        ValueError: if ``total_epochs`` is not positive.
    """
    if total_epochs is None:
        total_epochs = epochs
    if total_epochs <= 0:
        raise ValueError('total_epochs must be positive, got {}'.format(total_epochs))

    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    beta = max_beta * epoch / total_epochs
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over batch and dims.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + beta * KLD

In [None]:
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
tag = 'beta_scheduled'
# save_image does not create directories; make sure results/<tag>/ exists
# before test() and the sampling step below write into it.
os.makedirs(os.path.join('results', tag), exist_ok=True)
for epoch in range(1, epochs + 1):
    # Bind this epoch's number so the KL weight grows linearly over training.
    loss_function_beta = partial(loss_function_beta_scheduled, epoch=epoch)
    train(epoch, loss_function_beta)
    test(epoch, loss_function_beta, tag)
    with torch.no_grad():
        # Decode 64 draws from the N(0, I) prior to visualise generated digits.
        sample = torch.randn(64, latent_dim).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   'results/{}/sample_'.format(tag) + str(epoch) + '.png')

In [None]:
# KL statistic and fc3 weight norms per latent dimension for the
# beta-annealed model, each shown as its own bar chart.
t_KLD, weights = get_KLD_weights(model=model)
for bar_heights in (t_KLD.squeeze(), weights):
    plt.bar(np.arange(latent_dim), bar_heights)
    plt.show()

![sfb](figures/sfb2.png)

In [None]:
def loss_function_soft_bits(recon_x, x, mu, logvar, gamma_factor):
    """Per-sample negative ELBO with a separate KL weight per latent dimension.

    Returns ``BCE / N + sum_d gamma_d * KL_d`` where ``N`` is the number of
    samples in the batch, ``BCE`` is the binary cross-entropy summed over the
    batch and ``KL_d`` is the batch-mean KL(q(z_d | x) || N(0, 1)).

    Args:
        recon_x: reconstructions, one row of 784 probabilities per sample.
        x: targets; reshaped to ``(-1, 784)``.
        mu, logvar: posterior parameters with one row per sample.
        gamma_factor: KL weights broadcastable against one row of ``mu``,
            e.g. shape ``(1, latent_dim)``.
    """
    # Divide by the actual batch size (not the global batch_size) so the last,
    # smaller batch is scaled the same way as the batch-mean KL term below.
    n_samples = recon_x.size(0)
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), dim=0)

    return BCE / n_samples + torch.sum(gamma_factor * KLD)

In [None]:
def _soft_free_bits_gamma(kl_per_dim, prev_gamma, lambda_per_dim,
                          information_threshold, gamma_rate):
    """Return the next per-dimension KL weights for soft free bits.

    Weights of dimensions whose KL is at or below
    ``(1 - information_threshold) * lambda_per_dim`` are multiplied by
    ``1 - gamma_rate``; those at or above
    ``(1 + information_threshold) * lambda_per_dim`` by ``1 + gamma_rate``.
    The result is clamped to [0, 1].
    """
    prev_gamma = prev_gamma.to(kl_per_dim.device)
    low = (1 - information_threshold) * lambda_per_dim
    high = (1 + information_threshold) * lambda_per_dim
    ones = torch.ones_like(kl_per_dim)
    shrink = torch.where(kl_per_dim <= low, ones * (1. - gamma_rate), ones)
    grow = torch.where(kl_per_dim >= high, ones * (1. + gamma_rate), ones)
    return torch.clamp(shrink * grow * prev_gamma, 0., 1.)


def train(epoch, prev_gamma=None, lambda_per_dim=0.5,
          information_threshold=0.1, gamma_rate=0.1):
    """Run one soft-free-bits training epoch and return the final KL weights.

    Note: this replaces the earlier ``train(epoch, loss_function)`` in the
    notebook namespace.

    Args:
        epoch: epoch number, used for logging.
        prev_gamma: per-dimension KL weights to start from, broadcastable
            against the ``(1, latent_dim)`` batch KL statistic. ``None``
            starts from all ones (every KL term at full weight).
        lambda_per_dim: target KL per latent dimension.
        information_threshold: relative dead band around ``lambda_per_dim``
            inside which a dimension's weight is left unchanged.
        gamma_rate: multiplicative step applied to a weight outside the band.

    Returns:
        The KL weights after the last batch (``prev_gamma`` unchanged if the
        loader is empty).
    """
    model.train()
    train_loss = 0.
    gamma_factor = prev_gamma

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)

        with torch.no_grad():
            # Batch-mean KL per latent dimension, shape (1, latent_dim).
            kl_per_dim = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp(),
                                           dim=0, keepdim=True)
            if gamma_factor is None:
                gamma_factor = torch.ones_like(kl_per_dim)
            gamma_factor = _soft_free_bits_gamma(
                kl_per_dim, gamma_factor, lambda_per_dim,
                information_threshold, gamma_rate)

        loss = loss_function_soft_bits(recon_batch, data, mu, logvar,
                                       gamma_factor)
        loss.backward()
        # The loss is a per-sample average: weight it by the batch size so the
        # epoch summary is a per-sample average over the whole dataset.
        train_loss += loss.item() * len(data)
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()) + '\tGamma: ', gamma_factor.norm().item())

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))
    return gamma_factor

In [None]:
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
tag = 'soft_bits'
# save_image does not create directories; make sure results/<tag>/ exists
# before test() and the sampling step below write into it.
os.makedirs(os.path.join('results', tag), exist_ok=True)


def _soft_bits_test_loss(recon_x, x, mu, logvar, gamma_factor):
    """Soft-bits loss summed over the batch, as test() expects.

    test() adds up batch losses and divides by the test-set size, while
    loss_function_soft_bits returns a per-sample value, so scale it by the
    number of samples in the batch.
    """
    return loss_function_soft_bits(recon_x, x, mu, logvar, gamma_factor) * x.size(0)


# Random initial per-dimension KL weights in [0, 1); a small constant such as
# 0.01 * torch.ones((1, latent_dim)) is an alternative starting point.
# torch.rand returns a tensor that does not require gradients.
prev_gamma = torch.rand((1, latent_dim)).to(device)
for epoch in range(1, epochs + 1):
    next_gamma = train(epoch, prev_gamma)
    # Evaluate with the KL weights the model finished this epoch with.
    soft_loss = partial(_soft_bits_test_loss, gamma_factor=next_gamma)
    test(epoch, soft_loss, tag)
    with torch.no_grad():
        # Decode 64 draws from the N(0, I) prior to visualise generated digits.
        sample = torch.randn(64, latent_dim).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   'results/{}/sample_'.format(tag) + str(epoch) + '.png')
    prev_gamma = next_gamma

In [None]:
# KL statistic and fc3 weight norms per latent dimension for the
# soft-free-bits model, each shown as its own bar chart.
t_KLD, weights = get_KLD_weights(model=model)
for bar_heights in (t_KLD.squeeze(), weights):
    plt.bar(np.arange(latent_dim), bar_heights)
    plt.show()

## References

@inproceedings{kingma2014semi,
  title={Semi-supervised learning with deep generative models},
  author={Kingma, Durk P and Mohamed, Shakir and Rezende, Danilo Jimenez and Welling, Max},
  booktitle={Advances in neural information processing systems},
  pages={3581--3589},
  year={2014}
}

@inproceedings{dehban2017deep,
  title={A deep probabilistic framework for heterogeneous self-supervised learning of affordances},
  author={Dehban, Atabak and Jamone, Lorenzo and Kampff, Adam R and Santos-Victor, Jos{\'e}},
  booktitle={2017 IEEE-RAS 17th International Conference on Humanoid Robotics (Humanoids)},
  pages={476--483},
  year={2017},
  organization={IEEE}
}

@article{higgins2017beta,
  title={beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework.},
  author={Higgins, Irina and Matthey, Loic and Pal, Arka and Burgess, Christopher and Glorot, Xavier and Botvinick, Matthew and Mohamed, Shakir and Lerchner, Alexander},
  journal={ICLR},
  volume={2},
  number={5},
  pages={6},
  year={2017}
}

@article{bowman2015generating,
  title={Generating sentences from a continuous space},
  author={Bowman, Samuel R and Vilnis, Luke and Vinyals, Oriol and Dai, Andrew M and Jozefowicz, Rafal and Bengio, Samy},
  journal={arXiv preprint arXiv:1511.06349},
  year={2015}
}

@inproceedings{
he2018lagging,
title={Lagging Inference Networks and Posterior Collapse in Variational Autoencoders},
author={Junxian He and Daniel Spokoyny and Graham Neubig and Taylor Berg-Kirkpatrick},
booktitle={International Conference on Learning Representations},
year={2019},
url={https://openreview.net/forum?id=rylDfnCqF7},
}