* This notebook is a sandbox for experimenting with new NestedAE architectures.
* Features that prove useful are then ported into the main NestedAE source code.

In [1]:
import math
import random
import copy

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.model_selection import KFold
import torch
from torch.nn import Module, ModuleList, ModuleDict, Linear, L1Loss
from torch.utils.data import DataLoader
from torch.optim import SGD
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import KFold
import seaborn as sns
from scipy.stats import gaussian_kde
import umap

from NestedAE.nn_utils import check_dict_key_exists, set_layer_init

  from .autonotebook import tqdm as notebook_tqdm


#### Define the model

In [None]:
class AE(Module):
    """Multi-head autoencoder assembled from a declarative ``module_params`` dict.

    ``module_params['modules']`` maps a module name to a config dict with the
    required keys ``input_dim`` and ``output_dim`` and the optional keys
    ``hidden_layers`` (default 0), ``hidden_dim``, ``hidden_activation``,
    ``output_activation``, ``layer_dropout``, ``layer_kernel_init`` and
    ``layer_bias_init`` (default ``None``). A key whose value is ``None`` is
    treated as absent. The ``load_params`` key is accepted but ignored.

    Each entry becomes a ``ModuleList`` of ``Linear`` layers interleaved with
    the configured activation/dropout modules, stored in ``self.ae_modules``.
    When ``layer_kernel_init`` / ``layer_bias_init`` are truthy the list is
    passed through the project helper ``set_layer_init``.
    """

    # (module name, output key) for every head that consumes the latent ``z``,
    # in evaluation order. Heads absent from ``module_params`` are skipped, so
    # a config without e.g. a ``decoder`` does not raise KeyError in forward.
    _HEAD_OUTPUT_KEYS = (
        ('bg_predictor', 'y1_pred'),
        ('A_predictor', 'design1_pred'),
        ('B_predictor', 'design2_pred'),
        ('decoder', 'x_pred'),
        ('latents_predictor', 'latents_pred'),
    )

    def __init__(self, module_params):
        super().__init__()
        ae_modules = {}
        for module_name, module_dict in module_params['modules'].items():
            layer_list = self._build_layers(module_dict)
            # Initialize weights for all layers
            if self._get_option(module_dict, 'layer_kernel_init'):
                layer_list = set_layer_init(layer_list, module_dict, init='kernel')
            if self._get_option(module_dict, 'layer_bias_init'):
                layer_list = set_layer_init(layer_list, module_dict, init='bias')
            ae_modules[module_name] = layer_list
        self.ae_modules = ModuleDict(ae_modules)

    @staticmethod
    def _get_option(module_dict, key, default=None):
        """Return ``module_dict[key]``, or *default* if the key is missing or None."""
        value = module_dict.get(key)
        return default if value is None else value

    @classmethod
    def _build_layers(cls, module_dict):
        """Build ``[Linear, act?, dropout?] * hidden_layers + [Linear, out_act?]``."""
        hidden_layers = cls._get_option(module_dict, 'hidden_layers', 0)
        hidden_dim = cls._get_option(module_dict, 'hidden_dim')
        hidden_activation = cls._get_option(module_dict, 'hidden_activation')
        output_activation = cls._get_option(module_dict, 'output_activation')
        layer_dropout = cls._get_option(module_dict, 'layer_dropout')

        layer_list = ModuleList()
        in_features = module_dict['input_dim']
        # Hidden blocks need both a width and a positive count. Checking the
        # count matters: with hidden_dim set but hidden_layers == 0 the module
        # must still map straight to output_dim, not stop at hidden_dim.
        if hidden_dim is not None and hidden_layers > 0:
            for _ in range(hidden_layers):
                layer_list.append(Linear(in_features=in_features,
                                         out_features=hidden_dim,
                                         bias=True))
                if hidden_activation:
                    layer_list.append(hidden_activation)
                if layer_dropout:
                    layer_list.append(layer_dropout)
                in_features = hidden_dim
        # Output layer: no hidden activation/dropout after it.
        layer_list.append(Linear(in_features=in_features,
                                 out_features=module_dict['output_dim'],
                                 bias=True))
        if output_activation:
            layer_list.append(output_activation)
        return layer_list

    @staticmethod
    def _run_layers(layers, h):
        """Apply *layers* to *h* in sequence and return the final output."""
        for layer in layers:
            h = layer(h)
        return h

    def forward(self, x):
        """Encode *x* into ``z`` and run every configured head on ``z``.

        Returns a dict holding ``'z'`` plus one entry per head present in
        ``self.ae_modules`` (``'y1_pred'``, ``'design1_pred'``,
        ``'design2_pred'``, ``'x_pred'``, ``'latents_pred'``).
        """
        z = self._run_layers(self.ae_modules['encoder'], x)
        ae_module_outputs = {'z': z}
        for module_name, output_key in self._HEAD_OUTPUT_KEYS:
            if module_name in self.ae_modules:
                ae_module_outputs[output_key] = self._run_layers(
                    self.ae_modules[module_name], z)
        return ae_module_outputs

#### Dataset Details

In [3]:
# ---------------------------------------------------------
dataset_loc = 'datasets/perov_bandgaps_AND_h2_prod_rate/nestedae_dataset/perov_bandgaps_all.csv'
# dataset_loc = 'datasets/H2_prod_rate/DFT_data/props_bg_hof.xlsx'
# sheet_name = 'dataset'

#########################################
# FOR ABO3 BANDGAP DATASET
#########################################

latent_col_names = []
descriptors = ['A_ATRAD',
               'A_MASS',
               'A_IE',
               'A_EN',
               'A_EA',
               'A_AN',
               'B_ATRAD',
               'B_MASS',
               'B_IE',
               'B_EN',
               'B_EA',
               'B_AN'
]
# descriptors = ['surface_area_m2_g']
target = [
          'direct_bandgap'
          ]
# target = ['ProdRate', 'calc_T_K', 'calc_time_h', 'promoter_w', 'alcohol_percent']
# One-hot columns identifying the A-site and B-site ions.
target_A_ion = [
    'Ag_A', 'Al_A', 'As_A', 'Au_A', 'B_A', 'Ba_A', 'Be_A', 'Bi_A', 'Ca_A', 'Cd_A', 'Co_A', 'Cr_A', 'Cs_A', 'Cu_A',
    'Fe_A', 'Ga_A', 'Ge_A', 'Hf_A', 'Hg_A', 'In_A', 'Ir_A', 'K_A', 'La_A', 'Li_A', 'Mg_A', 'Mn_A', 'Mo_A', 'Na_A',
    'Nb_A', 'Ni_A', 'Os_A', 'Pb_A', 'Pd_A', 'Pt_A', 'Rb_A', 'Re_A', 'Rh_A', 'Ru_A', 'Sb_A', 'Sc_A', 'Si_A', 'Sn_A',
    'Sr_A', 'Ta_A', 'Te_A', 'Ti_A', 'Tl_A', 'V_A', 'W_A', 'Y_A', 'Zn_A', 'Zr_A'
]

