In [1]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import wandb
import os

os.environ["WANDB_NOTEBOOK_NAME"] = 'generative_models.ipynb'

seed = 10
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(0)

In [2]:
dataset_path = '../datasets/PSC_bandgaps/PSC_bandgaps_dataset.csv'

dataset_df = pd.read_csv(dataset_path)

bandgaps = dataset_df['Gap'].values
elemental_properties = dataset_df.iloc[:, 19:].values

# Creating a pytorch dataset
class BandgapDataset(torch.utils.data.Dataset):
    def __init__(self, bandgaps, elemental_properties):
        self.bandgaps = bandgaps
        self.elemental_properties = elemental_properties
    def __len__(self):
        return len(self.bandgaps)
    def __getitem__(self, idx):
        return self.bandgaps[idx], self.elemental_properties[idx]
    
# Standardize the dataset
elemental_properties = (elemental_properties - elemental_properties.mean(axis=0)) / elemental_properties.std(axis=0)

# Change the dtype to torch.float32
bandgaps = bandgaps.astype(np.float32)
elemental_properties = elemental_properties.astype(np.float32)

# Shuffle and split the dataset into train and validation sets
indices = np.arange(len(bandgaps))
np.random.shuffle(indices)

train_indices = indices[:int(0.8*len(bandgaps))]
val_indices = indices[int(0.8*len(bandgaps)):]
bandgaps_train = bandgaps[train_indices]
bandgaps_val = bandgaps[val_indices]
elemental_properties_train = elemental_properties[train_indices]
elemental_properties_val = elemental_properties[val_indices]
    
train_dataset = BandgapDataset(torch.from_numpy(bandgaps_train), torch.from_numpy(elemental_properties_train))
val_dataset = BandgapDataset(torch.from_numpy(bandgaps_val), torch.from_numpy(elemental_properties_val))

In [3]:
class VAE(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.activation = torch.nn.ReLU()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, hidden_dim),
            self.activation,
        )
        self.mu = torch.nn.Linear(hidden_dim, latent_dim)
        self.logvar = torch.nn.Linear(hidden_dim, latent_dim)
        self.predictor = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, 1),
            self.activation
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, input_dim),
        )
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        h = self.encoder(x)
        mu = self.mu(h)
        logvar = self.logvar(h)
        z = self.reparameterize(mu, logvar)
        y = self.predictor(z)
        x_recon = self.decoder(z)
        return x_recon, y, z, mu, logvar

Architecture notes:

- Using the sampled z for both prediction and decoding helps in reconstrutcion but the bandgap prediction is high (MAE ~ 0.76 eV)
- Epoch [900/1500], KL Loss: 0.3239647190140775, Reconst Loss: 0.9181102482466236, Pred Loss : 0.7544278908946401,  Total Loss: 1.9965028582664568
- Using the $\mu (X)$ instead as input to predictor does not affect the reconstruction loss much but the bandgap prediction is much better (MAE ~ 0.15 eV)
- Epoch [1400/1500], KL Loss: 0.009583918224820227, Reconst Loss: 0.9987305798317714, Pred Loss : 0.1031322028556248,  Total Loss: 1.1114467009099127

### Deriving the KL divergence loss for unit normal prior

- Lets start with any arbitrary distribution Q(z) and minimize the KL divergence of it with a distribution P(z|X)

$$
\begin{align*}
    D_{KL}(Q(z) || P(z|X)) & = \int Q(z) \log \frac{Q(z)}{P(z|X)} dz \\ \\
    D_{KL}(Q(z) || P(z|X)) & = E_{Q(z)}[ \log \frac{Q(z)}{P(z|X)} ] \\ \\
    D_{KL}(Q(z) || P(z|X)) & = E_{Q(z)}[ \log Q(z) - \log P(z|X) ] \\ \\
    D_{KL}(Q(z) || P(z|X)) & = E_{Q(z)}[ \log Q(z) - \log P(X|z) - \log P(z) + \log P(X) ] \\ \\
    D_{KL}(Q(z) || P(z|X)) & = E_{Q(z)}[ \log Q(z) - \log P(X|z) - \log P(z)] + \log P(X) \\ \\
    \log P(X) - D_{KL}(Q(z) || P(z|X)) & =  E_{Q(z)}[\log P(X|z)] - D_{KL}(Q(z) || P(z)) 
