In [None]:
from XAI import *
import pickle
from Utilities import set_global_seed
import random
import torch
from tdc.benchmark_group import admet_group
from Utilities import *
import numpy as np
from sklearn.cross_decomposition import CCA, PLSRegression, PLSCanonical, PLSSVD
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from sklearn.metrics import r2_score
from scipy.stats import pearsonr
from sklearn.linear_model import Ridge, Lasso, ElasticNet
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


# Placeholder roots: point these at the directories that hold the pickled datasets.
data_root_atomic = '../your_data_root/'
data_root_gap = '../your_data_root/'


def _unpickle(root, filename):
    """Open ``{root}/{filename}`` in binary mode and return the unpickled object.

    NOTE: ``pickle.load`` can execute arbitrary code -- only point this at trusted files.
    """
    with open(f'{root}/{filename}', 'rb') as handle:
        return pickle.load(handle)


# Atom-level set: structures plus per-structure property entries.
mols = _unpickle(data_root_atomic, 'structures.pkl')
charges = _unpickle(data_root_atomic, 'charges.pkl')
nmrs = _unpickle(data_root_atomic, 'nmr.pkl')
fkn = _unpickle(data_root_atomic, 'fukui_n.pkl')
fke = _unpickle(data_root_atomic, 'fukui_e.pkl')

# Gap set: its own structures plus one HOMO-LUMO gap entry per structure.
mols_gap = _unpickle(data_root_gap, 'structures.pkl')
gaps = _unpickle(data_root_gap, 'hlgaps.pkl')

In [None]:
def _pick(sequence, indices, convert=None):
    """Return ``[sequence[i] for i in indices]``, each item passed through ``convert`` when given."""
    chosen = [sequence[i] for i in indices]
    if convert is None:
        return chosen
    return [convert(item) for item in chosen]


set_global_seed(42)

group = admet_group(path='../data_tdc/')
names = group.dataset_names

# 5000-item random subsamples; reproducible through the seed set above.
subsample_index = random.sample(range(0, len(mols)), 5000)
subsample_index_gap = random.sample(range(0, len(mols_gap)), 5000)

mols = _pick(mols, subsample_index)
mols_gap = _pick(mols_gap, subsample_index_gap)
gaps = torch.tensor(_pick(gaps, subsample_index_gap))
charges = _pick(charges, subsample_index, torch.tensor)
nmrs = _pick(nmrs, subsample_index, torch.tensor)
fkn = _pick(fkn, subsample_index, torch.tensor)
fke = _pick(fke, subsample_index, torch.tensor)

In [None]:
# Probe-target name -> data. Built from an ordered list of pairs, so the iteration
# order below is exactly the order in which the properties are probed.
proprietas = dict([
    ('charges', charges),
    ('nmr', nmrs),
    ('fukui_e', fke),
    ('fukui_n', fkn),
    ('homo-lumo', gaps),
])

In [None]:
# Checkpoint directory names under ./TDC_checkpoints/. The leading part of each name
# (presumably the pretraining objective) is followed by the same shared suffix.
models = [
    f'{objective}_20L_wide_def_2e-5_16p'
    for objective in (
        'scratch',
        'homo-lumo',
        'masking',
        'charges',
        'nmr',
        'fukui_n',
        'fukui_e',
        'qm_all',
    )
]


In [None]:
def _load_frozen_model(model_dir, dataset_name):
    """Rebuild and freeze the fine-tuned GT checkpoint for one (model, TDC dataset) run.

    Looks in ``./TDC_checkpoints/{model_dir}/{dataset_name}_1/`` for a file whose name
    starts with ``epoch``. It rebuilds ``GT`` from that checkpoint's ``hyper_parameters``,
    loads the weights with ``transfer_weights(..., device='cuda')``, switches the model to
    eval mode and disables gradients on every parameter.

    Raises:
        FileNotFoundError: if the run directory contains no ``epoch*`` file.
    """
    run_dir = f'./TDC_checkpoints/{model_dir}/{dataset_name}_1/'
    # os.listdir order is arbitrary; sorting makes the pick deterministic when
    # several epoch checkpoints exist.
    candidates = sorted(c for c in os.listdir(run_dir) if c.startswith('epoch'))
    if not candidates:
        raise FileNotFoundError(f'no epoch* checkpoint found in {run_dir}')
    checkpoint_path = f'{run_dir}{candidates[0]}'

    # Only the hyper-parameter dict is needed here. Without map_location, torch.load
    # restores every saved tensor onto the device it was saved from.
    hparams = torch.load(checkpoint_path, map_location='cpu')['hyper_parameters']
    model = GT(checkpoint_path=checkpoint_path, **hparams)
    model, _ = transfer_weights(checkpoint_path, model, device='cuda')

    model.eval()
    model.freeze()
    for param in model.parameters():
        param.requires_grad = False
    return model


