In [None]:
import torch
import pickle
from XAI import *
from matplotlib import pyplot as plt
from tqdm import tqdm
from tdc.benchmark_group import admet_group
import os
from torch import stack, tensor, Generator, cat, float32, nonzero, set_float32_matmul_precision
import os
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import pickle
import random
import seaborn as sns

# Global numeric / RNG setup; runs once, before any model or sampling code.
# `set_float32_matmul_precision` is the torch function imported by name above.
set_float32_matmul_precision('high')
# NOTE(review): `set_global_seed` is not imported by name in this notebook —
# presumably it comes from `from XAI import *`. TODO confirm which RNGs it seeds;
# later cells draw from `random.sample` and pandas `.sample`.
set_global_seed(42)

def get_mode(x, bins):
    """Histogram-based mode estimate of `x`.

    `x` is binned with `np.histogram(x, bins)`; the value returned is the
    LEFT edge of the fullest bin (first one on ties), not its centre.
    """
    counts, edges = np.histogram(x, bins)
    fullest = np.argmax(counts)
    return edges[fullest]

def make_plt(data, label, what = 'mode', confidence = 0.67, bins = 50, ax = None):
    """Draw a per-column summary curve of `data` with a shaded percentile band.

    Parameters
    ----------
    data : 2-D array-like
        Statistics are taken along axis 0, giving one value per column.
        The first and last columns are computed but not drawn.
    label : str
        Legend label attached to the curve.
    what : {'mode', 'mean'}
        Summary used for the curve: the column mean, or the histogram mode
        returned by `get_mode`.
    confidence : float
        Central fraction of the values covered by the band
        (0.67 -> 16.5th to 83.5th percentile).
    bins : int or sequence
        Forwarded to `get_mode` when `what == 'mode'`.
    ax : object with `plot` / `fill_between`, optional
        Drawing target (e.g. a matplotlib Axes); pyplot is used when None.

    Raises
    ------
    ValueError
        If `what` is neither 'mode' nor 'mean'.
    """
    data = np.asarray(data)

    if what == 'mean':
        centre = np.mean(data, axis = 0)
    elif what == 'mode':
        centre = np.apply_along_axis(lambda column: get_mode(column, bins = bins), 0, data)
    else:
        # `raise(ValueError, msg)` raises a tuple, which Python 3 rejects with
        # a TypeError; raise a real exception instance instead.
        raise ValueError('use what = mode or mean')

    lower_bound, upper_bound = np.percentile(
        data, [(1 - confidence) / 2 * 100, (1 + confidence) / 2 * 100], axis = 0
    )

    # pyplot and Axes expose the same two calls, so one code path serves both.
    target = plt if ax is None else ax
    target.plot(centre[1:-1], label = label)
    target.fill_between(range(data.shape[1] - 2), lower_bound[1:-1], upper_bound[1:-1], alpha = 0.1)
        

In [None]:
# Compute, for every TDC ADMET benchmark and every fine-tuned model, the list of
# `get_rank_residuals` outputs over (a sample of) the benchmark's test molecules.
# Result layout: ranks[dataset_name][model_key] -> list with one entry per batch.
# NOTE(review): `GT`, `transfer_weights`, `get_iterator` and `get_rank_residuals`
# are not defined in this notebook — presumably supplied by `from XAI import *`.
group = admet_group(path = '../data_tdc/')
names = group.dataset_names
data = {}
seeds = [1]
models = [
    'charges_20L_wide_def_2e-5_16p', 
    'nmr_20L_wide_def_2e-5_16p', 
    'fukui_n_20L_wide_def_2e-5_16p', 
    'fukui_e_20L_wide_def_2e-5_16p', 
    'homo-lumo_20L_wide_def_2e-5_16p', 
    'qm_all_20L_wide_def_2e-5_16p', 
    'masking_20L_wide_def_2e-5_16p', 
    'scratch_20L_wide_def_2e-5_16p'
]

# Maximum number of test molecules analysed per dataset.
N_SMILES = 100

ranks = {}


def _model_key(model_name):
    """Short key for a checkpoint folder name: 'fukui_n_...' -> 'fukui_n', 'qm_all_...' -> 'all', otherwise the first '_'-separated token."""
    parts = model_name.split('_')
    if parts[0] == 'fukui':
        return parts[0] + '_' + parts[1]
    if parts[0] == 'qm':
        return 'all'
    return parts[0]