\end{align*}
$$

- Now instead of choosing any distribution for Q(z), it makes sense to choose a distribution for the z variables that depends on X. Hence we can replace Q(z) with Q(z|X).

$$
\begin{align*}
    \log P(X) - D_{KL}(Q(z|X) || P(z|X)) & =  E_{Q(z|X)}[\log P(X|z)] - D_{KL}(Q(z|X) || P(z))
\end{align*}
$$

- The left hand side contains the terms that we want to maximize. The log probability density of X and an error term that measures the deviation between the approximate distribution (Q(z|X)) and the true probability distribution (P(z|X)). To note P(X) is a high dimensional intractable distribution and we don't have access to P(z|X). By having a large enough capacity for Q(z|X) we are pulling it closer to P(z|X), lower the KL divergence term until we are only optimizing for the log probability density of X. 
- The right hand side contains terms that can be optimized via gradient descent. The first term is the expected value of the log likelihood of the data given the latent variables. The second term is the KL divergence between the approximate distribution and the prior distribution. 
- Stochastic gradient descent can be performed on the right hand side by assuming some forms of the distribution. The most common form for the posterior and liklihood is a multivariate Gaussian distribution and for the prior is unit normal distribution. 

$$
\begin{align*}
    D_{KL}(N(\mu_0, \Sigma_0) || N(\mu_1, \Sigma_1)) = \frac{1}{2} ( \text{tr}(\Sigma_1^{-1} \Sigma_0) + (\mu_1 - \mu_0)^T \Sigma_1^{-1} (\mu_1 - \mu_0) - k + \log \frac{\det \Sigma_1}{\det \Sigma_0} )
\end{align*}
$$

- 'k' is the dimensionality of the distribution. Substituting the prior as unit normal distribution, we get the KL divergence loss as
$$
\begin{align*}
    D_{KL}(N(\mu (X), \Sigma (X)) || N(O, I)) = \frac{1}{2} ( \text{tr}(\Sigma (X)) + (\mu (X))^T (\mu (X)) - k - \log \det \Sigma (X) )
\end{align*}
$$

- To back propagate the errors to the the neural network that approximates Q(z|X), so that we get z's that correctly reproduce the data, we need to find a way that allows backpropagation to work. This is where the reparameterization trick comes in. It allows us to sample for 'z' while giving access to the neural networks that approximate the mean and covariance functions for  Q(z|X). $ z = \mu (X) + \Sigma (X) * \epsilon $. Here $\mu (X) and \Sigma (X)$ are approximated by using neural networks and $\epsilon$ is sampled from the unit normal distribution.
- If any other distribution is to be modelled then the KL divergnce term must be modified accordingly and the appropriate reparameterization trick must be used.

Reference:
- Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 https://arxiv.org/abs/1312.6114 (Appendix B)
- Doersch, C. Tutorial on Variational Autoencoders. arXiv January 3, 2021. http://arxiv.org/abs/1606.05908.


### Trail 2 : Using a Gaussian Mixture Model as prior

Check out these github repos for how to code the model and the loss function :

1. https://github.com/jariasf/GMVAE/tree/master
2. https://github.com/RuiShu/vae-clustering


In [4]:
# # Defining the required layers for the GMVAE
# # Sample from the Gumbel-Softmax distribution and optionally discretize.
# class GumbelSoftmax(torch.nn.Module):

#   def __init__(self, f_dim, c_dim):
#     super(GumbelSoftmax, self).__init__()
#     self.logits = nn.Linear(f_dim, c_dim)
#     self.f_dim = f_dim
#     self.c_dim = c_dim
     