target_B_ion = [
    'Ag_B', 'Al_B', 'As_B', 'Au_B', 'B_B', 'Ba_B', 'Be_B', 'Bi_B', 'Ca_B', 'Cd_B', 'Co_B', 'Cr_B', 'Cs_B', 'Cu_B',
    'Fe_B', 'Ga_B', 'Ge_B', 'Hf_B', 'Hg_B', 'In_B', 'Ir_B', 'K_B', 'La_B', 'Li_B', 'Mg_B', 'Mn_B', 'Mo_B', 'Na_B',
    'Nb_B', 'Ni_B', 'Os_B', 'Pb_B', 'Pd_B', 'Pt_B', 'Rb_B', 'Re_B', 'Rh_B', 'Ru_B', 'Sb_B', 'Sc_B', 'Si_B', 'Sn_B',
    'Sr_B', 'Ta_B', 'Te_B', 'Ti_B', 'Tl_B', 'V_B', 'W_B', 'Y_B', 'Zn_B', 'Zr_B'
]
# target_prop_method = ['HT',
#                       'Novel',
#                       'PC',
#                       'SG',
#                       'SSR']
# target_prom_method = ['Impreg',
#                       'PD',
#                       'None']
# target_prom_type = ['NiO',
#                     'Pt',
#                     'RuO2',
#                     'None']
# target_sac_agent_presence = ['sac_agent_ohe']

standardize_descs = True

split_strategy = 'kfold'
train_split = 0.9
seeds = [0, 1, 2, 3, 4]
random_state = 42

plot_train_test_dist = False
for_fold = 0

plot_pcc_matrix = False

model_save_dir = 'nestedae_AE1_bandgaps_THEN_h2_prod_rate'
# ---------------------------------------------------------

# Read the file once, then take independent copies of the input and target
# columns so later in-place edits never touch a view of the raw frame.
raw_dataframe = pd.read_csv(dataset_loc)
X_dataframe = raw_dataframe[descriptors + latent_col_names].copy()
Y_dataframe = raw_dataframe[target + target_A_ion + target_B_ion].copy()

if standardize_descs:
    # NOTE(review): these statistics are fitted on ALL rows, including rows
    # that later land in the test folds. Fitting on the training split only
    # would avoid leaking test information into the inputs.
    col_means = X_dataframe.mean()
    col_std_devs = X_dataframe.std()  # sample std (ddof=1), same as Series.std
    desc_means = col_means.tolist()
    desc_std_devs = col_std_devs.tolist()
    # A constant column has std 0; dividing by 1 there maps it to all zeros
    # instead of NaN. desc_std_devs keeps the true value for inverse transforms.
    X_dataframe = (X_dataframe - col_means) / col_std_devs.replace(0, 1.0)
    print('Descriptors standardized.')
else:
    print('Descriptors not standardized.')


print(f'Dataframe Statistics : {X_dataframe.describe()}')

print('Dataset columns : \n')
print(X_dataframe.columns)
# Column layout: inputs (descriptors + latents) | target | A-ion one-hot | B-ion one-hot
dataset = np.concatenate((X_dataframe.to_numpy(dtype=np.float32),
                          Y_dataframe[target].to_numpy(dtype=np.float32),
                          Y_dataframe[target_A_ion].to_numpy(dtype=np.float32),
                          Y_dataframe[target_B_ion].to_numpy(dtype=np.float32)), axis=1)
# dataset = np.concatenate((X_dataframe.to_numpy(dtype=np.float32), 
#                           Y_dataframe[target].to_numpy(dtype=np.float32),
#                           Y_dataframe[target_prop_method].to_numpy(dtype=np.float32),
#                           Y_dataframe[target_prom_method].to_numpy(dtype=np.float32),
#                           Y_dataframe[target_prom_type].to_numpy(dtype=np.float32),
#                           Y_dataframe[target_sac_agent_presence].to_numpy(dtype=np.float32)),
#                           axis=1)
print(dataset.shape)

Descriptors standardized.
Dataframe Statistics :             A_ATRAD        A_MASS          A_IE          A_EN          A_EA  \
count  2.704000e+03  2.704000e+03  2.704000e+03  2.704000e+03  2.704000e+03   
mean   3.468626e-16 -1.366428e-16  6.910974e-16  1.051099e-17 -4.204395e-17   
std    1.000000e+00  1.000000e+00  1.000000e+00  1.000000e+00  1.000000e+00   
min   -1.845083e+00 -1.559214e+00 -2.328331e+00 -1.960749e+00 -1.108566e+00   
25%   -5.039411e-01 -7.788756e-01 -7.491930e-01 -7.479788e-01 -7.928972e-01   
50%   -1.944467e-02 -1.529768e-01  9.949930e-02  2.060673e-01 -2.392820e-01   
75%    2.333361e-01  6.018775e-01  6.240454e-01  7.396864e-01  7.237913e-01   
max    3.884613e+00  1.779861e+00  2.170528e+00  1.812315e+00  2.901779e+00   

               A_AN       B_ATRAD        B_MASS          B_IE          B_EN  \
count  2.704000e+03  2.704000e+03  2.704000e+03  2.704000e+03  2.704000e+03   
mean   1.471538e-16  3.655853e-16  8.351308e-17  6.881412e-16 -4.352206e-18   
st

#### Data Distribution Analysis

In [4]:
def _plot_train_test_hist(train_values, test_values, col, binwidth):
    """Overlay the train and test histograms (with KDE) of one column on a new figure."""
    plt.figure(figsize=(5, 3))
    sns.histplot(train_values, kde=True, label='Train', color='blue', stat='probability', binwidth=binwidth)
    sns.histplot(test_values, kde=True, label='Test', color='red', stat='probability', binwidth=binwidth)
    plt.title(f'Distribution of {col}')
    plt.xlabel(col)
    plt.ylabel('Probability')  # matches stat='probability'
    plt.legend()
    plt.show()


if split_strategy == 'kfold':
    # round(), not int(): e.g. 1 / (1 - 0.95) evaluates to 19.99999999999998,
    # which int() would truncate to 19 folds instead of 20.
    n_splits = round(1 / (1 - train_split))
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    train_idxs = []
    test_idxs = []
    for train_idx, test_idx in kf.split(dataset):
        train_idxs.append(train_idx)
        test_idxs.append(test_idx)
elif split_strategy == 'random':
    train_idxs = []
    test_idxs = []
    # Train - Test split of indices
    idxs = list(range(X_dataframe.shape[0]))
    n_train = int(train_split * len(idxs))
    for seed in seeds:
        random.seed(seed)
        np.random.seed(seed)
        # NOTE: idxs is shuffled in place, so each seed permutes the previous
        # seed's ordering; the splits are still fully reproducible.
        random.shuffle(idxs)
        train_idxs.append(idxs[:n_train])
        test_idxs.append(idxs[n_train:])
else:
    raise ValueError(f'Unknown split strategy: {split_strategy}')

