# An Introduction to the Wasserstein Auto-encoder

-------

## Authors
Joel Dapello<br>
Michael Sedelmeyer<br>
Wenjun Yan

-------

<a id="top"></a>
## Contents

Table of contents with markdown hyperlinks to each section of the notebook

1. [Motivation and background](#intro)

1. [Conceptual foundations](#concepts)

1. [Mathematics and algorithms](#details)

1. [Comparing results on MNIST](#mnist)

1. [Comparing results on FashionMNIST](#fmnist)

1. [Conclusions and further analysis](#conclusion)

1. [References and further reading](#sources)


- [Appendices: PyTorch Implementation](#appendix)
    - [Appendix A: Auto-encoder](#ae)
    - [Appendix B: Variational auto-encoder](#vae)
    - [Appendix C: Wasserstein auto-encoder](#wae)
    - [Appendix D: Plotting functions](#plots)

<a id="intro"></a>
## Motivation and background
[return to top](#top)

Designing generative models capable of capturing the structure of very high-dimensional data is a long-standing problem in statistical modeling. One class of models that has proved effective for this task is the auto-encoder (AE). AEs are neural-network-based models that assume the high-dimensional data being modeled can be reduced to a lower-dimensional manifold, defined on a space of latent variables. To do this, the AE defines an encoder network $Q$, which maps a high-dimensional input to a low-dimensional latent space $Z$, and a generator (decoder) network $G$, which maps $Z$ back to the high-dimensional input space. The whole system is trained end to end with stochastic gradient descent. In the case of the vanilla AE, the cost function minimizes the distance between the training data $X$ and its reconstruction, $\hat{X} = G(Q(X))$. While the standard AE is quite effective at learning a low-dimensional representation of the training data, it is prone to overfitting and typically fails as a generative model. Because nothing constrains the shape of the learned representation in latent space, it is unclear how to sample from $Z$ effectively: randomly drawn latent codes that fall far from those $G$ has learned to decode often produce nonsense.

The well-known variational auto-encoder (VAE) (Kingma & Welling, 2014) was introduced as a solution to this problem. The VAE builds on the AE framework with a modified cost function that maximizes the evidence lower bound (ELBO) on the log-likelihood of the data. This effectively introduces a regularization penalty, a Kullback–Leibler divergence term, which pushes the encoder distribution $Q(Z|X=x)$ of every input $x$ toward a specified prior distribution $P_Z$. As a result, the VAE is a much more powerful generative model than the standard AE, because samples drawn from $P_Z$ fall in a region that $G$ has learned to decode. Unfortunately, while the VAE performs admirably on simple datasets such as MNIST, on more complex datasets it tends to produce blurry samples.
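The KL penalty has a closed form when the encoder outputs a diagonal Gaussian $Q(Z|X=x)=\mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ and the prior is $P_Z=\mathcal{N}(0, I)$, which is the setting of our VAE in Appendix B. The minimal sketch below (the helper name `gaussian_kl` and the random inputs are our own illustrative choices) checks that formula against PyTorch's `torch.distributions.kl_divergence` and confirms that the penalty is never negative.

```python
import torch
from torch.distributions import Normal, kl_divergence

def gaussian_kl(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """KL( N(mu, exp(logvar)) || N(0, I) ), summed over latent dims, one value per row."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

torch.manual_seed(0)
mu = torch.randn(4, 20)        # 4 encoded inputs, 20 latent dimensions
logvar = torch.randn(4, 20)

closed_form = gaussian_kl(mu, logvar)
reference = kl_divergence(
    Normal(mu, torch.exp(0.5 * logvar)),                  # Normal takes a standard deviation
    Normal(torch.zeros_like(mu), torch.ones_like(mu)),    # standard normal prior
).sum(dim=1)

print(torch.allclose(closed_form, reference, atol=1e-5))  # True
print(bool((closed_form >= 0).all()))                      # True: KL is never negative
```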

In their 2018 International Conference on Learning Representations (ICLR) paper "Wasserstein Auto-Encoders", Tolstikhin et al. propose the Wasserstein auto-encoder (WAE) as a new algorithm for building a latent-variable generative model. This addition to the family of regularized auto-encoders minimizes a penalized form of the optimal transport cost (Villani, 2003) between the data distribution and the model distribution. Intuitively, the optimal transport cost is the minimal cost of transforming one distribution into another. The resulting objective is a reconstruction cost plus a regularizer $\mathcal{D}_Z(Q_Z,P_Z)$ that penalizes the discrepancy between the encoded distribution $Q_Z$ and the prior $P_Z$, which is a different penalty from that of the VAE. The WAE regularizer encourages the full encoded training distribution, a continuous mixture, to match $P_Z$, rather than forcing the encoding of each individual sample to match it as the VAE does (see [Figure 1](#fig1)). For this reason, the WAE shares many of the properties of VAEs while, as Tolstikhin et al. report, generating samples of better quality as measured by the FID score.
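For intuition about the "cost of transforming one distribution into another": in one dimension, with two empirical distributions of equal size and cost $c(x,y)=|x-y|$, the optimal transport plan simply pairs the $i$-th smallest point of one sample with the $i$-th smallest point of the other, so the 1-Wasserstein distance is the mean absolute difference of the sorted samples. The sketch below illustrates only this special case (the helper name `wasserstein_1d` is ours); it is not how the WAE estimates its latent penalty.

```python
import numpy as np

def wasserstein_1d(x: np.ndarray, y: np.ndarray) -> float:
    """W1 between two equal-size 1-D empirical distributions: match points in sorted order."""
    assert x.shape == y.shape
    return float(np.mean(np.abs(np.sort(x) - np.sort(y))))

rng = np.random.default_rng(0)
p = rng.normal(loc=0.0, scale=1.0, size=10_000)

# Shifting every point by 3 costs 3 per unit of mass moved.
print(wasserstein_1d(p, p + 3.0))              # 3.0, up to floating-point rounding
# Reordering the samples leaves the distribution unchanged, so nothing has to move.
print(wasserstein_1d(p, rng.permutation(p)))   # 0.0
```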

In this tutorial, we implement the generative adversarial network (GAN) formulation of the WAE (WAEgan). The WAEgan chooses $\mathcal{D}_Z$ to be the Jensen-Shannon divergence and estimates it with an adversarial objective in the latent space (CITE). Specifically, the WAEgan trains a discriminator network $D$ in the latent space $Z$ to distinguish samples drawn from $P_Z$ from samples drawn from $Q_Z$, so that the discriminator's objective serves as an estimate of $\mathcal{D}_Z(Q_Z,P_Z)$, while the encoder $Q$ learns to produce latent codes that fool $D$. In addition to the WAEgan, we also implement a VAE and a vanilla AE. To understand the WAE and its benefits, it helps to consider it alongside these two preceding, well-established algorithms. Side-by-side comparisons of all three algorithms, applied to the popular MNIST (CITE) and FashionMNIST (CITE) datasets with convolutional neural network (CNN) implementations in PyTorch, give a more intuitive understanding of the results.
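The adversarial estimate rests on a standard result from the GAN literature: for fixed $P_Z$ and $Q_Z$ with densities $p$ and $q$, the discriminator that maximizes $\mathbb{E}_{P_Z}[\log D(z)] + \mathbb{E}_{Q_Z}[\log(1-D(z))]$ is $D^*(z)=p(z)/(p(z)+q(z))$, and the maximum equals $2\,\mathrm{JS}(P_Z,Q_Z)-\log 4$. The toy check below uses two arbitrary discrete distributions over four latent "cells" as stand-ins for $P_Z$ and $Q_Z$; a trained discriminator network only approximates $D^*$.

```python
import numpy as np

def kl(a: np.ndarray, b: np.ndarray) -> float:
    """KL divergence (natural log) between discrete distributions with full support."""
    return float(np.sum(a * np.log(a / b)))

def jensen_shannon(p: np.ndarray, q: np.ndarray) -> float:
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = np.array([0.1, 0.4, 0.3, 0.2])      # toy stand-in for the prior P_Z
q = np.array([0.25, 0.25, 0.25, 0.25])  # toy stand-in for the encoded distribution Q_Z

d_star = p / (p + q)                    # optimal discriminator D*(z)
value = np.sum(p * np.log(d_star)) + np.sum(q * np.log(1.0 - d_star))

print(value, 2 * jensen_shannon(p, q) - np.log(4))   # the two numbers agree
```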

<a id="fig1"></a>
**Figure 1:** Conceptual comparison of AE reconstruction methods (after Tolstikhin et al., 2018). All three algorithms map inputs $x \in X$ to a latent code $z \in Z$ and then attempt to reconstruct $\hat{x}=G(z)$. The AE places no regularization penalty on $Z$, while the VAE and WAE use the Kullback–Leibler divergence (KLD) and the optimal transport cost, respectively, to penalize divergence of $Q_Z$ from the shape of the prior, $P_Z$. While the KLD forces $Q(Z|X=x)$ to match $P_Z$ for every input $x$, the optimal transport cost only requires the continuous mixture $Q_Z:=\int Q(Z|X)\,dP_X$ to match $P_Z$.

![Figure 1: conceptual comparison of AE, VAE, and WAE latent regularization](https://github.com/sedelmeyer/wasserstein-auto-encoder/blob/master/images/figure%201%20-%20reconstruction.png?raw=true "Figure 1")
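The distinction drawn in Figure 1 can be made concrete with a toy one-dimensional example (all numbers are illustrative assumptions, not outputs of our models). Suppose each input is encoded as $Q(Z|X=x_i)=\mathcal{N}(\mu_i, 0.1^2)$, with the means $\mu_i$ themselves spread as $\mathcal{N}(0, 0.99)$. Every individual posterior is far from $P_Z=\mathcal{N}(0,1)$, so a per-sample KL penalty stays large, yet the mixture $Q_Z$ is exactly $\mathcal{N}(0, 0.99 + 0.01)=\mathcal{N}(0,1)$.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
s = 0.1                                            # std of each per-input posterior Q(Z|X=x_i)
mu = rng.normal(0.0, np.sqrt(1 - s**2), size=n)    # posterior means, one per training input

# Per-sample KL( N(mu_i, s^2) || N(0, 1) ): the quantity the VAE penalizes for every x_i
kl_per_sample = 0.5 * (s**2 + mu**2 - 1 - np.log(s**2))

# One draw from each Q(Z|X=x_i) is a sample from the aggregate (mixture) posterior Q_Z
z = mu + s * rng.standard_normal(n)

print(f"mean per-sample KL: {kl_per_sample.mean():.2f}")          # about 2.30
print(f"aggregate Q_Z mean/std: {z.mean():.3f} / {z.std():.3f}")  # about 0.000 / 1.000
```

In this toy setting the VAE penalty would push every posterior toward $\mathcal{N}(0,1)$, while a penalty on $Q_Z$ alone would already be close to zero.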

<a id="details"></a>
## Mathematics and algorithms
[return to top](#top)

In this section we provide the mathematical details and algorithmic differences of each method, paying extra attention to the WAE and how it differs from the VAE.
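As a compact summary of the loss functions, and assuming the notation of Algorithms 1 and 2 in Appendix D (encoder $Q_{\phi}$, decoder $G_{\theta}$, latent discriminator $D_{\gamma}$, cost $c$, minibatch size $n$), the objectives minimized over $(\phi,\theta)$ can be written as below. The VAE line assumes a diagonal Gaussian encoder with outputs $\mu_{\phi}(x)$, $\sigma_{\phi}(x)$ and a standard normal prior, as in Appendix B; with $c(x,G_\theta(z)) = -\log p_\theta(x|z)$ its bracket is the negative ELBO. The WAE-GAN line is the encoder/decoder half of the alternating updates in Algorithm 2, using the deterministic encoder of Appendix C. In our code, $c$ is MSE for the AE and binary cross-entropy for the VAE and WAE.

$$
\begin{aligned}
\mathcal{L}_{\text{AE}} &= \frac{1}{n}\sum_{i=1}^n c\big(x_i,\, G_{\theta}(Q_{\phi}(x_i))\big) \\
\mathcal{L}_{\text{VAE}} &= \frac{1}{n}\sum_{i=1}^n \Big[c\big(x_i,\, G_{\theta}(\tilde{z}_i)\big) + D_{KL}\big(Q_{\phi}(Z\vert x_i)\,\|\,P_Z\big)\Big],
\quad \tilde{z}_i = \mu_{\phi}(x_i) + \sigma_{\phi}(x_i)\odot\epsilon_i,\ \ \epsilon_i\sim\mathcal{N}(0,I) \\
\mathcal{L}_{\text{WAE-GAN}} &= \frac{1}{n}\sum_{i=1}^n \Big[c\big(x_i,\, G_{\theta}(\tilde{z}_i)\big) - \lambda \log D_{\gamma}(\tilde{z}_i)\Big],
\quad \tilde{z}_i = Q_{\phi}(x_i)
\end{aligned}
$$

For the diagonal Gaussian encoder and $P_Z=\mathcal{N}(0,I)$, the KL term has the closed form

$$
D_{KL}\big(Q_{\phi}(Z\vert x)\,\|\,\mathcal{N}(0,I)\big) = -\frac{1}{2}\sum_{j=1}^{d_z}\Big(1 + \log\sigma_{\phi,j}^2(x) - \mu_{\phi,j}^2(x) - \sigma_{\phi,j}^2(x)\Big).
$$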

**latex to include:**
1. notational algorithms
1. loss function detail
1. mathematical representation of the reparameterization trick

**images to include:**
1. A small graphical representation of the reparameterization trick (small and simple node/edge plot)
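The reparameterization trick rewrites a draw $z\sim\mathcal{N}(\mu,\sigma^2)$ as $z=\mu+\sigma\odot\epsilon$ with $\epsilon\sim\mathcal{N}(0,I)$, so the randomness enters as an input and gradients can flow back to $\mu$ and $\log\sigma^2$. A minimal sketch, assuming the same `logvar` parameterization as `VAE_Encoder.reparameterize` in Appendix B, checks that autograd recovers $\partial z/\partial\mu = 1$ and $\partial z/\partial \log\sigma^2 = \tfrac12\sigma\epsilon$:

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([0.5, -1.0], requires_grad=True)
logvar = torch.tensor([0.0, 1.0], requires_grad=True)

eps = torch.randn(2)             # noise is sampled outside the differentiable path
std = torch.exp(0.5 * logvar)
z = mu + std * eps               # z is differentiable with respect to mu and logvar

z.sum().backward()
print(mu.grad)                                                  # tensor([1., 1.])
print(torch.allclose(logvar.grad, 0.5 * std.detach() * eps))   # True
```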

<a id="mnist"></a>
## Comparing results on MNIST
[return to top](#top)

In this section we specify the parameters used in our models and provide plots, metrics, and a written interpretation of the training results and latent-space representations of each algorithm on MNIST.

**images/tables to include:**
1. Sample of 5 original MNIST images and corresponding decoded images for AE, VAE, and WAE on separate rows
1. Latent space linear interpolation results of each model, pixel space vs AE vs VAE vs WAE on separate rows
1. t-SNE or PCA representation of pixel space vs latent space for each model to demonstrate differences
1. table summarizing comparative loss (and if possible FID results)
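One way the latent-interpolation figure could be produced is sketched below. `interpolate_latents` is a hypothetical helper (not yet part of our code) that returns evenly spaced points on the straight line between two latent codes; decoding each row with a trained decoder such as `ae_decoder`, `vae_decoder`, or `wae_decoder` from the appendices would give one row of the figure. For the pixel-space baseline, the same blend would be applied to the images themselves.

```python
import torch

def interpolate_latents(z_start: torch.Tensor, z_end: torch.Tensor, n_steps: int = 10) -> torch.Tensor:
    """Return n_steps codes on the segment from z_start to z_end (shape: n_steps x n_z)."""
    alphas = torch.linspace(0.0, 1.0, n_steps).unsqueeze(1)   # (n_steps, 1) mixing weights
    return (1 - alphas) * z_start.unsqueeze(0) + alphas * z_end.unsqueeze(0)

z_a, z_b = torch.zeros(20), torch.ones(20)   # stand-ins for two encoded images (n_z = 20)
path = interpolate_latents(z_a, z_b, n_steps=5)
print(path.shape)    # torch.Size([5, 20])
print(path[:, 0])    # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])

# With trained models, one row of the figure would be:
#   with torch.no_grad():
#       frames = wae_decoder(path)   # (n_steps, 1, 28, 28)
```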

<a id="fmnist"></a>
## Comparing results on FashionMNIST
[return to top](#top)

Same as above for MNIST

**images/tables to include:**
1. same as above for MNIST, but probably smaller and with fewer examples if results demonstrate similar characteristics

<a id="conclusion"></a>
## Conclusions and further analysis
[return to top](#top)

Here we summarize our conclusions from MNIST and FashionMNIST, and describe other datasets we may want to run as comparisons (e.g. celebrity face images, to test representation of data lying on a low-dimensional manifold such as faces, and RNA expression data, to investigate a novel application of the WAE).

<a id="sources"></a>
## References and Further Reading
[return to top](#top)

Cite the papers, repos, datasets, and blogs we used in our analysis, as well as any other resources we want to direct our readers toward

1. VAE paper
1. WAE paper
1. PyTorch/resources implementation of VAE
1. AE paper?
1. MNIST
1. FashionMNIST

<a id="appendix"></a>
## Appendices: PyTorch Implementation
[return to top](#top)

- The appendices contain the PyTorch code we run; each model has its own sub-appendix.
- We should save our most important plots as PNG files (stored on GitHub) so we can display them via markdown image links at the appropriate locations in our paper.

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [6]:
args = {}
args['dim_h'] = 40            # factor controlling size of hidden layers
args['n_channel'] = 1         # number of channels in the input data (MNIST is 1, aka greyscale)
args['n_z'] = 20              # number of dimensions in latent space. 
args['sigma'] = 1.0           # standard deviation of the Gaussian prior P_Z (WAE)
args['lambda'] = 0.01         # weight lambda of the latent (discriminator) penalty in the WAE loss
args['lr'] = 0.0002           # learning rate for Adam optimizer
args['epochs'] = 50           # how many epochs to run for
args['batch_size'] = 256      # mini-batch size
args['save'] = False          # save weights at each epoch of training if True
args['train'] = False         # train networks if True, else load networks from saved weights
args['dataset'] = 'mnist'     # specify which dataset to use

In [7]:
## load Dataset
if args['dataset'] == 'mnist':
    trainset = datasets.MNIST(
        root='./MNIST/',
        train=True,
        transform=transforms.ToTensor(),
        download=True
    )

    testset = datasets.MNIST(
        root='./MNIST/',
        train=False,
        transform=transforms.ToTensor(),
        download=True
    )
elif args['dataset'] == 'fmnist':
    trainset = datasets.FashionMNIST(
        root='./FMNIST/',
        train=True,
        transform=transforms.ToTensor(),
        download=True
    )

    testset = datasets.FashionMNIST(
        root='./FMNIST/',
        train=False,
        transform=transforms.ToTensor(),
        download=True
    )
else:
    raise ValueError("args['dataset'] must be 'mnist' or 'fmnist'")

train_loader = DataLoader(
    dataset=trainset,
    batch_size=args['batch_size'],
    shuffle=True
)

test_loader = DataLoader(
    dataset=testset,
    batch_size=args['batch_size'],
    shuffle=False
)

<a id="ae"></a>
### Appendix A: Auto-encoder 
[return to top](#top)

In [8]:
## create encoder model and decoder model
class AE_Encoder(nn.Module):
    def __init__(self, args):
        super(AE_Encoder, self).__init__()

        self.n_channel = args['n_channel']
        self.dim_h = args['dim_h']
        self.n_z = args['n_z']
        
        # convolutional filters organized according to the popular DCGAN (Radford et al., 2015) framework, which works well for image data
        self.conv = nn.Sequential(
            nn.Conv2d(self.n_channel, self.dim_h, 4, 2, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h, self.dim_h * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 2, self.dim_h * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 4, self.dim_h * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 8),
            nn.ReLU(True),
        )
        
        # final layer is fully connected
        self.fc = nn.Linear(self.dim_h * (2 ** 3), self.n_z)

    def forward(self, x):
        x = self.conv(x)
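        x = torch.flatten(x, start_dim=1)   # (batch, dim_h * 8, 1, 1) -> (batch, dim_h * 8); unlike squeeze(), safe for a batch of one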
        x = self.fc(x)
        return x

class AE_Decoder(nn.Module):
    def __init__(self, args):
        super(AE_Decoder, self).__init__()

        self.n_channel = args['n_channel']
        self.dim_h = args['dim_h']
        self.n_z = args['n_z']

        # first layer is fully connected
        self.fc = nn.Sequential(
            nn.Linear(self.n_z, self.dim_h * 8 * 7 * 7),
            nn.ReLU()
        )

        # transposed convolutions upsample the feature maps back to 28x28 (they are not true inverses of the convolutions)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(self.dim_h * 8, self.dim_h * 4, 4),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.dim_h * 4, self.dim_h * 2, 4),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.dim_h * 2, self.n_channel, 4, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(-1, self.dim_h * 8, 7, 7)
        x = self.deconv(x)
        return x
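The encoder and decoder only fit together because of the spatial arithmetic of their layers on 28×28 inputs. The sketch below traces that arithmetic with the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (assuming dilation 1 and no output padding, as in the layers above); the helper names are ours.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of nn.ConvTranspose2d with dilation=1 and output_padding=0."""
    return (size - 1) * stride - 2 * padding + kernel

# Encoder: four Conv2d(kernel=4, stride=2, padding=1) layers applied to a 28x28 image
enc_sizes = [28]
for _ in range(4):
    enc_sizes.append(conv_out(enc_sizes[-1], kernel=4, stride=2, padding=1))
print(enc_sizes)   # [28, 14, 7, 3, 1] -> a 1x1 map with dim_h * 8 channels feeds self.fc

# Decoder: fc output reshaped to 7x7, then ConvTranspose2d with kernel 4 and strides 1, 1, 2
dec_sizes = [7]
for stride in (1, 1, 2):
    dec_sizes.append(conv_transpose_out(dec_sizes[-1], kernel=4, stride=stride))
print(dec_sizes)   # [7, 10, 13, 28] -> back to 28x28
```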

In [9]:
# instantiate the encoder and decoder
ae_encoder, ae_decoder = AE_Encoder(args), AE_Decoder(args)

if args['train']:
    # specify loss (mean squared error of image reconstruction)
    criterion = nn.MSELoss()

    # use the Adam optimizer for both networks
    enc_optim = torch.optim.Adam(ae_encoder.parameters(), lr = args['lr'])
    dec_optim = torch.optim.Adam(ae_decoder.parameters(), lr = args['lr'])

    # halve the learning rate every 30 epochs
    enc_scheduler = torch.optim.lr_scheduler.StepLR(enc_optim, step_size=30, gamma=0.5)
    dec_scheduler = torch.optim.lr_scheduler.StepLR(dec_optim, step_size=30, gamma=0.5)

    for epoch in range(args['epochs']):
        ae_encoder.train()
        ae_decoder.train()

        for images, _ in tqdm(train_loader):
            ae_encoder.zero_grad()
            ae_decoder.zero_grad()

            z_hat = ae_encoder(images)
            x_hat = ae_decoder(z_hat)
            train_recon_loss = criterion(x_hat, images)

            train_recon_loss.backward()

            enc_optim.step()
            dec_optim.step()

        # the schedulers only take effect if they are stepped once per epoch
        enc_scheduler.step()
        dec_scheduler.step()

        # run validation set without tracking gradients, averaging the loss over all batches
        ae_encoder.eval()
        ae_decoder.eval()
        test_recon_loss = 0.0
        with torch.no_grad():
            for images, _ in tqdm(test_loader):
                z_hat = ae_encoder(images)
                x_hat = ae_decoder(z_hat)
                test_recon_loss += criterion(x_hat, images).item()
        test_recon_loss /= len(test_loader)

        if args['save']:
            save_path = './save/AE_{}-epoch_{}.pth'
            torch.save(ae_encoder.state_dict(), save_path.format('encoder', epoch))
            torch.save(ae_decoder.state_dict(), save_path.format('decoder', epoch))

        # the train loss shown is that of the last batch of the epoch
        print("Epoch: [{}/{}], \tTrain Reconstruction Loss: {}\n\t\t\tTest Reconstruction Loss: {}".format(
            epoch + 1, 
            args['epochs'], 
            train_recon_loss.item(),
            test_recon_loss
        ))
else:
    # load encoder and decoder weights from checkpoint
    enc_checkpoint = torch.load('save/AE_encoder-best_{}.pth'.format(args['dataset']))
    ae_encoder.load_state_dict(enc_checkpoint)

    dec_checkpoint = torch.load('save/AE_decoder-best_{}.pth'.format(args['dataset']))
    ae_decoder.load_state_dict(dec_checkpoint)

<a id="vae"></a>
### Appendix B: Variational auto-encoder
[return to top](#top)

In [10]:
## create encoder model and decoder model
class VAE_Encoder(nn.Module):
    def __init__(self, args):
        super(VAE_Encoder, self).__init__()

        self.n_channel = args['n_channel']
        self.dim_h = args['dim_h']
        self.n_z = args['n_z']
        
        # convolutional filters (same DCGAN-style stack as the AE encoder), which work well for image data
        self.conv = nn.Sequential(
            nn.Conv2d(self.n_channel, self.dim_h, 4, 2, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h, self.dim_h * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 2, self.dim_h * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 4, self.dim_h * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 8),
            nn.ReLU(True),
        )
        
        # two fully connected heads: fc1 outputs the mean mu, fc2 the log-variance logvar
        self.fc1 = nn.Linear(self.dim_h * (2 ** 3), self.n_z)
        self.fc2 = nn.Linear(self.dim_h * (2 ** 3), self.n_z)

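    def reparameterize(self, mu, logvar):
        # reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so gradients flow through mu and logvar rather than through the sampling step
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + std * eps
        return z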
        
    def bottleneck(self, h):
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar
        
    
    def forward(self, x):
        h = self.conv(x)
        h = torch.flatten(h, start_dim=1)   # (batch, dim_h * 8, 1, 1) -> (batch, dim_h * 8)
        z, mu, logvar = self.bottleneck(h)
        
        return z, mu, logvar

class VAE_Decoder(nn.Module):
    def __init__(self, args):
        super(VAE_Decoder, self).__init__()

        self.n_channel = args['n_channel']
        self.dim_h = args['dim_h']
        self.n_z = args['n_z']

        # first layer is fully connected
        self.fc = nn.Sequential(
            nn.Linear(self.n_z, self.dim_h * 8 * 7 * 7),
            nn.ReLU()
        )

        # transposed convolutions upsample the feature maps back to 28x28 (they are not true inverses of the convolutions)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(self.dim_h * 8, self.dim_h * 4, 4),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.dim_h * 4, self.dim_h * 2, 4),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.dim_h * 2, self.n_channel, 4, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(-1, self.dim_h * 8, 7, 7)
        x = self.deconv(x)
        return x

In [11]:
# instantiate the encoder and decoder
vae_encoder, vae_decoder = VAE_Encoder(args), VAE_Decoder(args)

if args['train']:
    # training uses BCE + KLD (below); MSE is only a test metric comparable across models
    criterion = nn.MSELoss()

    # use the Adam optimizer for both networks
    enc_optim = torch.optim.Adam(vae_encoder.parameters(), lr = args['lr'])
    dec_optim = torch.optim.Adam(vae_decoder.parameters(), lr = args['lr'])

    # halve the learning rate every 30 epochs
    enc_scheduler = torch.optim.lr_scheduler.StepLR(enc_optim, step_size=30, gamma=0.5)
    dec_scheduler = torch.optim.lr_scheduler.StepLR(dec_optim, step_size=30, gamma=0.5)

    for epoch in range(args['epochs']):
        vae_encoder.train()
        vae_decoder.train()

        for images, _ in tqdm(train_loader):
            vae_encoder.zero_grad()
            vae_decoder.zero_grad()

            z_hat, mu, logvar = vae_encoder(images)
            x_hat = vae_decoder(z_hat)

            # reconstruction term: Bernoulli negative log-likelihood summed over pixels and batch
            BCE = nn.functional.binary_cross_entropy(
                x_hat.view(-1, 784), 
                images.view(-1, 784), 
                reduction='sum'
            )

            # KL( N(mu, exp(logvar)) || N(0, I) ) in closed form; always non-negative
            KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

            # minimizing BCE + KLD (the negative ELBO) maximizes the ELBO
            neg_ELBO = BCE + KLD
            neg_ELBO.backward()

            enc_optim.step()
            dec_optim.step()

        # the schedulers only take effect if they are stepped once per epoch
        enc_scheduler.step()
        dec_scheduler.step()

        # run validation set without tracking gradients, averaging the loss over all batches
        vae_encoder.eval()
        vae_decoder.eval()
        test_recon_loss = 0.0
        with torch.no_grad():
            for images, _ in tqdm(test_loader):
                z_hat, mu, logvar = vae_encoder(images)
                x_hat = vae_decoder(z_hat)
                test_recon_loss += criterion(x_hat, images).item()
        test_recon_loss /= len(test_loader)

        if args['save']:
            save_path = './save/VAE_{}-epoch_{}.pth'
            torch.save(vae_encoder.state_dict(), save_path.format('encoder', epoch))
            torch.save(vae_decoder.state_dict(), save_path.format('decoder', epoch))

        # train losses shown are those of the last batch of the epoch
        print("Epoch: [{}/{}], \tTrain Reconstruction Loss: {} \tKLD:{}\n\t\t\tTest Reconstruction Loss: {}".format(
            epoch + 1, 
            args['epochs'], 
            BCE.item(),
            KLD.item(),
            test_recon_loss
        ))
else:
    # load encoder and decoder weights from checkpoint
    enc_checkpoint = torch.load('save/VAE_encoder-best_{}.pth'.format(args['dataset']))
    vae_encoder.load_state_dict(enc_checkpoint)

    dec_checkpoint = torch.load('save/VAE_decoder-best_{}.pth'.format(args['dataset']))
    vae_decoder.load_state_dict(dec_checkpoint)

<a id="wae"></a>
### Appendix C: Wasserstein auto-encoder
[return to top](#top)

In [12]:
## create encoder model and decoder model
class WAE_Encoder(nn.Module):
    def __init__(self, args):
        super(WAE_Encoder, self).__init__()

        self.n_channel = args['n_channel']
        self.dim_h = args['dim_h']
        self.n_z = args['n_z']
        
        # convolutional filters (same DCGAN-style stack as the AE encoder), which work well for image data
        self.conv = nn.Sequential(
            nn.Conv2d(self.n_channel, self.dim_h, 4, 2, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h, self.dim_h * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 2, self.dim_h * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 4, self.dim_h * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 8),
            nn.ReLU(True),
        )
        
        # final layer is fully connected
        self.fc = nn.Linear(self.dim_h * (2 ** 3), self.n_z)

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, start_dim=1)   # (batch, dim_h * 8, 1, 1) -> (batch, dim_h * 8)
        x = self.fc(x)
        return x

class WAE_Decoder(nn.Module):
    def __init__(self, args):
        super(WAE_Decoder, self).__init__()

        self.n_channel = args['n_channel']
        self.dim_h = args['dim_h']
        self.n_z = args['n_z']

        # first layer is fully connected
        self.fc = nn.Sequential(
            nn.Linear(self.n_z, self.dim_h * 8 * 7 * 7),
            nn.ReLU()
        )

        # transposed convolutions upsample the feature maps back to 28x28 (they are not true inverses of the convolutions)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(self.dim_h * 8, self.dim_h * 4, 4),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.dim_h * 4, self.dim_h * 2, 4),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.dim_h * 2, self.n_channel, 4, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(-1, self.dim_h * 8, 7, 7)
        x = self.deconv(x)
        return x

# define the discriminator, which operates on latent codes
class Discriminator(nn.Module):
    def __init__(self, args):
        super(Discriminator, self).__init__()

        self.dim_h = args['dim_h']
        self.n_z = args['n_z']

        # main body of the discriminator; the final sigmoid outputs the probability that z was drawn from P_Z
        self.main = nn.Sequential(
            nn.Linear(self.n_z, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.main(x)
        return x
    
# control which parameters are frozen / free for optimization
def free_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = True

def frozen_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = False
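`frozen_params` works because autograd computes no gradients for parameters with `requires_grad=False`, so a frozen network's `.grad` fields stay untouched while gradients still flow *through* it to the networks upstream. The toy check below uses two small stand-in `nn.Linear` layers (assumptions, not our models) and re-declares the two helpers so it runs on its own.

```python
import torch
import torch.nn as nn

def free_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = True

def frozen_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = False

encoder, critic = nn.Linear(4, 2), nn.Linear(2, 1)   # stand-ins for Q and D

free_params(encoder)
frozen_params(critic)
loss = critic(encoder(torch.randn(8, 4))).mean()
loss.backward()

print(encoder.weight.grad is not None)   # True: gradients flowed through the critic to the encoder
print(critic.weight.grad is None)        # True: the frozen critic received none
```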

In [13]:
# instantiate the discriminator plus fresh encoder and decoder models (not reusing AE/VAE weights, for a fair comparison)
wae_encoder, wae_decoder, discriminator = WAE_Encoder(args), WAE_Decoder(args), Discriminator(args)

# MSE is only a test metric comparable across models; training uses BCE below
criterion = nn.MSELoss()

if args['train']:
    enc_optim = torch.optim.Adam(wae_encoder.parameters(), lr = args['lr'])
    dec_optim = torch.optim.Adam(wae_decoder.parameters(), lr = args['lr'])
    dis_optim = torch.optim.Adam(discriminator.parameters(), lr = args['lr'])

    # halve the learning rate every 30 epochs
    enc_scheduler = torch.optim.lr_scheduler.StepLR(enc_optim, step_size=30, gamma=0.5)
    dec_scheduler = torch.optim.lr_scheduler.StepLR(dec_optim, step_size=30, gamma=0.5)
    dis_scheduler = torch.optim.lr_scheduler.StepLR(dis_optim, step_size=30, gamma=0.5)

    for epoch in range(args['epochs']):

        # train for one epoch -- set nets to train mode
        wae_encoder.train()
        wae_decoder.train()
        discriminator.train()

        for images, _ in tqdm(train_loader):
            # zero gradients for each batch
            wae_encoder.zero_grad()
            wae_decoder.zero_grad()
            discriminator.zero_grad()

            # ======== Train Discriminator ======== #

            # freeze auto-encoder params, free discriminator params
            frozen_params(wae_decoder)
            frozen_params(wae_encoder)
            free_params(discriminator)

            # run discriminator on draws from the prior P_Z = N(0, sigma^2 I)
            z = torch.randn(images.size()[0], args['n_z']) * args['sigma']
            d_z = discriminator(z)

            # run discriminator on encoder outputs z_hat ~ Q_Z
            z_hat = wae_encoder(images)
            d_z_hat = discriminator(z_hat)

            # ascend lambda * [log D(z) + log(1 - D(z_hat))] by descending its negative
            dis_loss = -args['lambda'] * (torch.log(d_z).mean() + torch.log(1 - d_z_hat).mean())
            dis_loss.backward()

            dis_optim.step()

            # ======== Train Generator ======== #

            # flip which networks are frozen, which are not
            free_params(wae_decoder)
            free_params(wae_encoder)
            frozen_params(discriminator)

            # run images through the auto-encoder
            z_hat = wae_encoder(images)
            x_hat = wae_decoder(z_hat)

            # discriminate the latent codes
            d_z_hat = discriminator(z_hat)

            # reconstruction cost: WAE allows any cost function c; we use per-pixel mean BCE
            BCE = nn.functional.binary_cross_entropy(x_hat, images)

            # latent term lambda * log D(z_hat), which the encoder tries to increase
            d_loss = args['lambda'] * torch.log(d_z_hat).mean()

            # descend BCE - d_loss
            gen_loss = BCE - d_loss
            gen_loss.backward()

            enc_optim.step()
            dec_optim.step()

        # the schedulers only take effect if they are stepped once per epoch
        enc_scheduler.step()
        dec_scheduler.step()
        dis_scheduler.step()

        # test on test set without tracking gradients, averaging the loss over all batches
        wae_encoder.eval()
        wae_decoder.eval()
        test_recon_loss = 0.0
        with torch.no_grad():
            for images, _ in tqdm(test_loader):
                z_hat = wae_encoder(images)
                x_hat = wae_decoder(z_hat)
                test_recon_loss += criterion(x_hat, images).item()
        test_recon_loss /= len(test_loader)

        if args['save']:
            save_path = './save/WAEgan_{}-epoch_{}.pth'
            torch.save(wae_encoder.state_dict(), save_path.format('encoder', epoch))
            torch.save(wae_decoder.state_dict(), save_path.format('decoder', epoch))
            torch.save(discriminator.state_dict(), save_path.format('discriminator', epoch))

        # print stats after each epoch (train losses are those of the last batch)
        print("Epoch: [{}/{}], \tTrain Reconstruction Loss: {} d loss: {}, \n\t\t\tTest Reconstruction Loss:{}".format(
            epoch + 1, 
            args['epochs'], 
            BCE.item(),
            d_loss.item(),
            test_recon_loss
        ))
        
else:
    enc_checkpoint = torch.load('save/WAEgan_encoder-best_{}.pth'.format(args['dataset']))
    wae_encoder.load_state_dict(enc_checkpoint)

    dec_checkpoint = torch.load('save/WAEgan_decoder-best_{}.pth'.format(args['dataset']))
    wae_decoder.load_state_dict(dec_checkpoint)
    
    dis_checkpoint = torch.load('save/WAEgan_discriminator-best_{}.pth'.format(args['dataset']))
    discriminator.load_state_dict(dis_checkpoint)

<a id="plots"></a>
### Appendix D: Plotting functions
[return to top](#top)

**BOTH ALGORITHMS SHOULD PROBABLY BE REWRITTEN USING LATEX IN A WAY THAT TAKES UP LESS VERTICAL SPACE**

<a id="algo2"></a>
**Algorithm 2:** Wasserstein auto-encoder with GAN-based penalty (WAE-GAN) pseudocode

**Require:** Regularization coefficient $\lambda > 0$.

> Initialize the parameters of the encoder $Q_{\phi}$, decoder $G_{\theta}$, and latent discriminator $D_{\gamma}$.

> **while** $(\phi, \theta)$ not converged **do**

>> Sample $\{x_1, \dotsc , x_n\}$ from the training set

>> Sample $\{z_1, \dotsc , z_n\}$ from the prior $P_Z$

>> Sample $\tilde{z}_i$ from $Q_{\phi}(Z\vert x_i)$ for $i=1, \dotsc , n$

>> Update $D_{\gamma}$ by ascending:
$$\frac{\lambda}{n}\sum_{i=1}^n \Big[\log D_{\gamma}(z_i) + \log\big(1-D_{\gamma}(\tilde{z}_i)\big)\Big]$$

>> Update $Q_{\phi}$ and $G_{\theta}$ by descending:
$$\frac{1}{n}\sum_{i=1}^n \Big[c\big(x_i, G_{\theta}(\tilde{z}_i)\big) - \lambda \cdot \log D_{\gamma}(\tilde{z}_i)\Big]$$

> **end while**
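A single iteration of Algorithm 2 can be sketched in a few lines. To keep the sketch self-contained it relies on toy assumptions of our own: small fully connected networks instead of the CNNs of Appendix C, random "images" in $[0,1]$, squared error as the cost $c$, a standard normal prior, and a deterministic encoder, so that $\tilde{z}_i = Q_{\phi}(x_i)$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n, x_dim, z_dim, lam = 64, 784, 8, 0.01   # toy sizes; lam plays the role of lambda

# Toy fully connected stand-ins for Q_phi, G_theta and D_gamma
encoder = nn.Sequential(nn.Linear(x_dim, 128), nn.ReLU(), nn.Linear(128, z_dim))
decoder = nn.Sequential(nn.Linear(z_dim, 128), nn.ReLU(), nn.Linear(128, x_dim), nn.Sigmoid())
disc = nn.Sequential(nn.Linear(z_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())

ae_opt = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=1e-3)
d_opt = torch.optim.Adam(disc.parameters(), lr=1e-3)

x = torch.rand(n, x_dim)    # stand-in minibatch {x_1, ..., x_n} with pixel values in [0, 1]
z = torch.randn(n, z_dim)   # {z_1, ..., z_n} drawn from the prior P_Z = N(0, I)

# Update D_gamma: ascend lam/n * sum[log D(z_i) + log(1 - D(z_tilde_i))] by descending its negative
z_tilde = encoder(x).detach()   # detached, so this step leaves the encoder untouched
d_loss = -lam * (torch.log(disc(z)) + torch.log(1 - disc(z_tilde))).mean()
d_opt.zero_grad()
d_loss.backward()
d_opt.step()

# Update Q_phi and G_theta: descend 1/n * sum[c(x_i, G(z_tilde_i)) - lam * log D(z_tilde_i)]
z_tilde = encoder(x)
recon = ((x - decoder(z_tilde)) ** 2).sum(dim=1)   # c = squared Euclidean distance (our choice)
ae_loss = (recon - lam * torch.log(disc(z_tilde)).squeeze(1)).mean()
ae_opt.zero_grad()
ae_loss.backward()   # also writes .grad on disc, but only ae_opt takes a step here
ae_opt.step()

print(f"discriminator loss {d_loss.item():.4f}, auto-encoder loss {ae_loss.item():.4f}")
```

In a full loop, `d_opt.zero_grad()` at the next discriminator step clears the stray gradients that `ae_loss.backward()` leaves on `disc`; the Appendix C loop avoids them instead by freezing the discriminator.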

<a id="algo1"></a>
**Algorithm 1:** Variational auto-encoder pseudocode for computing a stochastic gradient using the stochastic gradient variational Bayes (SGVB) estimator (after Kingma & Welling, 2014)

**Require:** Prior $P_Z = \mathcal{N}(0, I)$.

> Initialize the parameters for the encoder $Q_{\phi}$ and decoder $G_{\theta}$

> **while** $(\phi, \theta)$ not converged **do**

>> Sample $\{x_1, \dotsc , x_n\}$ from the training set

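>> Sample $\{\epsilon_1, \dotsc , \epsilon_n\}$ from $\mathcal{N}(0, I)$

>> Set $\tilde{z}_i = \mu_{\phi}(x_i) + \sigma_{\phi}(x_i) \odot \epsilon_i$ for $i=1, \dotsc , n$, where $\mu_{\phi}$ and $\sigma_{\phi}$ are the encoder's mean and standard deviation outputs, so that $\tilde{z}_i$ is a sample from $Q_{\phi}(Z\vert x_i)$

>> Update $Q_{\phi}$ and $G_{\theta}$ by descending:
$$\frac{1}{n}\sum_{i=1}^n \Big[c\big(x_i, G_{\theta}(\tilde{z}_i)\big) + D_{KL}\big(Q_{\phi}(Z\vert x_i)\,\|\,P_Z\big)\Big]$$

> **end while**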
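Likewise, one iteration of the VAE update can be sketched under toy assumptions of our own: fully connected networks instead of the CNNs of Appendix B, random inputs in $[0,1]$, a Bernoulli likelihood so that $c$ is binary cross-entropy summed over pixels, and a standard normal prior. The quantity descended is the negative ELBO: reconstruction cost plus KL.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n, x_dim, z_dim = 64, 784, 8

# Toy encoder body with two heads (mean and log-variance) standing in for Q_phi, and a toy G_theta
enc_body = nn.Sequential(nn.Linear(x_dim, 128), nn.ReLU())
enc_mu, enc_logvar = nn.Linear(128, z_dim), nn.Linear(128, z_dim)
decoder = nn.Sequential(nn.Linear(z_dim, 128), nn.ReLU(), nn.Linear(128, x_dim), nn.Sigmoid())

params = [*enc_body.parameters(), *enc_mu.parameters(), *enc_logvar.parameters(), *decoder.parameters()]
opt = torch.optim.Adam(params, lr=1e-3)

x = torch.rand(n, x_dim)                        # stand-in minibatch with pixel values in [0, 1]

h = enc_body(x)
mu, logvar = enc_mu(h), enc_logvar(h)
eps = torch.randn_like(mu)                      # epsilon_i ~ N(0, I)
z_tilde = mu + torch.exp(0.5 * logvar) * eps    # reparameterized sample from Q_phi(Z | x_i)

recon = F.binary_cross_entropy(decoder(z_tilde), x, reduction="none").sum(dim=1)  # c(x_i, G(z_tilde_i))
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)               # KL(Q(Z|x_i) || N(0, I))
loss = (recon + kl).mean()                      # negative ELBO averaged over the minibatch

opt.zero_grad()
loss.backward()
opt.step()

print(f"reconstruction {recon.mean().item():.2f}, KL {kl.mean().item():.4f}")
```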