#   def sample_gumbel(self, shape, is_cuda=False, eps=1e-20):
#     U = torch.rand(shape)
#     if is_cuda:
#       U = U.cuda()
#     return -torch.log(-torch.log(U + eps) + eps)

#   def gumbel_softmax_sample(self, logits, temperature):
#     y = logits + self.sample_gumbel(logits.size(), logits.is_cuda)
#     return F.softmax(y / temperature, dim=-1)

#   def gumbel_softmax(self, logits, temperature, hard=False):
#     """
#     ST-gumple-softmax
#     input: [*, n_class]
#     return: flatten --> [*, n_class] an one-hot vector
#     """
#     #categorical_dim = 10
#     y = self.gumbel_softmax_sample(logits, temperature)

#     if not hard:
#         return y

#     shape = y.size()
#     _, ind = y.max(dim=-1)
#     y_hard = torch.zeros_like(y).view(-1, shape[-1])
#     y_hard.scatter_(1, ind.view(-1, 1), 1)
#     y_hard = y_hard.view(*shape)
#     # Set gradients w.r.t. y_hard gradients w.r.t. y
#     y_hard = (y_hard - y).detach() + y
#     return y_hard 
  
#   def forward(self, x, temperature=1.0, hard=False):
#     logits = self.logits(x).view(-1, self.c_dim)
#     prob = F.softmax(logits, dim=-1)
#     y = self.gumbel_softmax(logits, temperature, hard)
#     return logits, prob, y

# class GMVAE(torch.nn.Module):
#     def __init__(self, input_dim, hidden_dim, latent_dim, num_components):
#         super(GMVAE, self).__init__()

#         self.input_dim = input_dim
#         self.hidden_dim = hidden_dim
#         self.latent_dim = latent_dim
#         self.num_components = num_components
#         self.activation = torch.nn.ReLU()
#         # p(y) = Cat(y|pi)
#         # p(z) = N(z|0, I)
#         # p(x|y,z) = f(x;y,z,theta)

#         # q(y|x) : Probability of seeing a label y given x
#         self.qy_x = torch.nn.Sequential(
#             torch.nn.Linear(input_dim, hidden_dim),
#             self.activation,
#             torch.nn.Linear(hidden_dim, hidden_dim),
#             self.activation,
#             torch.nn.Linear(hidden_dim, hidden_dim),
#             GumbelSoftmax(hidden_dim, num_components)
#         )
#         # q(z|x,y) : Probability of seeing a latent variable z given x and y
#         self.qz_xy = torch.nn.Sequential(
#             torch.nn.Linear(input_dim + num_components, hidden_dim),
#             self.activation,
#             torch.nn.Linear(hidden_dim, hidden_dim),
#             self.activation,
#             torch.nn.Linear(hidden_dim, hidden_dim),
#             self.activation
#         )


