# Lagrangian Variational Autoencoder (Draft)

[[source](https://arxiv.org/abs/1806.06514)]
[[TensorFlow](https://github.com/ermongroup/lagvae)]

Note: The code below is functional. However, the explanation here is still in draft form.

This is implemented with PyTorch and PyTorch Lightning. The code is based on the [paper](https://arxiv.org/abs/1806.06514) and [TensorFlow code](https://github.com/ermongroup/lagvae) by Zhao, Song, and Ermon.

The most notable change from a basic VAE is that, instead of directly optimizing the loss, the model is trained by solving the [Lagrangian dual problem](https://en.wikipedia.org/wiki/Duality_(optimization)#Dual_problem). A tutorial on convex optimization covering the Lagrangian dual problem can be found [here](http://eceweb.ucsd.edu/~gert/ECE273/hindiTutorial2.pdf). If you've taken vector calculus, you may have seen the Lagrangian before in the method of Lagrange multipliers. The idea is the same, but here it is applied more generally, including to inequality constraints.

The basic VAE gives us a way to learn a probability distribution $p(x, z)$, where $x$ are observations and $z$ are latent variables. This is done by factoring $p(x, z)$ as $p(x|z)p(z)$. The prior $p(z)$ is fixed a priori to be a standard normal, while $p(x|z) = p(x|z, \theta)$ is a parameterized distribution learned by the decoder. To make finding $p(x)$ tractable, we introduce the variational distribution $q(z|x)$, which approximates $p(z|x)$ and is learned by the encoder. Using the data distribution $q(x)$, we can similarly define a joint distribution $q(x, z) = q(z|x)q(x)$. Ideally, the variational distribution satisfies $q(z|x) = p(z|x)$ and the model matches the data, so we would like $p(x, z) = q(x, z)$. If the two joint distributions are equal, then their marginals are equal as well.

This suggests another way to train the model: enforce consistency between the two joint distributions. The basic VAE is trained to maximize the evidence lower bound (ELBO). As shown in the paper, this is equivalent to the optimization problem

$$ \min_\theta \text{KL}(q(x, z | \theta) \Vert p(x, z | \theta))$$

where KL is the Kullback-Leibler divergence. Other variational autoencoders can be shown to be equivalent to minimization problems with different divergences. With this in mind, let $\mathcal{D}$ be a vector of divergences between the two joint distributions. We require that $\mathcal{D} = 0$ (that is, $\mathcal{D}_i = 0$ for each $i$) if and only if $p(x, z) = q(x, z)$, which lets us use $\mathcal{D}$ to enforce consistency. Since the distributions depend on the model parameters, $\mathcal{D}$ depends on $\theta$; I'll write $\mathcal{D}(\theta)$ to emphasize this dependence.
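To see why maximizing the ELBO is equivalent to the KL problem above, expand the joint KL: $\text{KL}(q(x,z)\Vert p(x,z)) = -H_q(x) - \mathbb{E}_{q(x)}[\text{ELBO}(x)]$, where the data entropy $H_q(x)$ does not depend on $\theta$. The NumPy sketch below checks this identity on a small discrete model; the probability tables are random, made-up values chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def normalize(a, axis=None):
    # Turn non-negative weights into probabilities along the given axis
    return a / a.sum(axis=axis, keepdims=True)

n_x, n_z = 4, 3
q_x = normalize(rng.random(n_x))                       # data distribution q(x)
q_z_given_x = normalize(rng.random((n_x, n_z)), 1)     # encoder q(z|x), rows sum to 1
p_z = normalize(rng.random(n_z))                       # prior p(z)
p_x_given_z = normalize(rng.random((n_z, n_x)), 1)     # decoder p(x|z), rows sum to 1

q_xz = q_x[:, None] * q_z_given_x                      # q(x, z), indexed [x, z]
p_xz = (p_z[:, None] * p_x_given_z).T                  # p(x, z), indexed [x, z]

# KL(q(x, z) || p(x, z))
kl_joint = np.sum(q_xz * np.log(q_xz / p_xz))

# ELBO(x) = E_{q(z|x)}[log p(x, z) - log q(z|x)] for each x
elbo = np.sum(q_z_given_x * (np.log(p_xz) - np.log(q_z_given_x)), axis=1)
expected_elbo = np.sum(q_x * elbo)

# Entropy of the data distribution: a constant with respect to the model
entropy_q_x = -np.sum(q_x * np.log(q_x))

print(kl_joint, -entropy_q_x - expected_elbo)
```

Since $H_q(x)$ is fixed by the data, raising the expected ELBO lowers the joint KL by exactly the same amount.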

However, consistency alone does not guarantee any meaningful relationship between $x$ and $z$. For example, an encoder with $q(z|x) = p(z)$ and a decoder with $p(x|z) = q(x)$ ignore their inputs entirely, yet both joint distributions equal $p(z)q(x)$. The paper tackles this problem by introducing an extra function $f(\theta)$ that encodes our preference among consistent solutions. The choice of $f$ in the code below encodes a preference for mutual information between $x$ and $z$. Our optimization problem has now become

$$\min_\theta f(\theta) \text{ where } \mathcal{D}(\theta) = 0.$$

This has the dual problem

$$\max_{\lambda \geq 0}\min_\theta \left[f(\theta)+\lambda^T\mathcal{D}(\theta)\right],$$

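where the inner minimization now ranges over all $\theta$, including those with $\mathcal{D}(\theta) > 0$. Since divergences are never negative, a multiplier $\lambda \geq 0$ penalizes any inconsistency. This second problem can be related to the convex conjugate of $f$ in the dual space, which gives rise to the name 'dual problem.' The important point is that, when strong duality holds, both problems have the same optimal value, and (under suitable conditions) the same parameter $\theta$ solves both.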

We do need to make a slight adjustment, however. In practice, a model with finite capacity will not achieve $\mathcal{D} = 0$ exactly. Instead, we choose a 'consistency constraint' vector $\epsilon$ that bounds the size of $\mathcal{D}$. In this situation, we optimize

$$\min_\theta f(\theta) \text{ where } \mathcal{D}_i(\theta) \leq \epsilon_i \text{ for each } i,$$

which has the corresponding dual problem

$$\max_{\lambda \geq 0}\min_\theta \left[f(\theta)+\lambda^T(\mathcal{D}-\epsilon)\right].$$
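A one-dimensional toy problem (not from the paper, chosen only for illustration) shows the dual at work. Take $f(\theta) = \theta^2$ with a single constraint $\mathcal{D}(\theta) = (\theta - 3)^2 \leq \epsilon = 1$, so the feasible set is $2 \leq \theta \leq 4$ and the constrained minimum is at $\theta = 2$. Minimizing the Lagrangian over $\theta$ for a fixed $\lambda \geq 0$ gives the dual function $g(\lambda)$, and maximizing $g$ recovers the primal solution:

$$
\begin{aligned}
\mathcal{L}(\theta, \lambda) &= \theta^2 + \lambda\left[(\theta - 3)^2 - 1\right] \\
\theta^*(\lambda) &= \arg\min_\theta \mathcal{L}(\theta, \lambda) = \frac{3\lambda}{1 + \lambda} \\
g(\lambda) &= \mathcal{L}(\theta^*(\lambda), \lambda) = \frac{9\lambda}{1 + \lambda} - \lambda \\
g'(\lambda) &= \frac{9}{(1 + \lambda)^2} - 1 = 0 \quad\Rightarrow\quad \lambda^* = 2, \quad g(\lambda^*) = 4 \\
\theta^*(\lambda^*) &= 2, \quad f(2) = 4
\end{aligned}
$$

Both problems reach the value 4, so strong duality holds in this example, and the multiplier $\lambda^* = 2$ is positive because the constraint is active at the solution ($\mathcal{D}(2) = \epsilon$).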

## Code
### Variational Encoder

In [1]:
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
class VariationalEncoder(nn.Module):
    '''
    Encodes an observation to the latent variable parameters.
    
    :param z_dim: (int) Dimension of latent variable
    '''
    def __init__(self, z_dim=2):
        super(VariationalEncoder, self).__init__()
        self.network = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28**2, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU()
        )
        self.mean = nn.Linear(128, z_dim)
        self.logvar = nn.Linear(128, z_dim)

    def forward(self, input):
        x = self.network(input)
        
        mean = self.mean(x)
        scale = (self.logvar(x)/2).exp()
        return mean, scale

### Variational Decoder

In [3]:
class VariationalDecoder(nn.Module):
    '''
    Decodes the latent variable to an observation.
    
    :param z_dim: (int) Dimension of latent variable
    '''
    def __init__(self, z_dim=2):
        super(VariationalDecoder, self).__init__()
        self.z_dim = z_dim
        self.network = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Linear(512, 28**2),
            nn.Sigmoid()
        )

    def forward(self, z):
        x = self.network(z)
        return x.view((-1,1,28,28))

### Lagrangian Variational Autoencoder

In [4]:
class LagrangianVAE(nn.Module):
    '''
    Implements a Lagrangian VAE. This differs from other VAEs in that it is optimized
    according to a dual optimization problem.
    
    :param encoder: (nn.Module) Encoder network. Should return mean, scale.
    :param decoder: (nn.Module) Decoder network. Should return the observation reconstruction.
    :param mi: (float) Mutual information parameter. A positive value places emphasis on
        mutual information under p(x, z). A negative value emphasizes mutual information
        under q(x, z).
        
        More directly, a positive value emphasizes the reconstruction error while a negative
        value emphasizes the latent error.
    :param e: (tuple of float) Relaxed consistency constraints. These keep the optimization
        problem feasible for a model with finite capacity.
    :param lmbda: (tuple of float) Starting values for lambda.
    '''
    def __init__(self,
                 encoder,
                 decoder,
                 mi,
                 e = (1., 1.),
                 lmbda = (1., 1.)):
        super(LagrangianVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

        # The feasibility parameter (see section 5.2 in the paper)
        self.e = e
        #Mutual information parameter
        self.mi = mi

        self.L1 = nn.Parameter(torch.tensor(lmbda[0]))
        self.L2 = nn.Parameter(torch.tensor(lmbda[1]))

    def forward(self, input):
        mean, scale = self.encoder(input)

        eps = torch.randn_like(mean)
        z = mean+eps*scale

        return self.decoder(z)

    def generate_similar(self, input, noise=True):
        '''
        Generates similar images to input.

        :param input: (torch.Tensor) Original observation
        :param noise: (Bool) Whether to inject noise into the reconstruction.
            By default, z = N(mu, Sigma). If noise is False, then z = mu.
        '''
        mean, scale = self.encoder(input)
        if noise:
            eps = torch.randn_like(mean)
            z = mean+eps*scale
        else:
            z = mean
        return self.decoder(z)

    def child_parameters(self):
        '''
        Construct generator that returns all parameters but lambda.
        '''
        for network in [self.encoder, self.decoder]:
            for parameter in network.parameters():
                yield parameter

In [10]:
def mmd_loss(z):
    '''
    Compute the maximum mean discrepancy between the latent variable z and
    samples from the prior N(0, I).

    :param z: (torch.Tensor) Batch of latent variables
    :return: (torch.Tensor) Scaled MMD estimate
    '''
    true_sample = torch.randn_like(z)
    return 10000.*compute_mmd(z, true_sample) # Scale is built in

def elbo_loss(z, eps, scale):
    '''
    Compute a single-sample estimate of KL(q(z|x) || p(z)) with p(z) = N(0, I),
    the regularization term of the negative ELBO.

    :param z: (torch.Tensor) Latent variable
    :param eps: (torch.Tensor) Noise used to generate z. Should have
        eps ~ N(0, I).
    :param scale: (torch.Tensor) Scale used to generate z.
    :return: (torch.Tensor) KL estimate for each example in the batch.
    '''
    return torch.sum(z.pow(2)/2-eps.pow(2)/2-scale.log(), 1)

# Computes the maximum mean discrepancy (MMD) between samples x and y
def compute_mmd(x, y):
    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = x_kernel.mean() + y_kernel.mean() - 2*xy_kernel.mean()
    return mmd

def compute_kernel(x, y):
    x_size = x.size(0)
    y_size = y.size(0)
    dim = x.size(1)
    x = x.unsqueeze(1) # (x_size, 1, dim)
    y = y.unsqueeze(0) # (1, y_size, dim)
    tiled_x = x.expand(x_size, y_size, dim)
    tiled_y = y.expand(x_size, y_size, dim)
    kernel_input = (tiled_x - tiled_y).pow(2).mean(2)/float(dim)
    return torch.exp(-kernel_input) # (x_size, y_size)
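The two latent penalties can be sanity-checked outside the model. Despite its name, `elbo_loss` returns a single-sample estimate of $\text{KL}(q(z|x) \Vert p(z))$ for $p(z) = \mathcal{N}(0, I)$: with $z = \mu + \sigma\epsilon$, the normalizing constants cancel and $\log q(z|x) - \log p(z) = \sum_j \left(z_j^2/2 - \epsilon_j^2/2 - \log\sigma_j\right)$. The MMD term uses the kernel $k(x, y) = \exp\left(-\lVert x - y\rVert^2 / d^2\right)$, where $d$ is the latent dimension (the squared distance is averaged over dimensions and then divided by $d$ again). The NumPy sketch below re-implements both formulas for illustration, with arbitrary means, scales, sample sizes, and shift: averaged over many samples, the KL estimate should approach the closed-form Gaussian KL, and the MMD should be near zero for two samples from $\mathcal{N}(0, I)$ but clearly larger when one sample is shifted.

```python
import numpy as np

rng = np.random.default_rng(0)

# --- KL estimate, same formula as elbo_loss ---
def kl_estimate(z, eps, scale):
    # log q(z|x) - log p(z), summed over latent dimensions
    return np.sum(z**2 / 2 - eps**2 / 2 - np.log(scale), axis=1)

mean = np.array([0.5, -1.0])
scale = np.array([0.8, 1.5])
eps = rng.standard_normal((200_000, 2))
z = mean + eps * scale  # reparameterized samples from q(z|x)
mc_kl = kl_estimate(z, eps, scale).mean()
# Closed form KL(N(mean, diag(scale^2)) || N(0, I))
exact_kl = 0.5 * np.sum(mean**2 + scale**2 - 1 - 2 * np.log(scale))

# --- MMD with the same kernel as compute_kernel ---
def compute_kernel(x, y):
    # exp(-mean squared difference / dim) = exp(-||x - y||^2 / dim^2)
    dim = x.shape[1]
    sq = ((x[:, None, :] - y[None, :, :]) ** 2).mean(axis=2) / dim
    return np.exp(-sq)

def compute_mmd(x, y):
    # Biased MMD estimate, as in compute_mmd
    return (compute_kernel(x, x).mean() + compute_kernel(y, y).mean()
            - 2 * compute_kernel(x, y).mean())

prior = rng.standard_normal((500, 2))
mmd_match = compute_mmd(rng.standard_normal((500, 2)), prior)
mmd_shift = compute_mmd(rng.standard_normal((500, 2)) + 2.0, prior)

print(f"KL: Monte Carlo {mc_kl:.3f}, exact {exact_kl:.3f}")
print(f"MMD: matching {mmd_match:.4f}, shifted {mmd_shift:.4f}")
```

Note that the notebook multiplies the MMD by 10000 in `mmd_loss`, so its reported latent error is on a much larger scale than these raw values.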

## Optimizing the network

### Solving a max-min problem
The most notable change from a basic VAE is that, instead of directly optimizing the loss, we are solving the Lagrangian dual problem: we minimize the objective with respect to $\theta$ while maximizing it with respect to $\lambda$. The complication is that PyTorch optimizers minimize by default. There is an easy fix, though: after calling `backward()`, negate the stored gradients of $\lambda$. By linearity of differentiation, the negated gradient is exactly the gradient of `-loss`, so a minimizing step on it is an ascent step on the loss. Alternatively, you can call `(-loss).backward(retain_graph=True)`; since PyTorch frees the computation graph after a backward pass by default, the graph must be explicitly retained for the second pass. I didn't use this approach because it requires zeroing the gradients and calling `loss.backward()` again to update the other parameters. (Recent PyTorch versions also accept `maximize=True` for several built-in optimizers, such as `torch.optim.SGD` and `torch.optim.Adam`.)
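The sign-flip trick and the `retain_graph` alternative give the same gradient for $\lambda$. The PyTorch sketch below checks this on a toy Lagrangian chosen only for illustration; the values of `theta` and `lam` are arbitrary.

```python
import torch

theta = torch.tensor(0.5, requires_grad=True)
lam = torch.tensor(1.5, requires_grad=True)
loss = (theta - 2) ** 2 + lam * (theta ** 2 - 1)

# Approach 1: backpropagate the loss once, then negate lambda's stored gradient
loss.backward(retain_graph=True)  # keep the graph for the second pass below
flipped = -lam.grad.clone()

# Approach 2: backpropagate the negated loss through the retained graph
lam.grad = None
(-loss).backward()
direct = lam.grad.clone()

print(flipped.item(), direct.item())  # both equal -(theta**2 - 1) = 0.75
```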

We could use a single optimizer for both sets of parameters, as long as the sign of the $\lambda$ gradients is flipped before the step. As in the original code, we instead use a separate optimizer for each, which gives more control over the parameter updates; below, the networks use Adam and $\lambda$ uses RMSprop.
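A sketch of the single-optimizer variant, assuming a toy constrained problem in place of the VAE: minimize $(\theta - 2)^2$ subject to $\theta^2 \leq 1$, whose solution is $\theta = 1$ with multiplier $\lambda = 1$. One `torch.optim.SGD` instance holds both `theta` and `lam`; flipping the sign of `lam.grad` before `step()` turns the shared update into descent in `theta` and ascent in `lam`, and the clamp plays the role of the one in `update_lambda`.

```python
import torch

# Toy problem (hypothetical): minimize (theta - 2)^2 subject to theta^2 <= 1.
theta = torch.tensor(0.0, requires_grad=True)
lam = torch.tensor(0.0, requires_grad=True)

# A single optimizer holds both the "model" parameter and the multiplier
optim = torch.optim.SGD([theta, lam], lr=0.05)

for step in range(2000):
    optim.zero_grad()
    loss = (theta - 2) ** 2 + lam * (theta ** 2 - 1)
    loss.backward()
    lam.grad *= -1.          # flip first, so the shared step ascends in lam
    optim.step()             # descends in theta, ascends in lam
    with torch.no_grad():
        lam.clamp_(min=0.0)  # keep the multiplier non-negative

print(theta.item(), lam.item())
```

With plain SGD the iterates should settle near `theta = 1` and `lam = 1`. The notebook's separate Adam and RMSprop optimizers follow the same descent-ascent pattern, but with adaptive step sizes.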

In [11]:
from torch.nn import functional as F
import torch.utils.data as data
import torchvision.datasets as datasets
from torchvision.transforms import ToTensor

training_set = datasets.MNIST(root='./data/', train=True, download=True, transform=ToTensor())
batch_size = 100
training_loader = data.DataLoader(dataset=training_set, batch_size=batch_size, shuffle=True)

In [12]:
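def make_lagVAE(z_size, mi, e, lmbda=(1., 1.), optimize_lambda=True):
    '''
    Builds a LagrangianVAE with the default encoder and decoder.

    :param optimize_lambda: (bool) Whether lambda will be optimized over. Setting to False
        fixes lambda at its starting value.
    '''
    encoder = VariationalEncoder(z_size)
    decoder = VariationalDecoder(z_size)
    model = LagrangianVAE(encoder, decoder, mi, e, lmbda)
    # Freeze the Lagrange multipliers when they should not be optimized
    model.L1.requires_grad_(optimize_lambda)
    model.L2.requires_grad_(optimize_lambda)
    return model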

In [23]:
def train_lagVAE(model, data_loader, optimizer, lambda_optim, device='cpu'):
    L1, L2 = model.L1, model.L2
    e1, e2 = model.e
    mi = model.mi
    
    average_recon = 0.
    average_latent = 0.
    for n_batch, (x, _) in enumerate(data_loader):
        optimizer.zero_grad()
        lambda_optim.zero_grad()

        x = x.to(device=device)
        mean, scale = model.encoder(x)

        eps = torch.randn_like(mean)
        z = mean+eps*scale

        x_pred = model.decoder(z)

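        # Reconstruction error: summed squared error per image, which matches
        # -log p(x|z) up to scale and constants for a fixed-variance Gaussian decoder
        nll = F.mse_loss(x_pred, x, reduction='sum') / len(x)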

        #latent errors
        mmd = mmd_loss(z)
        elbo = elbo_loss(z, eps, scale).mean()

        if mi <= 0:
            loss = L1 * nll + (L1 - mi) * elbo + L2 * mmd - L1 * e1 - L2 * e2
        else:
            loss = (L1 + mi) * nll + L1 * elbo + L2 * mmd - L1 * e1 - L2 * e2

        loss.backward()
        
        optimizer.step()
        update_lambda(L1, L2, lambda_optim)

        average_latent = 0.9 * average_latent + 0.1 * float(mmd + elbo)
        average_recon = 0.9 * average_recon + 0.1 * float(nll)

        if n_batch % 30 == 0:
            print('\t'.join([
                '\rnll: {0:.4f}'.format(average_recon),
                'lat: {0:.4f}'.format(average_latent),
                'L1: {0:.3f}'.format(L1.item()),
                'L2: {0:.3f}'.format(L2.item())]),
                end=''
            )

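def update_lambda(L1, L2, optim):
    # Gradient ascent on lambda: flip the gradient sign so the (minimizing)
    # optimizer maximizes the Lagrangian, then clamp lambda to a positive range
    if L1.grad is None or L2.grad is None:
        return  # lambda is frozen (not optimized), so there is nothing to update
    L1.grad *= -1.
    L2.grad *= -1.
    optim.step()
    with torch.no_grad():
        L1.clamp_(0.001, 100)
        L2.clamp_(0.001, 100)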

In [25]:
e = (86., 5.)
mi = 1
lmbda = (1., 1.)
z_dim = 16

model = make_lagVAE(z_dim, mi, e, lmbda)
model = model.to(device)

optim = torch.optim.Adam(model.child_parameters(), lr=1e-4)
lambda_optim = torch.optim.RMSprop([model.L1, model.L2], lr=0.0001)

epoch_summary_text = '\rEpoch {0}/{1}'
loss_summary_text = '{0} loss: {1:.5}'

epochs = 10
for epoch in range(epochs):
    train_lagVAE(model, training_loader, optim, lambda_optim, device)

    # Finish the progress line with the epoch count
    print('\tEpoch {0}/{1}'.format(epoch+1, epochs))

nll: 53.3741	lat: 23.7278	L1: 0.965	L2: 1.0633	Epoch 1/10
nll: 52.2113	lat: 22.5336	L1: 0.905	L2: 1.118	Epoch 2/10
nll: 50.5696	lat: 23.8000	L1: 0.845	L2: 1.174	Epoch 3/10
nll: 49.7108	lat: 25.9649	L1: 0.785	L2: 1.230	Epoch 4/10
nll: 48.7753	lat: 25.3169	L1: 0.724	L2: 1.285	Epoch 5/10
nll: 47.9363	lat: 23.2921	L1: 0.664	L2: 1.340	Epoch 6/10
nll: 48.1585	lat: 25.6364	L1: 0.604	L2: 1.396	Epoch 7/10
nll: 48.5680	lat: 24.7749	L1: 0.544	L2: 1.451	Epoch 8/10
nll: 48.0952	lat: 29.3923	L1: 0.484	L2: 1.508	Epoch 9/10
nll: 48.0837	lat: 24.2637	L1: 0.424	L2: 1.563	Epoch 10/10


### Training

## Choosing model parameters

If $\lambda$ is fixed, then we are no longer solving the dual problem. However, fixed choices of $\lambda$ can correspond to other VAEs. The `optimize_lambda` option in the code below disables optimizing $\lambda$. This effectively makes $\lambda$ a hyperparameter that must be chosen by the user, and turns the Lagrangian VAE into an InfoVAE. However, the InfoVAE does not guarantee that the solution will be consistent according to our choice of $\epsilon$. The benefit of the Lagrangian VAE is that optimizing $\lambda$ enforces the constraints: if $\mathcal{D}_i > \epsilon_i$, then $\lambda_i$ grows until the constraint is satisfied, which forces the optimization to weigh that consistency violation more heavily. This is visible in the training log above, where `L2` (the multiplier on the MMD term) increases every epoch while `L1` decreases, indicating that the MMD constraint is violated while the first constraint is satisfied.
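In `train_lagVAE`, the gradient of the loss with respect to `L1` is `nll + elbo - e1` and with respect to `L2` is `mmd - e2`, so an ascent step raises a multiplier exactly when its constraint is violated. The plain-Python sketch below applies that rule with the same clamping range as `update_lambda`, but it is a simplification: it uses plain gradient ascent instead of RMSprop, and the divergence values are hypothetical and held fixed, so it only shows the direction of the updates.

```python
def ascend_lambda(lmbda, divergences, e, lr=1e-3, low=0.001, high=100.):
    # One projected gradient-ascent step per multiplier: dL/dlambda_i = D_i - e_i
    return tuple(
        min(max(l + lr * (d - eps), low), high)
        for l, d, eps in zip(lmbda, divergences, e)
    )

e = (86., 5.)             # consistency constraints used in the notebook
divergences = (70., 20.)  # hypothetical: first constraint satisfied, second violated

lmbda = (1., 1.)
for step in range(100):
    lmbda = ascend_lambda(lmbda, divergences, e)

print(lmbda)
```

The first multiplier is driven down to the lower clamp of 0.001, while the second keeps growing for as long as its constraint stays violated.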

The paper goes into detail about how a number of previous VAE models can be recovered by specific choices of parameters.

In [None]:
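from pytorch_lightning import Trainer

use_lagvae = True  # False fixes lambda at its starting value (an InfoVAE)
e = (86., 5.)
mi = 1
lmbda = (1., 1.)
z_dim = 2

dataset = 'mnist'
lagVAE = make_lagVAE(z_dim, mi, e, lmbda, optimize_lambda=use_lagvae)

# VAE_template (the LightningModule that wraps the model and its data loaders)
# and checkpoint_callback are not defined in this notebook.
model = VAE_template(lagVAE, dataset, batch_size=1000)

# max_nb_epochs is the argument name in older Lightning releases;
# newer releases call it max_epochs.
trainer = Trainer(gpus=[0],
                  max_nb_epochs=300,
                  checkpoint_callback=checkpoint_callback)

trainer.fit(model)
torch.save(model.state_dict(), './models/lagvae.pkl')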

## Examples

### Reconstruction of images
Below are image reconstructions, similar to those in my [VAE notebook](./basic_VAE.ipynb). A notable difference is that these images are much sharper, and the noisy reconstructions are also more consistent.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

image_dir = './images/'
image_desc = 'lagvae.png'
title_desc = 'with LagrangianVAE'

def image_plot(axis, image, **kwargs):
    axis.imshow(image.view(28,28), cmap='Greys', **kwargs)

In [None]:
#Set up picture
ax_settings = {'aspect':'equal', 'xticklabels':[], 'yticklabels':[], 'xticks':[], 'yticks':[]}
fig, ax = plt.subplots(5, 6, subplot_kw=ax_settings, figsize=(9, 7))
if dataset.lower() == 'mnist':
    dset = datasets.MNIST(root='./', train=True, download=True, transform=ToTensor())
elif dataset.lower() == 'fashion':
    dset = datasets.FashionMNIST(root='./', train=True, download=True, transform=ToTensor())
loader = data.DataLoader(dset, batch_size=6, shuffle=False)
image_sample, _ = next(iter(loader))


lagVAE.cpu()
lagVAE.eval()
with torch.no_grad():
    recons = lagVAE.generate_similar(image_sample, noise=False)
    for idx in range(6):
        image_plot(ax[0, idx], image_sample[idx,:])
        # The decoder already ends in a sigmoid, so its outputs are plotted directly
        image_plot(ax[1, idx], recons[idx,:])

    # Create similar examples with noise
    for row in range(2, 5):
        recons = lagVAE.generate_similar(image_sample)
        for idx in range(6):
            image_plot(ax[row, idx], recons[idx,:])
            
fig.suptitle('LagrangianVAE image reconstruction')
plt.savefig('./images/lagvae-recon.png')
ax[0,-1].text(30, 15, 'Original')
ax[1,-1].text(30, 15, 'Reconstruction from mean')
ax[2,-1].text(30, 15, 'Noisy reconstructions')
plt.savefig('./images/MNIST-lagVAE-reconstructions.png', bbox_inches='tight')

### Generating images by sampling the prior

In [None]:
fig, axes = plt.subplots(4, 4, subplot_kw=ax_settings, figsize=(6, 6))
with torch.no_grad():
    for ax_row in axes:
        for ax in ax_row:
            latent_sample = torch.randn(1, z_dim)
            out = lagVAE.decoder(latent_sample)
            image_plot(ax, out[0])  # decoder output is already in [0, 1]
        
fig.suptitle('Image generation with LagrangianVAE')
plt.savefig('./images/MNIST-lagVAE-latent_samples.png')

### Distribution of means

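Below we use Linear Discriminant Analysis (LDA) as a dimension reduction tool to visualize the distribution of encoded means. LDA is supervised: it uses the class labels to find the linear projection that best separates the classes. The plot shows the encodings of validation examples.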
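LDA can produce at most (number of classes − 1) components, and never more than the number of input features. The sketch below applies the same scikit-learn pipeline to synthetic Gaussian clusters that stand in for encoded means; the cluster data, class count, and latent size are made up for illustration.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n_classes, n_per_class, z_dim = 3, 100, 16

# Synthetic stand-ins for encoded means: one Gaussian cluster per class
centers = rng.normal(scale=3.0, size=(n_classes, z_dim))
means = np.concatenate([c + rng.normal(size=(n_per_class, z_dim)) for c in centers])
labels = np.repeat(np.arange(n_classes), n_per_class)

# Standardize, then project onto the 2 most discriminative directions
lda = make_pipeline(StandardScaler(), LinearDiscriminantAnalysis(n_components=2))
X_embed = lda.fit_transform(means, labels)
print(X_embed.shape)
```

With three classes, two components is the maximum LDA allows, which is why the notebook's 10-class MNIST case can comfortably request two.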

In [None]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
lda = make_pipeline(StandardScaler(), LinearDiscriminantAnalysis(n_components=2))

In [None]:
legend_label = model.val_dataloader.dataset.dataset.classes

model.eval()
with torch.no_grad():
    means = []
    labels = []
    for image, label in model.val_dataloader:
        mean, _ = lagVAE.encoder(image)
        means.append(mean)
        labels.append(label)
    means = torch.cat(means)
    labels = torch.cat(labels)
    lda.fit(means, labels)
    X_embed = lda.transform(means)

In [None]:
fig, ax = plt.subplots(1, 1)
class_labels = model.val_dataloader.dataset.dataset.classes
scatter = ax.scatter(X_embed[:, 0], X_embed[:, 1], c=labels, s=15, cmap='tab10')
handles, _ = scatter.legend_elements()
legend1 = ax.legend(handles, class_labels, bbox_to_anchor=(1.05, 1), title="Classes")
ax.add_artist(legend1)

plt.title('Distribution of encoded means')
plt.savefig('./images/lagVAE-mean-distribution.png', bbox_inches='tight')