# Check how well the train and test histograms of one fold match up
if plot_train_test_dist:
    X_dataframe_train = X_dataframe.iloc[train_idxs[for_fold]]
    Y_dataframe_train = Y_dataframe.iloc[train_idxs[for_fold]]
    X_dataframe_test = X_dataframe.iloc[test_idxs[for_fold]]
    Y_dataframe_test = Y_dataframe.iloc[test_idxs[for_fold]]
    for col in X_dataframe.columns:
        _plot_train_test_hist(X_dataframe_train[col], X_dataframe_test[col], col, binwidth=0.5)
    # Plot each target column on its own labelled figure, like the descriptors.
    for col in target:
        _plot_train_test_hist(Y_dataframe_train[col], Y_dataframe_test[col], col, binwidth=1)

### Pearson Correlation Analysis

In [5]:
# `dataset` is laid out as descriptors, then latent columns, then targets
# (see the dataset cell), so the axis labels must follow that same order.
choose_descriptors = descriptors + latent_col_names + target
choose_train_idx = 0
if plot_pcc_matrix:
    # Computing the pearson correlation coefficients (ALWAYS USE THE TRAINING DATASET TO AVOID DATA LEAKAGE)
    train_data = dataset[train_idxs[choose_train_idx]]
    # Kept at full precision; values are only rounded when drawn on the plot.
    pcc = np.corrcoef(train_data[:, :len(choose_descriptors)], rowvar=False)
    # Ref : Gryffin, https://online.ucpress.edu/collabra/article/9/1/87615/197169/A-Brief-Note-on-the-Standard-Error-of-the-Pearson
    std_err_pcc = 1 / ((train_data.shape[0] - 3) ** 0.5)
    print(f'Standard error in PCC : {std_err_pcc}')
    # Shrink |PCC| by its standard error and floor at 0, so correlations that
    # are within noise are shown as 0. The sign of the correlation is dropped.
    adj_pcc = np.clip((np.abs(pcc) - std_err_pcc) / (1 - std_err_pcc), 0, None)
    # adj_pcc = pcc

    fig, ax = plt.subplots(figsize=(16, 16))
    pcc_plot = ax.matshow(adj_pcc, cmap='coolwarm', vmin=-1, vmax=1)
    # Annotate every cell with its adjusted correlation coefficient.
    for (i, j), value in np.ndenumerate(adj_pcc):
        ax.text(j, i, f'{value:.2f}', ha='center', va='center', color='black')
    tick_positions = np.arange(len(choose_descriptors))
    ax.set_yticks(ticks=tick_positions, labels=choose_descriptors, rotation=0)
    ax.set_xticks(ticks=tick_positions, labels=choose_descriptors, rotation=90)
    # Add the colorbar
    cbar = plt.colorbar(pcc_plot, ax=ax, fraction=0.0455)
    # cbar.set_label('Adjusted Pearson Correlation Coefficient', rotation=270, labelpad=20)
    # plt.title('Adjusted Pearson Correlation Coefficient Matrix')
    plt.tight_layout()
    plt.show()

#### Train the model

In [7]:
##########################
load_ae = True
saved_module_name = 'AE1'
fold_num = 0
# Training params
num_epochs = 3000
lr = 0.01
momentum = 0
# Optimization loss params
l2_coeff = 0
l1_coeff = 0.001
weight_samples_target = True
weight_samples_lambda = 1.0
# Model params
num_y1_latents = 1
# Printing params
print_every_n_batches = 100
print_losses = True
debug = False
pred_lam = 1.0
design_lam = 10.0
latent_lam = 1.0
##########################

latent_dim = 8

# Layer widths derived from the dataset cell, so they always match the column
# slices taken in the training loop (x = descriptors + latent columns; one
# one-hot logit per A-site / B-site ion).
num_input_features = len(descriptors) + len(latent_col_names)
num_A_ions = len(target_A_ion)
num_B_ions = len(target_B_ion)

module_params = {'name':'AE1', 
                    'modules':{

                        'encoder':{
                            'input_dim':num_input_features,
                            'output_dim':latent_dim, 
                            'hidden_dim':25, 
                            'hidden_layers':1, 
                            'hidden_activation':None, 
                            'output_activation':None, 
                            'layer_kernel_init':'xavier_normal', 
                            'layer_bias_init':'zeros', 
                            },

                        'bg_predictor':{
                            'input_dim':latent_dim,
                            'output_dim':1,
                            'hidden_dim':25,
                            'hidden_layers':1,
                            'hidden_activation':torch.nn.ReLU(),
                            'output_activation':torch.nn.ReLU(),
                            'layer_kernel_init':'xavier_normal',
                            'layer_bias_init':'zeros'},

                        # 'h2_prod_rate_predictor':{
                        #     'input_dim':8,
                        #     'output_dim':1,
                        #     # 'hidden_dim':25,
                        #     # 'hidden_layers':1,
                        #     # 'hidden_activation':None,
                        #     'output_activation':torch.nn.ReLU(),
                        #     # 'layer_dropout':torch.nn.Dropout(p=0.5),
                        #     'layer_kernel_init':'xavier_normal',
                        #     'layer_bias_init':'zeros'},

                        'A_predictor':{
                            'input_dim':latent_dim,
                            'output_dim':num_A_ions,
                            'hidden_dim':25,
                            'hidden_layers':1,
                            'hidden_activation':None,
                            'output_activation':None,
                            'layer_kernel_init':'xavier_normal',
                            'layer_bias_init':'zeros'
                        },

                        'B_predictor':{
                            'input_dim':latent_dim,
                            'output_dim':num_B_ions,
                            'hidden_dim':25,
                            'hidden_layers':1,
                            'hidden_activation':None,
                            'output_activation':None,
                            'layer_kernel_init':'xavier_normal',
                            'layer_bias_init':'zeros'
                        }

                        # 'calc_temp_predictor':{
                        #     'input_dim':8,
                        #     'output_dim':1,
                        #     # 'hidden_dim':100,
                        #     # 'hidden_layers':1,
                        #     # 'hidden_activation':torch.nn.ReLU(),
                        #     'output_activation':torch.nn.ReLU(),
                        #     # 'layer_dropout':torch.nn.Dropout(p=0.5),
                        #     'layer_kernel_init':'xavier_normal',
                        #     'layer_bias_init':'zeros'},

                        # 'calc_time_predictor':{
                        #     'input_dim':8,
                        #     'output_dim':1,
                        #     # 'hidden_dim':100,
                        #     # 'hidden_layers':1,
                        #     # 'hidden_activation':torch.nn.ReLU(),
                        #     'output_activation':torch.nn.ReLU(),
                        #     # 'layer_dropout':torch.nn.Dropout(p=0.5),
                        #     'layer_kernel_init':'xavier_normal',
                        #     'layer_bias_init':'zeros'},

                        # 'prom_w_predictor':{
                        #     'input_dim':8,
                        #     'output_dim':1,
                        #     # 'hidden_dim':25,
                        #     # 'hidden_layers':1,
                        #     # 'hidden_activation':torch.nn.ReLU(),
                        #     'output_activation':torch.nn.ReLU(),
                        #     # 'layer_dropout':torch.nn.Dropout(p=0.5),
                        #     'layer_kernel_init':'xavier_normal',
                        #     'layer_bias_init':'zeros'},

                        # 'alc_percent_predictor':{
                        #     'input_dim':8,
                        #     'output_dim':1,
                        #     # 'hidden_dim':100,
                        #     # 'hidden_layers':1,
                        #     # 'hidden_activation':torch.nn.ReLU(),
                        #     'output_activation':torch.nn.ReLU(),
                        #     # 'layer_dropout':torch.nn.Dropout(p=0.5),
                        #     'layer_kernel_init':'xavier_normal',
                        #     'layer_bias_init':'zeros'},

                        # 'latents_predictor':{
                        #     'input_dim':8,
                        #     'output_dim':8,
                        #     # 'hidden_dim':25,
                        #     # 'hidden_layers':1,
                        #     # 'hidden_activation':None,
                        #     'output_activation':None,
                        #     # 'layer_dropout':torch.nn.Dropout(p=0.5),
                        #     'layer_kernel_init':'xavier_normal',
                        #     'layer_bias_init':'zeros'},

                        # 'prep_method_predictor':{
                        #     'input_dim':8,
                        #     'output_dim':5,
                        #     'hidden_dim':100,
                        #     'hidden_layers':1,
                        #     'hidden_activation':torch.nn.ReLU(),
                        #     'output_activation':None,
                        #     # 'layer_dropout':torch.nn.Dropout(p=0.5),
                        #     'layer_kernel_init':'xavier_normal',
                        #     'layer_bias_init':'zeros'
                        # },

                        # 'prom_method_predictor':{
                        #     'input_dim':8,
                        #     'output_dim':3,
                        #     # 'hidden_dim':25,
                        #     # 'hidden_layers':1,
                        #     # 'hidden_activation':torch.nn.ReLU(),
                        #     'output_activation':None,
                        #     # 'layer_dropout':torch.nn.Dropout(p=0.5),
                        #     'layer_kernel_init':'xavier_normal',
                        #     'layer_bias_init':'zeros'
                        # },

                        # 'prom_type_predictor':{
                        #     'input_dim':8,
                        #     'output_dim':4,
                        #     # 'hidden_dim':25,
                        #     # 'hidden_layers':1,
                        #     # 'hidden_activation':torch.nn.ReLU(),
                        #     'output_activation':None,
                        #     # 'layer_dropout':torch.nn.Dropout(p=0.5),
                        #     'layer_kernel_init':'xavier_normal',
                        #     'layer_bias_init':'zeros'
                        # },

                        # 'sac_agent_predictor':{
                        #     'input_dim':8,
                        #     'output_dim':1,
                        #     # 'hidden_dim':25,
                        #     # 'hidden_layers':1,
                        #     # 'hidden_activation':torch.nn.ReLU(),
                        #     'output_activation':None,
                        #     # 'layer_dropout':torch.nn.Dropout(p=0.5),
                        #     'layer_kernel_init':'xavier_normal',
                        #     'layer_bias_init':'zeros'
                        # }
                        
                    }}