class GMVAE(torch.nn.Module):
    def __init__(self, input_dim, y_dim, hidden_dim, latent_dim, model_activation_fn):
        super(GMVAE, self).__init__()

        # Note to self : This has elem props and bandgaps as input
        self.input_dim = input_dim
        self.y_dim = y_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.activation = model_activation_fn

        # Inference Network
        # q(y|x) : Based on the input, predict the class logit
        self.qy_logit = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, y_dim)
        )
        self.qy = torch.nn.Softmax(dim=1)
        # q(z|x,y) : Probability of seeing a latent based on X and class label
        self.qz_xy = torch.nn.Sequential(
            torch.nn.Linear(input_dim + y_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, hidden_dim),
            self.activation
        )
        self.mu = torch.nn.Linear(hidden_dim, latent_dim)
        self.logvar = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, latent_dim),
            torch.nn.Softplus()
        )
        # Generative Network
        # p(z|y) : Probability of seeing a latent given the class label
        self.mu_prior = torch.nn.Linear(y_dim, latent_dim)
        self.logvar_prior = torch.nn.Sequential(
            torch.nn.Linear(y_dim, latent_dim),
            torch.nn.Softplus()
        )
        # p(x|z) : Probability of seeing X given the latent
        self.px_z = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, input_dim)
        )
        # # p(x|z,y) : Probability of seeing X given the latent and class label
        # self.px_zy = torch.nn.Sequential(
        #     torch.nn.Linear(latent_dim + y_dim, hidden_dim),
        #     self.activation,
        #     torch.nn.Linear(hidden_dim, hidden_dim),
        #     self.activation,
        #     torch.nn.Linear(hidden_dim, input_dim)
        # )

    def forward(self, x):
        # Inference Net
        # Shape : [batch_size, y_dim]
        qy_logit = self.qy_logit(x)
        # Note to self : This repo has Gumbel Softmax as last layer : https://github.com/jariasf/GMVAE
        # Shape : [batch_size, y_dim]
        qy = self.qy(qy_logit)

        # Defining a tensor that will store the fixed class label for all members of the batch
        y_ = torch.zeros([x.shape[0], self.y_dim])
        z, mu, logvar, mu_prior, logvar_prior, reconst = [[None] * 10 for i in range(6)]
        for i in range(self.y_dim):
            # Add the class label to the tensor
            y = y_ + torch.eye(self.y_dim)[i]
            # Note to self : The generative model can take the predicted class label as input. This is what is done in the GMVAE repo
            # Note to self : In the Rui Shu repo the class label (y) is provided as a one hot vector. 
            h = torch.cat([x, y], dim=1)
            h = self.qz_xy(h)
            # Shape : batch_size, latent_dim
            mu[i] = self.mu(h)
            logvar[i] = self.logvar(h)
            # Note to self : Can use the reparameterization trick here instead. This gives modified M2 in Rui Shu's repo.
            # Using the predicted mean and logvar sample from a gaussian distribution.
            # z[i] = torch.normal(mu[i], logvar[i].exp().sqrt())
            eps = torch.randn_like(mu[i])
            z[i] = mu[i] + eps*torch.exp(0.5 * logvar[i])
            # Generative Net
            # Prior is a gaussian mixture. Using the fixed class label, the mean and logvar of the 'z' distribution are computed.
            # The posterior distribution aims to match this distribution
            mu_prior[i] = self.mu_prior(y)
            logvar_prior[i] = self.logvar_prior(y)
            # Using the sampled 'z' reconstruct the X
            reconst[i] = self.px_z(z[i])
        return z, mu, logvar, mu_prior, logvar_prior, reconst, qy_logit, qy

### Model training grounds

In [5]:
# Training loop

def log_normal(z, mu, logvar):
    c = torch.tensor(2*np.pi, dtype=torch.float32) 
    return torch.tensor(-0.5, dtype=torch.float32)*torch.sum(torch.log(c) + logvar + (z - mu).pow(2) / logvar.exp(), dim=1)

# Sum of reconstruction loss and KL divergence loss
def labelled_loss(X, reconst, z, mu, logvar, mu_prior, logvar_prior):
    c = torch.tensor(0.1, dtype=torch.float32)
    return torch.nn.MSELoss(reduction='sum')(X, reconst) + log_normal(z, mu, logvar) - log_normal(z, mu_prior, logvar_prior) - torch.log(c)

# Classification loss
def cross_entropy_loss(qy_logit, qy):
    return torch.nn.CrossEntropyLoss(reduction='mean')(qy_logit, qy)

# Derived by assuming posterior is Gaussian and prior is unit normal distribution.
def kl_divergence_loss_fn(mu, logvar):
        return torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1), dim=0)

mse_loss_fn = torch.nn.MSELoss(reduction='mean')
mae_loss_fn = torch.nn.L1Loss(reduction='mean')