def _find_checkpoint(ckpt_dir):
    """Return the path of the first (alphabetically) file in `ckpt_dir` whose name starts with 'epoch'."""
    # Sorted so the choice does not depend on the arbitrary order of os.listdir.
    candidates = sorted(c for c in os.listdir(ckpt_dir) if c.startswith('epoch'))
    if not candidates:
        raise FileNotFoundError(f'no "epoch*" checkpoint found in {ckpt_dir}')
    return os.path.join(ckpt_dir, candidates[0])


for name in names:

    benchmark = group.get(name)
    name = benchmark['name']
    test = benchmark['test']
    # NOTE(review): `train` / `valid` are not used below; the call is kept so the
    # sequence of library calls matches the original run. Drop it once confirmed
    # that it has no side effect the sampling depends on.
    train, valid = group.get_train_valid_split(benchmark = name, split_type = 'default', seed = seeds[0])

    # Explicit size check instead of a bare `except:` around `.sample(100)`:
    # small test sets are used in full (original order), larger ones are sampled.
    drugs = test['Drug']
    if len(drugs) >= N_SMILES:
        smiles = drugs.sample(N_SMILES).values.tolist()
    else:
        smiles = drugs.values.tolist()

    ranks[name] = {}

    for model in tqdm(models):

        subname = _model_key(model)

        # Accumulates results over all seeds for this (dataset, model) pair.
        r = []

        for seed in seeds:

            print(name, seed, ':\n')

            checkpoint_path = _find_checkpoint(f'./TDC_checkpoints/{model}/{name}_{seed}/')
            print(checkpoint_path)
            # Only the hyper-parameters are needed here; load on CPU so the whole
            # checkpoint is not materialised on the GPU a second time.
            loaded_path_hyper_dict = torch.load(checkpoint_path, map_location = 'cpu')['hyper_parameters']

            model_ = GT(
                checkpoint_path = checkpoint_path,
                **loaded_path_hyper_dict
            )

            model_, w = transfer_weights(checkpoint_path, model_, device = 'cuda')

            model_.eval()
            model_.freeze()

            for _param_name, param in model_.named_parameters():
                param.requires_grad = False

            dlt = get_iterator(smiles, is_prepared_as_packed_chython=False)

            for batch in tqdm(dlt):
                r.append(get_rank_residuals(model_, batch))

        ranks[name][subname] = r 

In [None]:
# Persist the `ranks` dict to disk with pickle.
# NOTE(review): the plotting cell further down reads './ranks_.pkl', not the
# './ranks_tol.pkl' written here — confirm which file name is intended.
with open('./ranks_tol.pkl', 'wb') as f:
    pickle.dump(ranks, f)

In [None]:
def invert(original_dict):
    """Swap the two key levels of a dict of dicts.

    For every `original_dict[a][b] = value` the result holds
    `result[b][a] = value`. A new outer dict and new inner dicts are built;
    the values themselves are not copied.
    """
    swapped = {}
    for outer_key, inner in original_dict.items():
        for inner_key, payload in inner.items():
            swapped.setdefault(inner_key, {})[outer_key] = payload
    return swapped

In [None]:
# Swap the two key levels of the in-memory `ranks` dict via invert().
# NOTE: not idempotent — running this cell a second time swaps the levels back.
# NOTE(review): the next cell rebinds `ranks` from a pickle file, so on a
# top-to-bottom run the result of this cell is overwritten.
inverted_dict = invert(ranks)
ranks = inverted_dict

In [None]:
# Load saved results and swap their two key levels with invert(); the loop
# below indexes `ranks` by model key first.
# NOTE(review): this reads './ranks_.pkl', while the save cell above writes
# './ranks_tol.pkl' — confirm which file is intended.
# pickle.load can execute arbitrary code: only load files you produced yourself.
with open('./ranks_.pkl', 'rb') as f:
    ranks = invert(pickle.load(f))
    
# Figure that the pyplot calls made later in this cell draw into.
fig = plt.figure(figsize=(10,5))
    
