In [None]:
import os
import torch
import pickle
from Utilities import check_valid
from XAI import *
from matplotlib import pyplot as plt
from tqdm import tqdm
from tdc.benchmark_group import admet_group
import os
from torch import stack, tensor, Generator, cat, float32, nonzero, set_float32_matmul_precision
import pandas as pd

# --- Runtime setup ---------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

set_float32_matmul_precision('high')
# NOTE(review): `set_global_seed` is not imported explicitly; presumably it is
# provided by `from XAI import *` -- confirm.
set_global_seed(42)

# --- Experiment configuration ----------------------------------------------
group = admet_group(path = '../data_tdc/')
names = group.dataset_names
data = {}  # data[dataset name][model name] -> list with one result per batch
seeds = [1,2,3,4,5]
models = [
    'charges_20L_wide_def_2e-5_16p', 
    'nmr_20L_wide_def_2e-5_16p', 
    'fukui_n_20L_wide_def_2e-5_16p', 
    'fukui_e_20L_wide_def_2e-5_16p', 
    'homo-lumo_20L_wide_def_2e-5_16p', 
    'qm_all_20L_wide_def_2e-5_16p', 
    'masking_20L_wide_def_2e-5_16p', 
    'scratch_20L_wide_def_2e-5_16p'
]

# Upper bound on the number of test molecules drawn per dataset.
max_smiles = 50

# --- Collect sensitivities for every dataset / model pair ------------------
for name in names:

    data[name] = {}

    # Only seed 1 is evaluated. Results are keyed by dataset and model only, so
    # looping over several seeds here would overwrite earlier entries.
    for seed in [1]:

        benchmark = group.get(name)
        benchmark_name = benchmark['name']

        # Only the held-out test split is needed to pick the molecules.
        drugs = benchmark['test']['Drug']

        # Draw a random subset when the split is large enough, otherwise keep
        # every molecule in its original order.
        if len(drugs) >= max_smiles:
            smiles = drugs.sample(max_smiles).values.tolist()
        else:
            smiles = drugs.values.tolist()

        smiles = [s for s in smiles if check_valid(s)]

        for model in models:

            # NOTE(review): rebuilt per model because it is not visible here
            # whether the returned object can be iterated more than once.
            dlt = get_dtl_from_smiles(smiles)
            ckpt_root = f'./TDC_checkpoints/{model}/{benchmark_name}_{seed}/'
            buffer = [
                get_sensitivities_per_topdistance(ckpt_root, batch)
                for batch in dlt
            ]

            data[name][model] = buffer

In [None]:
# Cache the collected results on disk so the plotting cells can be re-run
# without repeating the computation.
with open('./oversquashing_data.pkl', mode='wb') as handle:
    pickle.dump(data, handle)

In [None]:
import pickle

# Restore the cached results. Pickle files can execute code when loaded, so
# only open a file that this notebook wrote itself.
with open('./oversquashing_data.pkl', mode='rb') as handle:
    data = pickle.load(handle)

In [None]:
# Plotting configuration. The order of `models` is the order in which the
# models are processed (and therefore drawn) by the cells below.
group = admet_group(path='../data_tdc/')
names = group.dataset_names
seeds = list(range(1, 6))
models = [
    f'{prefix}_20L_wide_def_2e-5_16p'
    for prefix in (
        'qm_all',
        'nmr',
        'charges',
        'fukui_n',
        'fukui_e',
        'homo-lumo',
        'masking',
        'scratch',
    )
]

def normalize_list(dd):
    """Min-max rescale a sequence of numbers.

    The minimum is subtracted from every entry and the shifted values are then
    divided by their maximum, so the smallest entry maps to 0 and the largest
    to 1.

    Args:
        dd: Sequence of numeric values.

    Returns:
        list: The rescaled values in input order; an empty list for empty
        input. When all entries are equal the range is zero and the division
        is 0/0 (NaN under NumPy float arithmetic).
    """
    if len(dd) == 0:
        return []

    # Each extremum is computed once instead of once per element.
    lowest = np.min(dd)
    shifted = [d - lowest for d in dd]
    highest = np.max(shifted)
    return [d / highest for d in shifted]