# Per-fold loss histories, collected across all folds.
train_total_pred_loss_per_epoch_per_seed = []
train_y_pred_loss_per_epoch_per_seed = []
train_design_pred_loss_per_epoch_per_seed = []
train_latent_pred_loss_per_epoch_per_seed = []

val_y_pred_loss_per_epoch_per_seed = []
val_design_pred_loss_per_epoch_per_seed = []
val_latent_pred_loss_per_epoch_per_seed = []

ind_losses_dict_train = {}
ind_losses_dict_val = {}
for i in range(len(train_idxs)):
    print('\n')
    print(f'Fold {i}')
    print('\n' )
    train_dataset = dataset[train_idxs[i]]
    val_dataset = dataset[test_idxs[i]]
          
    print(f'Train numpy dataset shape : {train_dataset.shape}, Val. numpy dataset shape : {val_dataset.shape}')

    torch_train_dataset = torch.from_numpy(train_dataset).to(dtype=torch.float32)
    train_data_loader = DataLoader(torch_train_dataset, batch_size=train_dataset.shape[0], shuffle=True)

    torch_val_dataset = torch.from_numpy(val_dataset).to(dtype=torch.float32)
    val_data_loader = DataLoader(torch_val_dataset, batch_size=val_dataset.shape[0], shuffle=False)

    # RANDOM USED HERE - Param init
    # Delete previos model
    if i > 0:
        print("Deleting previous model")
        del ae
        ae = None
    random.seed(random_state)
    np.random.seed(random_state)
    torch.manual_seed(random_state)
    if load_ae:
        print('Loaded AE !')
        loaded_state_dict = f'runs/nestedae_{saved_module_name}_bandgaps_THEN_h2_prod_rate/nestedae_fold{fold_num}'
        ae = AE(module_params)
        ae.load_state_dict(torch.load(loaded_state_dict))
    else:
        ae = AE(module_params)
    print(ae)

    random.seed(random_state)
    np.random.seed(random_state)
    torch.manual_seed(random_state)
    sgd = SGD(ae.parameters(), lr=lr, momentum=momentum, weight_decay=l2_coeff)
    adam = torch.optim.Adam(ae.parameters(), lr=lr, weight_decay=l2_coeff)
    x_l1_loss = L1Loss(reduction='mean')

    y1_l1_loss = L1Loss(reduction='mean')

    # design1_loss = L1Loss(reduction='mean')
    # design2_loss = L1Loss(reduction='mean')
    # design3_loss = L1Loss(reduction='mean')
    # design4_loss = L1Loss(reduction='mean')
    design1_loss = torch.nn.CrossEntropyLoss(reduction='mean')
    design2_loss = torch.nn.CrossEntropyLoss(reduction='mean')
    # design8_loss = torch.nn.CrossEntropyLoss(reduction='mean')

    latent_l1_loss = L1Loss(reduction='mean')

    train_total_pred_loss_per_epoch = []

    train_y_pred_loss_per_epoch = []
    train_design_pred_loss_per_epoch = []
    train_latent_pred_loss_per_epoch = []
    train_x_pred_loss_per_epoch = []

    val_y_pred_loss_per_epoch = []
    val_design_pred_loss_per_epoch = []
    val_latent_pred_loss_per_epoch = []
    val_x_pred_loss_per_epoch = []

    for epoch in range(num_epochs):
        train_total_pred_loss_per_batch = 0
        train_y_pred_loss_per_batch = 0
        train_design_pred_loss_per_batch = 0
        train_latent_pred_loss_per_batch = 0
        train_x_pred_loss_per_batch = 0

        # Train Loop
        print(f' --------- Epoch Stats {epoch+1}/{num_epochs} --------- ')
        for batch, data in enumerate(train_data_loader):
            ae.train()
            x = data[:, 0:(len(descriptors) + len(latent_col_names))]
            if len(latent_col_names) > 0:
                latents = data[:, len(descriptors):(len(descriptors) + len(latent_col_names))]

            # Add additional property predictions here ...
            y1 = data[:, len(descriptors) + len(latent_col_names)]

            # Add design predictions here ..
            # For dataset 1
            design1 = data[:, (len(descriptors) + len(latent_col_names) + 1):(len(descriptors) + len(latent_col_names) + 1 + len(target_A_ion))]
            design2 = data[:, (len(descriptors) + len(latent_col_names) + 1 + len(target_A_ion)):(len(descriptors) + len(latent_col_names) + 1 + len(target_A_ion) + len(target_B_ion))]
            # For dataset 2

            # sgd.zero_grad()
            adam.zero_grad()
            ae_out = ae(x)
            y1_pred, z = ae_out['y1_pred'], ae_out['z']
            # For dataset 1
            design1_pred, design2_pred = ae_out['design1_pred'], ae_out['design2_pred']
            # For dataset 2

            if len(latent_col_names) > 0:
                latents_pred = ae_out['latents_pred']

            if weight_samples_target:
                num_nonzero_elements = torch.nonzero(y1).shape[0]
                total_elements = y1.shape[0]
                sample_wts = copy.deepcopy(y1).reshape(-1, 1)
                sample_wts[sample_wts != 0] = total_elements/num_nonzero_elements
                sample_wts[sample_wts == 0] = total_elements/(total_elements - num_nonzero_elements)
                # train_sampler = train_data_loader.sampler
                # shuffled_idxs_train = list(train_sampler)
                train_y1_pred_loss = torch.mean(weight_samples_lambda*sample_wts*torch.abs(y1_pred - y1.reshape(-1, 1)), dim=0, keepdim=False)
                if print_losses: print('train_y1_pred_loss:', train_y1_pred_loss)
                train_y_pred_loss = train_y1_pred_loss
            else:
                train_y1_pred_loss = y1_l1_loss(y1_pred, y1.reshape(-1, 1))
                if print_losses: print('train_y1_pred_loss:', train_y1_pred_loss)
                train_y_pred_loss = train_y1_pred_loss

            # For dataset 1
            train_design1_pred_loss = design1_loss(design1_pred, design1)
            if print_losses: print('train_design1_pred_loss:', train_design1_pred_loss)
            train_design2_pred_loss = design2_loss(design2_pred, design2)
            if print_losses: print('train_design2_pred_loss:', train_design2_pred_loss)
            # For dataset 2

            # For dataset 1
            train_design_pred_loss = train_design1_pred_loss + train_design2_pred_loss
            # For dataset 2
            # train_design_pred_loss = None

            if len(latent_col_names) > 0:
                train_latent_pred_loss = latent_l1_loss(latents_pred, latents)
            else:
                train_latent_pred_loss = torch.tensor(0)
            if print_losses: print('train_latent_pred_loss:', train_latent_pred_loss)
                    
            train_total_pred_loss = pred_lam*train_y_pred_loss + design_lam*train_design_pred_loss + latent_lam*train_latent_pred_loss

            # Get index of max value fro each row of design_pred_idxs
            design1_pred_idxs_train = torch.argmax(torch.softmax(design1_pred, dim=1), dim=1)
            design1_true_idxs_train = torch.argmax(design1, dim=1)
            if debug:
                print(f'First 20 predicted indices train : {design1_pred_idxs_train[:20]}')
                print(f'First 20 true indices train : {design1_true_idxs_train[:20]}')
            # Calculate train accuracy
            design1_train_accuracy = (design1_pred_idxs_train == design1_true_idxs_train).float().mean().item() * 100
            print(f'Train Accuracy : {design1_train_accuracy}')

            design2_pred_idxs_train = torch.argmax(torch.softmax(design2_pred, dim=1), dim=1)
            design2_true_idxs_train = torch.argmax(design2, dim=1)
            if debug:
                print(f'First 20 predicted indices train : {design2_pred_idxs_train[:20]}')
                print(f'First 20 true indices train : {design2_true_idxs_train[:20]}')
            # Calculate train accuracy
            design2_train_accuracy = (design2_pred_idxs_train == design2_true_idxs_train).float().mean().item() * 100
            print(f'Train Accuracy : {design2_train_accuracy}')

            if l1_coeff > 0:
                for name, param in ae.named_parameters():
                    train_total_pred_loss += l1_coeff * param.abs().sum()
            
            # ---  STEP 1  --- : Compute pearson correlation coefficients
            # --- STEP 1.1 --- : with respect to target
            rho_abs_w_target1 = torch.empty(num_y1_latents)
            y1_mean, y1_std = y1.mean(), y1.std()
            for j, l in enumerate(range(num_y1_latents)):
                z_mean, z_std = z[:, l].mean(), z[:, l].std()
                Czy = ((z[:, l] - z_mean) * (y1 - y1_mean)).mean()
                rho_abs_w_target1[j] = abs(Czy / (z_std * y1_std))
                # # Verify correctness of correflation coefficient agsinst torch corrcoef function
                # concat_tensor = cat((z[:, i].reshape(-1, 1), y.reshape(-1, 1)), dim=1)
                # print('Computed using inbuilt function')
                # print(abs(corrcoef(concat_tensor.T)[0, 1]))
            
            # --- STEP 1.2 --- : with respect to other latent variables
            rho_abs_w_latents = []
            for j in range(z.shape[1]):
                z_j_mean, z_j_std = z[:, j].mean(), z[:, j].std()
                for k in range(j+1, z.shape[1]):
                    z_k_mean, z_k_std = z[:, k].mean(), z[:, k].std()
                    Czz = ((z[:, j] - z_j_mean) * (z[:, k] - z_k_mean)).mean()
                    rho_abs_w_latents.append(abs(Czz / (z_j_std * z_k_std)))
            rho_abs_w_latents = torch.stack(rho_abs_w_latents)

            # ---  STEP 2  --- : Compute std err in PCCs and adjust rho_abs_w_target & rho_abs_w_latents
            std_err = 1/math.sqrt(len(y1_pred) - 3) # (Ref : Gryffin)
            if debug : print('std_err:', std_err)

            rho_abs_w_target1_adj = (rho_abs_w_target1 - std_err)/(1 - std_err)
            rho_abs_w_target1_adj[rho_abs_w_target1_adj < 0] = 0
            if debug : print('rho_abs_w_target1_adj :', rho_abs_w_target1_adj)

            rho_abs_w_latents_adj = (rho_abs_w_latents - std_err)/(1 - std_err)
            rho_abs_w_latents_adj[rho_abs_w_latents_adj < 0] = 0
            if debug : print('rho_abs_w_latents_adj :', rho_abs_w_latents_adj)

            # # ---  STEP 3  --- : Atleast one of the PCCs wrt target should be maximized
            lambda_0_1 = torch.mean(1 - rho_abs_w_target1_adj)
            if debug : print('lambda_0_1:', lambda_0_1)

            # # --- Step 4 --- : Favor PCCs with target that are close to 1 or below std_err
            # lambda_1 = torch.mean(torch.pow(torch.sin(math.pi*rho_abs_w_target_adj), 2))
            # if debug : print('lambda_1:', lambda_1)
            
            # --- Step 5 --- : Favor PCCs between latents that arer close to 0
            lambda_1 = torch.mean(torch.pow(torch.sin((math.pi/2)*rho_abs_w_latents_adj), 2))
            if debug : print('lambda_1:', lambda_1)
            
            # --- Step 5 --- : Add all the losses
            # train_total_pred_loss += lambda_0 + lambda_1 + lambda_2
            # train_total_pred_loss += lambda_0_1 + lambda_0_2 + lambda_1
            train_total_pred_loss += lambda_0_1 + lambda_1
            
            train_total_pred_loss.backward()
            # sgd.step()
            adam.step()
            # Store the losses for each batch
            train_total_pred_loss_per_batch += train_total_pred_loss.item()
            train_y_pred_loss_per_batch += train_y_pred_loss.item()
            train_design_pred_loss_per_batch += train_design_pred_loss.item()
            train_latent_pred_loss_per_batch += train_latent_pred_loss.item()
            # # Printing purposes
            # if print_losses:
            #     if batch % print_every_n_batches == 0:
            #         print(f'Batch {batch}/{len(train_data_loader)}, \
            #                Total Loss: {train_total_pred_loss.item():.4f}, \
            #                 Y Pred Loss: {train_y_pred_loss.item():.4f}, \
            #                 X Pred Loss: {train_x_pred_loss.item():.4f}')
        
        ind_losses_dict_train[f'fold_{i}'] = [train_y1_pred_loss.item(), 
                                              train_design1_pred_loss.item(), 
                                              train_design2_pred_loss.item()]

        val_y_pred_loss_per_batch = 0
        val_design_pred_loss_per_batch = 0
        val_latent_pred_loss_per_batch = 0

        # Validation Loop
        for batch, data in enumerate(val_data_loader):
            ae.eval()

            x = data[:, 0:(len(descriptors) + len(latent_col_names))]
            
            if len(latent_col_names) > 0:
                latents = data[:, len(descriptors):(len(descriptors) + len(latent_col_names))]

            # Add additional property predictions here ...
            y1 = data[:, len(descriptors) + len(latent_col_names)]

            # Add design predictions here ..
            # For dataset 1
            design1 = data[:, (len(descriptors) + len(latent_col_names) + 1):(len(descriptors) + len(latent_col_names) + 1 + len(target_A_ion))]
            design2 = data[:, (len(descriptors) + len(latent_col_names) + 1 + len(target_A_ion)):(len(descriptors) + len(latent_col_names) + 1 + len(target_A_ion) + len(target_B_ion))]
            # For dataset 2

            with torch.no_grad():
                ae_out = ae(x)

                y1_pred, z = ae_out['y1_pred'], ae_out['z']
                # For dataset 1
                design1_pred, design2_pred = ae_out['design1_pred'], ae_out['design2_pred']
                # For dataset 2
                if len(latent_col_names) > 0:
                    latents_pred = ae_out['latents_pred']

                val_y1_pred_loss = y1_l1_loss(y1_pred, y1.reshape(-1, 1))
                val_y_pred_loss = val_y1_pred_loss

                val_design1_pred_loss = design1_loss(design1_pred, design1)
                val_design2_pred_loss = design2_loss(design2_pred, design2)
                val_design_pred_loss = val_design1_pred_loss + val_design2_pred_loss

                if len(latent_col_names) > 0:
                    val_latent_pred_loss = latent_l1_loss(latents_pred, latents)
                else:
                    val_latent_pred_loss = torch.tensor(0)

            # Get index of max value fro each row of design_pred_idxs
            design1_pred_idxs_val = torch.argmax(torch.softmax(design1_pred, dim=1), dim=1)
            design1_true_idxs_val = torch.argmax(design1, dim=1)
            val_accuracy = (design1_pred_idxs_val == design1_true_idxs_val).float().mean().item() * 100
            print(f'Val Accuracy : {val_accuracy}')

            design2_pred_idxs_val = torch.argmax(torch.softmax(design2_pred, dim=1), dim=1)
            design2_true_idxs_val = torch.argmax(design2, dim=1)
            val_accuracy = (design2_pred_idxs_val == design2_true_idxs_val).float().mean().item() * 100
            print(f'Val Accuracy : {val_accuracy}')

            val_y_pred_loss_per_batch += val_y_pred_loss.item()
            val_design_pred_loss_per_batch += val_design_pred_loss.item()
            val_latent_pred_loss_per_batch += val_latent_pred_loss.item()

            # Printing purposes
            if print_losses:
                if batch % print_every_n_batches == 0:
                    print(f'Batch {batch}/{len(val_data_loader)}, Y Pred Loss: {val_y_pred_loss.item():.4f}, Design Pred Loss: {val_design_pred_loss.item():.4f}')

        # Store train loss curves
        train_total_pred_loss_per_epoch.append(train_total_pred_loss_per_batch / len(train_data_loader))

        # train_x_pred_loss_per_epoch.append(train_x_pred_loss_per_batch / len(train_data_loader))
        train_y_pred_loss_per_epoch.append(train_y_pred_loss_per_batch / len(train_data_loader))
        train_design_pred_loss_per_epoch.append(train_design_pred_loss_per_batch / len(train_data_loader))
        train_latent_pred_loss_per_epoch.append(train_latent_pred_loss_per_batch / len(train_data_loader))

        val_y_pred_loss_per_epoch.append(val_y_pred_loss_per_batch / len(val_data_loader))
        val_design_pred_loss_per_epoch.append(val_design_pred_loss_per_batch / len(val_data_loader))
        val_latent_pred_loss_per_epoch.append(val_latent_pred_loss_per_batch / len(val_data_loader))

        if print_losses:
            print(f' --------- Epoch Stats {epoch+1}/{num_epochs} --------- ')
            print(f' -- Train -- Total Loss: {train_total_pred_loss_per_epoch[-1]:.4f},\
                                 Y Pred Loss: {train_y_pred_loss_per_epoch[-1]:.4f},\
                                 Design Pred Loss: {train_design_pred_loss_per_epoch[-1]:.4f},\
                                 Latent Pred Loss: {train_latent_pred_loss_per_epoch[-1]:.4f}')
            print(f' --  Val  -- Y Pred Loss: {val_y_pred_loss_per_epoch[-1]:.4f},\
                                 Design Pred Loss: {val_design_pred_loss_per_epoch[-1]:.4f},\
                                 Latent Pred Loss: {val_latent_pred_loss_per_epoch[-1]:.4f}')
            print(f' ------------------------------------------')

    ind_losses_dict_val[f'fold_{i}'] = [val_y1_pred_loss.item(), 
                                        val_design1_pred_loss.item(), 
                                        val_design2_pred_loss.item()]

    train_total_pred_loss_per_epoch_per_seed.append(train_total_pred_loss_per_epoch)
    train_y_pred_loss_per_epoch_per_seed.append(train_y_pred_loss_per_epoch)
    train_design_pred_loss_per_epoch_per_seed.append(train_design_pred_loss_per_epoch)
    train_latent_pred_loss_per_epoch_per_seed.append(train_latent_pred_loss_per_epoch)

    val_y_pred_loss_per_epoch_per_seed.append(val_y_pred_loss_per_epoch)
    val_design_pred_loss_per_epoch_per_seed.append(val_design_pred_loss_per_epoch)
    val_latent_pred_loss_per_epoch_per_seed.append(val_latent_pred_loss_per_epoch)

    # Save the model to the runs directory
    model_save_path = f'runs/{model_save_dir}/nestedae_fold{i}'
    torch.save(ae.state_dict(), model_save_path)
    print(f'Model saved to {model_save_path}')



