In [None]:
import os
import pandas as pd
import metrics as mt
import os
import json
from rdkit import Chem
import numpy as np
import torch.nn as nn
from captum.attr import IntegratedGradients
from dgllife.utils import CanonicalAtomFeaturizer
from explainers_callers import (get_model_and_featurizer, get_ig_scores, 
                                get_molgraphx_scores, get_submoleculex_scores, DEVICE)   
from Source.explainers import visualize
from tqdm import tqdm

# Locations of the trained model folder and the labelled SDF used for evaluation.
# Either can be overridden through an environment variable so the notebook runs
# on another machine; the defaults are the original author's local paths.
MODEL_FOLDER = os.environ.get(
    "MOLGRAPHX_MODEL_FOLDER", "/home/cairne/WorkSpace/molgraphX/Output/ibench_N"
)
PATH_TO_SDF = os.environ.get(
    "MOLGRAPHX_SDF_PATH",
    "/home/cairne/WorkSpace/molgraphX/Data/ibenchmark/Datasets/N_train_lbl.sdf",
)

# Read every SDF record, dropping those RDKit could not parse (yielded as None),
# then order molecules from fewest to most atoms; later cells index into this order.
molecules = [mol for mol in Chem.SDMolSupplier(PATH_TO_SDF) if mol is not None]
molecules.sort(key=lambda mol: mol.GetNumAtoms())

def normalize(x: list[float]) -> list[float]:
    """Z-score a sequence of attribution scores.

    Subtracts the mean and divides by the population standard deviation
    (``np.std`` default, ddof=0). A tiny epsilon is added to the denominator
    so constant inputs do not trigger a division by zero.

    Args:
        x: Scores to standardize (e.g. one value per atom).

    Returns:
        The standardized scores as a plain Python list.
    """
    values = np.array(x)
    center, spread = values.mean(), values.std()
    return ((values - center) / (spread + 1e-10)).tolist()

In [None]:
# Reference molecule used for the side-by-side explainer comparison in later cells.
TARGET_INDEX = 30  # position in `molecules`, which is sorted by atom count
TARGET_MOLECULE = molecules[TARGET_INDEX]

labels = TARGET_MOLECULE.GetProp("lbls")  # raw label string stored in the SDF
model, featurizer = get_model_and_featurizer(MODEL_FOLDER)

# Attributions for TARGET_MOLECULE from MolGraphX, then SubMoleculeX (same call order as before).
molgraphx_scores, submoleculex_scores = (
    explain(TARGET_MOLECULE, model, featurizer)
    for explain in (get_molgraphx_scores, get_submoleculex_scores)
)

In [None]:
# Explain a fixed window of molecules (positions 400-499 of the atom-count-sorted
# list) with MolGraphX, keeping molecules and their scores index-aligned.
contribs = {"mols": molecules[400:500]}
contribs["contribs"] = [
    get_molgraphx_scores(m, model, featurizer) for m in tqdm(contribs["mols"])
]

In [None]:
# Z-score each molecule's MolGraphX contributions independently (per-molecule normalization).
contribs["contribs_scaled"] = [normalize(raw) for raw in contribs["contribs"]]

In [None]:
# Long-format table with one row per labelled atom: the molecule's SMILES, the
# atom's integer label from the "lbls" SDF property, and its normalized
# MolGraphX contribution. Molecules whose label string is exactly "NA" are skipped.
molgraphx_rows = []
for mol, scaled_contribs in zip(contribs["mols"], contribs["contribs_scaled"]):
    raw_lbls = mol.GetProp("lbls").split(",")
    if raw_lbls == ["NA"]:
        continue
    atom_lbls = [int(lbl) for lbl in raw_lbls]
    smiles = Chem.MolToSmiles(mol)
    # zip() would silently truncate a length mismatch and pair labels with the
    # wrong atoms, corrupting the AUC; fail loudly instead.
    if len(atom_lbls) != len(scaled_contribs):
        raise ValueError(
            f"{smiles}: {len(atom_lbls)} labels but "
            f"{len(scaled_contribs)} contributions"
        )
    molgraphx_rows.extend(
        (smiles, lbl, contrib) for lbl, contrib in zip(atom_lbls, scaled_contribs)
    )

metrics_df = pd.DataFrame(molgraphx_rows, columns=["molecule", "lbl", "contrib"])

In [None]:
import metrics as mt

auc = mt.calc_auc(
    metrics_df,
    which_lbls="positive",
    contrib_col_name="contrib",
)
# Cell output: mean of the "auc_pos" entry (presumably one AUC per molecule — confirm in metrics.calc_auc).
np.mean(auc["auc_pos"])