# NOTE: this re-defines `make_plt`, which is already defined in the notebook's
# first cell; the later definition silently wins. Keep the two in sync or
# delete one of them.
def make_plt(data, label, what = 'mode', confidence = 0.67, bins = 50, ax = None):
    """Draw a per-column summary curve of `data` with a shaded percentile band.

    Parameters
    ----------
    data : 2-D array-like
        Statistics are taken along axis 0, giving one value per column.
        The first and last columns are computed but not drawn.
    label : str
        Legend label attached to the curve.
    what : {'mode', 'mean'}
        Summary used for the curve: the column mean, or the histogram mode
        returned by `get_mode`.
    confidence : float
        Central fraction of the values covered by the band
        (0.67 -> 16.5th to 83.5th percentile).
    bins : int or sequence
        Forwarded to `get_mode` when `what == 'mode'`.
    ax : object with `plot` / `fill_between`, optional
        Drawing target (e.g. a matplotlib Axes); pyplot is used when None.

    Raises
    ------
    ValueError
        If `what` is neither 'mode' nor 'mean'.
    """
    data = np.asarray(data)

    if what == 'mean':
        centre = np.mean(data, axis = 0)
    elif what == 'mode':
        centre = np.apply_along_axis(lambda column: get_mode(column, bins = bins), 0, data)
    else:
        # `raise(ValueError, msg)` raises a tuple, which Python 3 rejects with
        # a TypeError; raise a real exception instance instead.
        raise ValueError('use what = mode or mean')

    lower_bound, upper_bound = np.percentile(
        data, [(1 - confidence) / 2 * 100, (1 + confidence) / 2 * 100], axis = 0
    )

    # pyplot and Axes expose the same two calls, so one code path serves both.
    target = plt if ax is None else ax
    target.plot(centre[1:-1], label = label)
    target.fill_between(range(data.shape[1] - 2), lower_bound[1:-1], upper_bound[1:-1], alpha = 0.1)
        
# For each model key: stack the result lists of all its datasets into one
# 2-D array, draw its summary curve with make_plt, and collect a fixed row
# subsample in long format (model / value / layer) for the box plot below.

# Empty placeholder frame; a later cell fills its columns from the three lists.
df = pd.DataFrame(np.array([[],[],[]]).T, columns = ['model', 'value', 'layer'])

model_list = []
value_list = []
layer_list = []

chiavi = ['all', 'nmr', 'charges', 'fukui_e', 'fukui_n', 'homo-lumo', 'masking', 'scratch']
keys = chiavi

# The same 500 row indices are reused for every model.
# NOTE(review): assumes each stacked array has at least 2100 rows (an index
# past the end raises IndexError); rows >= 2100 are never sampled. TODO confirm.
subsample_index = random.sample(range(0, 2100), 500)

for c in chiavi:

    # Concatenate the per-dataset lists stored under this model key.
    ts = []
    for v in ranks[c].values():
        ts += v

    ts = np.array(ts)

    # One pass over the columns. The original wrapped this in an extra
    # `for h in range(len(ts))` loop whose index was never used, so every entry
    # was appended len(ts) times: same distribution, thousands of times the memory.
    for g in range(ts.shape[1]):
        value_list += ts[subsample_index, g].tolist()
        model_list += [c] * len(subsample_index)
        layer_list += [g] * len(subsample_index)

    make_plt(ts, label = f'{c}', confidence = .67)

plt.tight_layout()
plt.legend()

In [None]:
# Build the long-format table in one constructor call instead of filling a
# placeholder frame created in another cell: same columns, same order, and the
# cell can be re-run on its own.
df = pd.DataFrame({'model': model_list, 'value': value_list, 'layer': layer_list})

In [None]:
# Grouped box plot of `value` per layer, one hue per model.
# Whiskers span the 15th-85th percentiles and outliers are hidden.

# Derive the tick count from the data instead of hard-coding 19 layers.
# `layer` holds 0-based integer indices; seaborn draws the sorted levels at
# positions 0..n_layers-1, which are relabelled 1-based here.
n_layers = df['layer'].nunique()

plt.figure(figsize = (25,8))
sns.boxplot(data=df, x="layer", y="value", hue="model", whis = (15, 85), fill = False, showfliers = False, gap = 0.2)
plt.xticks(np.arange(n_layers), np.arange(1, n_layers + 1), fontsize = 20)
plt.yticks(fontsize = 20)
plt.legend(fontsize = 20)
plt.xlabel(r'$k^{th}$-layer', fontsize = 24)
plt.ylabel(r'$\rho_L$', fontsize = 24)
plt.ylim(0,1.4)
plt.tight_layout()