# About VAEs

VAEs deal in observed variables $x$ and latent variables $z$. The assumption is that the distribution of the observed variable, $p(x)$, can be written in terms of a latent variable: the joint factorizes as $p(x, z) = p(x ~|~ z) \cdot p(z)$, and marginalizing over $z$ gives $p(x) = \int p(x ~|~ z) \, p(z) \, dz$. It is thus the job of the VAE to learn a prior over the latent space $p(z) \in P_z$ and a conditional distribution $p(x ~|~ z) \in P_{x ~|~ z}$, which together form a joint distribution $p(x, z) \in P_{(x, z)}$ whose marginal over $z$ approximates the distribution of our observed variable $x$.

$$P_{(x, z)} = \{p(x, z) ~|~ p(z) \in P_z, ~~ p(x ~|~ z) \in P_{x ~|~ z} \}$$
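
As a rough illustration of the factorization $p(x, z) = p(x ~|~ z) \cdot p(z)$, the sketch below draws samples by first sampling the latent variable and then the observation. It assumes a standard normal prior and a linear-Gaussian conditional; the names `W`, `SIGMA`, and `sample_joint` are hypothetical and chosen only for this example.

```python
import random

random.seed(0)

# Hypothetical 1-D latent and 2-D observation (illustrative values only).
W = [2.0, -1.0]   # weights of the assumed linear decoder
SIGMA = 0.1       # assumed observation noise

def sample_joint():
    """Ancestral sampling: z ~ p(z) = N(0, 1), then x ~ p(x | z) = N(W z, SIGMA^2 I)."""
    z = random.gauss(0.0, 1.0)
    x = [w * z + random.gauss(0.0, SIGMA) for w in W]
    return x, z

samples = [sample_joint() for _ in range(5)]
for x, z in samples:
    print(f"z = {z:+.3f}  ->  x = [{x[0]:+.3f}, {x[1]:+.3f}]")
```

Discarding $z$ and keeping only $x$ yields samples from the marginal $p(x)$. A VAE replaces the fixed linear map with a learned neural network.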

## Loss Functions
One way to measure the quality of our approximation is the KL divergence, where the expectation is taken over $x \sim p_{data}$:
$$D_{KL}(p_{data}(x)~||~ p_{vae}(x)) = E_{x \sim p_{data}}\left[ \log \frac{p_{data}(x)}{p_{vae}(x)} \right]$$

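**Lemma**: Minimizing the KL divergence is equivalent to maximizing the log-likelihood.
$$ E\left[ \log \frac{p_{data}(x)}{p_{vae}(x)} \right] = E\left[ \log p_{data}(x) \right] - E\left[\log p_{vae}(x) \right] \approx -H(p_{data}) - \frac{1}{N} \sum_{i=1}^{N} \log p_{vae}(x_i)$$
Here the expectation is over $x \sim p_{data}$, $H(p_{data})$ is the entropy of the data distribution (which does not depend on the model), and the last step replaces the expectation with an average over $N$ samples $x_i$ from the data. Since $\sum_i \log p_{vae}(x_i)$ is the log-likelihood, minimizing the KL divergence amounts to maximizing it.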
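
The decomposition behind the lemma can be checked numerically. The sketch below uses two made-up discrete distributions over three outcomes (the values are illustrative only) and confirms that the KL divergence equals the negative entropy of the data distribution minus the expected log-likelihood under the model.

```python
import math

# Two hypothetical discrete distributions over three outcomes.
p_data = [0.5, 0.3, 0.2]
p_vae = [0.4, 0.4, 0.2]

# KL(p_data || p_vae) computed directly from its definition.
kl = sum(p * math.log(p / q) for p, q in zip(p_data, p_vae))

# Entropy of the data distribution (independent of the model).
entropy = -sum(p * math.log(p) for p in p_data)

# Expected log-likelihood of the model under the data distribution.
expected_log_lik = sum(p * math.log(q) for p, q in zip(p_data, p_vae))

# Identity: KL = -H(p_data) - E_{p_data}[log p_vae]
print(kl, -entropy - expected_log_lik)
```

Because the entropy term is fixed by the data, only the expected log-likelihood changes when the model changes.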

Thus for our loss we need to find
$$\max \sum \log p_{vae}(x) = \max \sum \log \int p(x, z) dz$$

The primary issue is that the integral over $z$ becomes computationally intractable as the size of the latent space grows.

### Handling a large Latent Space 
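So we need to use some kind of approximation. We introduce a variational distribution $q_\lambda(z)$ over the latent variable, with parameters $\lambda$, and multiply and divide by it inside the integral:
$$ \log \int p(x, z) \, dz = \log \int \frac{q_\lambda(z)}{q_\lambda(z)} p(x, z) \, dz $$

By Jensen's inequality (the logarithm is concave) we can see that
$$\log \int q_\lambda(z) \frac{p(x, z)}{q_\lambda(z)} \, dz \ge \int q_\lambda(z) \log \left( \frac{p(x, z)}{q_\lambda(z)}\right) dz = E_{z \sim q_\lambda}\left[\log \left( \frac{p(x, z)}{q_\lambda(z)}\right)\right]$$
This lower bound is called the **ELBO**, or Evidence Lower Bound. **NOT FINISHED**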
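
To make the bound concrete, here is a minimal sketch on a toy model with a binary latent variable, where the evidence can be computed exactly by enumeration. The probabilities and the helper `elbo` are hypothetical, chosen only for illustration. The bound holds for any choice of $q$ and becomes tight when $q$ equals the true posterior $p(z ~|~ x)$.

```python
import math

# Toy model with a binary latent z and one fixed observation x (made-up numbers).
prior = [0.6, 0.4]        # p(z)
likelihood = [0.2, 0.7]   # p(x | z) evaluated at the observed x
joint = [pz * px for pz, px in zip(prior, likelihood)]   # p(x, z)

# With only two latent states the evidence is a two-term sum.
log_evidence = math.log(sum(joint))

def elbo(q):
    """E_{z ~ q}[log p(x, z) - log q(z)] for a discrete q over z."""
    return sum(qz * math.log(pxz / qz) for qz, pxz in zip(q, joint) if qz > 0)

# The exact posterior p(z | x), available here because the model is tiny.
posterior = [pxz / sum(joint) for pxz in joint]

for name, q in [("uniform", [0.5, 0.5]), ("skewed", [0.9, 0.1]), ("posterior", posterior)]:
    print(f"{name:9s} ELBO = {elbo(q):.4f}   log p(x) = {log_evidence:.4f}")
```

The gap between $\log p(x)$ and the ELBO is the KL divergence from $q$ to the posterior, which is why it vanishes in the last row.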

In [14]:
from typing import List

import torch
import torch.nn.functional as F
from torch import Tensor, nn

In [None]:
class VAE(nn.Module):

    def __init__(self, in_channels, latent_dim, hidden_dims=None):
        super().__init__()

        self.latent_dim = latent_dim

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
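        # With the default hidden_dims, five stride-2 convolutions reduce an
        # assumed 64x64 input to a 2x2 feature map, hence the factor of 4
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)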


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        # Reversed copy, so a list passed in by the caller is not mutated
        hidden_dims = hidden_dims[::-1]

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input):
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z):

        result = self.decoder_input(z)
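        # Hard-coded shape: assumes the deepest encoder layer has 512 channels
        # and a 2x2 feature map (the default hidden_dims with 64x64 inputs)
        result = result.view(-1, 512, 2, 2)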
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var]

    def loss_function(self, *args, **kwargs):
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

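        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        # Closed-form KL divergence between N(mu, exp(log_var)) and N(0, I),
        # summed over latent dimensions and averaged over the batch
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)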

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self, num_samples, current_device, **kwargs):
        z = torch.randn(num_samples, self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x, **kwargs):
        return self.forward(x)[0]
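
The `kld_loss` term above relies on the closed-form KL divergence between a diagonal Gaussian and the standard normal prior. As a sanity check, the sketch below compares that formula for a single latent dimension against a numerical integral of $q(z) \log \frac{q(z)}{p(z)}$. It uses only the standard library; `kl_closed_form`, `kl_numeric`, and the test values are hypothetical names and numbers for this example.

```python
import math

def kl_closed_form(mu, log_var):
    """KL( N(mu, exp(log_var)) || N(0, 1) ) for one latent dimension."""
    return -0.5 * (1 + log_var - mu ** 2 - math.exp(log_var))

def kl_numeric(mu, log_var, lo=-12.0, hi=12.0, steps=20000):
    """Midpoint-rule estimate of the integral of q(z) * log(q(z) / p(z))."""
    var = math.exp(log_var)
    dz = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        z = lo + (i + 0.5) * dz
        log_q = -0.5 * (math.log(2 * math.pi * var) + (z - mu) ** 2 / var)
        log_p = -0.5 * (math.log(2 * math.pi) + z ** 2)
        total += math.exp(log_q) * (log_q - log_p) * dz
    return total

mu, log_var = 0.7, -0.4
print(kl_closed_form(mu, log_var), kl_numeric(mu, log_var))
```

For a diagonal Gaussian the per-dimension terms simply add up, which corresponds to the `torch.sum(..., dim=1)` in `loss_function`.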

## References
- [Deep Generative Models @ Stanford](https://deepgenerativemodels.github.io/notes/vae/)
- [Deep Unsupervised Learning @ Berkeley](https://sites.google.com/view/berkeley-cs294-158-sp20/home)
- [Proof of Lemma](https://wiseodd.github.io/techblog/2017/01/26/kl-mle/)
- [PyTorch Reference Implementation](https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py)