In [1]:
from rdkit import Chem

with open("/home/cairne/WorkSpace/molgraphX_paper_scripts/Data/100_smiles.txt") as f:
    smiles = [Chem.MolToSmiles(Chem.MolFromSmiles(i)) for i in f.readlines()]


## MolGraphX

In [None]:
import copy
import torch
from argparse import ArgumentParser
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from rdkit import Chem
import sys
import numpy as np
from rdkit.Chem import rdDepictor, AllChem

sys.path.append("/home/cairne/WorkSpace/molgraphX_paper_scripts/GCNN_2D")
from Source.explainers.subgraphX.utils import draw_best_subgraph
from Source.explainers.utils import visualize, ExplainableModel
from Source.models.GCNN.featurizers import DGLFeaturizer
from Source.models.GCNN.model import GCNN
from Source.trainer import ModelShell
import json
from explainers_callers import (get_model_and_featurizer, get_ig_scores, get_subgX_scores,
                                get_molgraphx_scores, DEVICE)   
DEVICE = torch.device("cpu")
MODEL = ExplainableModel(
    ModelShell(
        GCNN,
        "/home/cairne/WorkSpace/molgraphX_paper_scripts/GCNN_2D/Output/trained_model",
        device=DEVICE
    ))

from Source.rdkit_heatmaps import mapvalues2mol
from Source.rdkit_heatmaps.utils import transform2png

def normalize_scores(scores):
    scores = (scores - scores.mean()) / scores.std()

def visualize(mol, 
            atom_scores: list[float] = None, 
            bond_scores: list[float] = None,
            save_path: str = None, 
            normalize=False, 
            show_values=True,
            set_atom_map=False):
    if atom_scores:
        atom_scores = np.array(atom_scores)
    if bond_scores:
        bond_scores = np.array(bond_scores)

    rdDepictor.Compute2DCoords(mol)
    if normalize: normalize_scores(atom_scores), normalize_scores(bond_scores)

    if show_values:
        if atom_scores is not None:
            for i, atom in enumerate(mol.GetAtoms()):
                atom.SetProp("atomNote", f"{atom_scores[i]:.2f}")


    canvas = mapvalues2mol(mol, 
                        atom_weights=atom_scores, 
                        bond_weights=bond_scores, 
                        # set_atom_map=set_atom_map
                        )
    img = transform2png(canvas.GetDrawingText())
    if save_path is not None: img.save(save_path)
    return img
FEATURIZER = DGLFeaturizer(require_edge_features=False,
                           add_self_loop=False,
                           node_featurizer=CanonicalAtomFeaturizer(),
                           edge_featurizer=CanonicalBondFeaturizer(),
                           canonical_atom_order=False)


results_dict = {}
for smile in smiles:
    try:
        print(smile)
        mol = Chem.MolFromSmiles(smile)
        graph = FEATURIZER.featurize(mol)
        molgraphX_scores = get_molgraphx_scores(mol, featurizer=FEATURIZER, model=MODEL,
                                                )
    
        img = visualize(copy.deepcopy(mol), 
                        atom_scores=molgraphX_scores, 
                        show_values=True,
                        set_atom_map=True)
        img.save(f"/home/cairne/WorkSpace/molgraphX_paper_scripts/GCNN_2D/Experiments/molgraphx/{smile}.png")
        results_dict[f"{smile}"] = molgraphX_scores
    except:
        print(smile, "INvalid")
        results_dict[f"{smile}"] = molgraphX_scores



with open("/home/cairne/WorkSpace/molgraphX_paper_scripts/GCNN_2D/Experiments/molgraphx_results.json", "w") as jf:
    json.dump(results_dict, jf)

## SubgraphX

In [3]:
subgraphX_kwargs = {
    "mode": "regression",
    "device": DEVICE,
    "explain_graph": True,  # verbose: True,
    "rollout": 20,  # Number of iteration to get the prediction (MCTS hyperparameter)
    "min_atoms": 1,
    "c_puct": 10.0,  # The hyperparameter which encourages the exploration (MCTS hyperparameter)
    "sample_num": None,
    # Sampling time of monte carlo sampling approximation for 'mc_shapley', 'mc_l_shapley' reward_methods
    "reward_method": "l_shapley",  # one of ["gnn_score", "mc_shapley", "l_shapley", "mc_l_shapley", "nc_mc_l_shapley"]
    "subgraph_building_method": "zero_filling",  # one of ["zero_filling", "split"]
}
results_dict = {}
q = 0
for smile in smiles:
    try:
        q += 1
        mol = Chem.MolFromSmiles(smile)
        graph = FEATURIZER.featurize(mol)

        subgraphs = get_subgraphX_subgraphs(mol, featurizer=FEATURIZER, explainable_model=MODEL,
                                            device=DEVICE,
                                            subgraphX_kwargs=subgraphX_kwargs, target_ids=(0,))

        img = draw_best_subgraph(copy.deepcopy(mol), subgraphs, max_nodes=5, show_value=True)
        img.save(f"/home/cairne/WorkSpace/molgraphX_paper_scripts/GCNN_2D/Experiments/subgraphx/{smile}.png")
        results_dict[f"{smile}"] = [[], []]
        for sbg in subgraphs:
            results_dict[f"{smile}"][0].append(sbg.coalition)
            results_dict[f"{smile}"][1].append(sbg.P)
        raise
    except:
        pass