def _collect_latents(model, loader, **layer_kwargs):
    """Run ``get_nth_layer`` over every batch of ``loader`` and gather its outputs.

    Each batch is unpacked as ``(mol, label)``. The full batch is passed to
    ``get_nth_layer(model, batch, device='cuda', x_=True, **layer_kwargs)``.

    Returns:
        ``(toks, lats, props)`` numpy arrays, built as follows:
        - toks: every ``mol.atoms`` concatenated along dim 1, then row 0;
        - lats: the third ``get_nth_layer`` output concatenated along dim 1;
        - props: the labels concatenated along dim 0.
    """
    # Lists, not tuple `+=`, which re-copies the whole tuple on every batch.
    tokens, latents, labels = [], [], []
    for batch in loader:
        mol, label = batch
        _, _, x, _, _ = get_nth_layer(model, batch, device='cuda', x_=True, **layer_kwargs)
        tokens.append(mol.atoms)
        latents.append(x)
        labels.append(label)

    props = torch.cat(labels, dim=0).cpu().numpy()
    lats = torch.cat(latents, dim=1).cpu().numpy()
    toks = torch.cat(tokens, dim=1).cpu().numpy()[0]
    return toks, lats, props


def _probe_r2(features, targets):
    """Return the held-out R^2 of a linear ElasticNet probe from ``features`` to ``targets``.

    Both sides are standardised, then split 50/50 with ``random_state=42``. An
    ``ElasticNet(alpha=0.1, l1_ratio=0.5)`` is fitted on the first half and scored on the second.

    NOTE(review): there are two small sources of test-set leakage, kept so results stay
    comparable with earlier runs:
    - the scalers are fitted before the split;
    - for per-row (atom-level) targets, rows from one molecule can fall in both halves.
    """
    X_scaled = StandardScaler().fit_transform(features)
    y_scaled = StandardScaler().fit_transform(targets.reshape(-1, 1))
    X_train, X_test, y_train, y_test = train_test_split(
        X_scaled, y_scaled, test_size=0.5, random_state=42)

    probe = ElasticNet(alpha=.1, l1_ratio=0.5)
    probe.fit(X_train, y_train)
    y_pred = probe.predict(X_test)
    return r2_score(y_test.reshape(-1), y_pred.reshape(-1))


ccas = {}

for name in names:
    ccas[name] = {}

    for model_dir in tqdm(models):
        # Drop the 22-character '_20L_wide_def_2e-5_16p' suffix ('scratch', 'nmr', ...).
        short_name = model_dir[:-22]
        ccas[name][short_name] = {}

        # One checkpoint load per (dataset, model), reused for every probed property.
        model = _load_frozen_model(model_dir, name)

        for chiave, proprieta in proprietas.items():
            molecule_level = chiave == 'homo-lumo'

            if molecule_level:
                data = TensorDataset(MoleculeDataset(mols_gap, unpack=True), proprieta)
                collate = chained_collate(collate_molecules, torch.stack)
                # NOTE(review): this branch uses get_nth_layer's default layer, while the
                # other branch requests n=20 explicitly. Confirm this is intended.
                layer_kwargs = {}
            else:
                data = TensorDataset(MoleculeDataset(mols, unpack=True), SizedList(proprieta))
                collate = chained_collate(collate_molecules, torch.cat)
                layer_kwargs = {'n': 20}

            dlt = DataLoader(data, collate_fn=collate, shuffle=False, batch_size=1, num_workers=16)
            toks, lats, props = _collect_latents(model, dlt, **layer_kwargs)

            # homo-lumo is probed on the latent rows whose token id is 1; every other
            # property is probed on the remaining rows.
            rows = (toks == 1) if molecule_level else (toks != 1)
            # R^2 of the probe (reported under the historical "CC" label).
            r2 = _probe_r2(lats[0][rows], props)

            print(f'{name} {short_name} {chiave} CC: {r2}')
            ccas[name][short_name][chiave] = r2

In [None]:
def _mean_std(values):
    """Return ``[mean, std]`` of ``values`` (np.std default, i.e. population std, ddof=0)."""
    return [np.mean(values), np.std(values)]


# Every dataset holds the same model/property keys, so read them from the first
# dataset/model instead of hard-coding 'caco2_wang' / 'scratch'.
reference = ccas[names[0]]
chiavi_modelli = list(reference.keys())
chiavi_proprieta = list(reference[chiavi_modelli[0]].keys())

# {model: {property: [mean, std]}} with statistics taken across all datasets in `names`.
tbplotted = {
    m: {p: _mean_std([ccas[n][m][p] for n in names]) for p in chiavi_proprieta}
    for m in chiavi_modelli
}

In [None]:
# Expose the 'qm_all' results (same dict object) under the shorter display key 'all'.
tbplotted.update({'all': tbplotted['qm_all']})

