In [None]:
import torch
import pickle
# NOTE(review): star import hides where GT, get_iterator, get_nth_layer, chython, ...
# come from; it is kept before the explicit imports below so those names win, as before.
from XAI import *
from matplotlib import pyplot as plt
from tqdm import tqdm
from tdc.benchmark_group import admet_group
import os
from itertools import islice
import numpy as np  # used throughout; previously only reachable via the XAI star import
from torch import stack, tensor, Generator, cat, float32, nonzero, set_float32_matmul_precision

# --- Configuration -----------------------------------------------------------
set_float32_matmul_precision('high')
set_global_seed(42)
group = admet_group(path = '../data_tdc/')
names = group.dataset_names

seed = 1
name = names[-2]  # dataset under study: second-to-last entry of group.dataset_names
ckpt_root = f'./TDC_checkpoints/qm_all_10L_wide_def_2e-5_16p/{name}_{seed}/'

# os.listdir order is arbitrary; sort so the chosen checkpoint is reproducible.
ckpts = sorted(a for a in os.listdir(ckpt_root) if a.startswith('epoch'))
if not ckpts:
    raise FileNotFoundError(f'no epoch* checkpoint found in {ckpt_root}')
ckpt_path = os.path.join(ckpt_root, ckpts[0])

# --- Data --------------------------------------------------------------------
# The benchmark dict carries its own 'name' (presumably the normalised form of
# `name`); keep it for the split call. A second group.get(...) is unnecessary.
benchmark = group.get(name)
name = benchmark['name']

train_val, test = benchmark['train_val'], benchmark['test']
train, valid = group.get_train_valid_split(benchmark = name, split_type = 'default', seed = seed)
smiles = test['Drug'].values.tolist()

# --- Model -------------------------------------------------------------------
# Only the hyper-parameters are read here; map to CPU so a GPU-saved checkpoint
# can be inspected on a CPU-only machine.
loaded_path_hyper_dict = torch.load(ckpt_path, map_location='cpu')['hyper_parameters']

model = GT(checkpoint_path=None, **loaded_path_hyper_dict)
model, w = transfer_matching_weights(ckpt_path, model)

# --- Select the n-th batch of the test set and extract its attention ---------
n = 0  # index of the batch to inspect (later cells also use it to pick smiles[n])
dlt = get_iterator(smiles, is_prepared_as_packed_chython=False)
try:
    batch = next(islice(dlt, n, None))
except StopIteration:
    # Fail loudly instead of silently using the last batch (or leaving `batch` undefined).
    raise IndexError(f'n={n} is out of range for the test-set iterator') from None
result, a_, L = get_nth_layer(model, batch)

In [None]:
# Attention map of the first entry of the selected batch. Row/column 0 is dropped
# (presumably a global/CLS token — TODO confirm in get_nth_layer) so axes index atoms.
plt.imshow(np.array(a_[0])[1:, 1:])
plt.colorbar()
plt.title(f'Attention map (batch #{n})');

In [None]:
# Parse every test-set SMILES string with chython and display the one whose
# index matches the inspected batch (assumes one molecule per batch — TODO confirm).
molecules = list(map(chython.smiles, smiles))
mol = molecules[n]
mol

In [None]:
# Inspect chython's stored 2D coordinates for the selected molecule. `_plane` is a
# private attribute; the plotting cell below reads it as {atom_number: (x, y)} —
# NOTE(review): confirm this layout against the chython version in use.
mol._plane

In [None]:
def _sorted_real_modes(matrix, descending):
    """Eigendecompose `matrix`; return (|eigenvalues|, real column-normalised eigenvectors).

    Modes (columns) are ordered by eigenvalue magnitude: largest first when
    `descending`, smallest first otherwise. Row 0 of every eigenvector is dropped
    before normalisation (presumably the global/CLS token — TODO confirm), so the
    remaining rows line up with the molecule's atoms.
    """
    vals, vecs = np.linalg.eig(np.asarray(matrix))
    order = np.argsort(np.abs(vals))
    if descending:
        order = order[::-1]
    # eig may return complex output for a non-symmetric matrix; only the real part is kept.
    modes = np.real(vecs[:, order])[1:, :]
    norms = np.linalg.norm(modes, axis=0)
    # A mode supported only on row 0 has zero norm after slicing: keep it as zeros
    # rather than producing 0/0 = NaN.
    norms[norms == 0] = 1.0
    return np.abs(vals)[order], modes / norms

