# VAE with Normalizing Flows (Planar Flows)

### Theory of Vanilla VAE

In a vanilla VAE we feed $x$ into an encoder neural network and obtain $(\mu, \log \sigma)$, the parameters of our approximate posterior:

$$
q_{\phi}(z \mid x) = \mathcal{N}(z \mid \mu_{\phi}(x), \sigma^2_{\phi}(x) I)
$$

We then draw a sample $z \sim q_{\phi}(z \mid x)$ using the reparametrization trick $z = \mu + \sigma \odot \epsilon$, where $\epsilon \sim \mathcal{N}(0, I)$ and $\odot$ is the elementwise product. The objective to maximize is the evidence lower bound (ELBO)

$$
\mathcal{L}_{\phi, \theta}(x) = \mathbb{E}_{q_{\phi}(z \mid x)}[\log p_{\theta}(x \mid z)] - \text{KL}(q_{\phi}(z \mid x) \parallel p(z))
$$

where, for the standard normal prior $p(z) = \mathcal{N}(0, I)$, the KL divergence has the closed form

$$
\text{KL}(q_{\phi}(z \mid x) \parallel p(z)) = -\frac{1}{2}\sum_{j=1}^{\text{dim}(z)} \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

The reconstruction term is easy to compute in two common cases: Bernoulli and Normal likelihoods. In the Bernoulli case (e.g. when working with binarized images) we have

$$
p_{\theta}(x \mid z) = \prod_{i=1}^{\text{dim}(x)} p_i(z)^{x_i}(1 - p_i(z))^{1 - x_i}
$$
where
$$
p = (p_1(z), \ldots, p_{\text{dim}(x)}(z))^\top 
$$
is the output of the decoder network, $z \longrightarrow \text{Decoder} \longrightarrow p \in [0, 1]^{\text{dim}(x)}$. This means that we can write the reconstruction term as:

$$
\begin{align}
    \mathbb{E}_{q_{\phi}(z \mid x)}[\log p_{\theta}(x \mid z)]
    &=  \mathbb{E}_{q_{\phi}(z \mid x)}\left[\log \prod_{i=1}^{\text{dim}(x)} p_i(z)^{x_i}(1 - p_i(z))^{1-  x_i}\right] \\
    &= \mathbb{E}_{q_{\phi}(z \mid x)}\left[\sum_{i=1}^{\text{dim}(x)} x_i \log p_i(z) + (1 - x_i) \log(1 - p_i(z))\right] \\
    &\approx \frac{1}{n_z}\sum_{j=1}^{n_{z}}\sum_{i=1}^{\text{dim}(x)} \left[x_i \log p_i(z^{(j)}) + (1 - x_i) \log(1 - p_i(z^{(j)}))\right] \qquad z^{(j)} \sim q_{\phi}(z \mid x)
\end{align}
$$

where $n_z$ is the number of samples drawn from $q_{\phi}(z \mid x)$. Usually we simply set $n_z = 1$, i.e. we sample a single latent variable $z \sim q_{\phi}(z \mid x)$ per datapoint. This leads to the following objective function:

$$
\mathcal{L}_{\phi, \theta}(x) \approx \sum_{i=1}^{\text{dim}(x)} \left[x_i \log p_i(z) + (1 - x_i) \log(1 - p_i(z))\right] + \frac{1}{2}\sum_{j=1}^{\text{dim}(z)} \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

Its negative, summed over a batch, can be coded as the following loss to minimize:

```python
import torch
import torch.nn.functional as F

def vae_loss(image, reconstruction, mu, logvar):
  """Loss for the Variational AutoEncoder (negative ELBO, summed over the batch)."""
  # Binary cross-entropy = negative Bernoulli log-likelihood of the image.
  recon_loss = F.binary_cross_entropy(
      input=reconstruction.view(-1, 28*28),    # input is p(z) (the mean reconstruction)
      target=image.view(-1, 28*28),            # target is x   (the true image)
      reduction='sum'
  )
  # KL divergence to the N(0, I) prior in closed form; logvar = log(sigma^2).
  kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
  return recon_loss + kl
```
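The loss above takes `mu`, `logvar` and `reconstruction` as given. The sketch below shows, under stated assumptions, how they are typically produced and runs two sanity checks on the same arithmetic: the summed BCE equals the negative Bernoulli log-likelihood, and the closed-form KL matches `torch.distributions.kl_divergence`. The helper `reparameterize`, the one-layer stand-in decoder and the random fake data are illustrative assumptions, not part of the original code. Note that the code uses `logvar` $= \log\sigma^2$ while the text uses $\log\sigma$; they are related by $\sigma = \exp(\tfrac{1}{2}\,\text{logvar})$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

def reparameterize(mu, logvar):
    """Hypothetical helper: z = mu + sigma * eps, eps ~ N(0, I), differentiable in mu, logvar."""
    sigma = torch.exp(0.5 * logvar)   # logvar = log(sigma^2)
    eps = torch.randn_like(sigma)     # noise that does not depend on the parameters
    return mu + sigma * eps

torch.manual_seed(0)
batch, z_dim, x_dim = 3, 2, 28 * 28
mu = torch.randn(batch, z_dim, requires_grad=True)             # stand-in encoder output
logvar = (0.1 * torch.randn(batch, z_dim)).requires_grad_()    # stand-in encoder output
z = reparameterize(mu, logvar)

decoder = nn.Sequential(nn.Linear(z_dim, x_dim), nn.Sigmoid())  # stand-in decoder
p = decoder(z)                                                  # probabilities in (0, 1)
x = (torch.rand(batch, x_dim) > 0.5).float()                    # fake binary images

# Check 1: summed BCE equals the negative Bernoulli log-likelihood.
bce = F.binary_cross_entropy(p, x, reduction='sum')
neg_loglik = -torch.sum(x * torch.log(p) + (1 - x) * torch.log(1 - p))

# Check 2: closed-form KL matches PyTorch's Normal-vs-Normal KL.
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
prior = Normal(torch.zeros(batch, z_dim), torch.ones(batch, z_dim))
kl_ref = kl_divergence(Normal(mu, torch.exp(0.5 * logvar)), prior).sum()

# Gradients reach mu and logvar through the reparametrized sample z.
(bce + kl_closed).backward()
```

Because the noise `eps` is sampled independently of `mu` and `logvar`, backpropagation through `z` gives ordinary gradients for the encoder outputs; this is the whole point of the reparametrization trick.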

### Theory of Normalizing Flows VAE

Here we want the encoder not only to output $(\mu, \log \sigma)$, which shift and scale the standard normal noise $\epsilon \sim \mathcal{N}(0, I)$, but also the parameters $\lambda$ of a series of transformations through which the resulting sample is then fed. In particular, we would like our encoder to work as follows:

$$
x \longrightarrow \text{Encoder} \longrightarrow (\mu, \log\sigma, \lambda)
$$

then we would use $(\mu, \log\sigma)$ to compute $z_0$ using the reparametrization trick

$$
z_0 = \mu + \sigma \odot \epsilon \qquad \epsilon \sim \mathcal{N}(0, I)
$$

and finally, we would feed $z_0$ through a series of $K$ transformations (with parameters $\lambda$) to reach the final $z_K$:

$$
z_K = f_K \circ \ldots \circ f_2 \circ f_1(z_0)
$$
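A hypothetical PyTorch encoder matching this description is sketched below. The class name `FlowEncoder`, the single hidden layer, and the choice of planar-flow parameters $\lambda_k = \{u_k, w_k, b_k\}$ (introduced later in this section) laid out as `(batch, K, dim(z))` tensors are illustrative assumptions, not the original architecture.

```python
import torch
import torch.nn as nn

class FlowEncoder(nn.Module):
    """Hypothetical encoder: x -> (mu, log_sigma, lambda) with lambda_k = {u_k, w_k, b_k}."""

    def __init__(self, x_dim: int, z_dim: int, n_flows: int, hidden: int = 256):
        super().__init__()
        self.z_dim, self.n_flows = z_dim, n_flows
        self.body = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, z_dim)
        self.log_sigma = nn.Linear(hidden, z_dim)
        # One (u_k, w_k, b_k) per flow: 2 * z_dim + 1 numbers for each of the K flows.
        self.u = nn.Linear(hidden, n_flows * z_dim)
        self.w = nn.Linear(hidden, n_flows * z_dim)
        self.b = nn.Linear(hidden, n_flows)

    def forward(self, x: torch.Tensor):
        h = self.body(x)
        n = x.shape[0]
        u = self.u(h).view(n, self.n_flows, self.z_dim)
        w = self.w(h).view(n, self.n_flows, self.z_dim)
        b = self.b(h).view(n, self.n_flows)
        return self.mu(h), self.log_sigma(h), (u, w, b)

enc = FlowEncoder(x_dim=28 * 28, z_dim=2, n_flows=4)
mu, log_sigma, (u, w, b) = enc(torch.rand(8, 28 * 28))
z0 = mu + log_sigma.exp() * torch.randn_like(mu)   # reparametrization trick for z_0
```

Producing $\lambda$ from $x$ makes the flow parameters amortized: each datapoint gets its own transformations, at the cost of $K(2\,\text{dim}(z) + 1)$ extra encoder outputs.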

This means that our approximating distribution would no longer be

$$
q_{\phi}(z \mid x) = \mathcal{N}(z \mid \mu_{\phi}(x), \sigma^2_{\phi}(x) I).
$$

Instead, the Gaussian becomes the base distribution $q_0(z_0)=\mathcal{N}(z_0 \mid \mu_{\phi}(x), \sigma_{\phi}^2(x) I)$ and, by the change of variables formula,

$$
\log q_{\phi}(z \mid x) = \log q_K(z_K) = \log q_0(z_0) - \sum_{k=1}^K \log \left|\det\frac{\partial f_k}{\partial z_{k-1}}\right|.
$$
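The change of variables formula can be checked numerically. This sketch uses two invertible affine maps $f_k(z) = A_k z + c_k$ as stand-in flows (chosen instead of planar flows only because the pushed-forward density is then Gaussian and known exactly; the matrices and offsets are arbitrary assumptions) and compares $\log q_0(z_0) - \sum_k \log|\det A_k|$ with the exact log-density of $z_K$.

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
mu, sigma = torch.tensor([0.5, -1.0]), torch.tensor([0.8, 1.5])
q0 = MultivariateNormal(mu, covariance_matrix=torch.diag(sigma ** 2))

# Two invertible affine maps f_k(z) = A_k z + c_k standing in for generic flows.
A = [torch.tensor([[2.0, 0.3], [0.1, 0.7]]), torch.tensor([[1.0, -0.5], [0.4, 1.2]])]
c = [torch.tensor([1.0, 0.0]), torch.tensor([-0.3, 0.2])]

z0 = q0.sample()
z, sum_logdet = z0, torch.tensor(0.0)
for A_k, c_k in zip(A, c):
    z = A_k @ z + c_k                                               # z_k = f_k(z_{k-1})
    sum_logdet = sum_logdet + torch.linalg.slogdet(A_k).logabsdet   # log|det df_k/dz_{k-1}|
log_qK_formula = q0.log_prob(z0) - sum_logdet

# Exact reference: an affine image of a Gaussian is again Gaussian.
M = A[1] @ A[0]
mean_K = A[1] @ (A[0] @ mu + c[0]) + c[1]
cov_K = M @ torch.diag(sigma ** 2) @ M.T
cov_K = 0.5 * (cov_K + cov_K.T)                                     # remove round-off asymmetry
log_qK_exact = MultivariateNormal(mean_K, covariance_matrix=cov_K).log_prob(z)
```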
By the law of the unconscious statistician (LOTUS), expectations under $q_K(z_K)$ can be rewritten as expectations under $q_0(z_0)$, so

$$
\begin{align}
\mathcal{L}_{\phi, \theta}(x) 
&= \mathbb{E}_{q_{\phi}(z \mid x)}[\log p_{\theta}(x \mid z)] - \text{KL}(q_{\phi}(z \mid x) \parallel p(z)) \\
&= \mathbb{E}_{q_K(z_K)}[\log p_{\theta}(x \mid z_K)] - \mathbb{E}_{q_K(z_K)}[\log q_K(z_K) - \log p(z_K)] \\
&= \mathbb{E}_{q_0(z_0)}[\log p_{\theta}(x \mid f_K \circ \ldots \circ f_2 \circ f_1(z_0))] - \mathbb{E}_{q_0(z_0)}[\log q_K(f_K \circ \ldots \circ f_2 \circ f_1(z_0)) - \log p(f_K \circ \ldots \circ f_2 \circ f_1(z_0))]
\end{align}
$$

As usual, we can approximate this with Monte Carlo by drawing samples $z_0^{(j)} \sim q_0(z_0) = \mathcal{N}(\mu, \sigma^2 I)$:

$$
\mathcal{L}_{\phi, \theta}(x) \approx \frac{1}{n_z}\sum_{j=1}^{n_z}\left[\log p_{\theta}(x \mid f_K \circ \ldots \circ f_2 \circ f_1(z_0^{(j)})) - \log q_K(f_K \circ \ldots \circ f_2 \circ f_1(z_0^{(j)})) + \log p(f_K \circ \ldots \circ f_2 \circ f_1(z_0^{(j)}))\right]
$$

Again, in practice we use a single sample $z_0 \sim q_0(z_0)$ and write $z_K = f_K \circ \ldots \circ f_1(z_0)$:

$$
\begin{align}
\mathcal{L}_{\phi, \theta}(x) 
&\approx \left[\sum_{i=1}^{\text{dim}(x)} x_i \log p_i(z_K) + (1 - x_i) \log(1 - p_i(z_K))\right] - \log q_K(z_K) + \log p(z_K) \\
&= -\text{BCE}(X, z_K) - (\log q_0(z_0) + \text{LADJ}) + \log p(z_K) \\
&= -\text{BCE}(X, z_K) - \log q_0(z_0) - \text{LADJ} + \log p(z_K)
\end{align}
$$
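Here $\text{BCE}(X, z_K)$ denotes the binary cross-entropy between the data $x$ and the decoder output $p(z_K)$, and $\text{LADJ}$ collects the log-determinant terms with the sign convention $\log q_K(z_K) = \log q_0(z_0) + \text{LADJ}$, i.e. $\text{LADJ} = -\sum_k \log|\det \partial f_k / \partial z_{k-1}|$. PyTorch optimizers, however, only minimize, so we minimize the negative of this quantity: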
$$
\text{BCE}(X, z_K) + \log q_0(z_0) + \text{LADJ} -\log p(z_K)
$$
Writing out each term explicitly, with $d = \text{dim}(z)$, the objective becomes
$$
\begin{align}
    \text{objective} 
    &= -\sum_{i=1}^{\text{dim}(x)} \left[x_i \log p_i(z_K) + (1 - x_i) \log(1 - p_i(z_K))\right] \\
    &\quad -\frac{d}{2}\log(2\pi) -\frac{1}{2}\log \det(\text{Diag}(\sigma^2))- \frac{1}{2}(z_0-\mu)^\top \text{Diag}\left(\frac{1}{\sigma^2}\right)(z_0 - \mu) \\
    &\quad -\sum_{k=1}^K \log |1 + u_k^\top (1 - \tanh^2(w_k^\top z_{k-1} + b_k))w_k| \\
    &\quad +\frac{d}{2}\log(2\pi) +\frac{1}{2}z_K^\top z_K
\end{align}
$$
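A minimal PyTorch sketch of this objective follows. It assumes the encoder outputs $\log\sigma$ (here `log_sigma`), that the flow outputs `zK` and `ladj` have already been computed, and that `ladj` follows this document's sign convention $\text{LADJ} = -\sum_k \log|\det \partial f_k/\partial z_{k-1}|$. The name `flow_vae_objective` and the random stand-in tensors are hypothetical. $\log q_0(z_0)$ and $\log p(z_K)$ are computed as Gaussian log-densities, using $\log\det(\text{Diag}(\sigma^2)) = 2\sum_j \log\sigma_j$.

```python
import math
import torch
import torch.nn.functional as F

def flow_vae_objective(x, p, z0, mu, log_sigma, zK, ladj):
    """Hypothetical single-sample negative ELBO for a flow VAE, summed over the batch.

    ladj uses the convention ladj = -sum_k log|det df_k/dz_{k-1}|,
    so that log q_K(z_K) = log q_0(z_0) + ladj.
    """
    d = z0.shape[-1]
    bce = F.binary_cross_entropy(p, x, reduction='sum')
    # log q_0(z_0) for a diagonal Gaussian N(mu, diag(sigma^2))
    log_q0 = (-0.5 * d * math.log(2 * math.pi)
              - log_sigma.sum(-1)
              - 0.5 * (((z0 - mu) / log_sigma.exp()) ** 2).sum(-1))
    # log p(z_K) under the standard normal prior N(0, I)
    log_pK = -0.5 * d * math.log(2 * math.pi) - 0.5 * (zK ** 2).sum(-1)
    return bce + (log_q0 + ladj - log_pK).sum()

# Usage with random stand-in tensors (batch of 5, dim(x) = 10, dim(z) = 3).
torch.manual_seed(0)
x = (torch.rand(5, 10) > 0.5).float()
p = torch.sigmoid(torch.randn(5, 10))
mu, log_sigma = torch.randn(5, 3), 0.1 * torch.randn(5, 3)
z0 = mu + log_sigma.exp() * torch.randn(5, 3)
zK, ladj = torch.randn(5, 3), torch.randn(5)   # stand-ins for the flow's outputs
loss = flow_vae_objective(x, p, z0, mu, log_sigma, zK, ladj)
```

Writing the Gaussian terms by hand keeps the correspondence with the formula visible; `torch.distributions.Normal(...).log_prob` would give the same values.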

In our case we use the transformation

$$
f(z) = z + u\, h(w^\top z + b) \qquad u, w \in \mathbb{R}^{\text{dim}(z)\times 1} \qquad b\in\mathbb{R} \qquad h(\cdot) = \tanh(\cdot)
$$

Its log-absolute-determinant-Jacobian (LADJ) is given by

$$
\log \left|\text{det}\frac{\partial f}{\partial z}\right| = \log |1 + u^\top h'(w^\top z + b)w|
$$

where $h'$ is the derivative of $h$; when $h = \tanh$ we have
$$
h'(\cdot) = 1 - \tanh^2(\cdot)
$$
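For $f(z) = z + u\,h(w^\top z + b)$ the Jacobian is $I + u\,h'(w^\top z + b)\,w^\top$, so by the matrix determinant lemma its determinant is $1 + u^\top h'(w^\top z + b)\,w$, which gives the formula above. The sketch below implements one planar flow with its LADJ and checks it against the full Jacobian determinant obtained with `torch.autograd.functional.jacobian`. The function name `planar_flow` and the tensor shapes are assumptions; float64 is used only to make the numerical comparison tight.

```python
import torch

def planar_flow(z, u, w, b):
    """One planar flow f(z) = z + u * tanh(w^T z + b) and its log|det Jacobian|.

    z: (batch, d); u, w: (d,); b: scalar tensor.
    """
    a = z @ w + b                                  # (batch,) pre-activation w^T z + b
    f_z = z + u * torch.tanh(a).unsqueeze(-1)      # (batch, d)
    h_prime = 1 - torch.tanh(a) ** 2               # tanh'(a)
    log_abs_det = torch.log(torch.abs(1 + h_prime * (u @ w)))  # log|1 + u^T h'(a) w|
    return f_z, log_abs_det

torch.manual_seed(0)
u, w = torch.randn(3, dtype=torch.float64), torch.randn(3, dtype=torch.float64)
b = torch.randn((), dtype=torch.float64)
z = torch.randn(4, 3, dtype=torch.float64)
f_z, log_abs_det = planar_flow(z, u, w, b)

# Brute-force check for the first sample: full 3x3 Jacobian via autograd.
J = torch.autograd.functional.jacobian(
    lambda v: planar_flow(v.unsqueeze(0), u, w, b)[0].squeeze(0), z[0])
brute = torch.log(torch.abs(torch.linalg.det(J)))
```

The closed form costs $O(d)$ per sample, whereas a general determinant costs $O(d^3)$; this cheap LADJ is what makes planar flows practical.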

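If we apply $K$ such transformations, each with its own parameters, and collect the negated sum of their log-determinants into a single term (so that $\log q_K(z_K) = \log q_0(z_0) + \text{LADJ}$), we get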

$$
\text{LADJ} = -\sum_{k=1}^K \log |1 + u_k^\top (1 - \tanh^2(w_k^\top z_{k-1} + b_k))w_k|
$$

**NOTE**: we have a different set of $\lambda_k = \{w_k, u_k, b_k\}$ for every transformation. 
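A sketch of stacking $K$ planar flows with per-datapoint parameters (as an amortized encoder would produce them), accumulating LADJ with this document's sign convention. The name `apply_planar_flows` and the `(batch, K, d)` layout are assumptions. Each log-determinant is evaluated at $z_{k-1}$, before the update. The final lines compare $-\text{LADJ}$ with the log-determinant of the full Jacobian of $z_0 \mapsto z_K$, which by the chain rule is the product of the per-flow determinants.

```python
import torch

def apply_planar_flows(z0, u, w, b):
    """Push z0 through K planar flows with per-sample parameters lambda_k = {u_k, w_k, b_k}.

    z0: (batch, d); u, w: (batch, K, d); b: (batch, K).
    Returns z_K and ladj = -sum_k log|1 + u_k^T h'(w_k^T z_{k-1} + b_k) w_k|,
    so that log q_K(z_K) = log q_0(z_0) + ladj.
    """
    z = z0
    ladj = torch.zeros(z0.shape[0], dtype=z0.dtype, device=z0.device)
    for k in range(u.shape[1]):
        u_k, w_k, b_k = u[:, k], w[:, k], b[:, k]
        a = (w_k * z).sum(-1) + b_k                    # (batch,) w_k^T z_{k-1} + b_k
        h_prime = 1 - torch.tanh(a) ** 2               # tanh'(a)
        uw = (u_k * w_k).sum(-1)                       # (batch,) u_k^T w_k
        ladj = ladj - torch.log(torch.abs(1 + h_prime * uw))
        z = z + u_k * torch.tanh(a).unsqueeze(-1)      # z_k = f_k(z_{k-1})
    return z, ladj

torch.manual_seed(0)
batch, K, d = 2, 3, 2
z0 = torch.randn(batch, d, dtype=torch.float64)
u = torch.randn(batch, K, d, dtype=torch.float64)
w = torch.randn(batch, K, d, dtype=torch.float64)
b = torch.randn(batch, K, dtype=torch.float64)
zK, ladj = apply_planar_flows(z0, u, w, b)

# Brute-force check for sample 0: Jacobian of the whole composition z_0 -> z_K.
J = torch.autograd.functional.jacobian(
    lambda v: apply_planar_flows(v.unsqueeze(0), u[:1], w[:1], b[:1])[0].squeeze(0), z0[0])
```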

Finally, not every transformation of this form is invertible; a sufficient condition is $w^\top u \geq -1$. One way to guarantee an invertible transformation is to modify $u$ after it comes out of the encoder, replacing it with

$$
\widehat{u} = u + \left[-1 + \log(1 + e^{w^\top u}) - w^\top u\right] \frac{w}{\parallel w \parallel^2}
$$
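A sketch of this modification of $u$: since $w^\top \widehat{u} = w^\top u + \left[-1 + \log(1 + e^{w^\top u}) - w^\top u\right] = -1 + \log(1 + e^{w^\top u}) > -1$, the resulting flow satisfies the sufficient condition $w^\top \widehat{u} \geq -1$. The helper name `make_invertible` is hypothetical, and `F.softplus` is used as a numerically stable $\log(1 + e^{x})$. The usage picks $u$ with $w^\top u = -2$, which violates the condition before the fix.

```python
import torch
import torch.nn.functional as F

def make_invertible(u, w):
    """Return u_hat = u + [m(w^T u) - w^T u] w / ||w||^2, with m(x) = -1 + log(1 + e^x).

    u, w: (..., d). Guarantees w^T u_hat > -1.
    """
    wu = (w * u).sum(-1, keepdim=True)              # w^T u
    m_wu = -1 + F.softplus(wu)                      # m(w^T u), stable log(1 + e^x)
    return u + (m_wu - wu) * w / (w * w).sum(-1, keepdim=True)

torch.manual_seed(0)
w = torch.randn(5, 3)
u = -2 * w / (w * w).sum(-1, keepdim=True)          # w^T u = -2: not invertible as-is
u_hat = make_invertible(u, w)
wu_hat = (w * u_hat).sum(-1)                        # equals -1 + softplus(-2) ~ -0.873
```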