In [None]:
def check(x):
    """Return the display label for a result key: 'qm_all' becomes 'all', anything else is unchanged."""
    # Compare the argument; comparing the function object `check` to a string is always False.
    if x == 'qm_all':
        return 'all'
    return x


# Split the [mean, std] pairs into two frames: columns are model keys, rows are
# probed properties. 'qm_all' and an existing 'all' key both map to column 'all';
# if both are present, the later entry wins.
means = pd.DataFrame({
    check(model_key): {check(prop_key): stats[0] for prop_key, stats in per_property.items()}
    for model_key, per_property in tbplotted.items()
})
stds = pd.DataFrame({
    check(model_key): {check(prop_key): stats[1] for prop_key, stats in per_property.items()}
    for model_key, per_property in tbplotted.items()
})

means = means.round(2)
stds = stds.round(2)

# Cell text: "mean ± std".
annotations = means.astype(str) + " ± " + stds.astype(str)

# Columns: the models whose keys match the probed-property names, then 'all' and
# 'scratch'. Model keys not listed here (e.g. 'masking') are left out of the figure.
x_order = chiavi_proprieta + ['all'] + ['scratch']
y_order = chiavi_proprieta

means = means.reindex(index=y_order, columns=x_order)
stds = stds.reindex(index=y_order, columns=x_order)
annotations = annotations.reindex(index=y_order, columns=x_order)

fig, ax = plt.subplots(figsize=(12, 6))
sns.heatmap(means, annot=annotations, fmt="", linewidths=.5, cmap="magma", ax=ax)

ax.tick_params(axis='x', labelrotation=45)
ax.tick_params(axis='y', labelrotation=45)

ax.set_xlabel('Models', fontsize=16)
ax.set_ylabel('Pretraining property', fontsize=16)
fig.tight_layout()
fig.savefig('./permanence_notitle.pdf')
plt.show()

In [None]:
# Latents of layer n=20 for the NMR probe set, from models[0] fine-tuned on names[0].
proprieta = proprietas['nmr']

data = TensorDataset(MoleculeDataset(mols, unpack=True), SizedList(proprieta))
dlt = DataLoader(data, collate_fn=chained_collate(collate_molecules, torch.cat),
                 shuffle=False, batch_size=1, num_workers=16)

# Keep the directory name separate from the loaded module instead of rebinding `model`.
model_dir = models[0]
name = names[0]

run_dir = f'./TDC_checkpoints/{model_dir}/{name}_1/'
# os.listdir order is arbitrary; sorting makes the pick deterministic.
ckpts = sorted(c for c in os.listdir(run_dir) if c.startswith('epoch'))
if not ckpts:
    raise FileNotFoundError(f'no epoch* checkpoint found in {run_dir}')
checkpoint_path = f'{run_dir}{ckpts[0]}'

# Only the hyper-parameters are read here, so keep the load on the CPU.
loaded_path_hyper_dict = torch.load(checkpoint_path, map_location='cpu')['hyper_parameters']

model = GT(
    checkpoint_path=checkpoint_path,
    **loaded_path_hyper_dict
)
model, w = transfer_weights(checkpoint_path, model, device='cuda')

model.eval()
model.freeze()
for param in model.parameters():
    param.requires_grad = False

# Lists (consumed by torch.cat in the next cell) instead of tuple `+=`, which is quadratic.
tokens, latents, properties = [], [], []
for batch in tqdm(dlt):
    mol, label = batch
    result, a_, x, xo, L = get_nth_layer(model, batch, n=20, device='cuda', x_=True)
    tokens.append(mol.atoms)
    latents.append(x)
    properties.append(label)

In [None]:
props = torch.cat(properties, dim=0).cpu().numpy()
lats = torch.cat(latents, dim=1).cpu().numpy()
toks = torch.cat(tokens, dim=1).cpu().numpy()[0]

# Latent rows whose token id is not 1.
atom_latents = lats[0][toks != 1]
toks = toks[toks != 1]

# NOTE: PCA is fitted on the raw (unstandardised) latents. random_state pins any
# randomized SVD solver; it is equivalent to the former np.random.seed(42) call.
pca = PCA(n_components=2, random_state=42)
principal_components = pca.fit_transform(atom_latents)

# Colour: z-scored property squashed by tanh so outliers do not dominate the colormap.
colorz = np.tanh((props - np.mean(props)) / np.std(props))

fig, ax = plt.subplots(figsize=(8, 6))
scatter = ax.scatter(principal_components[:, 0], principal_components[:, 1],
                     c=colorz, edgecolor='k', s=50)
fig.colorbar(scatter, ax=ax, label='tanh(z-scored property)')
ax.set_xlabel('Principal Component 1')
ax.set_ylabel('Principal Component 2')
ax.set_title('PCA of Example Data')
ax.grid()
plt.show()

explained_variance = pca.explained_variance_ratio_
print(f'Explained variance by component: {explained_variance}')