In [1]:
import sys
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import wandb
import os
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from torchmetrics.classification import MultilabelAccuracy
import copy
import joblib
import shutil
import shap
import seaborn as sns

from wandb_api_key import api_key
from utils.custom_utils import set_global_random_seed
from utils.dataset_utils import create_preprocessed_datasets

# ---- Run options -------------------------------------------------------------
# Directory holding dataset_inputs.py / train_inputs.py, and the sub-folder of '../runs'
# where results for this experiment are written.
inputs_dir = 'inputs_syn_grid_lin_data'
run_dir = 'syn_grid_lin_data'
nn = 1
mode = 'train'
accelerator = 'cpu'
kfolds =  5

# ---- Weights & Biases --------------------------------------------------------
# Requires a wandb account; the api key is available under the user settings tab.
user_name = 'nthota2'
project_name = 'results_for_RL_paper'
os.environ["WANDB_NOTEBOOK_NAME"] = 'hyperparam_search.ipynb'
os.environ["WANDB_API_KEY"] = api_key
# "online" enables cloud syncing (see the wandb docs for the offline/disabled alternatives).
os.environ["WANDB_MODE"] = "online"

# ---- Plotting ----------------------------------------------------------------
fig_aspect_ratio = 1/1.3

# # Font style (LaTeX rendering; currently disabled)
# plt.rcParams.update({
# "text.usetex":True,
# "font.family":"serif",
# "font.serif":["Computer Modern Roman"]})

# Shared styling: 10 pt text everywhere, and identical 5 x 1 ticks for major and minor
# ticks on both axes.
plt.rcParams.update({
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'xtick.major.size': 5,
    'xtick.major.width': 1,
    'xtick.minor.size': 5,
    'xtick.minor.width': 1,
    'ytick.major.size': 5,
    'ytick.major.width': 1,
    'ytick.minor.size': 5,
    'ytick.minor.width': 1,
    'axes.labelsize': 10,
    'axes.titlesize': 10,
    'legend.fontsize': 10,
})

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Create the run directory in '../runs' that stores the training results.
# exist_ok avoids the check-then-create race and keeps the cell re-runnable.
os.makedirs(f'../runs/{run_dir}', exist_ok=True)
# Make the inputs directory importable so dataset_inputs.py and train_inputs.py can be loaded.
# The membership check stops every re-run of this cell from appending a duplicate entry.
if inputs_dir not in sys.path:
    sys.path.append(inputs_dir)
from dataset_inputs import list_of_nn_datasets_dict
from train_inputs import list_of_nn_train_params_dict
# Also make the parent directory importable. Note: this only extends the import search path;
# it does not change the current working directory.
if '..' not in sys.path:
    sys.path.append('..')

## Load the dataset

In [None]:
import networkx as nx

# Directed graph saved to multiscale_graph.pdf: every input z_1..z_8 feeds every intermediate
# node f_1..f_4, and all four intermediates feed the output node f_5.
z_labels = [f'$z_{i}$' for i in range(1, 9)]
f1_labels = [f'$f_{i}$' for i in range(1, 5)]

G = nx.DiGraph()
# Same insertion order as explicit nested loops (z-major, then the f -> f_5 edges).
G.add_edges_from((z, f) for z in z_labels for f in f1_labels)
G.add_edges_from((f, '$f_5$') for f in f1_labels)

# Layout: inputs on row 1 at x = 1..8, intermediates on row 2 centred between input pairs
# (x = 1.5, 3.5, 5.5, 7.5), output on row 3.
pos = {z: (x, 1) for x, z in enumerate(z_labels, start=1)}
pos.update({f: (2 * j - 0.5, 2) for j, f in enumerate(f1_labels, start=1)})
pos['$f_5$'] = (4.5, 3)

# Enable LaTeX text rendering for this figure only. Assigning plt.rcParams['text.usetex']
# globally would leak into every later plot and silently make them require a LaTeX install.
with plt.rc_context({'text.usetex': True}):
    # A full-figure axes, which is what networkx's draw() itself adds to an empty figure when
    # no `ax` is given; creating it explicitly keeps the figure/axes named.
    fig = plt.figure()
    ax = fig.add_axes((0, 0, 1, 1))
    nx.draw(G, pos, ax=ax, with_labels=True, arrowsize=15, node_size=1000, font_size=20,
            node_color='skyblue', font_color='black', font_weight='bold')
    # NOTE(review): subplots_adjust only repositions subplot axes; `ax` is a plain add_axes
    # axes, so this call probably has no effect on the output — confirm and remove.
    fig.subplots_adjust(left=0, right=2, top=2, bottom=0)
    # Save the plot to pdf
    fig.savefig('multiscale_graph.pdf', dpi=300, bbox_inches='tight')
    # Show the plot
    plt.show()

## Build the model

In [None]:
# Test the simple autoencoder
# Smoke test: build a small UnsupervisedSimpleAE and print the decoder output for a random batch.
# NOTE(review): UnsupervisedSimpleAE is not defined or imported in any visible cell, so this cell
# depends on kernel state from elsewhere — confirm it runs on a fresh kernel.
randX = torch.rand(2, 2)  # presumably (batch=2, latent_dim=2) to match latent_dim=2 below — confirm
# self, input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, dec_activation_fn
simple_ae = UnsupervisedSimpleAE(15, 10, 0.01, 2, 3, 'tanh', None)
print(simple_ae.decoder(randX))

### Deriving the KL divergence loss for unit normal prior

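- Let's start with an arbitrary distribution Q(z) and minimize its KL divergence from the posterior P(z|X). The fourth line below applies Bayes' rule, $\log P(z|X) = \log P(X|z) + \log P(z) - \log P(X)$. The last line moves $\log P(X)$ to the left-hand side and regroups $E_{Q(z)}[\log Q(z) - \log P(z)]$ as $D_{KL}(Q(z) || P(z))$.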

$$
\begin{align*}
    D_{KL}(Q(z) || P(z|X)) & = \int Q(z) \log \frac{Q(z)}{P(z|X)} dz \\ \\
    D_{KL}(Q(z) || P(z|X)) & = E_{Q(z)}[ \log \frac{Q(z)}{P(z|X)} ] \\ \\
    D_{KL}(Q(z) || P(z|X)) & = E_{Q(z)}[ \log Q(z) - \log P(z|X) ] \\ \\
    D_{KL}(Q(z) || P(z|X)) & = E_{Q(z)}[ \log Q(z) - \log P(X|z) - \log P(z) + \log P(X) ] \\ \\
    D_{KL}(Q(z) || P(z|X)) & = E_{Q(z)}[ \log Q(z) - \log P(X|z) - \log P(z)] + \log P(X) \\ \\
    \log P(X) - D_{KL}(Q(z) || P(z|X)) & =  E_{Q(z)}[\log P(X|z)] - D_{KL}(Q(z) || P(z)) 
\end{align*}
$$

- Now instead of choosing any distribution for Q(z), it makes sense to choose a distribution for the z variables that depends on X. Hence we can replace Q(z) with Q(z|X).

$$
\begin{align*}
    \log P(X) - D_{KL}(Q(z|X) || P(z|X)) & =  E_{Q(z|X)}[\log P(X|z)] - D_{KL}(Q(z|X) || P(z))
\end{align*}
$$

- The left-hand side is the quantity we want to maximize: the log probability density of X, minus an error term that measures how far the approximate distribution Q(z|X) deviates from the true posterior P(z|X). Note that P(X) is an intractable high-dimensional distribution, and we don't have access to P(z|X). If Q(z|X) has enough capacity, training pulls it closer to P(z|X) and lowers this KL divergence term, until we are effectively optimizing only the log probability density of X.
- The right-hand side contains terms that can be optimized via gradient descent. The first term is the expected log likelihood of the data given the latent variables. The second term is the KL divergence between the approximate distribution and the prior. A KL divergence is never negative, so the right-hand side is a lower bound on $\log P(X)$: the evidence lower bound (ELBO).
- Stochastic gradient descent can be performed on the right-hand side once we assume specific forms for the distributions. The most common choice is a multivariate Gaussian for the approximate posterior and the likelihood, and a unit normal distribution for the prior.

$$
\begin{align*}
    D_{KL}(N(\mu_0, \Sigma_0) || N(\mu_1, \Sigma_1)) = \frac{1}{2} ( \text{tr}(\Sigma_1^{-1} \Sigma_0) + (\mu_1 - \mu_0)^T \Sigma_1^{-1} (\mu_1 - \mu_0) - k + \log \frac{\det \Sigma_1}{\det \Sigma_0} )
\end{align*}
$$

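- 'k' is the dimensionality of the distribution. Substituting the unit normal prior ($\mu_1 = 0$, $\Sigma_1 = I$) and the approximate posterior ($\mu_0 = \mu(X)$, $\Sigma_0 = \Sigma(X)$) gives the KL divergence loss:
$$
\begin{align*}
    D_{KL}(N(\mu (X), \Sigma (X)) || N(0, I)) = \frac{1}{2} ( \text{tr}(\Sigma (X)) + (\mu (X))^T (\mu (X)) - k - \log \det \Sigma (X) )
\end{align*}
$$
- In practice $\Sigma(X)$ is diagonal with entries $\sigma_j^2 = \exp(\text{logvar}_j)$, and the loss reduces to $-\frac{1}{2}\sum_{j=1}^{k} (1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2)$.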

- We need to backpropagate errors into the neural network that approximates Q(z|X), so that it produces z's that reproduce the data well. Sampling is not differentiable, so backpropagation cannot pass through it directly. The reparameterization trick solves this: it lets us sample z while keeping gradients flowing into the networks that approximate the mean and covariance of Q(z|X). We write $z = \mu(X) + \Sigma^{1/2}(X)\,\epsilon$, where $\mu(X)$ and $\Sigma(X)$ are approximated by neural networks and $\epsilon$ is sampled from the unit normal distribution.
- If any other distribution is to be modelled, the KL divergence term must be modified accordingly and the appropriate reparameterization trick must be used.

Reference:
- Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 https://arxiv.org/abs/1312.6114 (Appendix B)
- Doersch, C. Tutorial on Variational Autoencoders. arXiv January 3, 2021. http://arxiv.org/abs/1606.05908.


## Model Optimization

In [None]:
# Training loop

def log_normal(z, mu, logvar):
    """Log-density of the diagonal Gaussian N(mu, exp(logvar)) evaluated at z, summed over dim 1.

    Computes ``-0.5 * sum_j [log(2*pi) + logvar_j + (z_j - mu_j)**2 / exp(logvar_j)]`` along
    dim 1, so 2-D inputs yield one value per row.
    """
    # float32 constants, as in the reference formulation, so the result dtype/rounding is unchanged.
    log_two_pi = torch.log(torch.tensor(2*np.pi, dtype=torch.float32))
    scaled_sq_err = (z - mu).pow(2) / logvar.exp()
    summed = torch.sum(log_two_pi + logvar + scaled_sq_err, dim=1)
    return torch.tensor(-0.5, dtype=torch.float32) * summed

# Single-sample KL estimate used by the (commented-out) GMVAE "Model 3" path in train().
def labelled_loss(z, mu, logvar, mu_prior, logvar_prior):
    """Return ``log N(z; mu, logvar) - log N(z; mu_prior, logvar_prior) - log(0.1)``.

    NOTE(review): the constant ``-log(0.1)`` presumably stands for a uniform categorical prior
    over 10 components; it does not follow the model's ``y_dim`` — confirm.
    """
    log_posterior = log_normal(z, mu, logvar)
    log_prior = log_normal(z, mu_prior, logvar_prior)
    log_py = torch.log(torch.tensor(0.1, dtype=torch.float32))
    return log_posterior - log_prior - log_py

# Closed form of KL(N(mu, diag(exp(logvar))) || N(0, I)), i.e. Gaussian posterior vs unit normal prior.
def kl_divergence_loss_fn(mu, logvar):
    """Batch-mean KL divergence of a diagonal Gaussian from the unit normal.

    Per sample (summed over dim 1): ``-0.5 * sum_j (1 + logvar_j - mu_j**2 - exp(logvar_j))``;
    the returned value is the mean of these over dim 0.
    """
    per_dim_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    per_sample_kl = -0.5 * per_dim_terms.sum(dim=1)
    return per_sample_kl.mean(dim=0)

def train(train_subsampler, val_subsampler, model_type, run_dir, sweep_hyperparams=True, save_model=False, project_name=None, model_name=None, config=None):
    """Train one model, log per-epoch losses to wandb and optionally save it.

    Args:
        train_subsampler: Dataset used for training (callers pass a ``torch.utils.data.Subset``).
        val_subsampler: Dataset used for validation.
        model_type: One of 'SupervisedSimpleAE', 'UnsupervisedSimpleAE', 'SupervisedVAE',
            'UnsupervisedVAE' or 'GMVAE'.
        run_dir: Sub-directory of '../runs' under which a saved model is written.
        sweep_hyperparams: If True, read hyperparameters from ``wandb.config`` (populated by a
            wandb sweep agent); otherwise read them from ``config``.
        save_model: If True, save the model ``state_dict`` and the notebook-global ``scaler_X1``
            under ``../runs/{run_dir}/{model_name}/``.
        project_name: wandb project (non-sweep mode only).
        model_name: wandb run name and id (non-sweep mode only); also names the saved files.
        config: Dict of ``{hyperparameter: {'value': v}}`` entries (non-sweep mode only).

    Raises:
        ValueError: If ``model_type`` is not recognised.

    Note:
        Reads the notebook globals ``sweep_config`` (sweep mode) and ``scaler_X1`` (when
        ``save_model``). The active USER INPUT sections assume each batch is ``(y, input)``,
        that the model returns a 3-tuple whose last two items are ``(pred, recon)``, and that it
        has ``encoder``, ``predictor`` and ``decoder`` sub-modules; edit them when switching models.
    """
    if sweep_hyperparams:
        # Inside a sweep agent, wandb.config carries the hyperparameters chosen for this trial.
        run = wandb.init(job_type='training', resume=False, reinit=False, config=sweep_config)
        config = wandb.config
        input_dim = config.input_dim
        hidden_dim = config.hidden_dim
        y_dim = config.y_dim
        dropout = config.dropout
        l1_reg = config.l1_reg
        l2_reg = config.l2_reg
        latent_dim = config.latent_dim
        num_layers = config.num_layers
        activation_fn = config.activation_fn
        pred_activation_fn = config.pred_activation_fn
        dec_activation_fn = config.dec_activation_fn
        lr = config.learning_rate
        epochs = config.epochs
        batch_size = config.batch_size
    else:
        # id=model_name lets the caller look the run up afterwards via wandb.Api().
        run = wandb.init(project=project_name, name=model_name, id=model_name, job_type='training', resume=False, reinit=False, config=config)
        input_dim = config['input_dim']['value']
        hidden_dim = config['hidden_dim']['value']
        y_dim = config['y_dim']['value']
        dropout = config['dropout']['value']
        l1_reg = config['l1_reg']['value']
        l2_reg = config['l2_reg']['value']
        latent_dim = config['latent_dim']['value']
        num_layers = config['num_layers']['value']
        activation_fn = config['activation_fn']['value']
        pred_activation_fn = config['pred_activation_fn']['value']
        dec_activation_fn = config['dec_activation_fn']['value']
        lr = config['learning_rate']['value']
        epochs = config['epochs']['value']
        batch_size = config['batch_size']['value']

    if model_type == 'SupervisedSimpleAE':
        model = SupervisedSimpleAE(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, pred_activation_fn, dec_activation_fn)
    elif model_type == 'UnsupervisedSimpleAE':
        model = UnsupervisedSimpleAE(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, dec_activation_fn)
    elif model_type == 'SupervisedVAE':
        model = SupervisedVAE(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn)
    elif model_type == 'UnsupervisedVAE':
        model = UnsupervisedVAE(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn)
    elif model_type == 'GMVAE':
        model = GMVAE(input_dim, y_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn)
    else:
        raise ValueError('Invalid model type')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_dataloader = torch.utils.data.DataLoader(train_subsampler, batch_size=batch_size, shuffle=True)
    # No shuffling for validation: the epoch's validation loss is a mean of batch means, so a
    # fixed batching keeps it comparable from one epoch to the next.
    val_dataloader = torch.utils.data.DataLoader(val_subsampler, batch_size=batch_size, shuffle=False)

    for epoch in range(epochs):
        train_kl_loss_per_step = []

        train_recon1_loss_per_step = []
        train_recon2_loss_per_step = []
        train_recon3_loss_per_step = []
        train_recon4_loss_per_step = []

        train_recon1_acc_per_step = []
        train_recon2_acc_per_step = []
        train_recon3_acc_per_step = []

        train_pred_loss_per_step = []

        train_total_loss_per_step = []

        val_kl_loss_per_step = []

        val_recon1_loss_per_step = []
        val_recon2_loss_per_step = []
        val_recon3_loss_per_step = []
        val_recon4_loss_per_step = []

        val_recon1_acc_per_step = []
        val_recon2_acc_per_step = []
        val_recon3_acc_per_step = []

        val_pred_loss_per_step = []

        val_total_loss_per_step = []

        # Training mode: enables dropout (and any other train-only behaviour) in the model.
        # Needed every epoch because validation below switches the model to eval mode.
        model.train()
        # for i, input in enumerate(train_dataloader):
        for i, (y, input) in enumerate(train_dataloader):
        # for i, (input1, input2, input3, input4) in enumerate(train_dataloader):
            optimizer.zero_grad()

            ## ------------------------ USER INPUT START ------------------------

            # Model 1
            # _, recon1, recon2, recon3, recon4 = model(torch.concat((input1, input2, input3, input4), dim=1))
            _, pred, recon = model(input)
            # _, reconst = model(input)
            train_loss1 = torch.nn.L1Loss(reduction='mean')(recon, input)
            train_recon1_loss_per_step.append(train_loss1.item())

            # train_loss1 = torch.nn.CrossEntropyLoss(reduction='mean')(recon1, input1)
            # train_acc1_metric = MultilabelAccuracy(num_labels=3)
            # train_acc1 = train_acc1_metric(recon1, input1)
            # train_recon1_loss_per_step.append(train_loss1.item())
            # train_recon1_acc_per_step.append(train_acc1.item())

            # train_loss2 = torch.nn.CrossEntropyLoss(reduction='mean')(recon2, input2)
            # train_acc2_metric = MultilabelAccuracy(num_labels=4)
            # train_acc2 = train_acc2_metric(recon2, input2)
            # train_recon2_loss_per_step.append(train_loss2.item())
            # train_recon2_acc_per_step.append(train_acc2.item())

            # train_loss3 = torch.nn.CrossEntropyLoss(reduction='mean')(recon3, input3)
            # train_acc3_metric = MultilabelAccuracy(num_labels=8)
            # train_acc3 = train_acc3_metric(recon3, input3)
            # train_recon3_loss_per_step.append(train_loss3.item())
            # train_recon3_acc_per_step.append(train_acc3.item())

            # train_loss4 = torch.nn.L1Loss(reduction='mean')(pred, y)
            # train_recon4_loss_per_step.append(train_loss4.item())

            train_loss3 = torch.nn.L1Loss(reduction='mean')(pred, y)
            train_pred_loss_per_step.append(train_loss3.item())

            enc_params = torch.cat([x.view(-1) for x in model.encoder.parameters()])
            pred_params = torch.cat([x.view(-1) for x in model.predictor.parameters()])
            dec_params = torch.cat([x.view(-1) for x in model.decoder.parameters()])

            l1_regularization = l1_reg * (torch.norm(enc_params, 1) +
                                          torch.norm(pred_params, 1) +
                                          torch.norm(dec_params, 1))
            # NOTE(review): torch.norm(p, 2) is the unsquared L2 norm; conventional weight decay
            # penalises the squared norm — confirm which one is intended.
            l2_regularization = l2_reg * (torch.norm(enc_params, 2) +
                                          torch.norm(pred_params, 2) +
                                          torch.norm(dec_params, 2))

            train_total_loss = train_loss1 + train_loss3 + l1_regularization + l2_regularization
            # train_total_loss = train_loss1 + train_loss2 + train_loss3 + l1_regularization + l2_regularization
            # train_total_loss = train_loss1 + train_loss2 + l1_regularization + l2_regularization
            # train_total_loss = train_loss1 + l1_regularization + l2_regularization

            ## ------------------------ USER INPUT END ------------------------

            train_total_loss_per_step.append(train_total_loss.item())

            # # Model 2
            # z, pred, reconst, mu, logvar = model(input)
            # train_loss1 = kl_divergence_loss_fn(mu, logvar)
            # train_kl_loss_per_step.append(train_loss1.item())
            # train_loss2 = torch.nn.MSELoss(reduction='mean')(input, reconst)
            # train_reconst_loss_per_step.append(train_loss2.item())
            # train_loss3 = torch.nn.MSELoss(reduction='mean')(bandgaps.unsqueeze(dim=1), pred)
            # train_pred_loss_per_step.append(train_loss3.item())
            # train_total_loss = train_loss1 + train_loss2 + train_loss3
            # train_total_loss_per_step.append(train_total_loss.item())

            # # Model 3
            # z, pred, reconst, mu, logvar, mu_prior, logvar_prior, qy_logit, qy = model(input)
            # train_loss1 = torch.nn.CrossEntropyLoss(reduction='mean')(qy_logit, qy)
            # train_loss2 = [None] * model.y_dim
            # train_loss3 = [None] * model.y_dim
            # train_loss4 = [None] * model.y_dim
            # for i in range(model.y_dim):
            #     # Take mean across the batch
            #     train_loss2[i] = torch.mean(qy[:, i]*torch.nn.L1Loss(reduction='sum')(bandgaps.unsqueeze(dim=1), pred[i]), dtype=torch.float32)
            #     train_loss3[i] = torch.mean(qy[:, i]*torch.nn.MSELoss(reduction='sum')(input, reconst[i]), dtype=torch.float32)
            #     train_loss4[i] = torch.mean(qy[:, i]*labelled_loss(z[i], mu[i], logvar[i], mu_prior[i], logvar_prior[i]), dtype=torch.float32)
            # train_pred_loss_per_step.append(torch.stack(train_loss2).sum().item())
            # train_reconst_loss_per_step.append(torch.stack(train_loss3).sum().item())
            # train_kl_loss_per_step.append(torch.stack(train_loss4).sum().item())
            # train_total_loss = train_loss1 + torch.stack(train_loss2).sum() + torch.stack(train_loss3).sum() + torch.stack(train_loss4).sum()
            # train_total_loss_per_step.append(train_total_loss.item())

            train_total_loss.backward()
            optimizer.step()

        # Run the validation loop in eval mode (dropout disabled) and without building autograd
        # graphs, since validation never calls backward().
        model.eval()
        with torch.no_grad():
            # for i, input in enumerate(val_dataloader):
            for i, (y, input) in enumerate(val_dataloader):
            # for i, (y, input1, input2) in enumerate(val_dataloader):
            # for i, (input1, input2, input3, input4) in enumerate(val_dataloader):

                ## ------------------------ USER INPUT START ------------------------

                # Model 1
                # _, recon1, recon2, recon3, recon4 = model(torch.concat((input1, input2, input3, input4), dim=1))
                # _, pred, reconst1, reconst2 = model(torch.concat((input1, input2), dim=1))
                _, pred, recon = model(input)
                # _, reconst = model(input)

                val_loss1 = torch.nn.L1Loss(reduction='mean')(recon, input)
                val_recon1_loss_per_step.append(val_loss1.item())

                # val_loss1 = torch.nn.CrossEntropyLoss(reduction='mean')(recon1, input1)
                # val_recon1_loss_per_step.append(val_loss1.item())
                # val_acc1_metric = MultilabelAccuracy(num_labels=3)
                # val_acc1 = val_acc1_metric(recon1, input1)
                # val_recon1_acc_per_step.append(val_acc1.item())

                # val_loss2 = torch.nn.CrossEntropyLoss(reduction='mean')(recon2, input2)
                # val_recon2_loss_per_step.append(val_loss2.item())
                # val_acc2_metric = MultilabelAccuracy(num_labels=4)
                # val_acc2 = val_acc2_metric(recon2, input2)
                # val_recon2_acc_per_step.append(val_acc2.item())

                # val_loss3 = torch.nn.CrossEntropyLoss(reduction='mean')(recon3, input3)
                # val_recon3_loss_per_step.append(val_loss3.item())
                # val_acc3_metric = MultilabelAccuracy(num_labels=8)
                # val_acc3 = val_acc3_metric(recon3, input3)
                # val_recon3_acc_per_step.append(val_acc3.item())

                # val_loss4 = torch.nn.L1Loss(reduction='mean')(recon4, input4)
                # val_recon4_loss_per_step.append(val_loss4.item())

                val_loss3 = torch.nn.L1Loss(reduction='mean')(pred, y)
                val_pred_loss_per_step.append(val_loss3.item())

                val_total_loss = val_loss1 + val_loss3

                ## ------------------------ USER INPUT END ------------------------

                val_total_loss_per_step.append(val_total_loss.item())
                # val_total_loss = val_loss1 + val_loss2
                # val_total_loss_per_step.append(val_total_loss.item())
                # val_total_loss = val_loss1
                # val_total_loss_per_step.append(val_total_loss.item())

                # # Model 2
                # z, pred, reconst, mu, logvar = model(input)
                # val_loss1 = kl_divergence_loss_fn(mu, logvar)
                # val_kl_loss_per_step.append(val_loss1.item())
                # val_loss2 = torch.nn.MSELoss(reduction='mean')(input, reconst)
                # val_reconst_loss_per_step.append(val_loss2.item())
                # val_loss3 = torch.nn.MSELoss(reduction='mean')(bandgaps.unsqueeze(dim=1), pred)
                # val_pred_loss_per_step.append(val_loss3.item())
                # val_total_loss = val_loss1 + val_loss2 + val_loss3
                # val_total_loss_per_step.append(val_total_loss.item())

                # # Model 3
                # z, pred, reconst, mu, logvar, mu_prior, logvar_prior, qy_logit, qy = model(input)
                # val_loss1 = torch.nn.CrossEntropyLoss(reduction='mean')(qy_logit, qy)
                # val_loss2 = [None] * model.y_dim
                # val_loss3 = [None] * model.y_dim
                # val_loss4 = [None] * model.y_dim
                # for i in range(model.y_dim):
                #     # Take mean across batch
                #     val_loss2[i] = torch.mean(qy[:, i]*torch.nn.L1Loss(reduction='sum')(bandgaps.unsqueeze(dim=1), pred[i]), dtype=torch.float32)
                #     val_loss3[i] = torch.mean(qy[:, i]*torch.nn.MSELoss(reduction='sum')(input, reconst[i]), dtype=torch.float32)
                #     val_loss4[i] = torch.mean(qy[:, i]*labelled_loss(z[i], mu[i], logvar[i], mu_prior[i], logvar_prior[i]), dtype=torch.float32)
                # val_pred_loss_per_step.append(torch.stack(val_loss2).sum().item())
                # val_reconst_loss_per_step.append(torch.stack(val_loss3).sum().item())
                # val_kl_loss_per_step.append(torch.stack(val_loss4).sum().item())
                # val_total_loss = val_loss1 + torch.stack(val_loss2).sum() + torch.stack(val_loss3).sum() + torch.stack(val_loss4).sum()
                # val_total_loss_per_step.append(val_total_loss.item())

        # Log every metric of this epoch in a single call so they share one wandb step
        # (each wandb.log call advances the step counter by default).
        wandb.log({
            'epoch': epoch,
            # 'train_kl_loss_per_epoch': np.mean(train_kl_loss_per_step),  # GMVAE only
            'train_recon1_loss_per_epoch': np.mean(train_recon1_loss_per_step),
            # 'train_recon1_acc_per_epoch': np.mean(train_recon1_acc_per_step),
            # 'train_recon2_loss_per_epoch': np.mean(train_recon2_loss_per_step),
            # 'train_recon2_acc_per_epoch': np.mean(train_recon2_acc_per_step),
            # 'train_recon3_loss_per_epoch': np.mean(train_recon3_loss_per_step),
            # 'train_recon3_acc_per_epoch': np.mean(train_recon3_acc_per_step),
            # 'train_recon4_loss_per_epoch': np.mean(train_recon4_loss_per_step),
            'train_pred_loss_per_epoch': np.mean(train_pred_loss_per_step),
            'train_total_loss_per_epoch': np.mean(train_total_loss_per_step),
            # 'val_kl_loss_per_epoch': np.mean(val_kl_loss_per_step),  # GMVAE only
            'val_recon1_loss_per_epoch': np.mean(val_recon1_loss_per_step),
            # 'val_recon1_acc_per_epoch': np.mean(val_recon1_acc_per_step),
            # 'val_recon2_loss_per_epoch': np.mean(val_recon2_loss_per_step),
            # 'val_recon2_acc_per_epoch': np.mean(val_recon2_acc_per_step),
            # 'val_recon3_loss_per_epoch': np.mean(val_recon3_loss_per_step),
            # 'val_recon3_acc_per_epoch': np.mean(val_recon3_acc_per_step),
            # 'val_recon4_loss_per_epoch': np.mean(val_recon4_loss_per_step),
            'val_pred_loss_per_epoch': np.mean(val_pred_loss_per_step),
            'val_total_loss_per_epoch': np.mean(val_total_loss_per_step),
        })
    run.finish()
    # Save the model
    if save_model:
        model_dir = f'../runs/{run_dir}/{model_name}'
        # Creates intermediate directories too; no-op if it already exists.
        os.makedirs(model_dir, exist_ok=True)
        model_path = f'{model_dir}/{model_name}.pth'
        torch.save(model.state_dict(), model_path)
        # Save the scalers also in the model folder
        # scaler_X1_path = f'../runs/perovskite_multiscale_dataset_v2/{model_name}/scaler_X1.pkl'
        # NOTE(review): the file is named after "latentsfromAE2" but the object dumped is the
        # notebook-global scaler_X1 — confirm the pairing.
        scaler_X4_path = f'{model_dir}/scaler_latentsfromAE2.pkl'
        # scaler_y_path = f'../runs/perovskite_multiscale_dataset_v2/{model_name}/scaler_y.pkl'
        joblib.dump(scaler_X1, scaler_X4_path)
        if model_type == 'SupervisedSimpleAE':
            pass
            # joblib.dump(scaler_y, scaler_y_path)
    
    
## ------------------------ USER INPUT START ------------------------
# By default save_model is set to False for hyperparam runs.
# The dispatch below runs exactly one step; if several flags are True the earliest step wins
# (hyperparameter sweep > latent sweep > save).

# Step 1
sweep_for_hyperparams = False
# Step 2
sweep_for_latent = False
# Step 3 : Save the best performing model
save_model_on_local = True

ae_name = 'SupSimpleAE_2'
model_type = 'SupervisedSimpleAE'
dataset_name = 'randomData_sumf5'
# NOTE(review): this overrides the run_dir set in the first cell ('syn_grid_lin_data').
run_dir = 'results_for_RL_paper'

val_split = 0.2 # num_folds*val_split must be 1 

if sweep_for_hyperparams:
    sweep_type = 'grid' # Select between 'bayes', 'grid', 'random' 
    limit_num_trials_in_sweep = None # Typically 3*3*3=27 trials. Consider only (10% of space is explored) 0.1*81=8.1 ~ 8 trials
    num_folds = 5 # 494/5 = 98.8 in internal validation set
    parameters = {
            'input_dim':{
                'value':X1_final.shape[1] + X2_final.shape[1] + X3_final.shape[1] + X4_final.shape[1]
            },
            'hidden_dim':{
                'values':[25, 50, 75]
            },
            'latent_dim':{
                'value':4
            },
            'y_dim':{
                'value':None
            },
            'dropout':{
                'value':0
            },
            'l1_reg':{
                'value':0
            },
            'l2_reg':{
                'value':0.001
            },
            'num_layers':{
                'values':[1, 2, 3]
            },
            'activation_fn':{
                'values':['tanh', 'relu', None]
            },
            'dec_activation_fn':{
                'value':None
            },
            'pred_activation_fn':{
                'value':None
            },
            'batch_size':{
                'value':10
            },
            'learning_rate':{
                'value':1e-3
            },
            'epochs':{
                'value':1000
            }
        }
elif sweep_for_latent or save_model_on_local:
    # `elif`, not `if`: with a plain `if`, enabling Step 1 while save_model_on_local is still
    # True (its default) would silently replace the search grid above with these fixed values.
    losses_to_track = ['total']
    latent_space_to_sweep = [2, 4, 6, 8, 10, 12]
    seeds_to_sweep = [0, 1, 2, 3, 4]
    parameters = {
                'input_dim':{
                    'value':X1_final.shape[1]
                },
                'hidden_dim':{
                    'value':None,
                },
                'latent_dim':{
                    'value':10
                },
                'y_dim':{
                    'value':0
                },
                'dropout':{
                    'value':0
                },
                'l1_reg':{
                    'value':0
                },
                'l2_reg':{
                    'value':0.001
                },
                'num_layers':{
                    'value':1
                },
                'activation_fn':{
                    'value':None
                },
                'dec_activation_fn':{
                    'value':None
                },
                'pred_activation_fn':{
                    'value':None
                },
                'batch_size':{
                    'value':10
                },
                'learning_rate':{
                    'value':1e-3
                },
                'epochs':{
                    'value':1500
                }
            }

## ------------------------ USER INPUT END ------------------------

if sweep_for_hyperparams:
    ldim = parameters['latent_dim']['value']
    base_sweep_name = f'{ae_name}_ldim{ldim}_{dataset_name}'
    sweep_config = {
        'name': base_sweep_name,
        'method':sweep_type,
        'metric':{
            'name':'val_total_loss_per_epoch',
            'goal':'minimize'
            },
        'parameters':parameters
    }
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=orig_seed)
    for fold, (train_indices, val_indices) in enumerate(kf.split(dataset)):
        train_subsampler = torch.utils.data.Subset(dataset, train_indices)
        val_subsampler = torch.utils.data.Subset(dataset, val_indices)
        # Derive each fold's name from the base name; appending to the previous fold's name
        # would accumulate suffixes ('..._fold_0_fold_1', ...).
        sweep_config['name'] = base_sweep_name + '_fold_' + str(fold)
        sweep_id = wandb.sweep(sweep_config, project=project_name)
        # wandb.agent runs synchronously, so the lambda sees this fold's subsamplers.
        wandb.agent(sweep_id, lambda: train(train_subsampler, val_subsampler, model_type=model_type, run_dir=run_dir), project=project_name, count=limit_num_trials_in_sweep)
        # Finish the sweep
        wandb.finish()
elif sweep_for_latent:
    # Loss matrices of shape (number of seeds, number of latent sizes).
    loss_dict = {}
    for loss_name in losses_to_track:
        loss_dict[f'val_{loss_name}_loss'] = np.zeros((len(seeds_to_sweep), len(latent_space_to_sweep)))
        loss_dict[f'train_{loss_name}_loss'] = np.zeros((len(seeds_to_sweep), len(latent_space_to_sweep)))
    # Rows are indexed by the seed's position in seeds_to_sweep, not by its value, so any
    # seed values (not just 0..n-1) work.
    for seed_idx, seed in enumerate(seeds_to_sweep):
        for i, latent in enumerate(latent_space_to_sweep):
            parameters['latent_dim']['value'] = latent
            model_name = f'{ae_name}_ldim{latent}_{dataset_name}_seed{seed}'
            print(f'Running model {model_name}')
            num_samples = X1.shape[0]
            indices = np.arange(num_samples)
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            np.random.shuffle(indices)
            train_indices = indices[:int((1- val_split)*num_samples)]
            val_indices = indices[int((1- val_split)*num_samples):]
            train_subsampler = torch.utils.data.Subset(dataset, train_indices)
            val_subsampler = torch.utils.data.Subset(dataset, val_indices)
            train(train_subsampler, val_subsampler, model_type=model_type, run_dir=run_dir, sweep_hyperparams=False, save_model=save_model_on_local, project_name=project_name, model_name=model_name, config=parameters)
            # Extract the final losses from the wandb run (train() uses model_name as run id).
            api = wandb.Api()
            run = api.run(f'{project_name}/{model_name}')
            for loss_name in losses_to_track:
                loss_dict[f'val_{loss_name}_loss'][seed_idx, i] = run.summary[f'val_{loss_name}_loss_per_epoch']
                loss_dict[f'train_{loss_name}_loss'][seed_idx, i] = run.summary[f'train_{loss_name}_loss_per_epoch']
    # Store the loss matrices in the run directory (created here in case no model was saved).
    os.makedirs(f'../runs/{run_dir}', exist_ok=True)
    for loss_name in losses_to_track:
        val_loss_path = f'../runs/{run_dir}/val_{loss_name}_loss.npy'
        train_loss_path = f'../runs/{run_dir}/train_{loss_name}_loss.npy'
        np.save(val_loss_path, loss_dict[f'val_{loss_name}_loss'])
        np.save(train_loss_path, loss_dict[f'train_{loss_name}_loss'])
    # recon4_val_loss_path = f'../runs/{run_dir}/val_recon4_loss_for_peroveff2_film_valsplit0p1_bs1.npy'
    # recon4_train_loss_path = f'../runs/{run_dir}/train_recon4_loss_for_peroveff2_film_valsplit0p1_bs1.npy'
    # total_val_loss_path = f'../runs/{run_dir}/val_total_loss_for_peroveff2_film_valsplit0p1_bs1.npy'
    # total_train_loss_path = f'../runs/{run_dir}/train_total_loss_for_peroveff2_film_valsplit0p1_bs1.npy'
    # np.save(recon4_val_loss_path, val_recon4_loss_for_peroveff2_film_valsplit0p1_bs1)
    # np.save(recon4_train_loss_path, train_recon4_loss_for_peroveff2_film_valsplit0p1_bs1)
    # np.save(total_val_loss_path, val_total_loss_for_peroveff2_film_valsplit0p1_bs1)
    # np.save(total_train_loss_path, train_total_loss_for_peroveff2_film_valsplit0p1_bs1)
    # Reset the random seed to what it was before
    seed = orig_seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
elif save_model_on_local:
    num_samples = X1.shape[0]
    indices = np.arange(num_samples)
    # NOTE(review): unlike the latent sweep, indices are not shuffled here, so validation is
    # simply the last val_split fraction of the rows — confirm this is intended.
    train_indices = indices[:int((1- val_split)*num_samples)]
    val_indices = indices[int((1- val_split)*num_samples):]
    train_subsampler = torch.utils.data.Subset(dataset, train_indices)
    val_subsampler = torch.utils.data.Subset(dataset, val_indices)
    ldim = parameters['latent_dim']['value']
    model_name = f'{ae_name}_ldim{ldim}_{dataset_name}'
    train(train_subsampler, val_subsampler, model_type=model_type, run_dir=run_dir, sweep_hyperparams=False, save_model=save_model_on_local, project_name=project_name, model_name=model_name, config=parameters)
    # If any .npy files are found in the run directory transfer them into the model directory
    for file in os.listdir(f'../runs/{run_dir}'):
        if file.endswith('.npy'):
            shutil.move(f'../runs/{run_dir}/{file}', f'../runs/{run_dir}/{model_name}/{file}')
else:
    print('Operation not supported')

# Models for film scale perovskite dataset v2 trained using 90:10 split and batch size 10

# paths_seed10 = ['nthota2/perovskite_dataset_v2/i0siejd3',
#                 'nthota2/perovskite_dataset_v2/41lyo4zx',
#                 'nthota2/perovskite_dataset_v2/090ev6lt',
#                 'nthota2/perovskite_dataset_v2/4hx3g1ze',
#                 'nthota2/perovskite_dataset_v2/5lxjh9yz',
#                 'nthota2/perovskite_dataset_v2/plfhcevp',
#                 'nthota2/perovskite_dataset_v2/s2blmq1p',
#                 'nthota2/perovskite_dataset_v2/33kw47zq',
#                 'nthota2/perovskite_dataset_v2/51g6co9z',
#                 'nthota2/perovskite_dataset_v2/ea6e0vvg',
#                 'nthota2/perovskite_dataset_v2/s1xzk0v8']

# paths_seed1 = ['nthota2/perovskite_dataset_v2/661g95rw',
#                'nthota2/perovskite_dataset_v2/0g449pn8',
#                'nthota2/perovskite_dataset_v2/lqh9dg48',
#                'nthota2/perovskite_dataset_v2/e7dgdq6i',
#                'nthota2/perovskite_dataset_v2/26a8ng2u',
#                'nthota2/perovskite_dataset_v2/ks0zvbhp',
#                'nthota2/perovskite_dataset_v2/adx91d6o',
#                'nthota2/perovskite_dataset_v2/901war6q',
#                'nthota2/perovskite_dataset_v2/qwiryh15',
#                'nthota2/perovskite_dataset_v2/krs81m9u',
#                'nthota2/perovskite_dataset_v2/319183bd']

# paths_seed2 = ['nthota2/perovskite_dataset_v2/mdshovqj',
#                'nthota2/perovskite_dataset_v2/y3v0p2x1',
#                'nthota2/perovskite_dataset_v2/st18wn73',
#                'nthota2/perovskite_dataset_v2/t1ftdjr5',
#                'nthota2/perovskite_dataset_v2/041zbojc',
#                'nthota2/perovskite_dataset_v2/1v79q4r9',
#                'nthota2/perovskite_dataset_v2/00tu85tg',
#                'nthota2/perovskite_dataset_v2/4wx9bbt3',
#                'nthota2/perovskite_dataset_v2/2v4pf9yo',
#                'nthota2/perovskite_dataset_v2/jirnsqq2',
#                'nthota2/perovskite_dataset_v2/8bgcj3id']

# paths_seed3 = ['nthota2/perovskite_dataset_v2/9c91pcxl',
#                'nthota2/perovskite_dataset_v2/iisquapr',
#                'nthota2/perovskite_dataset_v2/f33cluje',
#                'nthota2/perovskite_dataset_v2/kmx635x0',
#                'nthota2/perovskite_dataset_v2/zcohntic',
#                'nthota2/perovskite_dataset_v2/nx650722',
#                'nthota2/perovskite_dataset_v2/yk3x5pn5',
#                'nthota2/perovskite_dataset_v2/lt3yut3v',
#                'nthota2/perovskite_dataset_v2/uwni8be0',
#                'nthota2/perovskite_dataset_v2/9cpliop8',
#                'nthota2/perovskite_dataset_v2/ptophhdq']

# paths_seed4 = ['nthota2/perovskite_dataset_v2/g5som797',
#                'nthota2/perovskite_dataset_v2/fgz3igxx',
#                'nthota2/perovskite_dataset_v2/l3q1w43r',
#                'nthota2/perovskite_dataset_v2/3shs46xm',
#                'nthota2/perovskite_dataset_v2/iscztzx0',
#                'nthota2/perovskite_dataset_v2/aeqbbo34',
#                'nthota2/perovskite_dataset_v2/hful22lx',
#                'nthota2/perovskite_dataset_v2/31bx1k6h',
#                'nthota2/perovskite_dataset_v2/b7lr6u9d',
#                'nthota2/perovskite_dataset_v2/9tv3phh1',
#                'nthota2/perovskite_dataset_v2/oioxfv5e']

In [1]:
# Scratch check: a list holding four copies of 2
print([2 for _ in range(3 + 1)])

[2, 2, 2, 2]


## Model Interpretability

#### Plotting model performance vs latent dimension

In [None]:
# TODO : Store the below loss values in .npy arrays and move them to the model directory

# For RL paper
# # Grid data - UnsupervisedSimpleAE 1
# latent_space = [1, 2, 4, 6, 8]
# total_val_loss = [0.001025053068588022, 0.001063714546035044, 0.0011066086881328374, 0.0013437549932859838, 0.0009386805177200586]
# total_train_loss = [0.007932848995551467, 0.008279352798126638, 0.004237618821207434, 0.004374776501208544, 0.004906855116132647]

# # Grid data - Nonlinf5 SupervisedSimpleAE 2
# latent_space = [1, 2, 4, 6]
# total_val_loss = [0.7463283464312553, 0.2002287097275257, 0.1977713629603386, 0.1788929458707571]
# total_train_loss = [0.4705591835081578, 0.14468571869656444, 0.1250611103605479, 0.13816878804937005]

# # Grid data - Sumf5 - SupervisedSimpleAE 2
# latent_space = [1, 2, 4, 6]
# total_val_loss = [0.49156802892684937, 0.019491535145789385, 0.02250012196600437, 0.02030603913590312]
# total_train_loss = [0.4140172880142927, 0.03587072214577347, 0.02835194836370647, 0.03037486143875867]

# # Random data - UnsupervisedSimpleAE 1
# latent_space = [2, 4, 6, 8]
# total_val_loss = [0.6480237692594528, 0.4781555384397506, 0.24337586015462875, 0.002268550335429609]
# total_train_loss = [0.6661303304135799, 0.4441600125283003, 0.20833437889814377, 0.007954166503623128]

# # Random data - Nonlinf5 - SupervisedSimpleAE 2
# latent_space = [1, 2, 4, 6, 8, 10, 11, 12]
# total_val_loss = [0.7916551381349564, 0.6637141406536102, 0.5525364577770233, 0.3102440983057022, 0.19166426360607147, 0.051493472419679165, 0.021430929424241185, 0.007555923308245838]
# total_train_loss = [0.7504490427672863, 0.6266670003533363, 0.4413919039070606, 0.30615816451609135, 0.18379461765289307, 0.05851957411505282, 0.028894496499560773, 0.016192137030884624]

# # Random data - Sumf5 - SupervisedSimpleAE 2
# # Will have to change the predictor as only prediction loss is high, reconstruction is good.
# latent_space = [1, 2, 4, 6, 8, 10, 11, 12]
# total_val_loss = [1.4578111469745636, 1.4702374935150146, 1.383190006017685, 1.16259, 0.99866, 0.81431, 0.75846, 0.83324]
# total_train_loss = [1.4289479702711103, 1.2744147181510923, 1.0854754857718945, 0.96673, 0.84347, 0.72946, 0.69708, 0.66834]

# Arun2024 (Input dim = 15 + 4 = 19)
# latent_space =      [2,     4,     6,     8,     10,    12,    14]
# pred_val_loss =     [0.206, 0.178, 0.244, 0.227, 0.148, 0.233, 0.231]
# pred_train_loss =   [0.11,  0.076, 0.082, 0.067, 0.074, 0.071, 0.067]
# recont_val_loss =   [0.561, 0.069, 0.055, 0.050, 0.038, 0.046, 0.056]
# recont_train_loss = [0.234, 0.055, 0.037, 0.038, 0.037, 0.042, 0.039]
# total_val_loss = [0.76786, 0.24755343379718917, 	0.24418, 0.27646, 0.18622, 0.27943, 0.28719]
# total_train_loss = [0.38542, 0.16735203887741917, 	0.15132, 0.13772, 0.14288, 0.14528, 0.13803]

# for 10 this is the best that can be done. Adding more layers or increasing the hidden dim or using non linea act does not work.
# val : 0.7632711380720139
# train : 0.7363427169620991
run_dir = 'results_for_RL_paper'
loss_dir = f'../runs/{run_dir}/SupSimpleAE_2_ldim10_randomData_nonlinf5'
train_loss_matrix_path = f'{loss_dir}/train_total_loss.npy'
val_loss_matrix_path = f'{loss_dir}/val_total_loss.npy'

# Loss matrices are indexed [seed, i] by the sweep loop that writes them; columns are
# assumed to follow latent_space_to_sweep below -- TODO confirm against the sweep config.
train_loss_matrix = np.load(train_loss_matrix_path)
val_loss_matrix = np.load(val_loss_matrix_path)

latent_space_to_sweep = [2, 4, 6, 8, 10, 12]

# Fail with a clear message if the hard-coded sweep does not match the saved matrices,
# instead of an opaque length-mismatch error from errorbar.
for split_name, loss_matrix in (('train', train_loss_matrix), ('val', val_loss_matrix)):
    if loss_matrix.ndim != 2 or loss_matrix.shape[1] != len(latent_space_to_sweep):
        raise ValueError(
            f'{split_name} loss matrix has shape {loss_matrix.shape}; expected '
            f'(num_seeds, {len(latent_space_to_sweep)}) to match latent_space_to_sweep'
        )

# Mean and (population) std across seeds for each latent dimension
train_mean_losses = train_loss_matrix.mean(axis=0)
train_std_losses = train_loss_matrix.std(axis=0)
val_mean_losses = val_loss_matrix.mean(axis=0)
val_std_losses = val_loss_matrix.std(axis=0)

fig, ax = plt.subplots(figsize=(4, 3))

# Plot the latent dimension vs total train and validation loss (error bars = std over seeds)
ax.errorbar(latent_space_to_sweep, train_mean_losses, yerr=train_std_losses, label='train', marker='o', linestyle='-')
ax.errorbar(latent_space_to_sweep, val_mean_losses, yerr=val_std_losses, label='val', marker='o', linestyle='-')
# # Label the mean values on top with red color
# for i, txt in enumerate(val_mean_losses):
#     # ax.annotate(f'{txt:.3f}', (latent_space_to_sweep[i], val_mean_losses[i]), textcoords="offset points", xytext=(0,15), ha='center', color='r')
#     print(f'{latent_space_to_sweep[i]} : Val Std dev. : {val_std_losses[i]}')
#     print(f'{latent_space_to_sweep[i]} : Val Mean : {val_mean_losses[i]}')
# for i, txt in enumerate(train_mean_losses):
#     # ax.annotate(f'{txt:.3f}', (latent_space_to_sweep[i], train_mean_losses[i]), textcoords="offset points", xytext=(0,-15), ha='center', color='b')
#     print(f'{latent_space_to_sweep[i]} : Train Std dev. : {train_std_losses[i]}')
#     print(f'{latent_space_to_sweep[i]} : Train Mean : {train_mean_losses[i]}')
# ax.plot(latent_space, total_train_loss, marker='o', linestyle='', label='train')
# ax.plot(latent_space, total_val_loss, marker='o', linestyle='', label='val')
ax.set_xlabel('Latent space dimension')
# One tick per swept latent dimension
ax.set_xticks(latent_space_to_sweep)
ax.set_ylabel('Total MAE (Reconst. + Pred.)')
ax.legend()

fig.tight_layout()

# Save the figure
fig.savefig('latent_space_vs_total_loss.pdf')

#### Loading the torch model

In [None]:
# Hyperparameters used to rebuild the network before loading its weights; they must
# match the configuration the checkpoint was trained with.
parameters = {
    'input_dim': {'value': X1_final.shape[1]},
    'hidden_dim': {'value': None},
    'latent_dim': {'value': 10},
    'y_dim': {'value': 0},
    'dropout': {'value': 0},
    'l1_reg': {'value': 0},
    'l2_reg': {'value': 0.001},
    'num_layers': {'value': 1},
    'activation_fn': {'value': None},
    'dec_activation_fn': {'value': None},
    'pred_activation_fn': {'value': None},
    'batch_size': {'value': 10},
    'learning_rate': {'value': 1e-3},
    'epochs': {'value': 1500},
}

# Load torch model
ldim = parameters['latent_dim']['value']
# model_name = f'{ae_name}_ldim{ldim}_{dataset_name}'
# model_name = 'UnsupSimpleAE_3_ldim12_peroveff2_device_none_3_75_e2000'
# model_name = 'SupSimpleAE_2_ldim2_gridData_nonlinf5'
# model_name = 'best_SupSimpleAE_2_ldim2_gridData_sumf5'
model_name = 'SupSimpleAE_2_ldim10_randomData_nonlinf5'
run_dir = 'results_for_RL_paper'
model_type = 'SupervisedSimpleAE'
model_path = f'../runs/{run_dir}/{model_name}/' + model_name + '.pth'

# Flatten {'name': {'value': v}} into {'name': v} for readable constructor calls
hp = {name: spec['value'] for name, spec in parameters.items()}
if model_type == 'UnsupervisedSimpleAE':
    model = UnsupervisedSimpleAE(hp['input_dim'], hp['hidden_dim'], hp['dropout'],
                                 hp['latent_dim'], hp['num_layers'],
                                 hp['activation_fn'], hp['dec_activation_fn'])
elif model_type == 'SupervisedSimpleAE':
    model = SupervisedSimpleAE(hp['input_dim'], hp['hidden_dim'], hp['dropout'],
                               hp['latent_dim'], hp['num_layers'], hp['activation_fn'],
                               hp['pred_activation_fn'], hp['dec_activation_fn'])
elif model_type == 'VAE':
    model = UnsupervisedVAE(hp['input_dim'], hp['hidden_dim'], hp['dropout'],
                            hp['latent_dim'], hp['num_layers'], hp['activation_fn'])
elif model_type == 'GMVAE':
    model = GMVAE(hp['input_dim'], hp['y_dim'], hp['hidden_dim'], hp['dropout'],
                  hp['latent_dim'], hp['num_layers'], hp['activation_fn'])
else:
    # Fail fast: otherwise `model` is undefined here, or silently stale from an earlier
    # cell, and the weights below would be loaded into the wrong network.
    raise ValueError(f'Unsupported model_type: {model_type!r}')
# print(torch.load(model_path))

# map_location lets weights saved on a GPU load on the CPU used in this notebook
model.load_state_dict(torch.load(model_path, map_location='cpu'))
# Inference / interpretability only from here on
model.eval()

#### Save the latents to .csv file

In [None]:
# Save the latents to .csv file
# Feature blocks fed to the encoder. The model loaded above is built with
# input_dim = X1_final.shape[1], so only X1_final is used by default; append
# X2_final, X3_final, X4_final here for a model trained on the concatenated inputs.
encoder_input_arrays = [X1_final]
encoder_inputs = torch.cat([torch.from_numpy(arr) for arr in encoder_input_arrays], dim=1)
expected_input_dim = parameters['input_dim']['value']
if encoder_inputs.shape[1] != expected_input_dim:
    raise ValueError(
        f'Encoder input has {encoder_inputs.shape[1]} features but the model expects '
        f'{expected_input_dim}; adjust encoder_input_arrays'
    )

# Inference only. The latent code is the first output for every model unpacking used in
# this notebook ((z, pred, reconst), (z, recon1, ...)), so index it instead of assuming
# a fixed tuple length.
with torch.no_grad():
    z = model(encoder_inputs)[0]
latent_filename = f'latents_from_{dataset_name}'

# Data folder paths and file names
dataset_folder_name = 'synthetic_dataset'
dataset_file_name = 'synthetic_data_gridSamples_200_sumf5.csv'
new_dataset_file_name = 'synthetic_data_gridSamples_200_sumf5_with_ae1_latents_concat.csv'
run_folder_name = 'perovskite_multiscale_dataset_v2'
model_folder_name = 'best_SupSimpleAE_1_ldim4_arun2024'
AE_number = '1'

dataset_folder = '../datasets/' + dataset_folder_name
dataset_file = dataset_folder + '/' + dataset_file_name
concatenated_dataset_file = dataset_folder + '/' + new_dataset_file_name

run_folder = '../runs/' + run_dir
latent_file = run_folder + '/' + model_name + '/' + latent_filename + '.csv'

# Save the latents to .csv file (create the model directory if it does not exist yet)
os.makedirs(os.path.dirname(latent_file), exist_ok=True)
z_df = pd.DataFrame(z.detach().numpy())
# pred_df = pd.DataFrame(pred.detach().numpy())
# Concatenate the latents and the predictions
# z_pred_df = pd.concat([z_df, pred_df], axis=1)
# z_pred_df = pd.DataFrame(z.detach().numpy())
z_df.to_csv(latent_file, index=False, header=False)

# latents = pd.read_csv(latent_file, header=None, skiprows=None)
# data = pd.read_csv(dataset_file)

# for i in range(len(latents.columns)):
#     data['AE'+ AE_number +'_latent_'+str(i)] = latents[i]

# data.to_csv(concatenated_dataset_file, index=False)

In [None]:
# # Find which 'z'is closest to the 'z' provided below
# z_query = torch.tensor([0.26077228, 3.95919058, 0., 0., 0., 0., 1.19784881, 1.33860231], dtype=torch.float32)
# z_query = z_query.unsqueeze(dim=0)
# z_query = z_query.repeat(z.shape[0], 1)
# dist = torch.nn.PairwiseDistance(p=1)
# distances = dist(z, z_query)
# closest_idx = torch.argmin(distances).item()
# print(torch.min(distances))
# print(f'Closest z to the query z is at index : {closest_idx}')
# print(f'Closest z to the query z is : {z[closest_idx]}')
# print(f'Reconst for query z is : {reconst[closest_idx]}')
# print(f'INvert scaling for reconst : {scaler_X.inverse_transform(reconst[closest_idx].detach().numpy().reshape(1, -1))}')

In [None]:
# Inspect one scaled sample, the all-zeros baseline used for attribution,
# and the model's output at that baseline.
input = torch.tensor(X_scaled_np32[1, :], requires_grad=True)
baseline = torch.zeros_like(input)
for tensor_to_show in (input, baseline):
    print(tensor_to_show)
print(model(baseline))

#### Feature Importance

##### 1. Using Integrated Gradients

In [None]:
from captum.attr import IntegratedGradients, DeepLift, InputXGradient, Saliency
import seaborn as sns


def model_wrapper(input):
    """Run the global `model` on `input` and return only its prediction head.

    Captum attributes the returned tensor back to `input`. Assumes `model` returns a
    (z, pred, reconst) triple; alternative unpackings for other models are kept below.
    """
    z, pred, reconst = model(input)
    # z, pred, reconst1, reconst2 = model(input)
    # z, recon1, recon2, recon3, recon4 = model(input)
    return pred

# Attribution methods over the prediction head; only IntegratedGradients is used below,
# the others are kept for quick switching.
sal = Saliency(model_wrapper)
ig = IntegratedGradients(model_wrapper)
ixg = InputXGradient(model_wrapper)
dl = DeepLift(model_wrapper)

# Dedicated name instead of shadowing the built-in `input` for the rest of the session
attr_inputs = torch.from_numpy(X1_final)
# attr_inputs = torch.cat((torch.from_numpy(X1_scaled_np32), torch.from_numpy(X2_np32)), dim=1)
# attr_inputs = torch.concat((torch.from_numpy(X1_final), torch.from_numpy(X2_final), torch.from_numpy(X3_final), torch.from_numpy(X4_final)), dim=1)

# attr_sal = sal.attribute(attr_inputs, target=None)
# attr_sal_np = attr_sal.detach().numpy()
# attr_ixg = ixg.attribute(attr_inputs, target=None)
# attr_ixg_np = attr_ixg.detach().numpy()
attr_ig = ig.attribute(attr_inputs, baselines=0, target=None)
attr_ig_np = attr_ig.detach().numpy()
# attr_dl = dl.attribute(attr_inputs, baselines=0, target=None)
# attr_dl_np = attr_dl.detach().numpy()

num_feats = attr_ig_np.shape[1]
# Per-feature spread of the attributions across samples (population std)
attr_std = np.std(attr_ig_np, axis=0)

fig, ax = plt.subplots(figsize=(8, 5))
# Annotations use x in data coordinates and y in axes coordinates (1.0 = top edge), so the
# std labels sit just above the plot whatever the attribution scale is.
label_transform = ax.get_xaxis_transform()
# NOTE(review): descriptors1 is also (re)defined further down in this notebook; make sure
# it is set to the X1_final feature names before running this cell.
using_boxplot = True
if using_boxplot:
    # Distribution of attributions per feature, with mean lines
    ax.boxplot(attr_ig_np, showmeans=True, meanline=True)
    # boxplot places feature i at x = i + 1; label each box with its std
    for i, std in enumerate(attr_std):
        ax.text(i + 1, 1.01, f'{std:0.1f}', transform=label_transform, ha='center', va='bottom')
    ax.set_xticklabels(descriptors1, rotation=90)
    # Extra padding so the title clears the std labels
    ax.set_title('Integrated Gradients calc. wrt pred', pad=15)
    plt.show()
else:
    means = np.mean(attr_ig_np, axis=0)
    positions = np.arange(num_feats)
    # ax.bar(positions, means, yerr=attr_std, align='center', alpha=0.5, ecolor='black', capsize=10)
    ax.scatter(positions, means, label='mean', color='r', marker='.')
    ax.errorbar(positions, means, yerr=attr_std, fmt='o', capsize=5)
    for i, std in enumerate(attr_std):
        ax.text(i, 1.01, f'{std:0.1f}', transform=label_transform, ha='center', va='bottom')
    ax.set_xticks(positions)
    # descriptors = ['DMF', 'DMF; DMSO', 'DMSO', 'DMSO; GBL',
    #                '1', '1; 1', '1; 4', '2; 1', '3; 1', '3; 7', '4; 1', '7; 3',
    #                '8; 1', '9; 1', 'Anisole', 'Chlorobenzene', 'Diethyl ether',
    #                'Ethyl acetate', 'N2', 'Toluene',
    #                'AE1_latent_0', 'AE1_latent_1', 'AE1_latent_2', 'AE1_latent_3']
    descriptors = ['PET | ITO', 'SLG | FTO', 'SLG | ITO',
                   'NiO-c', 'PEDOT:PSS', 'PTAA', 'Spiro-MeOTAD',
                   'C60 | BCP', 'PCBM-60', 'PCBM-60 | BCP', 'PCBM-60 | ZnO-np',
                   'SnO2-c', 'SnO2-np', 'TiO2-c', 'TiO2-c | TiO2-mp',
                   'AE2_latent_0', 'AE2_latent_1', 'AE2_latent_2', 'AE2_latent_3',
                   'AE2_latent_4', 'AE2_latent_5', 'AE2_latent_6', 'AE2_latent_7']
    # The hard-coded list belongs to the perovskite device dataset; fall back to
    # descriptors1 when it does not match the number of attributed features.
    ax.set_xticklabels(descriptors if len(descriptors) == num_feats else descriptors1, rotation=90)
    fig.tight_layout()
    plt.show()

# # Matrix Plot model weights
# for name, param in model.encoder.named_parameters():
#     if name == 'layers.0.weight':
#         ax = plt.figure(figsize=(9, 8))
#         ax = sns.heatmap(param.detach().numpy(), annot=True, fmt='.3f', cmap='coolwarm')
#         # Remove y axis labels
#         ax.yticks([])
#         ax.set_title('Encoder Layer 1 Weights')
    

##### 2. Pearson Correlation Coeffs

In [None]:
def model_wrapper(input):
    """Run the global `model` on `input` and return only its prediction head.

    Assumes `model` returns a (z, pred, reconst) triple; alternative unpackings for other
    models are kept below.
    """
    z, pred, reconst = model(input)
    # z, pred, reconst1, reconst2 = model(input)
    # z, recon1, recon2, recon3, recon4 = model(input)
    return pred


# Inference only: no autograd graph is needed for a correlation analysis
with torch.no_grad():
    pred = model_wrapper(torch.from_numpy(X1_final))
# Columns are the variables: the X1_final features followed by the prediction column(s)
corrcoefs = np.corrcoef(X1_final, y=pred.numpy(), rowvar=False)

# In the corr coefs, zero out every value whose magnitude is below the threshold
threshold = 0.1
corrcoefs[np.abs(corrcoefs) < threshold] = 0

# Label the trailing prediction column(s); more than one label if the model predicts several targets
num_pred_cols = corrcoefs.shape[0] - X1_final.shape[1]
pred_labels = ['f5_pred'] if num_pred_cols == 1 else [f'f5_pred_{k}' for k in range(num_pred_cols)]
# NOTE(review): descriptors1 is also (re)defined further down in this notebook; make sure
# it holds the X1_final feature names before running this cell.
tick_labels = list(descriptors1) + pred_labels

# Diverging colormap: blue (-1) -> gray (0, i.e. below threshold) -> red (+1)
# cmap = sns.diverging_palette(220, 20, as_cmap=True)
# cmap.set_bad(color='black')
import matplotlib.colors as colors
cmap = colors.LinearSegmentedColormap.from_list('mycmap', ['blue', 'gray', 'red'])
norm = colors.Normalize(vmin=-1, vmax=1)

# Matrix plot of corr coefs (figure size in inches)
fig, ax = plt.subplots(figsize=(5, 3.85))
sns.heatmap(corrcoefs, annot=False, cmap=cmap, norm=norm, ax=ax)

# Feature names along the top edge instead of the bottom
ax.tick_params(axis='x', labelbottom=False, labeltop=True)
ax.xaxis.tick_top()

# Heatmap cells are centred at integer + 0.5
tick_positions = np.arange(len(tick_labels)) + 0.5
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels, rotation=90)
ax.set_yticks(tick_positions)
ax.set_yticklabels(tick_labels, rotation=0)

# Adjust subplot size
fig.tight_layout()

# Save the plot to pdf
fig.savefig('corr_coefs.pdf')

##### 3. Using kernel SHAP

In [None]:
from nn.vanilla_ae import VanillaAE
from utils.custom_utils import read_from_pickle
from inputs_perov_data.nn_inputs import list_of_nn_params_dict
from inputs_perov_data.train_inputs import list_of_nn_train_params_dict
from inputs_perov_data.dataset_inputs import list_of_nn_datasets_dict

# 1-based number of the autoencoder to analyse. Kept separate from `ae`, which holds the
# model instance below.
ae_number = 1

ae_index = int(ae_number) - 1

run_dir = '../runs/perovskite_multiscale_dataset_3'

# # Read the ae params, train params and datasets
# list_of_nn_params_dict = read_from_pickle('list_of_nn_params_dict.pkl', run_dir)
# list_of_nn_train_params_dict = read_from_pickle('list_of_nn_train_params_dict.pkl', run_dir)
# list_of_nn_datasets_dict = read_from_pickle('list_of_nn_datasets_dict.pkl', run_dir)

nn_params_dict = list_of_nn_params_dict[ae_index]
nn_train_params_dict = list_of_nn_train_params_dict[ae_index]
nn_datasets_dict = list_of_nn_datasets_dict[ae_index]

global_seed = nn_train_params_dict['global_seed']

nn_save_dir = run_dir + '/' + nn_params_dict['model_type']
print(nn_save_dir)

ae = VanillaAE(nn_save_dir,
               nn_params_dict,
               nn_train_params_dict,
               nn_datasets_dict)

# Load model weights from checkpoint into the existing instance.
# If VanillaAE is a LightningModule (the checkpoints/last.ckpt layout suggests so),
# `load_from_checkpoint` is a classmethod that RETURNS a new model. Calling it on `ae` and
# discarding the result leaves `ae` with its freshly initialised weights (newer Lightning
# raises instead). Loading the state dict directly avoids that.
checkpoint_path = nn_save_dir + '/checkpoints/last.ckpt'
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# NOTE(review): Lightning checkpoints keep weights under 'state_dict'; a bare state dict is
# accepted as well. Newer torch defaults torch.load to weights_only=True, which may reject
# some non-tensor checkpoint entries -- confirm for the installed version.
state_dict = checkpoint.get('state_dict', checkpoint)
ae.load_state_dict(state_dict)
# Inference / interpretability only from here on
ae.eval()

# # Load the training dataset
# train_dataset_path = nn_save_dir + '/datasets/train_dataset.pt'
# train_dataset = torch.load(train_dataset_path)
# train_dataset_np = train_dataset[:]['all_props'].detach().numpy()

# predict_dataset_path = nn_save_dir + '/datasets/predict_dataset.pt'
# predict_dataset = torch.load(predict_dataset_path)

In [None]:
# Combine all the inputs into one dataset, if only one input then skip
combined = False
background_dataset_gen = 'all_samples'  # 'kmeans' or 'all_samples'
descriptors1 = ['A_ion_rad', 'A_at_wt', 'A_EA', 'A_IE', 'A_En',
                'B_ion_rad', 'B_at_wt', 'B_EA', 'B_IE', 'B_En',
                'X_ion_rad', 'X_at_wt', 'X_EA', 'X_IE', 'X_En']
# descriptors2=['AE1_l0', 'AE1_l1', 'AE1_l2', 'AE1_l3', 'AE1_l4', 'AE1_l5',
#                  'AE1_l6', 'AE1_l7', 'AE1_l8', 'AE1_l9', 'AE1_l10', 'AE1_l11']+\
#                 list(ae.datasets.variable_preprocessors['etm'].categories_[0]) +\
#                 list(ae.datasets.variable_preprocessors['htm'].categories_[0])

if combined:
    dataset = torch.concat((ae.all_samples['latents'], ae.all_samples['etm'], ae.all_samples['htm']), dim=1)
else:
    dataset = ae.all_samples['all_props']
# Numpy copy of the samples to explain; the background set must share these columns
dataset_np = dataset.detach().numpy()

# Using kmeans to cluster the data
from sklearn.cluster import KMeans

if background_dataset_gen == 'kmeans':
    krange = np.arange(1, 200)
    inertia = np.zeros(len(krange))

    for idx, k in enumerate(krange):
        kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(dataset_np)
        inertia[idx] = kmeans.inertia_
    # Background = cluster centres of the largest k in the sweep, as float32 for the model
    background_dataset = kmeans.cluster_centers_.astype(np.float32)

    # Plot elbow plot
    fig, ax = plt.subplots()
    ax.plot(krange, inertia, marker='o')
    ax.set_xlabel('Number of clusters')
    ax.set_ylabel('Inertia')
elif background_dataset_gen == 'all_samples':
    # Sample from `dataset` (not always 'all_props') so the background columns match the
    # explained samples when combined=True.
    # X100 = shap.utils.sample(dataset_np, 100)
    background_dataset = shap.utils.sample(dataset_np, len(dataset_np))
else:
    raise ValueError(
        f"Unknown background_dataset_gen {background_dataset_gen!r}; use 'kmeans' or 'all_samples'"
    )

In [None]:
class modelWrapper(torch.nn.Module):
    """Adapt a dict-input model (the VanillaAE `ae`) to SHAP's numpy-in / numpy-out interface.

    SHAP's KernelExplainer calls the wrapper with a numpy array of samples; the array is
    wrapped as {'all_props': tensor} and the model's 'bg_pred' output is returned as numpy.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input):
        # REWRITE BASED ON INPUTS FED TO MODEL !!
        # input = {'latents':torch.from_numpy(input[:,:12]),
        #          'etm':torch.from_numpy(input[:,12:19]),
        #          'htm':torch.from_numpy(input[:,19:])}
        batch = {'all_props': torch.from_numpy(input)}
        # Inference only: avoid building an autograd graph for every SHAP evaluation batch
        with torch.no_grad():
            submodule_outputs = self.model(batch)
        #submodule_outputs_from_loaded = ae(ae.all_samples)
        return submodule_outputs['bg_pred'].detach().numpy()

model_wrapper = modelWrapper(ae)
# Numpy copy of the explained samples, reused for the explainer and the summary plot
dataset_np = dataset.detach().numpy()
kernel_explainer = shap.KernelExplainer(model_wrapper, data=background_dataset, link='identity', feature_names=descriptors1)
raw_shap_values = kernel_explainer.shap_values(dataset_np)
# Older shap releases return a list with one (n_samples, n_features) array per model output;
# newer ones return one array with outputs on the last axis. Normalise to the latter, then
# drop the singleton output axis (assumes a single prediction column).
if isinstance(raw_shap_values, list):
    raw_shap_values = np.stack(raw_shap_values, axis=-1)
shap_values = np.asarray(raw_shap_values)
if shap_values.ndim == 3:
    shap_values = np.squeeze(shap_values, axis=-1)
shap.initjs()

shap.summary_plot(shap_values, features=dataset_np, feature_names=descriptors1, max_display=10, plot_type='dot', show=False, plot_size=(5.8, 3.5))

In [None]:
# Use SHAP force plot to explain a single sample
# shap.force_plot(kernel_explainer.expected_value, shap_values[34], features=ae.all_samples['all_props'].detach().numpy()[34], feature_names=descriptors1)
sample_idx = 34
sample_features = ae.all_samples['all_props'].detach().numpy()[sample_idx]
shap.plots.force(
    kernel_explainer.expected_value,
    shap_values[sample_idx],
    features=sample_features,
    feature_names=descriptors1,
)

In [None]:
import shap

# Background dataset for KernelExplainer: 100 rows drawn from X1_final
X100 = shap.utils.sample(X1_final, nsamples=100)

class modelWrapper(torch.nn.Module):
    """Adapt a (z, pred, reconst)-returning model to SHAP's numpy-in / numpy-out interface.

    SHAP's KernelExplainer calls the wrapper with a numpy array of samples; the array is
    converted to a tensor and only the prediction head is returned, as numpy.
    NOTE(review): this redefines an earlier `modelWrapper` in this notebook; a distinct
    name would avoid order-dependent behaviour.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input):
        # For kernel explainer: SHAP passes numpy, the model expects tensors.
        # Inference only: no autograd graph per SHAP evaluation batch.
        with torch.no_grad():
            z, pred, reconst = self.model(torch.from_numpy(input))
        return pred.detach().numpy()
    
model_wrapper = modelWrapper(model)

# explainer = shap.Explainer(model_wrapper, X100)
kernel_explainer = shap.KernelExplainer(model=model_wrapper, data=X100)
# gradient_explainer = shap.GradientExplainer(model_wrapper, torch.from_numpy(X100))
# deep_explainer = shap.DeepExplainer(model_wrapper, torch.from_numpy(X100))

# Shapley values for all the points in the dataset
raw_shap_values = kernel_explainer.shap_values(X1_final)
# shap_values = explainer(X100)
# shap_values = gradient_explainer.shap_values(torch.from_numpy(X100))
# shap_values = deep_explainer.shap_values(torch.from_numpy(X100))
# Older shap releases return a list with one (n_samples, n_features) array per model output;
# newer ones return one array with outputs on the last axis. Normalise to the latter, then
# drop the singleton output axis (assumes a single prediction column).
if isinstance(raw_shap_values, list):
    raw_shap_values = np.stack(raw_shap_values, axis=-1)
shap_values = np.asarray(raw_shap_values)
if shap_values.ndim == 3:
    shap_values = np.squeeze(shap_values, axis=-1)

shap.initjs()

# For displaying multiple samples. show=False keeps the plot as a matplotlib figure so it can be saved
shap.summary_plot(shap_values, features=X1_final, feature_names=descriptors1, max_display=10, plot_type='dot', show=False, plot_size=(5.8, 3.5))

# Save the SHAP summary plot to pdf; bbox_inches='tight' keeps the feature-name labels from being clipped
plt.savefig('SHAP_summary_plot.pdf', bbox_inches='tight')

# Explaining a single prediction
# shap.waterfall_plot(shap.Explanation(values=shap_values[1], base_values=0, data=X1_final[1], feature_names=descriptors1))

# Requires an explanation object to be passed
# shap.plots.bar(shap.Explanation(values=shap_values[:], base_values=0, data=X1_final[:], feature_names=descriptors1))
# shap.force_plot(kernel_explainer.expected_value, shap_values, descriptors1)



In [None]:
# implementing own version of integrated gradients
# Explain one scaled sample against an all-zeros baseline. A dedicated name avoids
# shadowing the built-in `input`.
ig_input = torch.tensor(X_scaled_np32[1, :])
print(ig_input)
baseline = torch.zeros_like(ig_input)
print(baseline)


def interpolated_features(num_steps, x=None, x_baseline=None):
    """Return the straight-line path from the baseline to the input.

    Args:
        num_steps: number of equal intervals; the path has num_steps + 1 points, from the
            baseline (alpha = 0) to the input (alpha = 1) inclusive.
        x: sample to explain; defaults to the module-level `ig_input`.
        x_baseline: reference point; defaults to the module-level `baseline`.

    Returns:
        Tensor of shape (num_steps + 1, *x.shape).
    """
    x = ig_input if x is None else x
    x_baseline = baseline if x_baseline is None else x_baseline
    # One alpha per path point, broadcast over every feature dimension of x
    alphas = torch.linspace(0, 1, num_steps + 1).reshape(-1, *([1] * x.dim()))
    return x_baseline + alphas * (x - x_baseline)


def compute_gradients(interpolated_feats):
    """Return d(sum of pred)/d(point) at every interpolation point, stacked along dim 0.

    Each point is fed to the global `model` as a batch of one, so its gradient does not
    depend on the other points. Assumes `model` returns (z, pred, reconst).
    """
    grads = []
    for point in interpolated_feats:
        # Fresh leaf tensor per point: .grad is only populated on leaves, and reusing one
        # leaf across iterations would accumulate gradients.
        point = point.detach().unsqueeze(0).requires_grad_(True)
        z, pred, reconst = model(point)
        (grad,) = torch.autograd.grad(pred.sum(), point)
        grads.append(grad.squeeze(0))
    return torch.stack(grads)


num_steps = 10
path_alphas = torch.linspace(0, 1, num_steps + 1)
computed_grads = compute_gradients(interpolated_features(num_steps))
# Trapezoidal estimate of IG_i = (x_i - x'_i) * integral_0^1 dF/dx_i(x' + a (x - x')) da
ig_attributions = (ig_input - baseline) * torch.trapezoid(computed_grads, dx=1.0 / num_steps, dim=0)
print(ig_attributions)

fig, ax = plt.subplots()
ax.plot(path_alphas.numpy(), computed_grads.numpy())
ax.set_xlabel('alpha (baseline -> input)')
ax.set_ylabel('d pred / d feature')
ax.set_title('Gradients along the integration path')


#### 2D plots of latent space

In [None]:
# z, pred, reconst, mu, logvar = model(torch.from_numpy(elemental_properties))
# z, pred, reconst, mu, logvar, mu_prior, logvar_prior, qy_logit, qy = model(torch.from_numpy(elemental_properties))

# # Only for 2D latent plotting
# plt.scatter(z[:, 0].detach().numpy(), z[:, 1].detach().numpy(), c=pred.detach().numpy(), cmap='viridis', s=10, alpha=0.5)
# plt.colorbar()
# plt.show()

### Observations for Unit normal prior
- The example above shows that multivariate Gaussian distributions are useful because each
    dimension can encode a separate degree of freedom (DOF), which yields structured and disentangled
    representations. However, they are unimodal and therefore cannot capture complex (multimodal) structure.
    A natural extension is to use a different prior, and a Gaussian Mixture Model (GMM) is the next choice.
- The latent space is segregated into different classes.
- However, inference is non-trivial.