# Attention-matrix modes, largest |eigenvalue| first.
eigvals, eigvecs = _sorted_real_modes(np.array(a_)[0], descending=True)
# Modes of L (presumably a graph Laplacian from get_nth_layer), smallest |eigenvalue| first.
eigvalsl, eigvecsl = _sorted_real_modes(L, descending=False)

In [None]:
import networkx as nx

# --- Layout: chython 2D coordinates, re-indexed to 0-based and rotated -------
# Atom numbers in `mol._plane` are assumed to run 1..len(mol); they are shifted
# to 0-based node ids so they match the adjacency-matrix / eigenvector rows.
rdkit_coords = mol._plane
n_ = len(mol)
key_mapping = {i: i - 1 for i in range(1, n_ + 1)}
rdkit_coords = {key_mapping[int(old_key)]: value for old_key, value in rdkit_coords.items()}


def rotate_point(point):
    """Return the 2D `point` rotated by +90 degrees about the origin, as a new array."""
    x, y = point
    return np.dot(np.array([[0, -1], [1, 0]]), np.array([x, y]))


def translate_x(t, d):
    """Return a float copy of the 2D point `t` shifted by `d` along x.

    `t` is not modified, so the arrays stored in `rotated_coords` stay unchanged
    across loop iterations and after this cell.
    """
    shifted = np.array(t, dtype=float)
    shifted[0] += d
    return shifted


def draw_mode(graph, pos, values, cmap, ax, node_size):
    """Draw `graph` on `ax` at `pos`, colouring nodes by `values` with black outlines."""
    nx.draw_networkx_edges(graph, pos, ax=ax, edge_color='black')
    nx.draw_networkx_nodes(
        graph,
        pos,
        ax=ax,
        node_size=node_size,
        node_shape='o',
        node_color=values,
        edgecolors='black',
        cmap=cmap,
    )


rotated_coords = {key: rotate_point(value) for key, value in rdkit_coords.items()}

# Molecular graph without self-loops.
G_1 = mol.adjacency_matrix()
np.fill_diagonal(G_1, 0)
G_1_n = nx.Graph(G_1)

line_width = 0.5  # NOTE(review): currently unused — intended for edge/outline widths?
figsize = (10, 6)
ns = 200            # node marker size
panel_offset = 4.5  # x-offset (layout units) of the attention panel relative to the L panel
output_dir = './visual_modes'
cmaps = [plt.cm.Reds, plt.cm.Greens, plt.cm.Blues]

os.makedirs(output_dir, exist_ok=True)

# Skip mode 0 and plot up to len(cmaps) further modes, limited by how many modes
# this molecule actually has.
n_modes = min(len(cmaps) + 1, eigvecs.shape[1], eigvecsl.shape[1])
for j in range(1, n_modes):
    lap_mode = eigvecsl[:, j]
    # Eigenvectors are only defined up to sign: flip the attention mode so it
    # correlates positively with the corresponding mode of L.
    sign = 1 if lap_mode @ eigvecs[:, j] > 0 else -1
    attn_mode = sign * eigvecs[:, j]
    cmap = cmaps[j - 1]

    pos_left = rotated_coords
    pos_right = {key: translate_x(value, panel_offset) for key, value in pos_left.items()}

    fig, ax = plt.subplots(figsize=figsize)
    draw_mode(G_1_n, pos_left, lap_mode, cmap, ax, ns)    # left: mode of L
    draw_mode(G_1_n, pos_right, attn_mode, cmap, ax, ns)  # right: attention mode

    # Colorbar for the attention panel; pass ax explicitly (required for an
    # orphan ScalarMappable in recent matplotlib).
    sm = plt.cm.ScalarMappable(
        cmap=cmap, norm=plt.Normalize(vmin=attn_mode.min(), vmax=attn_mode.max())
    )
    sm.set_array([])
    fig.colorbar(sm, ax=ax)

    ax.axis('off')
    ax.set_title(f'Eigenvector #{j}')
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f'eigvec_{j}.png'))
    plt.close(fig)