train_reconst_loss_per_step = []
train_reconst_loss_per_epoch = []
train_pred_loss_per_step = []
train_pred_loss_per_epoch = []
train_kl_loss_per_step = []
train_kl_loss_per_epoch = []

val_reconst_loss_per_step = []
val_reconst_loss_per_epoch = []
val_pred_loss_per_step = []
val_pred_loss_per_epoch = []
val_kl_loss_per_step = []
val_kl_loss_per_epoch = []

train_classification_loss_per_step = []
train_classification_loss_per_epoch = []
train_labelled_loss_per_step = []
train_labelled_loss_per_epoch = []

val_classification_loss_per_step = []
val_classification_loss_per_epoch = []
val_labelled_loss_per_step = []
val_labelled_loss_per_epoch = []

train_total_loss_per_step = []
train_total_loss_per_epoch = []

val_total_loss_per_step = []
val_total_loss_per_epoch = []

def train(model, train_dataset, val_dataset, epochs, batch_size, learning_rate, print_after):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
    for epoch in range(epochs):
        for i, (bandgaps, elem_props) in enumerate(train_dataloader):
            input = torch.cat([elem_props, bandgaps.unsqueeze(dim=1)], dim=1)
            #input = elem_props

            optimizer.zero_grad()

            # z, mu, logvar, reconst, bandgap_preds = model(input)

            # loss1 = kl_divergence_loss_fn(mu, logvar)
            # loss2 = mse_loss_fn(reconst, elem_props)
            # loss3 = mae_loss_fn(bandgap_preds, bandgaps.unsqueeze(dim=0))
            # total_loss = loss1 + loss2 + loss3
            
            # train_kl_loss_per_step.append(loss1.item())
            # train_reconst_loss_per_step.append(loss2.item())
            # train_pred_loss_per_step.append(loss3.item())
            # train_total_loss_per_step.append(total_loss.item())

            z, mu, logvar, mu_prior, logvar_prior, reconst, qy_logit, qy = model(input)

            loss1 = cross_entropy_loss(qy_logit, qy)
            train_classification_loss_per_step.append(loss1.item())

            loss2 = [None] * model.y_dim
            for i in range(model.y_dim):
                loss2[i] = torch.mean(qy[:, i]*labelled_loss(input, reconst[i], z[i], mu[i], logvar[i], mu_prior[i], logvar_prior[i]), dtype=torch.float32)
            
            train_labelled_loss_per_step.append(torch.stack(loss2).sum().item())
            total_loss = loss1 + torch.stack(loss2).sum()
            train_total_loss_per_step.append(total_loss.item())

            total_loss.backward()
            optimizer.step()

        # train_kl_loss_per_epoch.append(np.mean(train_kl_loss_per_step))
        # train_reconst_loss_per_epoch.append(np.mean(train_reconst_loss_per_step))
        # train_pred_loss_per_epoch.append(np.mean(train_pred_loss_per_step))
        train_classification_loss_per_epoch.append(np.mean(train_classification_loss_per_step))
        train_labelled_loss_per_epoch.append(np.mean(train_labelled_loss_per_step))
        train_total_loss_per_epoch.append(np.mean(train_total_loss_per_step))
        if epoch % print_after == 0:
            print(f'Epoch [{epoch}/{epochs}')
            print(f'     Train Classif Loss: {train_classification_loss_per_epoch[-1]}, Train Labelled Loss: {train_labelled_loss_per_epoch[-1]}, Train Total Loss: {train_total_loss_per_epoch[-1]}')

        # Run the validation loop
        for i, (bandgaps, elem_props) in enumerate(val_dataloader):
            # Concatenate the bandgaps and elem_props 
            input = torch.cat([elem_props, bandgaps.unsqueeze(dim=1)], dim=1)

            z, mu, logvar, mu_prior, logvar_prior, reconst, qy_logit, qy = model(input)

            loss1 = cross_entropy_loss(qy_logit, qy)
            val_classification_loss_per_step.append(loss1.item())

            loss2 = [None] * model.y_dim
            for i in range(model.y_dim):
                loss2[i] = torch.mean(qy[:, i]*labelled_loss(input, reconst[i], z[i], mu[i], logvar[i], mu_prior[i], logvar_prior[i]), dtype=torch.float32)
            
            val_labelled_loss_per_step.append(torch.stack(loss2).sum().item())
            total_loss = loss1 + torch.stack(loss2).sum()
            val_total_loss_per_step.append(total_loss.item())

        # val_kl_loss_per_epoch.append(np.mean(train_kl_loss_per_step))
        # val_reconst_loss_per_epoch.append(np.mean(train_reconst_loss_per_step))
        # val_pred_loss_per_epoch.append(np.mean(train_pred_loss_per_step))
        val_classification_loss_per_epoch.append(np.mean(val_classification_loss_per_step))
        val_labelled_loss_per_epoch.append(np.mean(val_labelled_loss_per_step))
        val_total_loss_per_epoch.append(np.mean(val_total_loss_per_step))
        if epoch % print_after == 0:
            print(f'    Val. Classif Loss: {val_classification_loss_per_epoch[-1]}, Val. Labelled Loss: {val_labelled_loss_per_epoch[-1]}, Val. Total Loss: {val_total_loss_per_epoch[-1]}')

