__NOTE:__ The limitations of a simple autoencoder motivate the variational autoencoder (VAE). Instead of learning a single fixed latent vector for each input, a VAE learns a latent distribution for the data at hand. The latent distribution is assumed to be Gaussian, so the network only has to learn its mean and the log of its variance.

Ref for why gaussian is used: https://stats.stackexchange.com/questions/402569/why-do-we-use-gaussian-distributions-in-variational-autoencoder

The log of the variance is learned rather than the variance itself: the variance can only be positive, which would restrict the values the network may output, whereas the log of the variance ranges over $\left(-\infty, \infty\right)$ and therefore places no restriction on what the network learns. To learn the mean and the log of the variance, the KL divergence (which measures how far one distribution is from another) is used as a loss term alongside the reconstruction loss. Using the learned mean and log of variance, new latent vectors can be sampled and decoded to generate the required data. Sampling is done as follows (here $z$ denotes the latent vector):

$$
z_{new} = \mu_{z} + \sigma_{z} * \epsilon \\
\text{where} \,\, \epsilon \sim \mathcal{N}\left(0, 1\right) \\
\\
$$

The network learns $\log{\left(\sigma_{z}^2\right)}$, denoted $z_{lvar}$, so $\sigma_{z}$ has to be recovered from it. We know:

$$
\sigma_{z} = \exp^{\left(\log{\left(\sigma_{z}\right)}\right)} \\

\text{Multiplying and dividing by 2} \\

\sigma_{z} = \exp^{\left(\frac{2}{2}\log{\left(\sigma_{z}\right)}\right)} \\

\sigma_{z} = \exp^{\left(\frac{2\log{\left(\sigma_{z}\right)}}{2}\right)} \\

\sigma_{z} = \exp^{\left(\frac{\log{\left(\sigma_{z}^2\right)}}{2}\right)} \\

\sigma_{z} = \exp^{\left(0.5z_{lvar}\right)}
$$

# Importing Libraries

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchsummary import summary
import wandb

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Helper Functions

In [2]:
def save_checkpoint(model, save_path):
    """Write the model's ``state_dict`` (not the whole module) to ``save_path``.

    Args:
        model: Object exposing ``state_dict()``, e.g. an ``nn.Module``.
        save_path: Destination passed straight to ``torch.save``.
    """
    weights = model.state_dict()
    torch.save(weights, save_path)
    
def load_checkpoint(model, save_path):
    """Load a saved ``state_dict`` from ``save_path`` into ``model``.

    The checkpoint is deserialized onto the CPU first so that a file written
    on a CUDA machine can still be restored on a CPU-only one;
    ``load_state_dict`` then copies each tensor into the model's existing
    parameters, wherever they live.

    Args:
        model: Module whose parameters are overwritten in place.
        save_path: Path (or file-like object) passed to ``torch.load``.

    Returns:
        The same ``model`` instance, with the loaded weights.
    """
    # map_location avoids "Attempting to deserialize object on a CUDA device"
    # errors when the saving and loading machines differ.
    state_dict = torch.load(save_path, map_location = 'cpu')
    model.load_state_dict(state_dict)
    return model

def visualization(dataloader, model, title, clear = True):
    """Plot sample inputs, their reconstructions and a 2-D latent scatter.

    Puts ``model`` in eval mode, iterates over every batch of ``dataloader``
    with gradients disabled, and builds three matplotlib figures:

    1. a 10x10 grid of input images (at most 10 per label 0-9),
    2. a 10x10 grid of the corresponding sigmoid-activated model outputs,
    3. a scatter of the first two columns of the flattened ``model.encoder``
       output for every sample, coloured by label.

    Batches are moved to the module-level ``device`` before use.

    Args:
        dataloader: Iterable of ``(images, labels)`` batches supporting
            ``len()``. Images are reshaped to ``(-1, 1, 28, 28)`` before
            being fed to the encoder.
        model: Module exposing an ``encoder`` attribute. Its forward output
            is passed through a sigmoid, so it is presumably raw logits --
            TODO confirm against the model definition.
        title: Only referenced by the disabled wandb logging line below;
            otherwise unused.
        clear: When true, all three figures are cleared and closed before
            returning.

    Returns:
        None.
    """
    model.eval()
    latent_encoder = model.encoder
    squash = nn.Sigmoid()
    # Remaining number of examples still wanted for each label 0-9.
    quota = dict.fromkeys(range(10), 10)
    fig_inputs, grid_inputs = plt.subplots(nrows = 10, ncols = 10, figsize = (10,10))
    fig_recons, grid_recons = plt.subplots(nrows = 10, ncols = 10, figsize = (10,10))
    latent_points, point_labels = [], []
    cell = 0  # running index of the next free grid cell (row-major)

    with torch.no_grad():
        progress = tqdm(dataloader, total = len(dataloader), leave = False)
        for batch_imgs, batch_lbls in progress:
            batch_imgs = batch_imgs.to(device)
            batch_lbls = batch_lbls.to(device)

            codes = latent_encoder(batch_imgs.view(-1, 1, 28, 28))
            recons = squash(model(batch_imgs))
            recons = recons.view(len(recons), 1, 28, 28)
            codes = codes.view(len(recons), -1)

            latent_points.extend(codes.cpu().detach().numpy().tolist())
            point_labels.extend(batch_lbls.cpu().detach().numpy().tolist())

            # Cells are filled in the order samples are encountered, so the
            # grid is not guaranteed to have one label per row.
            for digit in quota:
                wanted = quota[digit]
                if wanted <= 0:
                    continue
                picked = torch.where(batch_lbls == digit)[0][:wanted]
                quota[digit] = wanted - len(picked)
                for pos in picked:
                    pos = pos.item()
                    row, col = cell // 10, cell % 10
                    original = batch_imgs[pos].detach().cpu().numpy()[0]
                    generated = recons[pos].detach().cpu().numpy()[0]
                    grid_inputs[row, col].imshow(original, cmap = 'gray')
                    grid_inputs[row, col].axis('off')
                    grid_recons[row, col].imshow(generated, cmap = 'gray')
                    grid_recons[row, col].axis('off')
                    cell += 1

    fig_latent, latent_axes = plt.subplots()
    latent_points = np.array(latent_points)
    point_labels = np.array(point_labels)
    palette = plt.get_cmap('Spectral', 10)
    points = latent_axes.scatter(latent_points[:,0], latent_points[:,1], c = point_labels, cmap = palette)
    plt.colorbar(points, drawedges = True, ax = latent_axes)
    # wandb.log({f"{title}_original": fig_1, f"{title}_regenerated":fig_2, f"{title}_visualization": wandb.Image(fig_3)})
    if clear:
        for figure in (fig_inputs, fig_recons, fig_latent):
            figure.clear()
            plt.close(figure)

def KL_divergence(latent_mean, latent_log_var):
    """Return the batch-averaged KL divergence from N(mean, var) to N(0, 1).

    Uses the closed form ``-0.5 * (1 + log_var - mean^2 - exp(log_var))``
    per latent dimension, sums it over the last axis, then averages the
    resulting values.

    Args:
        latent_mean: Tensor of latent means.
        latent_log_var: Tensor of latent log-variances, same shape as
            ``latent_mean``.

    Returns:
        Scalar tensor holding the mean divergence.
    """
    variance = torch.exp(latent_log_var)
    mean_squared = torch.square(latent_mean)
    per_dimension = -0.5*(1 + latent_log_var - mean_squared - variance)
    per_example = torch.sum(per_dimension, dim = -1)
    return torch.mean(per_example)

KL Divergence Derivation (univariate gaussian):
$$

\text{Given:} \\
P \sim \mathcal{N}\left(\mu, \sigma^2 \right) \\
Q \sim \mathcal{N}\left(0, 1 \right) \\

D_{KL} \left(p ||q\right) = \int^{\infty}_{-\infty} \left(p(x)\log{\frac{p(x)}{q(x)}}\right) dx \\

= \int^{\infty}_{-\infty} \left(p(x)\log{\frac{\frac{1}{\cancel{\sqrt{2\pi}}\sigma}\exp^{\left(-\frac{\left(x - \mu\right)^2}{2\sigma^2}\right)}}{\frac{1}{\cancel{\sqrt{2\pi}}}\exp^{\left(-\frac{\left(x\right)^2}{2}\right)}}}\right) dx \\

= \int^{\infty}_{-\infty} \left(p(x)\log{\frac{\frac{1}{\sigma}\exp^{\left(-\frac{\left(x - \mu\right)^2}{2\sigma^2}\right)}}{\exp^{\left(-\frac{\left(x\right)^2}{2}\right)}}}\right) dx \\

= \int^{\infty}_{-\infty} \left(p(x)\log\left(\frac{1}{\sigma}\exp^{\left(-\frac{\left(x - \mu\right)^2}{2\sigma^2}\right)}\right)\right) dx- \int^{\infty}_{-\infty} \left(p(x)\log\left(\exp^{\left(-\frac{\left(x\right)^2}{2}\right)}\right)\right) dx \\

= \int^{\infty}_{-\infty} \left(p(x)\log\left(\frac{1}{\sigma}\exp^{\left(-\frac{\left(x - \mu\right)^2}{2\sigma^2}\right)}\right)\right)dx- \int^{\infty}_{-\infty} \left(p(x)\left(-\frac{\left(x\right)^2}{2}\right)\right) dx \\

= \int^{\infty}_{-\infty} \left(p(x) \log\left(\frac{1}{\sigma}\right)\right)dx + \int^{\infty}_{-\infty} \left(p(x) \log\left(\exp^{\left(-\frac{\left(x^2 + \mu^2 - 2x\mu\right)}{2\sigma^2}\right)}\right) \right)dx + \frac{1}{2}\int^{\infty}_{-\infty} \left(x^2p(x)\right) dx \\

[\text{Sidenote:} \,\,\,\, \mathbb{E}\left[x\right] = \int^{\infty}_{-\infty} x\,p(x)\,dx = \mu\\
\mathbb{E}\left[x^2\right] = \int^{\infty}_{-\infty} x^2p(x)\,dx \\
\text{Now we know that: } \mathbb{E}\left[(x-\mu)^2\right] = \sigma^2 \,\text{(variance)} = \mathbb{E}\left[x^2\right] - \left(\mathbb{E}\left[x\right]\right)^2 \\ 
\implies \mathbb{E}\left[x^2\right] = \sigma^2 + \mu^2\,\,\,\,] \\

= \log\left(\frac{1}{\sigma}\right)\int^{\infty}_{-\infty} \left(p(x)\right)dx + \int^{\infty}_{-\infty} \left(p(x) \left(-\frac{\left(x^2 + \mu^2 - 2x\mu\right)}{2\sigma^2}\right) \right)dx + \frac{1}{2} \left(\sigma^2 + \mu^2\right) \\

\text{Since } p(x) \text{ is a pdf, its integral over all } x \text{ is 1} \\

= \log\left(\frac{1}{\sigma}\right) - \frac{1}{2\sigma^2}\int^{\infty}_{-\infty}p(x) \left(x^2 + \mu^2 - 2x\mu\right)dx + \frac{1}{2} \left(\sigma^2 + \mu^2\right) \\

= \log\left(\frac{1}{\sigma}\right) - \frac{1}{2\sigma^2}\left(\int^{\infty}_{-\infty}p(x)x^2dx + \mu^2\int^{\infty}_{-\infty}p(x)dx - 2\mu\int^{\infty}_{-\infty}p(x)xdx\right) + \frac{1}{2} \left(\sigma^2 + \mu^2\right) \\

= \log\left(\frac{1}{\sigma}\right) - \frac{1}{2\sigma^2}\left(\sigma^2 + \mu^2 + \mu^2 - 2\mu^2\right) + \frac{1}{2} \left(\sigma^2 + \mu^2\right) \\

= \log\left(\frac{1}{\sigma}\right) - \frac{1}{2\sigma^2}\left(\sigma^2 + \cancel{\mu^2} + \cancel{\mu^2} - \cancel{2\mu^2}\right) + \frac{1}{2} \left(\sigma^2 + \mu^2\right) \\

= \log\left(\frac{1}{\sigma}\right) - \frac{1}{2\cancel{\sigma^2}}\left(\cancel{\sigma^2}\right) + \frac{1}{2} \left(\sigma^2 + \mu^2\right) \\

= \log\left(\frac{1}{\sigma}\right) - \frac{1}{2} + \frac{1}{2} \left(\sigma^2 + \mu^2\right) \\

= \frac{1}{2}\left(-2\log(\sigma) - 1 + \sigma^2 + \mu^2 \right) \\

= -\frac{1}{2}\left(2\log(\sigma) + 1 - \sigma^2 - \mu^2 \right) \\

= -\frac{1}{2}\left(\log(\sigma^2) + 1 - \sigma^2 - \mu^2 \right)
$$

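The expression derived above is the final formula used to compute the KL divergence loss for a single latent dimension; for a multi-dimensional latent with independent dimensions it is summed over the dimensions.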