# Pool the per-model sensitivity profiles over all datasets, draw a quick
# matplotlib preview and collect long-format records
# (neighbour, model, sens) for the seaborn figure further down.
tops = 7  # profile lengths 2 .. tops-1 are inspected, i.e. neighbours 1 .. tops-2
j = 10  # a profile is dropped when any of its first j entries is NaN
model_pos_shift = 0  # x-offset of the current model inside each neighbour group
fig = plt.figure(figsize=(40,10))
neighbs_list = []
sens_list = []
models_list = []
plotted_models = []  # short model labels in plotting order, used for the title

for k in models:

    cumulation = []

    for name in names:
        try:
            cumulation += [s for s in data[name][k] if not np.isnan(s[:j]).any()]
        except (LookupError, TypeError, ValueError) as err:
            # Best effort: a dataset/model pair that is missing or cannot be
            # checked for NaNs is skipped, but reported rather than hidden.
            print(f'Skipping {name} / {k}: {err!r}')

    cumulation = [normalize_list(c) for c in cumulation]

    # Short label: first token of the model name, with special cases for the
    # two fukui variants and the combined qm model.
    parts = k.split('_')
    if parts[0] == 'fukui':
        subname = '_'.join(parts[:2])
    elif parts[0] == 'qm':
        subname = 'all atomic qm'
    else:
        subname = parts[0]
    plotted_models.append(subname)

    for i in range(2,tops):

        # Entry i-1 of every profile with at least i entries is recorded as
        # neighbour i-1.
        neighbour = i - 1
        values = np.array([c[neighbour] for c in cumulation if len(c) >= i])
        if values.size == 0:
            continue  # no profile reaches this position for this model

        neighbs_list += [neighbour]*len(values)
        sens_list += values.tolist()
        models_list += [subname]*len(values)
        # Draw the box at the neighbour index it is recorded under; ticks are
        # set once after the loop instead of being managed per call.
        plt.boxplot(values, positions = [neighbour+model_pos_shift], widths=(.05), sym = '',
                    whis = (15, 85), manage_ticks = False)

    model_pos_shift += .1

# One tick per neighbour group, centred under that group's boxes.
neighbour_ids = list(range(1, tops-1))
group_centre = .1 * max(len(models)-1, 0) / 2
plt.xticks([n+group_centre for n in neighbour_ids], [str(n) for n in neighbour_ids])

plt.xlabel("n-th neighbour")
plt.ylabel(r"$\langle|\frac{\delta h_i}{\delta x_j^n}|\rangle_{ij}$")
# The figure overlays every model, so the title lists all of them.
plt.title('Models (left to right): ' + ', '.join(plotted_models))

# Long-format table consumed by the seaborn figure.
df = pd.DataFrame({'neighbour': neighbs_list, 'model': models_list, 'sens': sens_list})

In [None]:
# Write the collected records into the long-format table, one column at a time.
for column, values in (
    ('neighbour', neighbs_list),
    ('model', models_list),
    ('sens', sens_list),
):
    df[column] = values

In [None]:
def change(x):
    """Map the label 'all atomic qm' to 'all'; return any other value unchanged."""
    return 'all' if x == 'all atomic qm' else x
    
# Shorten the combined-model label so the legend stays compact.
df['model'] = df['model'].apply(change)

In [None]:
import seaborn as sns

fig = plt.figure(figsize=(15, 5))

# Grouped box plot: one group per neighbour value, one unfilled box per model,
# whiskers at the 15th / 85th percentiles and no outlier markers.
sns.boxplot(
    data=df,
    x="neighbour",
    y="sens",
    hue="model",
    gap=.5,
    showfliers=False,
    fill=False,
    whis=(15, 85),
)

# Axis labels, tick labels and legend use enlarged fonts.
plt.xlabel(r'$k^{th}$-neighbour', fontsize=20)
plt.ylabel(r'$\mathcal{S}_k$', fontsize=20)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.legend(fontsize=16)

plt.tight_layout()
plt.savefig('./oversquashing.pdf')