In [None]:
# Integrated Gradients baseline: per-atom scores are node-feature attributions
# summed over the feature axis.
# NOTE(review): this evaluates molecules[:100] while the MolGraphX cell uses
# molecules[400:500], so the two mean AUCs are not computed on the same molecules.
ig_contribs = {"mols": [], "contribs": []}


class CaptumModelWrapper(nn.Module):
    """Adapt a graph-level model to Captum, which perturbs a plain feature tensor.

    Captum calls ``forward`` with a node-feature tensor; the wrapper clones a
    template graph, swaps that tensor in as ``x`` and runs the wrapped model.
    """

    def __init__(self, model, graph=None):
        super().__init__()
        self.model = model
        # Template graph whose structure is reused on every forward call.
        # When None, fall back to a module-level ``graph`` (original behaviour).
        self.graph = graph

    def forward(self, x):
        template = self.graph if self.graph is not None else graph
        g = template.clone()
        g.x = x
        return self.model(g)


def ig_atom_scores(mol) -> list[float]:
    """Return Integrated Gradients scores for ``mol``, one per row of its node-feature matrix."""
    mol_graph = featurizer.featurize(mol).to(DEVICE)
    input_tensor = mol_graph.x.clone().detach().requires_grad_(True)
    ig = IntegratedGradients(CaptumModelWrapper(model, mol_graph))
    attributions = ig.attribute(
        input_tensor,
        target=0,
        n_steps=50,
        method="gausslegendre",
        # Rows of ``x`` are atoms of ONE graph, not independent examples. Without
        # this, Captum concatenates all interpolation steps along dim 0 and the
        # model sees a single n_steps*N-node "graph". One step per forward pass
        # keeps every call a valid copy of this molecule.
        internal_batch_size=input_tensor.shape[0],
    )
    return attributions.detach().cpu().numpy().sum(axis=1).tolist()


# Loop-invariant model setup, done once rather than per molecule.
model.to(DEVICE)
model.eval()

for mol in tqdm(molecules[:100]):
    ig_contribs["mols"].append(mol)
    # Featurize the molecule being iterated (previously TARGET_MOLECULE every time).
    ig_contribs["contribs"].append(ig_atom_scores(mol))

ig_contribs["contribs_scaled"] = [normalize(c) for c in ig_contribs["contribs"]]

# Later cells compare/visualize `ig_scores` against TARGET_MOLECULE, so compute
# them for that molecule explicitly instead of relying on the last loop iteration.
ig_scores = ig_atom_scores(TARGET_MOLECULE)

# One row per labelled atom; molecules whose label string is exactly "NA" are skipped.
ig_rows = []
for mol, scaled_contribs in zip(ig_contribs["mols"], ig_contribs["contribs_scaled"]):
    raw_lbls = mol.GetProp("lbls").split(",")
    if raw_lbls == ["NA"]:
        continue
    atom_lbls = [int(lbl) for lbl in raw_lbls]
    smiles = Chem.MolToSmiles(mol)
    # zip() would silently truncate a length mismatch and misalign labels.
    if len(atom_lbls) != len(scaled_contribs):
        raise ValueError(
            f"{smiles}: {len(atom_lbls)} labels but "
            f"{len(scaled_contribs)} contributions"
        )
    ig_rows.extend(
        (smiles, lbl, contrib) for lbl, contrib in zip(atom_lbls, scaled_contribs)
    )

metrics_df = pd.DataFrame(ig_rows, columns=["molecule", "lbl", "contrib"])
metrics_df.to_csv("ig.csv")
auc = mt.calc_auc(metrics_df, which_lbls="positive", contrib_col_name="contrib")
np.mean(auc["auc_pos"])

In [None]:
# Side-by-side z-scored attributions for TARGET_MOLECULE (MolGraphX, SubMoleculeX,
# Integrated Gradients), followed by its raw label string.
for raw_scores in (molgraphx_scores, submoleculex_scores, ig_scores):
    print(normalize(raw_scores))

print(labels)

In [None]:
# Pass TARGET_MOLECULE and its z-scored MolGraphX attributions to `visualize`
# (presumably an atom-coloured depiction — confirm in Source.explainers).
visualize(TARGET_MOLECULE, normalize(molgraphx_scores))


In [None]:
# Pass TARGET_MOLECULE and its z-scored SubMoleculeX attributions to `visualize`
# (presumably an atom-coloured depiction — confirm in Source.explainers).
visualize(TARGET_MOLECULE, normalize(submoleculex_scores))


In [None]:
# Pass TARGET_MOLECULE and z-scored Integrated Gradients scores to `visualize`.
# `ig_scores` is produced by the Integrated Gradients cell; the overlay is only
# meaningful if those scores were computed for TARGET_MOLECULE.
visualize(TARGET_MOLECULE, normalize(ig_scores))