Fold 0


Train numpy dataset shape : (2433, 117), Val. numpy dataset shape : (271, 117)
Loaded AE !
 --> Setting out layer kernel init with xavier normal distribution with gain 1
 --> Setting out layer kernel init with xavier normal distribution with gain 1
 --> Setting out layer kernel init with xavier normal distribution with gain 1
 --> Setting out layer kernel init with xavier normal distribution with gain 1
 --> Setting out layer kernel init with xavier normal distribution with gain 1
 --> Setting out layer kernel init with xavier normal distribution with gain 1
 --> Setting out layer kernel init with xavier normal distribution with gain 1
 --> Setting out layer kernel init with xavier normal distribution with gain 1
AE(
  (ae_modules): ModuleDict(
    (encoder): ModuleList(
      (0): Linear(in_features=12, out_features=25, bias=True)
      (1): Linear(in_features=25, out_features=8, bias=True)
    )
    (bg_predictor): ModuleList(
      (0): Linear(in_features=8, out_features=

KeyboardInterrupt: 

### Individual data fold statistics

In [None]:
def _fold_loss_summary(losses_by_fold):
    """Tabulate per-fold losses and summarise them.

    ``losses_by_fold`` maps ``'fold_<i>'`` to a list of individual loss terms
    recorded for that fold. Returns ``(table, per_fold_sum, per_term_mean,
    per_term_std)`` where ``table`` has one row per fold and one column per
    loss term.
    """
    table = pd.DataFrame.from_dict(losses_by_fold).to_numpy().T
    per_fold_sum = np.sum(table, axis=1)
    per_term_mean = np.mean(table, axis=0)
    per_term_std = np.std(table, axis=0)
    return table, per_fold_sum, per_term_mean, per_term_std


# --- Train split statistics ---
(ind_losses_train, ind_losses_train_sum,
 ind_losses_train_means, ind_losses_train_std_dev) = _fold_loss_summary(ind_losses_dict_train)
for summary in (ind_losses_train, ind_losses_train_means, ind_losses_train_std_dev):
    print(summary)
print("\n")

# --- Validation split statistics ---
(ind_losses_val, ind_losses_val_sum,
 ind_losses_val_means, ind_losses_val_std_dev) = _fold_loss_summary(ind_losses_dict_val)
for summary in (ind_losses_val, ind_losses_val_means, ind_losses_val_std_dev):
    print(summary)

print('\n')
print(ind_losses_train_sum)

# Fold whose summed train losses are smallest
print(f'Fold : {np.argmin(ind_losses_train_sum)}')

### Load the selected trained AE

In [None]:
# Select which trained model to reload: the cross-validation fold and the
# module name, both of which are part of the checkpoint path built below.
fold_num = 0
saved_module_name = 'AE1'
latent_dim = 8

# Architecture of the saved model. It must reproduce the training-time
# configuration exactly, otherwise load_state_dict() raises on missing,
# unexpected or wrongly-shaped parameters.
module_params = {
    'name': saved_module_name,
    'modules': {

        # Encoder: 12 input features -> latent_dim latent variables
        'encoder': {
            'input_dim': 12,
            'output_dim': latent_dim,
            'hidden_dim': 25,
            'hidden_layers': 1,
            'hidden_activation': None,
            'output_activation': torch.nn.Tanh(),
            'layer_kernel_init': 'xavier_normal',
            'layer_bias_init': 'zeros',
        },

        # Single-output property head (bandgap, per the checkpoint path)
        'bg_predictor': {
            'input_dim': latent_dim,
            'output_dim': 1,
            'hidden_dim': 25,
            'hidden_layers': 1,
            'hidden_activation': torch.nn.ReLU(),
            'output_activation': torch.nn.ReLU(),
            # 'layer_dropout':torch.nn.Dropout(p=0.5),
            'layer_kernel_init': 'xavier_normal',
            'layer_bias_init': 'zeros',
        },

        # 52-way design heads; presumably logits over the A-/B-ion one-hot
        # vocabularies used as design targets in training -- confirm.
        'A_predictor': {
            'input_dim': latent_dim,
            'output_dim': 52,
            'hidden_dim': 25,
            'hidden_layers': 1,
            'hidden_activation': None,
            'output_activation': None,
            'layer_kernel_init': 'xavier_normal',
            'layer_bias_init': 'zeros',
        },

        'B_predictor': {
            'input_dim': latent_dim,
            'output_dim': 52,
            'hidden_dim': 25,
            'hidden_layers': 1,
            'hidden_activation': None,
            'output_activation': None,
            'layer_kernel_init': 'xavier_normal',
            'layer_bias_init': 'zeros',
        },
    },
}

loaded_state_dict = f'runs/nestedae_{saved_module_name}_bandgaps_THEN_h2_prod_rate/nestedae_fold{fold_num}'
loaded_ae = AE(module_params)
# map_location='cpu' lets a checkpoint saved during a GPU run be restored on a
# CPU-only machine; load_state_dict() copies the tensors into the freshly built
# model's parameters regardless of where they were loaded.
loaded_ae.load_state_dict(torch.load(loaded_state_dict, map_location='cpu'))
loaded_ae.eval()

### Property predictions for all samples in current dataset

In [None]:
# `dataset` columns follow the training layout: descriptors, then any latent
# columns (together the model inputs), then the bandgap target.
n_model_inputs = len(descriptors) + len(latent_col_names)
x_torch_all = torch.from_numpy(dataset[:, 0:n_model_inputs]).to(dtype=torch.float32)
# Target column is the first one after the inputs, matching how the training
# loop reads y1 (instead of a hard-coded column index).
bandgaps_true = torch.from_numpy(dataset[:, n_model_inputs])

loaded_ae.eval()
with torch.no_grad():
    ae_out = loaded_ae(x_torch_all)

latents = ae_out['z']
bandgaps_pred = ae_out['y1_pred']

print(bandgaps_pred)

# Parity plot of true vs. predicted bandgaps; the red dashed line is y = x.
bg_min, bg_max = bandgaps_true.min(), bandgaps_true.max()
plt.figure(figsize=(10, 10))
plt.scatter(bandgaps_true, bandgaps_pred, alpha=0.5)
plt.xlabel('True Bandgap')
plt.ylabel('Predicted Bandgap')
plt.title('True vs Predicted Bandgap')
plt.plot([bg_min, bg_max], [bg_min, bg_max], 'r--', lw=2)
plt.xlim(bg_min, bg_max)
plt.ylim(bg_min, bg_max)
plt.grid()
plt.show()

# Mean absolute error over the whole dataset
bandgap_mae = np.mean(np.abs(bandgaps_pred.reshape(-1, 1).detach().numpy()
                             - bandgaps_true.reshape(-1, 1).detach().numpy()))
print(f'Bandgap MAE : {bandgap_mae}')

#### Property predictions for samples in the H2 Prod Rate Dataset

In [None]:
# Load the descriptor columns of the H2-production-rate dataset. `.copy()`
# makes the column selection an independent frame before it is modified below.
x_df = pd.read_csv('datasets/H2_prod_rate/props_from_sa_h2_rate.csv')[descriptors].copy()
print(x_df.shape)
# Standardize with desc_means / desc_std_devs from an earlier cell (presumably
# the statistics the model was trained with -- confirm), NOT with this file's
# own statistics, so inputs are on the scale the model expects.
for i, desc in enumerate(x_df.columns.tolist()):
    x_df[desc] = (x_df[desc] - desc_means[i]) / desc_std_devs[i]

x_torch = torch.from_numpy(x_df.to_numpy(dtype=np.float32))
loaded_ae.eval()
with torch.no_grad():
    ae_out = loaded_ae(x_torch)
latents = ae_out['z']
bandgaps = ae_out['y1_pred']
# Most likely class index from each design head
a_design_pred = torch.argmax(torch.softmax(ae_out['design1_pred'], dim=1), dim=1)
b_design_pred = torch.argmax(torch.softmax(ae_out['design2_pred'], dim=1), dim=1)
np.savetxt('datasets/H2_prod_rate/a_design_pred_for_h2_prod_dataset.csv', a_design_pred.detach().numpy(), delimiter=",")
np.savetxt('datasets/H2_prod_rate/b_design_pred_for_h2_prod_dataset.csv', b_design_pred.detach().numpy(), delimiter=",")
np.savetxt('datasets/H2_prod_rate/bandgap_pred_for_h2_prod_dataset.csv', bandgaps.detach().numpy(), delimiter=",")
print(bandgaps)
np.savetxt('datasets/H2_prod_rate/latents_for_h2_prod_dataset.csv', latents.detach().numpy(), delimiter=",")

#### Plot model training results

In [None]:
num_latents = len(latent_col_names)
epochs = np.arange(1, num_epochs + 1)


def _mean_std_over_runs(loss_per_epoch_per_run, use_log=True):
    """Return the per-epoch mean and std of a loss across runs.

    ``loss_per_epoch_per_run`` holds one per-epoch loss curve per run
    (fold/seed); statistics are taken across runs (axis 0). With ``use_log``
    the curves are log-transformed first so the band is drawn in log space.
    """
    curves = np.array(loss_per_epoch_per_run)
    if use_log:
        curves = np.log(curves)
    return np.mean(curves, axis=0), np.std(curves, axis=0)


def _plot_loss_panel(position, mean, std, loss_per_epoch_per_run, title, color, use_log=True):
    """Draw subplot ``position`` of 7: mean curve, +/- 1 std band, and a legend
    giving the final-epoch mean +/- std of the raw (non-log) loss."""
    final_epoch = np.array(loss_per_epoch_per_run)[:, -1]
    plt.subplot(7, 1, position)
    plt.plot(epochs, mean,
             label=f'{round(np.mean(final_epoch), 2)} +/-{round(np.std(final_epoch), 2)}',
             color=color)
    plt.fill_between(epochs, mean - std, mean + std, color=color, alpha=0.2)
    plt.title(title)
    plt.xlabel('Epochs')
    # Label the axis with the scale that was actually plotted.
    plt.ylabel('log(loss)' if use_log else 'loss')
    plt.legend()


# Without latent targets the latent loss is recorded as 0, so log() would
# yield -inf; plot those curves on a linear scale instead.
latent_use_log = num_latents > 0

train_mean_total_pred_loss, train_std_total_pred_loss = _mean_std_over_runs(train_total_pred_loss_per_epoch_per_seed)
train_mean_y_pred_loss, train_std_y_pred_loss = _mean_std_over_runs(train_y_pred_loss_per_epoch_per_seed)
train_mean_design_pred_loss, train_std_design_pred_loss = _mean_std_over_runs(train_design_pred_loss_per_epoch_per_seed)
train_mean_latent_pred_loss, train_std_latent_pred_loss = _mean_std_over_runs(
    train_latent_pred_loss_per_epoch_per_seed, use_log=latent_use_log)

val_mean_y_pred_loss, val_std_y_pred_loss = _mean_std_over_runs(val_y_pred_loss_per_epoch_per_seed)
val_mean_design_pred_loss, val_std_design_pred_loss = _mean_std_over_runs(val_design_pred_loss_per_epoch_per_seed)
val_mean_latent_pred_loss, val_std_latent_pred_loss = _mean_std_over_runs(
    val_latent_pred_loss_per_epoch_per_seed, use_log=latent_use_log)

# Plotting
plt.figure(figsize=(6, 12))
_plot_loss_panel(1, train_mean_total_pred_loss, train_std_total_pred_loss,
                 train_total_pred_loss_per_epoch_per_seed, 'Train Total Loss', 'black')
_plot_loss_panel(2, train_mean_y_pred_loss, train_std_y_pred_loss,
                 train_y_pred_loss_per_epoch_per_seed, 'Train Y Pred Loss', 'blue')
_plot_loss_panel(3, train_mean_design_pred_loss, train_std_design_pred_loss,
                 train_design_pred_loss_per_epoch_per_seed, 'Train Design Pred Loss', 'green')
_plot_loss_panel(4, train_mean_latent_pred_loss, train_std_latent_pred_loss,
                 train_latent_pred_loss_per_epoch_per_seed, 'Train Latent Pred Loss', 'darkorange',
                 use_log=latent_use_log)
_plot_loss_panel(5, val_mean_y_pred_loss, val_std_y_pred_loss,
                 val_y_pred_loss_per_epoch_per_seed, 'Val Y Pred Loss', 'lightskyblue')
_plot_loss_panel(6, val_mean_design_pred_loss, val_std_design_pred_loss,
                 val_design_pred_loss_per_epoch_per_seed, 'Val Design Pred Loss', 'yellowgreen')
_plot_loss_panel(7, val_mean_latent_pred_loss, val_std_latent_pred_loss,
                 val_latent_pred_loss_per_epoch_per_seed, 'Val Latent Pred Loss', 'orange',
                 use_log=latent_use_log)
plt.tight_layout()

#### Model feature importance analysis

In [None]:
import shap

# Model inputs labelled in the data's actual column order: descriptors first,
# then any latent columns (as sliced in the training loop and when building
# x_torch_all). Labels in any other order would make SHAP attribute importance
# to the wrong feature names whenever latent columns are present.
dataset_df = pd.DataFrame(dataset[:, :len(descriptors + latent_col_names)],
                          columns=descriptors + latent_col_names)


def convert_to_tensor(obs, latent_idx=5):
    """Encode observations with the loaded AE and return one latent variable.

    This is the model function explained by SHAP below.

    Parameters
    ----------
    obs : pandas.DataFrame or array-like
        Rows of model inputs with the same column layout as ``dataset_df``.
    latent_idx : int, default 5
        Index of the latent variable ``z[:, latent_idx]`` to return.

    Returns
    -------
    torch.Tensor
        1-D tensor holding the selected latent for every row of ``obs``.
    """
    # np.asarray accepts both DataFrames and the plain arrays SHAP may pass.
    obs_tensor = torch.from_numpy(np.asarray(obs, dtype=np.float32))
    loaded_ae.eval()
    with torch.no_grad():
        ae_out = loaded_ae(obs_tensor)
    return ae_out['z'][:, latent_idx]


explainer = shap.explainers.Exact(convert_to_tensor, dataset_df)
shap_values = explainer(dataset_df)

shap.plots.bar(shap_values, max_display=15)

