# An Introduction to the Wasserstein Auto-encoder

-------

## Authors
Joel Dapello<br>
Michael Sedelmeyer<br>
Wenjun Yan

-------

<a id="top"></a>
## Contents

Table of contents with markdown hyperlinks to each section of the notebook

1. [Motivation and background](#intro)

1. [Conceptual foundations](#concepts)

1. [Mathematics and algorithms](#details)

1. [Comparing results on MNIST](#mnist)

1. [Comparing results on FashionMNIST](#fmnist)

1. [Conclusions and further analysis](#conclusion)

1. [References and further reading](#sources)


- [Appendices: PyTorch Implementation](#appendix)
    - [Appendix A: Auto-encoder](#ae)
    - [Appendix B: Variational auto-encoder](#vae)
    - [Appendix C: Wasserstein auto-encoder](#wae)
    - [Appendix D: Plotting functions](#plots)

<a id="intro"></a>
## Motivation and background
[return to top](#top)

Designing generative models capable of capturing the structure of very high-dimensional data is a long-standing problem in statistical modeling. One class of models that has proved effective for this task is the auto-encoder (AE). AEs are neural-network-based models that assume the high-dimensional data being modeled can be reduced to a lower-dimensional manifold defined on a space of latent variables. To do this, the AE defines an encoder network $Q$, which maps a high-dimensional input to a low-dimensional latent space $Z$, and a generator (decoder) network $G$, which maps $Z$ back to the high-dimensional input space. The whole system is trained end-to-end with stochastic gradient descent; in the case of the vanilla AE, the cost function minimizes the distance between the training data $X$ and its reconstruction, $\hat{X} = G(Q(X))$. While the standard AE is quite effective at learning a low-dimensional representation of the training data, it is prone to overfitting and typically fails as a generative model. With no constraint on the shape of the learned representation in latent space, it is unclear how to sample from $Z$ effectively: randomly drawn latent codes that lie far from those $G$ has learned to decode often produce nonsensical outputs.
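The vanilla AE objective can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions rather than a convolutional implementation: it assumes flattened 28×28 images, small fully connected networks, a hypothetical 8-dimensional latent space, squared Euclidean distance as the reconstruction distance, and Adam (a widely used variant of stochastic gradient descent) as the optimizer. A random tensor stands in for a batch of MNIST images.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

INPUT_DIM, LATENT_DIM = 28 * 28, 8  # assumed: flattened 28x28 images, 8-d latent space

# Encoder Q: high-dimensional input x -> low-dimensional latent code z
encoder = nn.Sequential(nn.Linear(INPUT_DIM, 256), nn.ReLU(), nn.Linear(256, LATENT_DIM))
# Generator G: latent code z -> reconstruction in input space (Sigmoid keeps pixels in [0, 1])
generator = nn.Sequential(
    nn.Linear(LATENT_DIM, 256), nn.ReLU(), nn.Linear(256, INPUT_DIM), nn.Sigmoid()
)
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(generator.parameters()), lr=1e-3)


def ae_step(x):
    """One end-to-end gradient step on the reconstruction loss ||x - G(Q(x))||^2."""
    x_hat = generator(encoder(x))                # reconstruction G(Q(x))
    loss = ((x - x_hat) ** 2).sum(dim=1).mean()  # squared distance per image, averaged over the batch
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


x = torch.rand(64, INPUT_DIM)  # random stand-in for a batch of MNIST images
losses = [ae_step(x) for _ in range(50)]
print(f"reconstruction loss: {losses[0]:.2f} -> {losses[-1]:.2f}")
```

Nothing in this loss constrains where `encoder` places its codes, so drawing a random `z` and passing it to `generator` carries no guarantee of a plausible image; that is the weakness the VAE addresses.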

The well-known variational auto-encoder (VAE) (Kingma & Welling, 2014) was introduced as a solution to this problem. The VAE builds on the AE framework with a modified cost function designed to maximize the evidence lower bound (ELBO) on the log-likelihood of the data. This effectively introduces a regularization penalty, the Kullback–Leibler divergence, which pushes each encoder distribution $Q(Z|X=x)$ toward a specified prior distribution $P_Z$. The VAE is therefore a much more powerful generative model than the standard AE, because samples drawn from $P_Z$ fall in a region of latent space that $G$ has learned to decode. Unfortunately, while the VAE performs admirably on simple datasets such as MNIST, on more complex datasets it tends to produce blurry samples.
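The VAE loss can be sketched as follows. This minimal PyTorch sketch assumes a diagonal-Gaussian encoder $Q(Z|X=x)$ that outputs a mean and a log-variance, a standard normal prior $P_Z = \mathcal{N}(0, I)$, pixel values in $[0, 1]$ scored with binary cross-entropy, and hypothetical layer sizes; the names `TinyVAE` and `vae_loss` are illustrative only. Minimizing the returned loss maximizes the ELBO, and the KL term is the regularization penalty described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

INPUT_DIM, LATENT_DIM = 28 * 28, 8  # assumed sizes for illustration


class TinyVAE(nn.Module):
    """Fully connected VAE with a diagonal-Gaussian encoder Q(Z|X=x) (hypothetical architecture)."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Sequential(nn.Linear(INPUT_DIM, 256), nn.ReLU())
        self.mu = nn.Linear(256, LATENT_DIM)      # mean of Q(Z|X=x)
        self.logvar = nn.Linear(256, LATENT_DIM)  # log-variance of Q(Z|X=x)
        self.decoder = nn.Sequential(
            nn.Linear(LATENT_DIM, 256), nn.ReLU(), nn.Linear(256, INPUT_DIM), nn.Sigmoid()
        )

    def forward(self, x):
        h = self.hidden(x)
        mu, logvar = self.mu(h), self.logvar(h)
        eps = torch.randn_like(mu)              # eps ~ N(0, I)
        z = mu + torch.exp(0.5 * logvar) * eps  # reparameterization trick: z = mu + sigma * eps
        return self.decoder(z), mu, logvar


def vae_loss(x, x_hat, mu, logvar):
    """Negative ELBO averaged over the batch; also returns the KL term on its own."""
    n = x.shape[0]
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum") / n
    # Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dimensions
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / n
    return recon + kl, kl


model = TinyVAE()
x = torch.rand(16, INPUT_DIM)  # random stand-in for a batch of images in [0, 1]
x_hat, mu, logvar = model(x)
loss, kl = vae_loss(x, x_hat, mu, logvar)
loss.backward()  # gradients reach the encoder through mu and logvar thanks to the reparameterization
```

The KL term is zero exactly when every encoder output equals the prior ($\mu = 0$, $\log\sigma^2 = 0$), which is the sense in which it pulls each $Q(Z|X=x)$ toward $P_Z$.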

In their 2018 International Conference on Learning Representations (ICLR) paper "Wasserstein Auto-Encoders", Tolstikhin et al. propose the Wasserstein auto-encoder (WAE) as a new algorithm for building latent-variable generative models. This addition to the family of regularized auto-encoders minimizes a penalized form of the optimal transport cost (Villani, 2003), formulated as a Wasserstein distance between the data distribution $P_X$ and the model distribution $P_G$. Intuitively, this is the cost of transforming one distribution into the other. It leads to a regularizer $\mathcal{D}_Z(Q_Z,P_Z)$ that differs from the VAE's: rather than pushing the encoding of each individual sample toward the prior, as the VAE does, the WAE regularizer encourages the full encoded training distribution $Q_Z$ to form a continuous mixture that matches $P_Z$ (see [Figure 1](#fig1)). As a result, the WAE shares many of the properties of VAEs while generating better-quality samples, an improvement attributed to the way the optimal transport penalty shapes the latent space.
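For reference, the WAE objective of Tolstikhin et al. (2018) can be summarized as follows. For a deterministic decoder $G$ and a cost function $c$ (for example, squared Euclidean distance), the paper shows that the optimal transport cost can be written as an optimization over encoders whose aggregate distribution $Q_Z$ equals the prior:

$$W_c(P_X, P_G) = \inf_{Q:\, Q_Z = P_Z} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z\vert X)}\big[c\big(X, G(Z)\big)\big]$$

Relaxing the constraint $Q_Z = P_Z$ into a penalty weighted by $\lambda > 0$, with $\mathcal{D}_Z$ any divergence between distributions on $Z$, gives the objective the WAE minimizes:

$$D_{\mathrm{WAE}}(P_X, P_G) = \inf_{Q(Z\vert X) \in \mathcal{Q}} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z\vert X)}\big[c\big(X, G(Z)\big)\big] + \lambda \cdot \mathcal{D}_Z(Q_Z, P_Z)$$

Here $\mathcal{Q}$ denotes the family of encoders being optimized over; the first term is a reconstruction cost and the second is the latent-space regularizer.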

In this tutorial, we implement the generative adversarial network (GAN) formulation of the WAE (WAE-GAN). The WAE-GAN chooses $\mathcal{D}_Z$ to be the Jensen-Shannon divergence and estimates it with an adversarial objective on the latent space. Specifically, the WAE-GAN trains a discriminator network $D$ in the latent space $Z$ to distinguish samples drawn from $P_Z$ from samples drawn from $Q_Z$, so that $D$ provides an estimate of $\mathcal{D}_Z(Q_Z,P_Z)$, while the encoder $Q$ learns to produce latent codes that fool $D$. In addition to the WAE-GAN, we implement a VAE and a vanilla AE. To understand the WAE and its benefits, it is important to consider it in the context of these two preceding, well-established algorithms. Side-by-side comparisons of each algorithm applied to the popular MNIST (CITE) and FashionMNIST (CITE) datasets, using convolutional neural network (CNN) implementations in PyTorch, give a more intuitive understanding of the results.
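One WAE-GAN training iteration can be sketched as below, following the two updates of the WAE-GAN pseudocode (Algorithm 2). This is a minimal sketch under stated assumptions: fully connected networks, a deterministic encoder (which the WAE formulation permits), a standard normal prior $P_Z$, squared Euclidean distance as the cost $c$, $\lambda = 1$, and hypothetical layer sizes and learning rates. The discriminator outputs a logit, so `binary_cross_entropy_with_logits` computes the $\log D$ and $\log(1 - D)$ terms in a numerically stable way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

INPUT_DIM, LATENT_DIM, LAMBDA = 28 * 28, 8, 1.0  # assumed sizes and regularization coefficient

encoder = nn.Sequential(nn.Linear(INPUT_DIM, 256), nn.ReLU(), nn.Linear(256, LATENT_DIM))  # deterministic Q
decoder = nn.Sequential(  # generator G
    nn.Linear(LATENT_DIM, 256), nn.ReLU(), nn.Linear(256, INPUT_DIM), nn.Sigmoid()
)
# Latent discriminator D: outputs a logit for "this code was drawn from the prior P_Z"
discriminator = nn.Sequential(nn.Linear(LATENT_DIM, 128), nn.ReLU(), nn.Linear(128, 1))

opt_ae = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=5e-4)


def wae_gan_step(x):
    n = x.shape[0]
    z_prior = torch.randn(n, LATENT_DIM)  # z_i ~ P_Z = N(0, I)
    z_tilde = encoder(x)                  # z~_i = Q(x_i)

    # 1) Discriminator: ascend lambda/n * sum[log D(z_i) + log(1 - D(z~_i))],
    #    i.e. descend the equivalent binary cross-entropy (prior codes -> 1, encoded codes -> 0).
    d_prior = discriminator(z_prior)
    d_enc = discriminator(z_tilde.detach())  # detach: this update must not change the encoder
    d_loss = LAMBDA * (
        F.binary_cross_entropy_with_logits(d_prior, torch.ones_like(d_prior))
        + F.binary_cross_entropy_with_logits(d_enc, torch.zeros_like(d_enc))
    )
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # 2) Encoder and decoder: descend 1/n * sum[c(x_i, G(z~_i)) - lambda * log D(z~_i)]
    recon = ((x - decoder(z_tilde)) ** 2).sum(dim=1).mean()  # c = squared Euclidean distance
    d_out = discriminator(z_tilde)
    penalty = F.binary_cross_entropy_with_logits(d_out, torch.ones_like(d_out))  # = -mean log D(z~)
    ae_loss = recon + LAMBDA * penalty
    opt_ae.zero_grad()
    ae_loss.backward()
    opt_ae.step()  # only encoder/decoder parameters are stepped here
    return d_loss.item(), recon.item(), penalty.item()


x = torch.rand(32, INPUT_DIM)  # random stand-in for a batch of MNIST images
history = [wae_gan_step(x) for _ in range(5)]
```

Detaching `z_tilde` in the discriminator update keeps that step from changing the encoder, while the second update lets gradients flow from the discriminator back into the encoder, which is what pushes $Q_Z$ toward $P_Z$.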

<a id="fig1"></a>
**Figure 1:** Conceptual comparison of AE reconstruction methods (after Tolstikhin et al., 2018). All three algorithms map inputs $x \in X$ to a latent code $z \in Z$ and then attempt to reconstruct $\hat{x}=G(z)$. The AE places no regularization penalty on $Z$, while the VAE and WAE each add a regularizer penalizing the discrepancy between the encoder's distribution and the prior $P_Z$: the VAE uses the Kullback–Leibler divergence (KLD), and the WAE uses a divergence $\mathcal{D}_Z$ arising from its optimal transport formulation. While the KLD pushes each $Q(Z|X=x)$ toward $P_Z$, the WAE penalty encourages the continuous mixture $Q_Z := \int Q(Z|X)\, dP_X$ to match $P_Z$.

![Figure 1: conceptual comparison of AE, VAE, and WAE latent regularization](https://github.com/sedelmeyer/wasserstein-auto-encoder/blob/master/images/figure%201%20-%20reconstruction.png?raw=true "Figure 1")

<a id="details"></a>
## Mathematics and algorithms
[return to top](#top)

In this section we describe the mathematical details of each method and the algorithmic differences among them, paying extra attention to the WAE and how it differs from the VAE.

**LaTeX to include:**
1. notational algorithms
1. loss function detail
1. mathematical representation of the reparameterization trick

**images to include:**
1. A small graphical representation of the reparameterization trick (small and simple node/edge plot)
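As a starting point for the reparameterization-trick and loss-function items listed above, the standard formulation from Kingma & Welling (2014) is sketched here. It assumes a diagonal-Gaussian encoder with mean $\mu_{\phi}(x)$ and standard deviation $\sigma_{\phi}(x)$, and a standard normal prior $P_Z = \mathcal{N}(0, I)$ on a $d$-dimensional latent space. Sampling is rewritten as a deterministic, differentiable function of the encoder outputs plus independent noise:

$$\tilde{z} = \mu_{\phi}(x) + \sigma_{\phi}(x) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

Because the randomness now lives in $\epsilon$, which does not depend on $\phi$, the gradient of an expectation over $Q_{\phi}(Z\vert x)$ can be estimated by Monte Carlo and backpropagated through the encoder:

$$\nabla_{\phi}\, \mathbb{E}_{Q_{\phi}(Z\vert x)}\big[f(Z)\big] = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)}\big[\nabla_{\phi} f\big(\mu_{\phi}(x) + \sigma_{\phi}(x) \odot \epsilon\big)\big]$$

Under the same Gaussian assumptions, the VAE's KL penalty has the closed form

$$D_{KL}\big(Q_{\phi}(Z\vert x)\,\Vert\,P_Z\big) = \frac{1}{2}\sum_{j=1}^{d}\left(\mu_{\phi,j}(x)^2 + \sigma_{\phi,j}(x)^2 - \log \sigma_{\phi,j}(x)^2 - 1\right)$$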

<a id="mnist"></a>
## Comparing results on MNIST
[return to top](#top)

In this section we specify the parameters used in our models and provide plots, metrics, and a written interpretation of the training results and latent-space representations of our algorithms on MNIST.

**images/tables to include:**
1. Sample of 5 original MNIST images and corresponding decoded images for AE, VAE, and WAE on separate rows
1. Latent space linear interpolation results of each model, pixel space vs AE vs VAE vs WAE on separate rows
1. tSNE or PCA representation of pixel space vs latent space for each model to demonstrate differences
1. table summarizing comparative loss (and if possible FID results)
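Latent interpolation of the kind listed above can be produced with a small helper. In this sketch, `interpolate_latents` is a hypothetical name, and the encoder `Q` and generator `G` in the comment stand for whichever trained model is being compared; for the pixel-space row, the same function can be applied directly to two flattened images.

```python
import torch


def interpolate_latents(z_start, z_end, steps=8):
    """Return `steps` evenly spaced points on the line from z_start to z_end, endpoints included."""
    alphas = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # shape (steps, 1), broadcasts over latent dims
    return (1 - alphas) * z_start.unsqueeze(0) + alphas * z_end.unsqueeze(0)


# With a trained encoder Q and generator G (hypothetical names), one row of the figure would be:
#   codes = interpolate_latents(Q(x_a).squeeze(0), Q(x_b).squeeze(0)); row = G(codes)
z_a, z_b = torch.zeros(8), torch.ones(8)
path = interpolate_latents(z_a, z_b, steps=5)
print(path[:, 0])  # first coordinate moves from 0 to 1 in equal steps: 0.00, 0.25, 0.50, 0.75, 1.00
```

A model whose latent space is well regularized should decode every intermediate code to a plausible image, whereas gaps in an unregularized latent space tend to show up as implausible images midway along the path.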

<a id="fmnist"></a>
## Comparing results on FashionMNIST
[return to top](#top)

Same as above for MNIST

**images/tables to include:**
1. same as above for MNIST, but probably smaller and with fewer examples if results demonstrate similar characteristics

<a id="conclusion"></a>
## Conclusions and further analysis
[return to top](#top)

Here we summarize our conclusions on MNIST and FashionMNIST and describe other datasets we may want to use for comparison (e.g. celebrity faces, to represent data such as faces that lie on a low-dimensional manifold, or RNA expression data, to investigate a novel application of WAE).

<a id="sources"></a>
## References and Further Reading
[return to top](#top)

Cite the papers, repos, datasets, and blogs we used in our analysis, as well as any other resources we want to direct our readers toward

1. VAE paper
1. WAE paper
1. PyTorch/resources implementation of VAE
1. AE paper?
1. MNIST
1. FashionMNIST

<a id="appendix"></a>
## Appendices: PyTorch Implementation
[return to top](#top)

- The appendices are where we lay out and run our PyTorch code; each model has its own sub-appendix
- We should output our most important plots to png (saved on GitHub) so we can display them via markdown img link at the appropriate locations in our paper

```python
# Import libraries
# Set parameter args
# Load data train and test sets
```

<a id="ae"></a>
### Appendix A: Auto-encoder 
[return to top](#top)

<a id="vae"></a>
### Appendix B: Variational auto-encoder
[return to top](#top)

<a id="wae"></a>
### Appendix C: Wasserstein auto-encoder
[return to top](#top)

<a id="plots"></a>
### Appendix D: Plotting functions
[return to top](#top)

**BOTH ALGORITHMS SHOULD PROBABLY BE REWRITTEN USING LATEX IN A WAY THAT TAKES UP LESS VERTICAL SPACE**

<a id="algo2"></a>
**Algorithm 2:** Wasserstein auto-encoder with GAN-based penalty (WAE-GAN) pseudocode

**Require:** Regularization coefficient $\lambda > 0$.

> Initialize the parameters of the encoder $Q_{\phi}$, decoder $G_{\theta}$, and latent discriminator $D_{\gamma}$.

> **while** $(\phi, \theta)$ not converged **do**

>> Sample $\{x_1, \dotsc , x_n\}$ from the training set

>> Sample $\{z_1, \dotsc , z_n\}$ from the prior $P_Z$

>> Sample $\tilde{z}_i$ from $Q_{\phi}(Z\vert x_i)$ for $i=1, \dotsc , n$

>> Update $D_{\gamma}$ by ascending:
>> $$\frac{\lambda}{n}\sum_{i=1}^n \left[\log D_{\gamma}(z_i) + \log\big(1-D_{\gamma}(\tilde{z}_i)\big)\right]$$

>> Update $Q_{\phi}$ and $G_{\theta}$ by descending:
>> $$\frac{1}{n}\sum_{i=1}^n \left[c\big(x_i, G_{\theta}(\tilde{z}_i)\big) - \lambda \cdot \log D_{\gamma}(\tilde{z}_i)\right]$$

> **end while**

<a id="algo1"></a>
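**Algorithm 1:** Variational auto-encoder pseudocode for computing a stochastic gradient using the stochastic gradient variational Bayes (SGVB) estimator (Kingma & Welling, 2014)

**Require:** Regularization coefficient $\lambda > 0$. Taking $\lambda = 1$ and $c$ as the negative log-likelihood $-\log p_{\theta}(x \vert z)$ recovers the standard VAE objective.

> Initialize the parameters for the encoder $Q_{\phi}$ and decoder $G_{\theta}$.

> **while** $(\phi, \theta)$ not converged **do**

>> Sample $\{x_1, \dotsc , x_n\}$ from the training set

>> Sample $\{\epsilon_1, \dotsc , \epsilon_n\}$ from the noise distribution $\mathcal{N}(0, I)$

>> Set $\tilde{z}_i = \mu_{\phi}(x_i) + \sigma_{\phi}(x_i) \odot \epsilon_i$, a reparameterized sample from $Q_{\phi}(Z\vert x_i)$ whose mean $\mu_{\phi}$ and standard deviation $\sigma_{\phi}$ are output by the encoder, for $i=1, \dotsc , n$

>> Update $Q_{\phi}$ and $G_{\theta}$ by descending:
>> $$\frac{1}{n}\sum_{i=1}^n \left[c\big(x_i, G_{\theta}(\tilde{z}_i)\big) + \lambda \cdot D_{KL}\big(Q_{\phi}(Z\vert x_i)\,\Vert\,P_Z\big)\right]$$

> **end while**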