print(results_dict)
with open("/home/cairne/WorkSpace/molgraphX_paper_scripts/GCNN_2D/Experiments/subgraphx_results.json", "w") as jf:
    json.dump(results_dict, jf)

{}


In [20]:
subgraphs[1].P

0.4653752566006213

## Integrated Gradients

In [8]:

from captum.attr import Saliency, IntegratedGradients
from collections import defaultdict
import itertools
import numpy as np
from Source.explainers.utils import visualize, ExplainableModel
from rdkit import Chem
from Source.rdkit_heatmaps import mapvalues2mol
from Source.rdkit_heatmaps.utils import transform2png
import copy

MODEL = ModelShell(
        GCNN,
        "/home/cairne/WorkSpace/molgraphX/Output/trained_model",
        device=DEVICE
    ).models[0]

def model_forward(edge_mask, data):
    batch = torch.zeros(data.x.shape[0], dtype=int).to(DEVICE)
    out = MODEL(data, edge_weights=edge_mask)["mu"]
    return out


def explain(method, data, target=0):
    print(data)
    input_mask = torch.ones(data.edge_index.shape[1]).requires_grad_(True).to(DEVICE)
    if method == 'ig':
        ig = IntegratedGradients(model_forward)
        mask = ig.attribute(input_mask, target=target,
                            additional_forward_args=(data,),
                            internal_batch_size=data.edge_index.shape[1])
    elif method == 'saliency':
        saliency = Saliency(model_forward)
        mask = saliency.attribute(input_mask, target=target,
                                  additional_forward_args=(data,))
    else:
        raise Exception('Unknown explanation method')

    edge_mask = np.abs(mask.cpu().detach().numpy())
    if edge_mask.max() > 0:  # avoid division by zero
        edge_mask = edge_mask / edge_mask.max()
    return edge_mask


def aggregate_edge_directions(edge_mask, data):
    edge_mask_dict = defaultdict(float)
    for val, u, v in list(zip(edge_mask, *data.edge_index)):
        u, v = u.item(), v.item()
        if u > v:
            u, v = v, u
        edge_mask_dict[(u, v)] += val
    return edge_mask_dict


def edge_mask_to_node_mask(edge_mask) -> {}:
    edge_keys = edge_mask.keys()
    num_atoms = max(itertools.chain(*[[i[0], i[1]] for i in edge_keys])) + 1
    atom_vals = []
    for a in range(num_atoms):
        atom_vals.append([])
        for atom_pair in edge_keys:
            if a in atom_pair:
                atom_vals[-1].append(edge_mask[atom_pair])
    for i, val in enumerate(atom_vals):
        atom_vals[i] = np.mean([i.detach() for i in val])
    return atom_vals


results_dict = {}
for smile in smiles:
    sample_molecule = Chem.MolFromSmiles(smile)
    data = FEATURIZER.featurize(sample_molecule)

    input_mask = torch.ones(data.edge_index.shape[1]).requires_grad_(True).to(DEVICE)
    ig = IntegratedGradients(model_forward)
    mask = ig.attribute(input_mask,
                        additional_forward_args=(data,),
                        internal_batch_size=data.edge_index.shape[1]
                        )

    edge_mask_dict = aggregate_edge_directions(mask, data)
    node_mask = edge_mask_to_node_mask(edge_mask_dict)
    img = visualize(copy.deepcopy(sample_molecule), 
                    atom_scores=node_mask,
                    show_values=True,
                    set_atom_map=True)
    img.save(f"/home/cairne/WorkSpace/molgraphX/Experiments/png_results/ig/{smile}.png")
    results_dict[f"{smile}"] = node_mask

with open("/home/cairne/WorkSpace/molgraphX/Experiments/png_results/ig_results.json", "w") as jf:
    json.dump(results_dict, jf)


  model = model_class(**torch.load(path_to_config))
  state_dict = torch.load(path_to_state, map_location=device)


RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.