# input_dim = elemental_properties.shape[1]
# hidden_dim = 50
# latent_dim = 2
# batch_size = 1
# learning_rate = 1e-3
# epochs = 1500
# print_after = 1
# model = VAE(input_dim, hidden_dim, latent_dim)

hyperparams = {
    'model_activation_fn': torch.nn.ReLU(),
    'input_dim': elemental_properties.shape[1] + 1,
    'y_dim':3,
    'hidden_dim': 50,
    'latent_dim':2,
    'batch_size': 1,
    'learning_rate': 1e-3,
    'epochs': 1500,
}
print_after = 100
run = wandb.init(project='gmvae', name='trial1', config=hyperparams, job_type='training', notes='NA')
model = GMVAE(hyperparams['input_dim'], hyperparams['y_dim'], hyperparams['hidden_dim'], hyperparams['latent_dim'], hyperparams['model_activation_fn'])

train(model, train_dataset, val_dataset, hyperparams['epochs'], hyperparams['batch_size'], hyperparams['learning_rate'], print_after)

model_save_name = 'gmvae_trail1.pt'
torch.save(model.state_dict(), model_save_name)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnthota2[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch [0/1500
     Train Classif Loss: 0.3050380939599149, Train Labelled Loss: 40.06689712427839, Train Total Loss: 40.37193523358695
    Val. Classif Loss: 0.004014177042464702, Val. Labelled Loss: 32.52751094163066, Val. Total Loss: 32.53152513985682


KeyboardInterrupt: 

In [None]:
# Plotting the 2D latent space
#elem_props_reconst, bandgap_preds, z, mu, logvar = model(elemental_properties)
z, mu, logvar, mu_prior, logvar_prior, reconst, qy_logit, qy = model(torch.cat([torch.from_numpy(elemental_properties, dtype=torch.float32), torch.from_numpy(bandgaps, dtype=torch.float32)], dim=1))

plt.scatter(z[:, 0].detach().numpy(), z[:, 1].detach().numpy(), c=bandgaps, cmap='viridis')
plt.colorbar()
plt.show()

### Observations for Unit normal prior
- What we observe from the above example is that although multivariate Gaussian distribution are useful
    as each dimension can encode a separate DOF which results in representations that are sturctured and disentangled, 
    they are unimodal and hence cannot encode complex representations. A natural extension is to then use a different
    prior. Gaussain Mixture Model (GMM) is the next choice.
- Latent space is segregated into different classes.
- However, inference is non-trivial.

In [5]:
import uuid

print(uuid.uuid1())

979b4368-d0f6-11ee-a068-7a2a0020f23d
