In [1]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import wandb
import os
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler, OneHotEncoder
import copy
import joblib

# Login to wandb. Create a wandb account and get the API key from the user settings tab.
# SECURITY: never hard-code the API key in the notebook -- it ends up in version
# control and in every shared copy. Export WANDB_API_KEY in the shell that starts
# Jupyter, or run `wandb login` once; wandb reads the key from there. A key that
# has already been committed should be revoked and regenerated.
user_name = 'nthota2'
project_name = 'perovskite_dataset_v2'
os.environ["WANDB_NOTEBOOK_NAME"] = 'generative_models.ipynb'
# offline (legacy name: dryrun) = log locally without syncing to the cloud
# online = enables cloud syncing
# disabled = no logging at all
os.environ["WANDB_MODE"] = "online"

# Seed every RNG the notebook uses from the same value so runs are reproducible
# (the Python `random` module was previously seeded with a different constant).
seed = 10
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

## Load the dataset

In [2]:
## ------------------------ USER INPUT START ------------------------

dataset1_to_train = 'PSC_eff_v2'
dataset2_to_train = 'latents_from_SupAE1'

# For RL dataset
# descriptors = ['z1', 'z2', 'z3', 'z4', 
#                'z5', 'z6', 'z7', 'z8']
# descriptors = ['f1', 'f2', 'f3', 'f4', 'AE1_latent_0']
# descriptors = ['f1', 'f2', 'f3', 'f4', 
#                'AE1_latent_0', 'AE1_latent_1', 'AE1_latent_2', 'AE1_latent_3', 
#                'AE1_latent_4', 'AE1_latent_5', 'AE1_latent_6', 'AE1_latent_7']
# For perovskite dataset v2
# descriptors1 = ['A_ion_rad', 'A_at_wt', 'A_EA', 'A_IE', 'A_En',
#                 'B_ion_rad', 'B_at_wt', 'B_EA', 'B_IE', 'B_En',
#                 'X_ion_rad', 'X_at_wt', 'X_EA', 'X_IE', 'X_En']
# Descriptor groups 1-3 are columns of dataset1_to_train; group 4 is read from
# dataset2_to_train. Set a group (2-4) to None to leave it out.
descriptors1 = ['Perovskite_deposition_solvents'] 
descriptors2 = ['Perovskite_deposition_solvents_mixing_ratios']
descriptors3 = ['Perovskite_deposition_quenching_media']
descriptors4 = ['AE1_latent_0', 'AE1_latent_1', 'AE1_latent_2', 'AE1_latent_3']
# descriptors2 = ['Cubic', 'Tetra', 'Ortho', 'Hex']

# Per-group preprocessing: standardizeN takes precedence over oheN; with
# load_scalerN=True a saved StandardScaler is loaded (joblib) from scalerN_loc
# instead of being fitted here.
standardize1 = False
ohe1 = True
load_scaler1 = False
scaler1_loc = None

standardize2 = False
ohe2 = True
load_scaler2 = False
scaler2_loc = None

standardize3 = False
ohe3 = True
load_scaler3 = False
scaler3_loc = None

standardize4 = True
ohe4 = False
load_scaler4 = False
scaler4_loc = None

# Name of the target column in dataset1_to_train, or None for unsupervised training.
label = None
standardizelabel = False

# Dictionary containing paths to datasets
dataset_path_dict = {
    'gridSamples_200_nonlinf5': '../datasets/synthetic_dataset/synthetic_data_gridSamples_200_with_ae1_latents_concat.csv',
    'gridSamples_200_sumf5': '../datasets/synthetic_dataset/synthetic_data_gridSamples_200_sumf5_with_ae1_latents_concat.csv',
    'randomSamples_200_nonlinf5': '../datasets/synthetic_dataset/synthetic_data_randomSamples_200_with_ae1_latents_concat.csv',
    'randomSamples_200_sumf5': '../datasets/synthetic_dataset/synthetic_data_randomSamples_200_sumf5_with_ae1_latents_concat.csv',
    'PSC_bandgaps_v1': '../datasets/PSC_bandgaps/PSC_bandgaps_dataset.csv',
    'PSC_eff_v1': '../datasets/PSC_efficiencies/PSC_efficiencies_dataset.csv',
    'PSC_eff_v2': '../datasets/PSC_efficiencies/PSC_efficiencies_dataset_2.csv',
    'HSE_arun2022': '../datasets/PSC_bandgaps/HSE_data_arun2022.csv',
    'HSE_arun2024': '../datasets/PSC_bandgaps/HSE_data_arun2024.csv',
    'latents_from_SupAE1':'../runs/perovskite_multiscale_dataset_v2/best_SupSimpleAE_1_ldim4_arun2024/latents_from_PSC_efficiencies_dataset_2.csv',
}

data1 = pd.read_csv(dataset_path_dict[dataset1_to_train])
X1 = pd.DataFrame(data1, columns=descriptors1)
# Check if the dataset has any nan values
print(X1.isnull().values.any())
print(X1.shape)

# Groups 2 and 3 come from the same CSV as group 1 (dataset1_to_train), so copy
# the frame already in memory instead of parsing the file from disk again.
if descriptors2 is not None:
    data2 = data1.copy()
    X2 = pd.DataFrame(data2, columns=descriptors2)
    print(X2.shape)

if descriptors3 is not None:
    data3 = data1.copy()
    X3 = pd.DataFrame(data3, columns=descriptors3)
    print(X3.shape)

if descriptors4 is not None:
    data4 = pd.read_csv(dataset_path_dict[dataset2_to_train])
    X4 = pd.DataFrame(data4, columns=descriptors4)
    print(X4.shape)
    # Group 4 lives in a different file and is paired with group 1 purely by
    # row position, so a row-count mismatch would silently misalign samples.
    assert len(X4) == len(X1), (
        f'{dataset2_to_train} has {len(X4)} rows but {dataset1_to_train} has {len(X1)}; '
        'descriptor groups are paired by row position.'
    )

if label is not None:
    y_copy = copy.deepcopy(data1[label].to_numpy())
    y = y_copy.reshape(-1,1)
    print(y.shape)
## ------------------------ USER INPUT END ------------------------

# Creating a pytorch dataset
class XsandYDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(y, X1, X2)`` triples indexed in lockstep.

    The dataset length is taken from the label container ``y``.
    """

    def __init__(self, y, X1, X2):
        self.y, self.X1, self.X2 = y, X1, X2

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        target = self.y[idx]
        first_features = self.X1[idx]
        second_features = self.X2[idx]
        return target, first_features, second_features

class XandYDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(y, X)`` pairs; its length comes from ``y``."""

    def __init__(self, y, X):
        self.y, self.X = y, X

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        target, features = self.y[idx], self.X[idx]
        return target, features
    
class XsDataset(torch.utils.data.Dataset):
    """Map-style dataset over four feature containers indexed in lockstep.

    Item ``idx`` is ``(X1[idx], X2[idx], X3[idx], X4[idx])``; the length is
    taken from ``X1``.
    """

    def __init__(self, X1, X2, X3, X4):
        self.X1, self.X2, self.X3, self.X4 = X1, X2, X3, X4

    def __len__(self):
        return len(self.X1)

    def __getitem__(self, idx):
        groups = (self.X1, self.X2, self.X3, self.X4)
        return tuple(group[idx] for group in groups)

class XDataset(torch.utils.data.Dataset):
    """Thin map-style dataset that returns ``X[idx]`` for each index."""

    def __init__(self, X):
        self.X = X

    def __len__(self):
        n_items = len(self.X)
        return n_items

    def __getitem__(self, idx):
        item = self.X[idx]
        return item
    
# Standardize / one-hot encode each descriptor group and wrap the result in a
# torch Dataset.
def _encode_features(X, standardize, ohe, load_scaler, scaler_loc, name):
    """Encode one descriptor DataFrame as a float32 numpy array.

    Exactly one transformation is applied, checked in this order:
      1. ``standardize``: a StandardScaler loaded with joblib from ``scaler_loc``
         (when ``load_scaler``) or fitted on ``X``; prints whether
         ``inverse_transform`` reproduces the raw values.
      2. ``ohe``: a dense OneHotEncoder fitted on ``X``; prints its categories.
      3. otherwise: ``X`` cast to float32 unchanged.

    Args:
        X: descriptor DataFrame for one group.
        standardize, ohe, load_scaler, scaler_loc: that group's USER INPUT flags.
        name: label used in the scaling-check message (e.g. 'X1').

    Returns:
        ``(encoded float32 ndarray, fitted transformer or None)``.
    """
    if standardize:
        scaler = joblib.load(scaler_loc) if load_scaler else StandardScaler().fit(X)
        X_scaled_np = scaler.transform(X)
        # Round-trip sanity check: the inverse transform should recover X.
        X_np = scaler.inverse_transform(X_scaled_np)
        X_final = X_scaled_np.astype(np.float32)
        print(f'{name} Scaling done accurately ? : {np.allclose(X_np, X.to_numpy())}')
        return X_final, scaler
    if ohe:
        encoder = OneHotEncoder(sparse_output=False).fit(X)
        X_final = encoder.transform(X).astype(np.float32)
        print(f'OHE Categories : {encoder.categories_}')
        return X_final, encoder
    return X.to_numpy().astype(np.float32), None


# The fitted transformers keep their original names (scaler_XN / ohe_XN) so
# later cells that reference them keep working.
X1_final, _fitted = _encode_features(X1, standardize1, ohe1, load_scaler1, scaler1_loc, 'X1')
if standardize1:
    scaler_X1 = _fitted
elif ohe1:
    ohe_X1 = _fitted
print(f'X1_final shape : {X1_final.shape}')

# Groups 2-4 are optional: a group whose descriptor list is None is skipped.
if descriptors2 is not None:
    X2_final, _fitted = _encode_features(X2, standardize2, ohe2, load_scaler2, scaler2_loc, 'X2')
    if standardize2:
        scaler_X2 = _fitted
    elif ohe2:
        ohe_X2 = _fitted
    print(f'X2_final shape : {X2_final.shape}')

if descriptors3 is not None:
    X3_final, _fitted = _encode_features(X3, standardize3, ohe3, load_scaler3, scaler3_loc, 'X3')
    if standardize3:
        scaler_X3 = _fitted
    elif ohe3:
        ohe_X3 = _fitted
    print(f'X3_final shape : {X3_final.shape}')

if descriptors4 is not None:
    X4_final, _fitted = _encode_features(X4, standardize4, ohe4, load_scaler4, scaler4_loc, 'X4')
    if standardize4:
        scaler_X4 = _fitted
    elif ohe4:
        ohe_X4 = _fitted
    print(f'X4_final shape : {X4_final.shape}')

if label is not None:
    if standardizelabel:
        scaler_y = StandardScaler().fit(y)
        y_scaled_np = scaler_y.transform(y)
        y_np = scaler_y.inverse_transform(y_scaled_np)
        y_final = y_scaled_np.astype(np.float32)
        print(f'y Scaling done accurately ? : {np.allclose(y_np, y)}')
    else:
        y_final = y.astype(np.float32)

# Pick the Dataset wrapper that matches the groups actually loaded. Every wrapper
# indexes its tensors in lockstep, so all groups must have the same row count.
if label is not None:
    if descriptors2 is not None:
        dataset = XsandYDataset(torch.from_numpy(y_final), torch.from_numpy(X1_final), torch.from_numpy(X2_final))
    else:
        dataset = XandYDataset(torch.from_numpy(y_final), torch.from_numpy(X1_final))
elif descriptors2 is None:
    dataset = XDataset(torch.from_numpy(X1_final))
elif descriptors3 is not None and descriptors4 is not None:
    feature_arrays = (X1_final, X2_final, X3_final, X4_final)
    assert len({len(a) for a in feature_arrays}) == 1, \
        f'Descriptor groups have mismatched row counts: {[len(a) for a in feature_arrays]}'
    dataset = XsDataset(*(torch.from_numpy(a) for a in feature_arrays))
else:
    # XsDataset needs all four groups; there is no wrapper for a partial set.
    raise ValueError('Unsupervised training with descriptors2 set also needs descriptors3 and descriptors4 '
                     '(or set descriptors2 = None to train on descriptors1 only).')

print(len(dataset))

False
(2245, 1)
(2245, 1)
(2245, 1)
(2245, 4)
OHE Categories : [array(['DMF', 'DMF; DMSO', 'DMSO', 'DMSO; GBL', 'GBL'], dtype=object)]
X1_final shape : (2245, 5)
OHE Categories : [array(['1', '3; 7', '4; 1', '7; 3', '9; 1'], dtype=object)]
X2_final shape : (2245, 5)
OHE Categories : [array(['Chlorobenzene', 'Diethyl ether', 'Ethyl acetate', 'Toluene',
       'Unknown'], dtype=object)]
X3_final shape : (2245, 5)
X4 Scaling done accurately ? : True
X4_final shape : (2245, 4)
2245


## Models

In [3]:
## ------------------------ USER INPUT START ------------------------
# Lookup table from the activation names used in configs to module instances.
# The None entry lets a config explicitly request "no activation".
activation_fns = {
    'relu': torch.nn.ReLU(),
    'elu': torch.nn.ELU(),
    'tanh': torch.nn.Tanh(),
    'sigmoid': torch.nn.Sigmoid(),
    'softmax': torch.nn.Softmax(dim=1),
}
activation_fns[None] = None

# When True, Predictor ignores num_layers/activations and builds a fixed
# Linear-Tanh-Linear-Tanh-Linear stack.
custom_arch = False
# If true mention the feature dimensions for each input in the decoder.
multiple_outputs = True
## ------------------------ USER INPUT END ------------------------

class Encoder(torch.nn.Module):
    """MLP mapping an ``input_dim`` feature vector to a ``latent_dim`` code.

    Layer stack (``self.layers``, applied in order):
      * num_layers == 1: Linear(input_dim, latent_dim) [+ activation]
      * num_layers >= 2: Linear(input_dim, hidden_dim) [+ activation] + Dropout,
        then (num_layers - 2) x [Linear(hidden_dim, hidden_dim) [+ activation] + Dropout],
        then a final Linear(hidden_dim, latent_dim) with no activation.

    Args:
        input_dim, hidden_dim, latent_dim: feature widths.
        dropout: dropout probability after each hidden layer.
        num_layers: number of Linear layers; must be >= 1.
        activation_fn: key into the notebook-level ``activation_fns`` dict
            (e.g. 'relu', 'tanh', or None for no activation).

    Raises:
        ValueError: unknown ``activation_fn`` or ``num_layers < 1``.
    """
    def __init__(self, input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn):
        super(Encoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        try:
            self.activation_fn = activation_fns[activation_fn]
        except KeyError:
            raise ValueError('Invalid activation function')
        # An empty layer stack would make forward() fail obscurely; reject it here.
        if self.num_layers < 1:
            raise ValueError('num_layers must be >= 1')

        self.layers = torch.nn.ModuleList()

        if self.num_layers == 1:
            self.layers.append(torch.nn.Linear(self.input_dim, self.latent_dim))
            if self.activation_fn is not None:
                self.layers.append(self.activation_fn)
        else:
            for i in range(self.num_layers):
                if i == 0:
                    self.layers.append(torch.nn.Linear(self.input_dim, self.hidden_dim))
                    if self.activation_fn is not None:
                        self.layers.append(self.activation_fn)
                    self.layers.append(torch.nn.Dropout(self.dropout))
                elif i == self.num_layers - 1:
                    # Output layer: linear, so the latent code is unbounded.
                    self.layers.append(torch.nn.Linear(self.hidden_dim, self.latent_dim))
                else:
                    self.layers.append(torch.nn.Linear(self.hidden_dim, self.hidden_dim))
                    if self.activation_fn is not None:
                        self.layers.append(self.activation_fn)
                    self.layers.append(torch.nn.Dropout(self.dropout))

    def forward(self, x):
        """Apply every layer in ``self.layers`` in order and return the latent code."""
        z = x
        for layer in self.layers:
            z = layer(z)
        return z
    
class Predictor(torch.nn.Module):
    """MLP head mapping a ``latent_dim`` code to a single scalar prediction.

    Layer stack (``self.layers``, applied in order):
      * notebook-level ``custom_arch`` True at construction time:
        Linear(latent_dim, hidden_dim), Tanh, Linear(hidden_dim, hidden_dim), Tanh,
        Linear(hidden_dim, 1). ``num_layers``, ``dropout`` and the activation
        arguments do not affect the layers in this case.
      * num_layers == 1: Linear(latent_dim, 1) [+ output activation]
      * num_layers >= 2: Linear(latent_dim, hidden_dim) [+ activation] + Dropout,
        (num_layers - 2) x [Linear(hidden_dim, hidden_dim) [+ activation] + Dropout],
        then Linear(hidden_dim, 1) [+ output activation].

    ``input_dim`` is stored for reference but not used to build layers.

    Args:
        activation_fn, output_activation_fn: keys into ``activation_fns``.

    Raises:
        ValueError: unknown activation key, or ``num_layers < 1`` for the
            configurable architecture.
    """
    def __init__(self, input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, output_activation_fn):
        super(Predictor, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        try:
            self.activation_fn = activation_fns[activation_fn]
            self.output_activation_fn = activation_fns[output_activation_fn]
        except KeyError:
            raise ValueError('Invalid activation function')
        # Snapshot the notebook-level switch so this attribute reflects the
        # architecture actually built (it used to be hard-coded to False).
        self.custom_arch = custom_arch
        if not self.custom_arch and self.num_layers < 1:
            raise ValueError('num_layers must be >= 1')

        self.layers = torch.nn.ModuleList()

        if self.custom_arch:
            self.layers.append(torch.nn.Linear(self.latent_dim, self.hidden_dim))
            self.layers.append(torch.nn.Tanh())
            self.layers.append(torch.nn.Linear(self.hidden_dim, self.hidden_dim))
            self.layers.append(torch.nn.Tanh())
            self.layers.append(torch.nn.Linear(self.hidden_dim, 1))
        else:
            if self.num_layers == 1:
                self.layers.append(torch.nn.Linear(self.latent_dim, 1))
                if self.output_activation_fn is not None:
                    self.layers.append(self.output_activation_fn)
            else:
                for i in range(self.num_layers):
                    if i == 0:
                        self.layers.append(torch.nn.Linear(self.latent_dim, self.hidden_dim))
                        if self.activation_fn is not None:
                            self.layers.append(self.activation_fn)
                        self.layers.append(torch.nn.Dropout(self.dropout))
                    elif i == self.num_layers - 1:
                        self.layers.append(torch.nn.Linear(self.hidden_dim, 1))
                        if self.output_activation_fn is not None:
                            self.layers.append(self.output_activation_fn)
                    else:
                        self.layers.append(torch.nn.Linear(self.hidden_dim, self.hidden_dim))
                        if self.activation_fn is not None:
                            self.layers.append(self.activation_fn)
                        self.layers.append(torch.nn.Dropout(self.dropout))

    def forward(self, x):
        """Apply every layer in ``self.layers`` in order; returns shape (..., 1)."""
        pred = x
        for layer in self.layers:
            pred = layer(pred)
        return pred
    
class Decoder(torch.nn.Module):
    """MLP mapping a ``latent_dim`` code back to feature space.

    A shared trunk (``self.layers``) is followed, when multiple outputs are
    enabled, by one linear head per descriptor group
    (``output1_layer``, ``output2_layer``, ...).

    Trunk layout:
      * num_layers == 1: Linear(latent_dim, trunk_out) [+ output activation]
      * num_layers >= 2: Linear(latent_dim, hidden_dim) [+ activation] + Dropout,
        (num_layers - 2) x [Linear(hidden_dim, hidden_dim) [+ activation] + Dropout],
        then Linear(hidden_dim, trunk_out) [+ output activation].
    ``trunk_out`` is ``hidden_dim`` with multiple outputs (the heads read it),
    otherwise ``input_dim`` (the trunk reconstructs the input directly).

    Args:
        output_activation_fn: key into ``activation_fns`` applied after the
            trunk's last Linear (the per-group heads themselves are linear).
        output_dims: width of each per-group head. The default (5, 5, 5, 4)
            matches the encoded widths of X1..X4 printed by the data cell.

    Returns (forward):
        A tuple with one tensor per head when multiple outputs are enabled,
        otherwise a single tensor of width ``input_dim``.

    Raises:
        ValueError: unknown activation key or ``num_layers < 1``.
    """
    def __init__(self, input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, output_activation_fn, output_dims=(5, 5, 5, 4)):
        super(Decoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        try:
            self.activation_fn = activation_fns[activation_fn]
            self.output_activation_fn = activation_fns[output_activation_fn]
        except KeyError:
            raise ValueError('Invalid activation function')
        if self.num_layers < 1:
            raise ValueError('num_layers must be >= 1')
        # Read the notebook-level switch once; construction and forward() both
        # follow this attribute, so they can never disagree. (It was hard-coded
        # to True while the heads and forward() read the global.)
        self.multiple_outputs = multiple_outputs
        self.output_dims = tuple(output_dims)
        trunk_out_dim = self.hidden_dim if self.multiple_outputs else self.input_dim

        self.layers = torch.nn.ModuleList()

        if self.num_layers == 1:
            self.layers.append(torch.nn.Linear(self.latent_dim, trunk_out_dim))
            if self.output_activation_fn is not None:
                self.layers.append(self.output_activation_fn)
        else:
            for i in range(self.num_layers):
                if i == 0:
                    self.layers.append(torch.nn.Linear(self.latent_dim, self.hidden_dim))
                    if self.activation_fn is not None:
                        self.layers.append(self.activation_fn)
                    self.layers.append(torch.nn.Dropout(self.dropout))
                elif i == self.num_layers - 1:
                    self.layers.append(torch.nn.Linear(self.hidden_dim, trunk_out_dim))
                    if self.output_activation_fn is not None:
                        self.layers.append(self.output_activation_fn)
                else:
                    self.layers.append(torch.nn.Linear(self.hidden_dim, self.hidden_dim))
                    if self.activation_fn is not None:
                        self.layers.append(self.activation_fn)
                    self.layers.append(torch.nn.Dropout(self.dropout))

        if self.multiple_outputs:
            # Heads are registered as output1_layer, output2_layer, ... so the
            # default configuration keeps the same state_dict keys as before.
            for head_idx, out_dim in enumerate(self.output_dims, start=1):
                setattr(self, f'output{head_idx}_layer', torch.nn.Linear(self.hidden_dim, out_dim))

    def forward(self, x):
        """Run the trunk, then either every head (tuple) or return the trunk output."""
        recon = x
        for layer in self.layers:
            recon = layer(recon)
        if self.multiple_outputs:
            return tuple(getattr(self, f'output{head_idx}_layer')(recon)
                         for head_idx in range(1, len(self.output_dims) + 1))
        return recon

class SupervisedSimpleAE(torch.nn.Module):
    """Deterministic autoencoder with a scalar prediction head on the latent code.

    forward(x) returns ``(z, pred, *reconstructions)``: one reconstruction per
    decoder head when the decoder returns a tuple, otherwise the single
    reconstruction tensor, i.e. ``(z, pred, recon)``.
    """
    def __init__(self, input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, pred_activation_fn, dec_activation_fn):
        super(SupervisedSimpleAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn)
        self.predictor = Predictor(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, pred_activation_fn)
        self.decoder = Decoder(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, dec_activation_fn)

    def forward(self, x):
        z = self.encoder(x)
        pred = self.predictor(z)
        recon = self.decoder(z)
        # A multi-head decoder returns a tuple (one tensor per descriptor group).
        # Flatten it rather than unpacking a fixed count: the old
        # `recon1, recon2 = ...` raised ValueError against the 4-head decoder.
        if isinstance(recon, tuple):
            return (z, pred, *recon)
        return z, pred, recon
    
class UnsupervisedSimpleAE(torch.nn.Module):
    """Deterministic autoencoder: encoder to latent code to decoder.

    forward(x) returns ``(z, *reconstructions)``: one reconstruction per decoder
    head when the decoder returns a tuple, otherwise ``(z, recon)``.
    """
    def __init__(self, input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, dec_activation_fn):
        super(UnsupervisedSimpleAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn)
        self.decoder = Decoder(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, dec_activation_fn)

    def forward(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)
        # Flatten a multi-head tuple instead of hard-coding four heads, so the
        # model keeps working if the decoder's head count changes.
        if isinstance(recon, tuple):
            return (z, *recon)
        return z, recon
    
class SupervisedVAE(torch.nn.Module):
    """Gaussian VAE with a scalar prediction head on the sampled latent.

    The encoder produces ``hidden_dim`` features. Linear heads map them to the
    posterior mean ``mu`` and log-variance ``logvar``, and ``z`` is drawn with
    the reparameterization trick.

    forward(x) returns ``(z, pred, reconst, mu, logvar)``; ``reconst`` is
    whatever the decoder returns (a tuple of heads or a single tensor).
    """
    def __init__(self, input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, pred_activation_fn, dec_activation_fn):
        super(SupervisedVAE, self).__init__()
        # The encoder's output width must equal the in_features of the mu/logvar
        # heads (hidden_dim). It used to emit latent_dim features, which only
        # worked when hidden_dim == latent_dim.
        self.encoder = Encoder(input_dim, hidden_dim, dropout, hidden_dim, num_layers, activation_fn)
        self.predictor = Predictor(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, pred_activation_fn)
        self.decoder = Decoder(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, dec_activation_fn)
        self.mu = torch.nn.Linear(hidden_dim, latent_dim)
        self.logvar = torch.nn.Linear(hidden_dim, latent_dim)

    def reparameterize(self, mu, logvar):
        """Differentiably sample ``z = mu + eps * exp(0.5 * logvar)``, ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        h = self.encoder(x)
        mu = self.mu(h)
        logvar = self.logvar(h)
        z = self.reparameterize(mu, logvar)
        pred = self.predictor(z)
        reconst = self.decoder(z)
        return z, pred, reconst, mu, logvar
    
class UnsupervisedVAE(torch.nn.Module):
    """Gaussian VAE: encoder features, then (mu, logvar), then sampled z, then decoder.

    forward(x) returns ``(z, reconst, mu, logvar)``; ``reconst`` is whatever
    the decoder returns (a tuple of heads or a single tensor).
    """
    def __init__(self, input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, dec_activation_fn):
        super(UnsupervisedVAE, self).__init__()
        # Encoder output width must match the in_features of the mu/logvar heads
        # (hidden_dim); it previously emitted latent_dim features, a shape
        # mismatch whenever hidden_dim != latent_dim.
        self.encoder = Encoder(input_dim, hidden_dim, dropout, hidden_dim, num_layers, activation_fn)
        self.decoder = Decoder(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, dec_activation_fn)
        self.mu = torch.nn.Linear(hidden_dim, latent_dim)
        self.logvar = torch.nn.Linear(hidden_dim, latent_dim)

    def reparameterize(self, mu, logvar):
        """Differentiably sample ``z = mu + eps * exp(0.5 * logvar)``, ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        h = self.encoder(x)
        mu = self.mu(h)
        logvar = self.logvar(h)
        z = self.reparameterize(mu, logvar)
        reconst = self.decoder(z)
        return z, reconst, mu, logvar
    
# Check out these github repos for how to code the model and the loss function :

# 1. https://github.com/jariasf/GMVAE/tree/master
# 2. https://github.com/RuiShu/vae-clustering
    
class GMVAE(torch.nn.Module):
    """Work-in-progress Gaussian-mixture VAE (cf. jariasf/GMVAE, RuiShu/vae-clustering).

    NOTE(review): as written this class cannot run end to end. ``forward``
    iterates over ``self.encoder_layers``, ``self.predictor_layers`` and
    ``self.decoder_layers``, but ``__init__`` never defines them (it builds
    ``self.encoder`` / ``self.decoder`` modules instead), so ``forward`` raises
    AttributeError. TODO: wire the sub-networks up before using this model.

    forward returns a 9-tuple
    ``(z, pred, reconst, mu, logvar, mu_prior, logvar_prior, qy_logit, qy)``.
    The first seven are Python lists of length 10; only the first ``y_dim``
    slots are filled (one per class index).
    """
    def __init__(self, input_dim, y_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, dec_activation_fn):
        super(GMVAE, self).__init__()
        """
        A GMVAE has three main modules:
        q(y|x) : Predict the class label based on X
        q(z|y,x) : Predict the latent variable based on X and the class label
        """
        # NOTE(review): forward() does not call self.encoder / self.decoder; it
        # looks for *_layers attributes that are never created.
        self.encoder = Encoder(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn)
        self.decoder = Decoder(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, dec_activation_fn)
        try:
            self.activation_fn = activation_fns[activation_fn]
        except KeyError:
            raise ValueError('Invalid activation function')
        # Duplicate of the assignment inside the try block above.
        self.activation_fn = activation_fns[activation_fn]
        self.input_dim = input_dim
        self.y_dim = y_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        # q(y|x) classifier producing y_dim logits.
        # NOTE(review): unlike Encoder, the activation is appended without a None
        # check; with activation_fn=None a None entry lands in this ModuleList and
        # forward() would try to call it. TODO guard like Encoder does.
        self.qy_logit_x_layers = torch.nn.ModuleList()
        if self.num_layers == 1:
            self.qy_logit_x_layers.append(torch.nn.Linear(self.input_dim, self.y_dim))
            self.qy_logit_x_layers.append(self.activation_fn)
        else:
            for i in range(self.num_layers):
                if i == 0:
                    self.qy_logit_x_layers.append(torch.nn.Linear(self.input_dim, self.hidden_dim))
                    self.qy_logit_x_layers.append(self.activation_fn)
                    self.qy_logit_x_layers.append(torch.nn.Dropout(self.dropout))
                elif i == self.num_layers - 1:
                    self.qy_logit_x_layers.append(torch.nn.Linear(self.hidden_dim, self.y_dim))
                else:
                    self.qy_logit_x_layers.append(torch.nn.Linear(self.hidden_dim, self.hidden_dim))
                    self.qy_logit_x_layers.append(self.activation_fn)
                    self.qy_logit_x_layers.append(torch.nn.Dropout(self.dropout))

        # Posterior heads q(z|x,y); they expect hidden_dim input features.
        self.mu = torch.nn.Linear(self.hidden_dim, self.latent_dim)
        # NOTE(review): Softplus makes this "logvar" non-negative, so the implied
        # variance exp(logvar) is always >= 1 -- confirm this is intended.
        self.logvar = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_dim, self.latent_dim),
            torch.nn.Softplus()
        )
        # Class-conditional prior heads p(z|y), fed the one-hot class vector.
        self.mu_prior = torch.nn.Linear(self.y_dim, self.latent_dim)
        self.logvar_prior = torch.nn.Sequential(
            torch.nn.Linear(self.y_dim, self.latent_dim),
            torch.nn.Softplus()
        )
    
    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * sqrt(exp(logvar))`` with ``eps`` drawn by ``torch.randn_like``."""
        std = torch.sqrt(torch.exp(logvar))
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Run q(y|x), then for every class index build q(z|x,y), p(z|y) and the decoder outputs."""
        for i, qy_logit_x_layer in enumerate(self.qy_logit_x_layers):
            if i == 0:
                qy_logit = qy_logit_x_layer(x)
            else:
                qy_logit = qy_logit_x_layer(qy_logit)
        qy = torch.nn.Softmax(dim=1)(qy_logit)

        # Defining a tensor that will store the fixed class label for all members of the batch
        # NOTE(review): created without device=x.device, so this lands on the
        # default device (CPU unless changed) -- mismatch if x is on a GPU.
        y_ = torch.zeros([x.shape[0], self.y_dim])
        # NOTE(review): 10 slots are hard-coded; y_dim > 10 raises IndexError in
        # the loop below. Presumably this should be self.y_dim.
        z, pred, reconst, mu, logvar, mu_prior, logvar_prior = [[None] * 10 for i in range(7)]
        for i in range(self.y_dim):
            # Add the class label to the tensor
            y = y_ + torch.eye(self.y_dim)[i]
            # Note to self : The generative model can take the predicted class label as input. This is what is done in the GMVAE repo
            # Note to self : In the Rui Shu repo the class label (y) is provided as a one hot vector. 
            # h has input_dim + y_dim columns here; the (missing) encoder layers
            # must map that to hidden_dim for self.mu / self.logvar.
            h = torch.cat([x, y], dim=1)
            # NOTE(review): self.encoder_layers is never defined -> AttributeError.
            for j, encoder_layer in enumerate(self.encoder_layers):
                if j == 0:
                    h = encoder_layer(h)
                else:
                    h = encoder_layer(h)
            mu[i] = self.mu(h)
            logvar[i] = self.logvar(h)
            # Note to self : Can use the reparameterization trick here instead. This gives modified M2 in Rui Shu's repo.
            # Using the predicted mean and logvar sample from a gaussian distribution.
            # z[i] = torch.normal(mu[i], logvar[i].exp().sqrt())
            z[i] = self.reparameterize(mu[i], logvar[i])
            # NOTE(review): self.predictor_layers is never defined -> AttributeError.
            for j, pred_layer in enumerate(self.predictor_layers):
                if j == 0:
                    pred[i] = pred_layer(z[i])
                else:
                    pred[i] = pred_layer(pred[i])    
            mu_prior[i] = self.mu_prior(y)
            logvar_prior[i] = self.logvar_prior(y)
            # NOTE(review): self.decoder_layers is never defined -> AttributeError.
            for j, decoder_layer in enumerate(self.decoder_layers):
                if j == 0:
                    reconst[i] = decoder_layer(z[i])
                else:
                    reconst[i] = decoder_layer(reconst[i])
        return z, pred, reconst, mu, logvar, mu_prior, logvar_prior, qy_logit, qy

In [4]:
# Smoke test: build a small UnsupervisedSimpleAE and push a random batch of
# latent codes through its decoder alone, printing the reconstruction(s).
randX = torch.rand(2, 2)  # (batch=2, latent_dim=2); width must equal latent_dim below
simple_ae = UnsupervisedSimpleAE(
    input_dim=15,
    hidden_dim=10,
    dropout=0.01,
    latent_dim=2,
    num_layers=3,
    activation_fn='tanh',
    dec_activation_fn=None,
)
print(simple_ae.decoder(randX))

(tensor([[-0.1194,  0.0835, -0.2836,  0.1417,  0.0063],
        [-0.1277,  0.0807, -0.2891,  0.1496,  0.0049]],
       grad_fn=<AddmmBackward0>), tensor([[-0.1037,  0.2275,  0.3310,  0.1295,  0.2166],
        [-0.0951,  0.2288,  0.3349,  0.1155,  0.2255]],
       grad_fn=<AddmmBackward0>), tensor([[-0.1771, -0.0429, -0.1416,  0.1270,  0.1181],
        [-0.1824, -0.0526, -0.1386,  0.1236,  0.1124]],
       grad_fn=<AddmmBackward0>), tensor([[-0.4547,  0.1833,  0.3178,  0.0128],
        [-0.4551,  0.1840,  0.3242,  0.0161]], grad_fn=<AddmmBackward0>))


### Deriving the KL divergence loss for unit normal prior

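- Let's start with an arbitrary distribution Q(z) and minimize its KL divergence from the true posterior P(z|X). The fourth line below applies Bayes' rule, $\log P(z|X) = \log P(X|z) + \log P(z) - \log P(X)$. The fifth line pulls $\log P(X)$ out of the expectation because it does not depend on z, and the last line rearranges the terms.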

$$
\begin{align*}
    D_{KL}(Q(z) || P(z|X)) & = \int Q(z) \log \frac{Q(z)}{P(z|X)} dz \\ \\
    D_{KL}(Q(z) || P(z|X)) & = E_{Q(z)}[ \log \frac{Q(z)}{P(z|X)} ] \\ \\
    D_{KL}(Q(z) || P(z|X)) & = E_{Q(z)}[ \log Q(z) - \log P(z|X) ] \\ \\
    D_{KL}(Q(z) || P(z|X)) & = E_{Q(z)}[ \log Q(z) - \log P(X|z) - \log P(z) + \log P(X) ] \\ \\
    D_{KL}(Q(z) || P(z|X)) & = E_{Q(z)}[ \log Q(z) - \log P(X|z) - \log P(z)] + \log P(X) \\ \\
    \log P(X) - D_{KL}(Q(z) || P(z|X)) & =  E_{Q(z)}[\log P(X|z)] - D_{KL}(Q(z) || P(z)) 
\end{align*}
$$

- Now instead of choosing any distribution for Q(z), it makes sense to choose a distribution for the z variables that depends on X. Hence we can replace Q(z) with Q(z|X).

$$
\begin{align*}
    \log P(X) - D_{KL}(Q(z|X) || P(z|X)) & =  E_{Q(z|X)}[\log P(X|z)] - D_{KL}(Q(z|X) || P(z))
\end{align*}
$$

- The left-hand side contains the terms we want to maximize: the log probability density of X, minus an error term that measures the deviation between the approximate distribution Q(z|X) and the true posterior P(z|X). Note that P(X) is a high-dimensional, intractable distribution, and we don't have access to P(z|X). Because a KL divergence is never negative, the right-hand side is a lower bound on $\log P(X)$ (the evidence lower bound, ELBO). By giving Q(z|X) enough capacity, we pull it closer to P(z|X), shrinking the KL divergence term until we are effectively optimizing the log probability density of X itself.
- The right hand side contains terms that can be optimized via gradient descent. The first term is the expected value of the log likelihood of the data given the latent variables. The second term is the KL divergence between the approximate distribution and the prior distribution. 
- Stochastic gradient descent can be performed on the right-hand side once we assume specific forms for the distributions. The most common choice is a multivariate Gaussian for the approximate posterior Q(z|X) and the likelihood P(X|z), and a unit normal distribution for the prior P(z).

$$
\begin{align*}
    D_{KL}(N(\mu_0, \Sigma_0) || N(\mu_1, \Sigma_1)) = \frac{1}{2} ( \text{tr}(\Sigma_1^{-1} \Sigma_0) + (\mu_1 - \mu_0)^T \Sigma_1^{-1} (\mu_1 - \mu_0) - k + \log \frac{\det \Sigma_1}{\det \Sigma_0} )
\end{align*}
$$

- $k$ is the dimensionality of the distribution. Substituting the unit normal prior $N(0, I)$ for $N(\mu_1, \Sigma_1)$ and $N(\mu(X), \Sigma(X))$ for $N(\mu_0, \Sigma_0)$, we get the KL divergence loss as
$$
\begin{align*}
    D_{KL}(N(\mu (X), \Sigma (X)) || N(0, I)) = \frac{1}{2} ( \text{tr}(\Sigma (X)) + (\mu (X))^T (\mu (X)) - k - \log \det \Sigma (X) )
\end{align*}
$$

- To backpropagate errors to the neural network that approximates Q(z|X), so that it produces z's that correctly reproduce the data, the sampling step itself must be differentiable. This is where the reparameterization trick comes in. It lets us sample z while keeping gradient access to the neural networks that approximate the mean and covariance of Q(z|X), via $z = \mu(X) + \Sigma^{1/2}(X)\,\epsilon$. Here $\mu(X)$ and $\Sigma(X)$ are approximated by neural networks. $\Sigma^{1/2}(X)$ is the standard deviation (for a diagonal $\Sigma$, the element-wise square root of the variances), and $\epsilon$ is sampled from the unit normal distribution.
- If any other distribution is to be modelled, the KL divergence term must be modified accordingly, and an appropriate reparameterization trick must be used.

Reference:
- Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 https://arxiv.org/abs/1312.6114 (Appendix B)
- Doersch, C. Tutorial on Variational Autoencoders. arXiv January 3, 2021. http://arxiv.org/abs/1606.05908.


## Model Optimization

In [5]:
# Training loop

def log_normal(z, mu, logvar):
    """Diagonal-Gaussian log-density of each row of ``z``.

    Returns, per row, ``-0.5 * sum_d [log(2*pi) + logvar_d + (z_d - mu_d)**2 / exp(logvar_d)]``,
    with the sum taken over dim 1.
    """
    log_two_pi = torch.log(torch.tensor(2*np.pi, dtype=torch.float32))
    per_dim = log_two_pi + logvar + (z - mu).pow(2) / logvar.exp()
    return per_dim.sum(dim=1).mul(-0.5)

# KL divergence loss
def labelled_loss(z, mu, logvar, mu_prior, logvar_prior, y_prior=0.1):
    """Single-sample KL-style term for one mixture component of a GMVAE.

    Returns, per batch row:
        log N(z; mu, exp(logvar)) - log N(z; mu_prior, exp(logvar_prior)) - log(y_prior)

    Args:
        z: latent sample (presumably drawn from N(mu, exp(logvar))).
        mu, logvar: posterior mean / log-variance.
        mu_prior, logvar_prior: prior mean / log-variance for the component.
        y_prior: probability of the component. The default 0.1 reproduces the
            previously hard-coded constant; it presumably stands for a uniform
            prior 1 / y_dim with 10 classes -- pass 1 / y_dim for other counts.
    """
    log_py = torch.log(torch.tensor(y_prior, dtype=torch.float32))
    return log_normal(z, mu, logvar) - log_normal(z, mu_prior, logvar_prior) - log_py

# Derived by assuming posterior is Gaussian and prior is unit normal distribution.
def kl_divergence_loss_fn(mu, logvar):
    """Batch-mean KL( N(mu, diag(exp(logvar))) || N(0, I) ).

    Per sample: ``-0.5 * sum_d (1 + logvar_d - mu_d**2 - exp(logvar_d))``
    (sum over dim 1), then averaged over the batch (dim 0).
    """
    per_dim = 1 + logvar - mu.pow(2) - logvar.exp()
    per_sample = per_dim.sum(dim=1).mul(-0.5)
    return per_sample.mean(dim=0)

def _hparam(config, key, from_sweep):
    """Read one hyperparameter from a wandb sweep config or a sweep-style dict.

    Sweep runs expose values as attributes of ``wandb.config``. Manual runs pass a
    ``{'key': {'value': v}}`` dict, so there the value sits under ``'value'``.
    """
    return getattr(config, key) if from_sweep else config[key]['value']


def _build_model(model_type, input_dim, hidden_dim, y_dim, dropout, latent_dim, num_layers,
                 activation_fn, pred_activation_fn, dec_activation_fn):
    """Instantiate the autoencoder named by ``model_type`` (classes presumably defined in an earlier cell)."""
    if model_type == 'SupervisedSimpleAE':
        return SupervisedSimpleAE(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, pred_activation_fn, dec_activation_fn)
    if model_type == 'UnsupervisedSimpleAE':
        return UnsupervisedSimpleAE(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn, dec_activation_fn)
    if model_type == 'SupervisedVAE':
        return SupervisedVAE(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn)
    if model_type == 'UnsupervisedVAE':
        return UnsupervisedVAE(input_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn)
    if model_type == 'GMVAE':
        return GMVAE(input_dim, y_dim, hidden_dim, dropout, latent_dim, num_layers, activation_fn)
    raise ValueError('Invalid model type')


def _reconstruction_losses(model, batch):
    """Run one batch through ``model`` and return its four reconstruction losses.

    ``batch`` is ``(input1, input2, input3, input4)``. The blocks are concatenated
    along dim 1 to form the model input. Blocks 1-3 are scored with cross-entropy
    against their (presumably one-hot, see ohe1-ohe3) targets; block 4 is scored with L1.
    """
    input1, input2, input3, input4 = batch
    _, recon1, recon2, recon3, recon4 = model(torch.concat((input1, input2, input3, input4), dim=1))
    cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')
    l1_loss = torch.nn.L1Loss(reduction='mean')
    return (cross_entropy(recon1, input1),
            cross_entropy(recon2, input2),
            cross_entropy(recon3, input3),
            l1_loss(recon4, input4))


_RECON_LOSS_NAMES = ('recon1', 'recon2', 'recon3', 'recon4')


def train(train_subsampler, val_subsampler, model_type, run_dir, sweep_hyperparams=True, save_model=False, project_name=None, model_name=None, config=None):
    """Train one autoencoder and log per-epoch train/validation losses to wandb.

    Args:
        train_subsampler, val_subsampler: datasets whose items unpack to
            ``(input1, input2, input3, input4)``.
        model_type: 'SupervisedSimpleAE', 'UnsupervisedSimpleAE', 'SupervisedVAE',
            'UnsupervisedVAE' or 'GMVAE'.
        run_dir: sub-directory of ``../runs`` used when ``save_model`` is True.
        sweep_hyperparams: if True, this runs inside a wandb sweep agent. Hyperparameters
            are then read from ``wandb.config``, and the global ``sweep_config`` is passed
            to ``wandb.init``. Otherwise they are read from ``config``.
        save_model: if True, write the state dict and ``scaler_X4`` to
            ``../runs/<run_dir>/<model_name>/``.
        project_name, model_name: wandb project and run name for manual runs.
            ``model_name`` also names the saved files.
        config: ``{'key': {'value': v}}`` dict, used when ``sweep_hyperparams`` is False.

    Raises:
        ValueError: if ``model_type`` is not recognised.

    Note:
        The loss loop expects the forward pass to return
        ``(z, recon1, recon2, recon3, recon4)``. VAE-style models would additionally need
        their KL term (e.g. ``kl_divergence_loss_fn`` / ``labelled_loss``) added here.
    """
    if sweep_hyperparams:
        run = wandb.init(job_type='training', resume=False, reinit=False, config=sweep_config)
        config = wandb.config
    else:
        run = wandb.init(project=project_name, name=model_name, job_type='training', resume=False, reinit=False, config=config)

    def hp(key):
        return _hparam(config, key, sweep_hyperparams)

    model = _build_model(model_type, hp('input_dim'), hp('hidden_dim'), hp('y_dim'), hp('dropout'),
                         hp('latent_dim'), hp('num_layers'), hp('activation_fn'),
                         hp('pred_activation_fn'), hp('dec_activation_fn'))
    l1_reg, l2_reg = hp('l1_reg'), hp('l2_reg')
    epochs, batch_size = hp('epochs'), hp('batch_size')

    optimizer = torch.optim.Adam(model.parameters(), lr=hp('learning_rate'))
    train_dataloader = torch.utils.data.DataLoader(train_subsampler, batch_size=batch_size, shuffle=True)
    # A fixed validation order keeps the per-epoch mean of batch losses reproducible.
    val_dataloader = torch.utils.data.DataLoader(val_subsampler, batch_size=batch_size, shuffle=False)

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_history = {name: [] for name in _RECON_LOSS_NAMES + ('total',)}
        for batch in train_dataloader:
            optimizer.zero_grad()
            recon_losses = _reconstruction_losses(model, batch)

            enc_params = torch.cat([p.view(-1) for p in model.encoder.parameters()])
            dec_params = torch.cat([p.view(-1) for p in model.decoder.parameters()])
            l1_regularization = l1_reg * (torch.norm(enc_params, 1) + torch.norm(dec_params, 1))
            # The "L2" penalty uses the (unsquared) Euclidean norm, not the squared norm.
            l2_regularization = l2_reg * (torch.norm(enc_params, 2) + torch.norm(dec_params, 2))

            train_total_loss = sum(recon_losses) + l1_regularization + l2_regularization
            for name, loss in zip(_RECON_LOSS_NAMES, recon_losses):
                train_history[name].append(loss.item())
            train_history['total'].append(train_total_loss.item())

            train_total_loss.backward()
            optimizer.step()

        # ---- validation pass: no dropout, no autograd graph ----
        model.eval()
        val_history = {name: [] for name in _RECON_LOSS_NAMES + ('total',)}
        with torch.no_grad():
            for batch in val_dataloader:
                recon_losses = _reconstruction_losses(model, batch)
                for name, loss in zip(_RECON_LOSS_NAMES, recon_losses):
                    val_history[name].append(loss.item())
                # Regularization is a training-only penalty, so it is left out here.
                val_history['total'].append(sum(recon_losses).item())

        metrics = {'epoch': epoch}
        for split, history in (('train', train_history), ('val', val_history)):
            for name, values in history.items():
                metrics[f'{split}_{name}_loss_per_epoch'] = np.mean(values)
        # A single log call per epoch puts all of an epoch's metrics on one wandb step.
        wandb.log(metrics)
    run.finish()

    if save_model:
        model_dir = f'../runs/{run_dir}/{model_name}'
        os.makedirs(model_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(model_dir, f'{model_name}.pth'))
        # Keep the input scaler next to the weights so the path follows run_dir.
        joblib.dump(scaler_X4, os.path.join(model_dir, 'scaler_latentsfromAE1.pkl'))
        # TODO: for 'SupervisedSimpleAE' the target scaler could be saved as well:
        # joblib.dump(scaler_y, os.path.join(model_dir, 'scaler_y.pkl'))
    
# Hyperparameter space in wandb sweep format:
#   {'value': v}        -> fixed hyperparameter
#   {'values': [...]}   -> swept hyperparameter
# Manual runs (sweep_for_hyperparams = False) and the checkpoint-loading cell read
# parameters[...]['value'], so every entry must use 'value' in that case.

# USER INPUT HERE
# -----------------------------------------------------------------
# By default save_model is set to False for hyperparam runs.
sweep_for_hyperparams = True
model_name = 'UnsupSimpleAE_2_ldim4_peroveff2'
run_dir = 'perovskite_multiscale_dataset_v2'
save_model = False
model_type = 'UnsupervisedSimpleAE'
sweep_name = 'UnsupSimpleAE_2_ldim4_peroveff2'
sweep_type = 'grid' # Select between 'bayes', 'grid', 'random'
# Max trials per sweep; None runs the whole space.
# The grid below is 3 hidden_dim x 3 num_layers x 3 activation_fn = 27 trials.
limit_num_trials_in_sweep = None
num_folds = 5 # K-fold CV for sweeps; each fold holds out ~1/num_folds of the samples
# Hold-out fraction for a single manual run, tied to num_folds so both use the same split size.
val_split = 1 / num_folds
parameters = {
    'input_dim': {'value': X1_final.shape[1] + X2_final.shape[1] + X3_final.shape[1] + X4_final.shape[1]},
    'hidden_dim': {'values': [25, 50, 75]},
    'latent_dim': {'value': 4},
    'y_dim': {'value': None},  # only used by GMVAE
    'dropout': {'value': 0},
    'l1_reg': {'value': 0},
    'l2_reg': {'value': 0.001},
    'num_layers': {'values': [1, 2, 3]},
    'activation_fn': {'values': ['tanh', 'relu', None]},
    'dec_activation_fn': {'value': None},
    'pred_activation_fn': {'value': None},
    'batch_size': {'value': 10},
    'learning_rate': {'value': 1e-3},
    'epochs': {'value': 1000},
}

# parameters = {
#             'input_dim':{
#                 'value':X1.shape[1] + X2.shape[1]
#             },
#             'hidden_dim':{
#                 'value':50,
#             },
#             'latent_dim':{
#                 'value':4
#             },
#             'y_dim':{
#                 'value':0
#             },
#             'dropout':{
#                 'value':0
#             },
#             'l1_reg':{
#                 'value':0
#             },
#             'l2_reg':{
#                 'value':0.001
#             },
#             'num_layers':{
#                 'value':3
#             },
#             'activation_fn':{
#                 'value':'relu'
#             },
#             'dec_activation_fn':{
#                 'value':None
#             },
#             'pred_activation_fn':{
#                 'value':'relu'
#             },
#             'batch_size':{
#                 'value':1
#             },
#             'learning_rate':{
#                 'value':1e-3
#             },
#             'epochs':{
#                 'value':1000
#             }
#         }

# -----------------------------------------------------------------

if sweep_for_hyperparams:
    sweep_config = {
        'name': sweep_name,
        'method': sweep_type,
        'metric': {
            'name': 'val_total_loss_per_epoch',
            'goal': 'minimize'
            },
        'parameters': parameters
    }
    # K-fold CV: each fold launches its own sweep over the full search space, so a
    # configuration's CV score comes from comparing it across the per-fold sweeps.
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=seed)
    for fold, (train_indices, val_indices) in enumerate(kf.split(dataset)):
        train_subsampler = torch.utils.data.Subset(dataset, train_indices)
        val_subsampler = torch.utils.data.Subset(dataset, val_indices)
        # Build the name from sweep_name every time. Appending to sweep_config['name']
        # would pile suffixes up ('<name>_fold_0_fold_1', ...).
        sweep_config['name'] = sweep_name + '_fold_' + str(fold)
        sweep_id = wandb.sweep(sweep_config, project=project_name)
        # Default arguments bind this fold's subsets when the lambda is created.
        wandb.agent(
            sweep_id,
            lambda tr=train_subsampler, va=val_subsampler: train(tr, va, model_type=model_type, run_dir=run_dir),
            project=project_name,
            count=limit_num_trials_in_sweep,
        )
        # Finish the sweep
        wandb.finish()
else:
    # Single shuffled train/validation split, sized from the dataset the indices address.
    num_samples = len(dataset)
    indices = np.arange(num_samples)
    np.random.shuffle(indices)
    num_train = int((1 - val_split) * num_samples)
    train_indices, val_indices = indices[:num_train], indices[num_train:]
    train_subsampler = torch.utils.data.Subset(dataset, train_indices)
    val_subsampler = torch.utils.data.Subset(dataset, val_indices)
    train(train_subsampler, val_subsampler, model_type=model_type, run_dir=run_dir, sweep_hyperparams=False, save_model=save_model, project_name=project_name, model_name=model_name, config=parameters)


Create sweep with ID: qlalchu1
Sweep URL: https://wandb.ai/nthota2/perovskite_dataset_v2/sweeps/qlalchu1


[34m[1mwandb[0m: Agent Starting Run: ilh31zun with config:
[34m[1mwandb[0m: 	activation_fn: tanh
[34m[1mwandb[0m: 	batch_size: 10
[34m[1mwandb[0m: 	dec_activation_fn: None
[34m[1mwandb[0m: 	dropout: 0
[34m[1mwandb[0m: 	epochs: 1000
[34m[1mwandb[0m: 	hidden_dim: 25
[34m[1mwandb[0m: 	input_dim: 19
[34m[1mwandb[0m: 	l1_reg: 0
[34m[1mwandb[0m: 	l2_reg: 0.001
[34m[1mwandb[0m: 	latent_dim: 4
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	num_layers: 1
[34m[1mwandb[0m: 	pred_activation_fn: None
[34m[1mwandb[0m: 	y_dim: None
[34m[1mwandb[0m: Currently logged in as: [33mnthota2[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


Create sweep with ID: il9xkni0
Sweep URL: https://wandb.ai/nthota2/perovskite_dataset_v2/sweeps/il9xkni0


<IPython.core.display.HTML object>
<IPython.core.display.HTML object>


Exception in thread Exception in threading.excepthook:
Exception ignored in thread started by: <bound method Thread._bootstrap of <Thread(Thread-6, stopped 11492470784)>>
Traceback (most recent call last):
  File "/Users/nikhilthota/miniconda3/envs/nestedae/lib/python3.9/threading.py", line 937, in _bootstrap
    self._bootstrap_inner()
  File "/Users/nikhilthota/miniconda3/envs/nestedae/lib/python3.9/threading.py", line 982, in _bootstrap_inner
    self._invoke_excepthook(self)
  File "/Users/nikhilthota/miniconda3/envs/nestedae/lib/python3.9/threading.py", line 1264, in invoke_excepthook
    local_print("Exception in threading.excepthook:",
  File "/Users/nikhilthota/miniconda3/envs/nestedae/lib/python3.9/site-packages/ipykernel/iostream.py", line 559, in flush
    self.pub_thread.schedule(self._flush)
  File "/Users/nikhilthota/miniconda3/envs/nestedae/lib/python3.9/site-packages/ipykernel/iostream.py", line 251, in schedule
    self._event_pipe.send(b"")
  File "/Users/nikhilthota/

## Model Interpretability

#### Plotting model performance vs latent dimension

In [None]:
# For RL paper
# # Grid data - UnsupervisedSimpleAE 1
# latent_space = [1, 2, 4, 6, 8]
# total_val_loss = [0.001025053068588022, 0.001063714546035044, 0.0011066086881328374, 0.0013437549932859838, 0.0009386805177200586]
# total_train_loss = [0.007932848995551467, 0.008279352798126638, 0.004237618821207434, 0.004374776501208544, 0.004906855116132647]

# # Grid data - Nonlinf5 SupervisedSimpleAE 2
# latent_space = [1, 2, 4, 6]
# total_val_loss = [0.7463283464312553, 0.2002287097275257, 0.1977713629603386, 0.1788929458707571]
# total_train_loss = [0.4705591835081578, 0.14468571869656444, 0.1250611103605479, 0.13816878804937005]

# # Grid data - Sumf5 - SupervisedSimpleAE 2
# latent_space = [1, 2, 4, 6]
# total_val_loss = [0.49156802892684937, 0.019491535145789385, 0.02250012196600437, 0.02030603913590312]
# total_train_loss = [0.4140172880142927, 0.03587072214577347, 0.02835194836370647, 0.03037486143875867]

# # Random data - UnsupervisedSimpleAE 1
# latent_space = [2, 4, 6, 8]
# total_val_loss = [0.6480237692594528, 0.4781555384397506, 0.24337586015462875, 0.002268550335429609]
# total_train_loss = [0.6661303304135799, 0.4441600125283003, 0.20833437889814377, 0.007954166503623128]

# # Random data - Nonlinf5 - SupervisedSimpleAE 2
# latent_space = [1, 2, 4, 6, 8, 10, 11, 12]
# total_val_loss = [0.7916551381349564, 0.6637141406536102, 0.5525364577770233, 0.3102440983057022, 0.19166426360607147, 0.051493472419679165, 0.021430929424241185, 0.007555923308245838]
# total_train_loss = [0.7504490427672863, 0.6266670003533363, 0.4413919039070606, 0.30615816451609135, 0.18379461765289307, 0.05851957411505282, 0.028894496499560773, 0.016192137030884624]

# # Random data - Sumf5 - SupervisedSimpleAE 2
# # Will have to change the predictor as only prediction loss is high, reconstruction is good.
# latent_space = [1, 2, 4, 6, 8, 10, 11, 12]
# total_val_loss = [1.4578111469745636, 1.4702374935150146, 1.383190006017685, 1.16259, 0.99866, 0.81431, 0.75846, 0.83324]
# total_train_loss = [1.4289479702711103, 1.2744147181510923, 1.0854754857718945, 0.96673, 0.84347, 0.72946, 0.69708, 0.66834]

# Arun2024
latent_space =      [2,     4,     6,     8,     10,    12,    14]
pred_val_loss =     [0.206, 0.178, 0.244, 0.227, 0.148, 0.233, 0.231]
pred_train_loss =   [0.11,  0.076, 0.082, 0.067, 0.074, 0.071, 0.067]
recont_val_loss =   [0.561, 0.069, 0.055, 0.050, 0.038, 0.046, 0.056]
recont_train_loss = [0.234, 0.055, 0.037, 0.038, 0.037, 0.042, 0.039]
# total_val_loss = [0.76786, 0.24755343379718917, 	0.24418, 0.27646, 0.18622, 0.27943, 0.28719]
# total_train_loss = [0.38542, 0.16735203887741917, 	0.15132, 0.13772, 0.14288, 0.14528, 0.13803]

# for 10 this is the best that can be done. Adding more layers or increasing the hidden dim or using non linea act does not work.
# val : 0.7632711380720139
# train : 0.7363427169620991

# Total val loss vs SimpleAE latent space variation
# latent_space = [2, 3, 4, 8, 12]
# total_val_loss = [0.5776124358177185, 0.5138478517532349, 0.4121862232685089, 0.2896019071340561, 0.2887373447418213]

# Plot train and validation reconstruction loss against the latent dimension.
fig, ax = plt.subplots(figsize=(7, 5))
ax.plot(latent_space, recont_val_loss, marker='o', linestyle='', label='val')
ax.plot(latent_space, recont_train_loss, marker='o', linestyle='', label='train')
ax.set_xlabel('Latent space dimension')
ax.set_ylabel('Reconst. MAE')
ax.set_title('Reconstruction loss vs latent dimension')
ax.legend()
plt.show()

#### Loading the torch model

In [None]:
# Load a trained checkpoint for the inference / interpretability cells below.
model_name = 'best_SupSimpleAE_1_ldim4_arun2024'
run_dir = 'perovskite_multiscale_dataset_v2'
model_path = f'../runs/{run_dir}/{model_name}/' + model_name + '.pth'


def _fixed_param(name):
    """Return ``parameters[name]['value']``, failing clearly if the entry is still a sweep list."""
    entry = parameters[name]
    if 'value' not in entry:
        raise KeyError(f"parameters['{name}'] has no fixed 'value' ({entry!r}); "
                       "set the architecture this checkpoint was trained with before loading it.")
    return entry['value']


# The architecture must match the checkpoint exactly, otherwise load_state_dict raises.
# To load another model type, swap in UnsupervisedSimpleAE / SupervisedVAE /
# UnsupervisedVAE / GMVAE with the constructor arguments those classes take in `train`.
model = SupervisedSimpleAE(_fixed_param('input_dim'),
                           _fixed_param('hidden_dim'),
                           _fixed_param('dropout'),
                           _fixed_param('latent_dim'),
                           _fixed_param('num_layers'),
                           _fixed_param('activation_fn'),
                           _fixed_param('pred_activation_fn'),
                           _fixed_param('dec_activation_fn'))
# Inference mode for the cells that follow (e.g. disables dropout).
model.eval()
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location='cpu'))

#### Save the latents to .csv file

In [None]:
# Encode the data with the loaded model and save the latents (+ predictions) to .csv.
latent_filename = 'latents_from_PSC_efficiencies_dataset_2'

# Data folder paths and file names
dataset_folder_name = 'synthetic_dataset'
dataset_file_name = 'synthetic_data_gridSamples_200_sumf5.csv'
new_dataset_file_name = 'synthetic_data_gridSamples_200_sumf5_with_ae1_latents_concat.csv'
run_folder_name = 'perovskite_multiscale_dataset_v2'
model_folder_name = 'best_SupSimpleAE_1_ldim4_arun2024'
AE_number = '1'

dataset_folder = '../datasets/' + dataset_folder_name
dataset_file = dataset_folder + '/' + dataset_file_name
concatenated_dataset_file = dataset_folder + '/' + new_dataset_file_name

# Use this cell's own folder names; run_dir / model_name are reassigned by other cells.
run_folder = '../runs/' + run_folder_name
latent_file = run_folder + '/' + model_folder_name + '/' + latent_filename + '.csv'

model.eval()
with torch.no_grad():
    # NOTE(review): only X1_final and X2_final are fed here, while the training config's
    # input_dim also counts X3_final and X4_final. Confirm this matches the loaded model.
    z, pred, _, _ = model(torch.cat((torch.from_numpy(X1_final), torch.from_numpy(X2_final)), dim=1))

z_df = pd.DataFrame(z.detach().numpy())
pred_df = pd.DataFrame(pred.detach().numpy())
# Concatenate the latents and the predictions column-wise (latents first).
z_pred_df = pd.concat([z_df, pred_df], axis=1)
z_pred_df.to_csv(latent_file, index=False, header=False)

# latents = pd.read_csv(latent_file, header=None, skiprows=None)
# data = pd.read_csv(dataset_file)

# for i in range(len(latents.columns)):
#     data['AE'+ AE_number +'_latent_'+str(i)] = latents[i]

# data.to_csv(concatenated_dataset_file, index=False)

In [None]:
# # Find which 'z'is closest to the 'z' provided below
# z_query = torch.tensor([0.26077228, 3.95919058, 0., 0., 0., 0., 1.19784881, 1.33860231], dtype=torch.float32)
# z_query = z_query.unsqueeze(dim=0)
# z_query = z_query.repeat(z.shape[0], 1)
# dist = torch.nn.PairwiseDistance(p=1)
# distances = dist(z, z_query)
# closest_idx = torch.argmin(distances).item()
# print(torch.min(distances))
# print(f'Closest z to the query z is at index : {closest_idx}')
# print(f'Closest z to the query z is : {z[closest_idx]}')
# print(f'Reconst for query z is : {reconst[closest_idx]}')
# print(f'INvert scaling for reconst : {scaler_X.inverse_transform(reconst[closest_idx].detach().numpy().reshape(1, -1))}')

In [None]:
# Sanity check: the 2nd sample, an all-zeros baseline of the same shape, and the
# model output at that baseline (the reference point used for attribution).
# NOTE(review): X_scaled_np32 differs from the X1_scaled_np32 / X2_np32 names used
# by the captum cell. Confirm it is still defined.
input = torch.tensor(X_scaled_np32[1, :], requires_grad=True)
baseline = torch.zeros_like(input)
print(input, baseline, sep='\n')
print(model(baseline))

#### Feature Importance

In [None]:
from captum.attr import IntegratedGradients, DeepLift, InputXGradient, Saliency
import seaborn as sns

def model_wrapper(input):
    """Forward ``input`` through ``model`` and return only the prediction head output.

    captum attributes a single output tensor, so the latent code and the
    reconstructions are dropped here.
    """
    # z, pred, reconst = model(input)
    z, pred, reconst1, reconst2 = model(input)
    return pred

# Attribution methods. Only Integrated Gradients is plotted below; the others are kept
# so the commented-out comparisons can be switched on.
sal = Saliency(model_wrapper)
ig = IntegratedGradients(model_wrapper)
ixg = InputXGradient(model_wrapper)
dl = DeepLift(model_wrapper)
input = torch.cat((torch.from_numpy(X1_scaled_np32), torch.from_numpy(X2_np32)), dim=1)
# attr_sal = sal.attribute(input, target=None)
# attr_sal_np = attr_sal.detach().numpy()
# attr_ixg = ixg.attribute(input, target=None)
# attr_ixg_np = attr_ixg.detach().numpy()
# target=-1 attributes the last column of pred, against an all-zeros baseline.
attr_ig = ig.attribute(input, baselines=0, target=-1)
attr_ig_np = attr_ig.detach().numpy()
# attr_dl = dl.attribute(input, baselines=0, target=None)
# attr_dl_np = attr_dl.detach().numpy()

means = np.mean(attr_ig_np, axis=0)
std_dev = np.std(attr_ig_np, axis=0)
num_feats = attr_ig_np.shape[1]
positions = np.arange(num_feats)

# descriptors1/2 name the raw descriptors. After one-hot encoding there can be more
# input columns than names, in which case label the columns by index instead.
feature_names = descriptors1 + descriptors2
if len(feature_names) != num_feats:
    feature_names = [str(i) for i in range(num_feats)]

fig, ax = plt.subplots(figsize=(8, 5))
using_boxplot = False
if using_boxplot:
    ax.boxplot(attr_ig_np, showmeans=True, meanline=True)
    # Boxes sit at x = 1..num_feats; write each column's std just above its largest value.
    for i in range(num_feats):
        ax.text(i + 1, attr_ig_np[:, i].max(), '{:0.1f}'.format(std_dev[i]), ha='center', va='bottom')
    ax.set_xticks(positions + 1)
else:
    ax.scatter(positions, means, label='mean', color='r', marker='.')
    ax.errorbar(positions, means, yerr=std_dev, fmt='o', capsize=5)
    # Write each column's std just above the top of its error bar.
    for i in range(num_feats):
        ax.text(i, means[i] + std_dev[i], '{:0.1f}'.format(std_dev[i]), ha='center', va='bottom')
    ax.set_xticks(positions)
ax.set_xticklabels(feature_names, rotation=90)
ax.set_title('Integrated Gradients calc. wrt pred')
ax.set_ylabel('IG attribution (mean ± std over samples)')
plt.tight_layout()
plt.show()

# # Matrix Plot model weights
# for name, param in model.encoder.named_parameters():
#     if name == 'layers.0.weight':
#         ax = plt.figure(figsize=(9, 8))
#         ax = sns.heatmap(param.detach().numpy(), annot=True, fmt='.3f', cmap='coolwarm')
#         # Remove y axis labels
#         ax.yticks([])
#         ax.set_title('Encoder Layer 1 Weights')

In [None]:
# Hand-rolled Integrated Gradients for one sample (a cross-check of captum's result).
NUM_IG_STEPS = 10

input = torch.tensor(X_scaled_np32[1, :], requires_grad=True)
print(input)
baseline = torch.zeros_like(input)
print(baseline)

def interpolated_features(num_steps, x=None, x_baseline=None):
    """Return ``num_steps + 1`` points on the straight line from the baseline to the sample.

    Args:
        num_steps: number of intervals; the path uses alphas = linspace(0, 1, num_steps + 1).
        x, x_baseline: path endpoints. They default to the global ``input`` / ``baseline``.

    Returns:
        Tensor of shape ``(num_steps + 1, *x.shape)``.
    """
    x = input if x is None else x
    x_baseline = baseline if x_baseline is None else x_baseline
    alphas = torch.linspace(0, 1, num_steps + 1)
    delta = x - x_baseline
    return torch.stack([x_baseline + alpha * delta for alpha in alphas])

def compute_gradients(interpolated_feats):
    """Gradient of the model prediction w.r.t. each interpolated input.

    Assumes a single prediction output; multiple outputs are summed before
    differentiation.

    Returns:
        Tensor of shape ``(num_points, *point_shape)``.
    """
    grads = []
    for point in interpolated_feats:
        # Make each point a fresh leaf so autograd can differentiate w.r.t. it.
        point = point.detach().unsqueeze(dim=0).clone().requires_grad_(True)
        # pred is the 2nd forward output, matching the (z, pred, ...) unpacking in this notebook.
        pred = model(point)[1]
        # autograd.grad avoids accumulating into the model parameters' .grad.
        (grad,) = torch.autograd.grad(pred.sum(), point)
        grads.append(grad.squeeze(dim=0))
    return torch.stack(grads)

alphas = torch.linspace(0, 1, NUM_IG_STEPS + 1)
computed_grads = compute_gradients(interpolated_features(NUM_IG_STEPS))
# IG attribution = (x - baseline) * integral over alpha of the gradient (trapezoid rule).
attributions = (input - baseline).detach() * torch.trapezoid(computed_grads, alphas, dim=0)
print(attributions)

fig, ax = plt.subplots()
ax.plot(alphas, computed_grads)
ax.set_xlabel('alpha (0 = baseline, 1 = input)')
ax.set_ylabel('d pred / d input feature')
ax.set_title('Gradients along the integration path');


#### 2D plots of latent space

In [None]:
# z, pred, reconst, mu, logvar = model(torch.from_numpy(elemental_properties))
# z, pred, reconst, mu, logvar, mu_prior, logvar_prior, qy_logit, qy = model(torch.from_numpy(elemental_properties))

# # Only for 2D latent plotting
# plt.scatter(z[:, 0].detach().numpy(), z[:, 1].detach().numpy(), c=pred.detach().numpy(), cmap='viridis', s=10, alpha=0.5)
# plt.colorbar()
# plt.show()

### Observations for the unit normal prior
- The example above shows both a strength and a limitation of multivariate Gaussian distributions. Each dimension can encode a separate degree of freedom (DOF), which gives representations that are structured and disentangled. However, these distributions are unimodal and therefore cannot capture complex, multimodal structure. A natural extension is to use a different prior, and a Gaussian Mixture Model (GMM) is the next choice.
- Latent space is segregated into different classes.
- However, inference is non-trivial.