In [None]:
import torch
import pickle
from XAI import *
from matplotlib import pyplot as plt
from tqdm import tqdm
from tdc.benchmark_group import admet_group
import os
from torch import stack, tensor, Generator, cat, float32, nonzero, set_float32_matmul_precision

# Use PyTorch's 'high' float32 matmul precision setting (faster, slightly less exact kernels).
set_float32_matmul_precision('high')
# `set_global_seed` is not defined in this notebook; presumably it comes from `from XAI import *`.
set_global_seed(42)

group = admet_group(path = '../data_tdc/')
names = group.dataset_names
data = {}  # f'{model}_{name}_{seed}' -> outputs of Laplacian_Rollout_analysis for that run
seeds = [1, 2, 3, 4, 5]
models = [
    'scratch_20L_wide_def_2e-5_16p',
    'charges_20L_wide_def_2e-5_16p',
    'nmr_20L_wide_def_2e-5_16p',
    'fukui_n_20L_wide_def_2e-5_16p',
    'fukui_e_20L_wide_def_2e-5_16p',
    'homo-lumo_20L_wide_def_2e-5_16p',
    'qm_all_20L_wide_def_2e-5_16p',
    'masking_20L_wide_def_2e-5_16p'
]

for name in tqdm(names):

    # The benchmark and its test-set SMILES depend only on the dataset, so they are
    # loaded once here rather than once per (seed, model) pair. The loop variable
    # `name` is deliberately NOT reassigned: the downstream averaging cell looks
    # results up with the entries of `names`, so the keys must use the same string.
    benchmark = group.get(name)
    test = benchmark['test']
    smiles = test['Drug'].values.tolist()
    # The train/valid split that used to be computed here was never used, so it was dropped.

    for seed in seeds:

        for model in models:

            ckpt_root = f'./TDC_checkpoints/{model}/{name}_{seed}/'
            # os.listdir returns entries in arbitrary order; sort so that the chosen
            # checkpoint is deterministic when a run directory holds more than one.
            ckpts = sorted(a for a in os.listdir(ckpt_root) if a.startswith('epoch'))
            if not ckpts:
                raise FileNotFoundError(f"No checkpoint starting with 'epoch' found in {ckpt_root}")
            ckpt_path = os.path.join(ckpt_root, ckpts[0])

            zetas, etas, number_of_laplacians_, which_laplacians = Laplacian_Rollout_analysis(
                ckpt_path, smiles, is_prepared_as_packed_chython=False, device='cuda'
            )
            data[f'{model}_{name}_{seed}'] = {
                'zetas': zetas,
                'etas': etas,
                'number_of_laplacians_': number_of_laplacians_,
                'which_laplacians': which_laplacians
            }

In [None]:
# Cache the collected per-run spectra on disk so the expensive loop need not be repeated.
with open('./spectra_data_total.pkl', mode='wb') as spectra_file:
    pickle.dump(data, spectra_file)

In [None]:
import pickle 

# Restore the cached per-run spectra into `data`.
with open('./spectra_data_total.pkl', mode='rb') as spectra_file:
    data = pickle.load(spectra_file)

In [None]:
# No numpy import is visible elsewhere in this notebook; import it here so the cell
# does not depend on `np` leaking in through a star-import.
import numpy as np

# avg_zetas_[dataset][short model label] -> mean zeta, averaged first within each
# seed's run and then across the five seeds.
avg_zetas_ = {}

models = [
    'scratch_20L_wide_def_2e-5_16p',
    'charges_20L_wide_def_2e-5_16p',
    'nmr_20L_wide_def_2e-5_16p',
    'fukui_n_20L_wide_def_2e-5_16p',
    'fukui_e_20L_wide_def_2e-5_16p',
    'homo-lumo_20L_wide_def_2e-5_16p',
    'qm_all_20L_wide_def_2e-5_16p',
    'masking_20L_wide_def_2e-5_16p'
]

for name in tqdm(names):

    avg_zetas_[name] = {}

    for model in models:

        # Derive the short label from the run name: the first '_'-separated token,
        # except 'fukui' which keeps its variant suffix ('fukui_n' / 'fukui_e')
        # and 'qm' which is relabelled 'all'.
        tokens = model.split('_')
        if tokens[0] == 'fukui':
            subname = tokens[0] + '_' + tokens[1]
        elif tokens[0] == 'qm':
            subname = 'all'
        else:
            subname = tokens[0]

        per_seed_means = [
            np.mean(data[f'{model}_{name}_{seed}']['zetas'])
            for seed in [1, 2, 3, 4, 5]
        ]
        avg_zetas_[name][subname] = np.mean(per_seed_means)

In [None]:
# Save the per-dataset / per-model averages for later plotting.
with open('./avg_zetas.pkl', mode='wb') as avg_file:
    pickle.dump(avg_zetas_, avg_file)

In [None]:
# Read the saved averages back in under the name `avg_zetas`.
with open('./avg_zetas.pkl', mode='rb') as avg_file:
    avg_zetas = pickle.load(avg_file)

In [None]:
def invert(original_dict):
    """Swap the two key levels of a dict of dicts.

    Given ``{outer: {inner: value}}`` this returns a new dict
    ``{inner: {outer: value}}``. The input is not modified. Inner keys appear
    in the result in the order they are first encountered; outer keys with an
    empty inner dict contribute nothing.
    """
    swapped = {}
    for outer_key, nested in original_dict.items():
        for inner_key, value in nested.items():
            # setdefault creates the per-inner-key bucket on first sight.
            swapped.setdefault(inner_key, {})[outer_key] = value
    return swapped

# Swap the two key levels of the loaded averages (outer keys become inner keys).
# NOTE: this rebinds `avg_zetas` to its own inversion, so the cell is not idempotent:
# running it a second time without reloading the pickle swaps the levels back.
avg_zetas = invert(avg_zetas)

In [None]:
# For every outer key of `avg_zetas`, keep only the inner values as a flat list
# (inner keys are dropped, insertion order is preserved).
reorganised_avgzs = {
    outer_key: list(inner.values())
    for outer_key, inner in avg_zetas.items()
}

In [None]:
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

# Swarm plot of the averaged zeta values, one column of points per pretraining method.

# Order in which the models appear along the x axis.
models_list = ['all', 'charges', 'nmr', 'fukui_n' ,'fukui_e', 'homo-lumo', 'masking', 'scratch']

# Use a dedicated name for the plotting input: `data` already holds the per-run
# spectra dict that the averaging cell reads, and overwriting it here would make
# that cell fail if it were re-run after this one.
zeta_by_model = {m: reorganised_avgzs[m] for m in models_list}

# Wrapping each list in a Series lets pandas pad unequal-length columns with NaN.
df = pd.DataFrame({k: pd.Series(v) for k, v in zeta_by_model.items()})

# Long format: one row per (Model, zeta) pair, as seaborn expects.
df_melted = df.melt(var_name='Model', value_name=r'$\zeta$')

fig, ax = plt.subplots(figsize=(15, 5))
sns.swarmplot(x='Model', y=r'$\zeta$', data=df_melted, hue = 'Model', legend = 'brief', ax=ax)
ax.tick_params(axis='x', labelrotation=45, labelsize=16)
ax.tick_params(axis='y', labelsize=16)
ax.legend(fontsize=16)
ax.set_xlabel('Models by pretraining method', fontsize=20)
ax.set_ylabel(r'$\langle \zeta \rangle_{\rm{test}}$', fontsize=20)
fig.tight_layout()
fig.savefig('./swarm_zeta.pdf')