With thanks to Lucky Pattanaik for supplying the majority of this code.

In [5]:
from rdkit import Chem
import numpy as np

In [3]:
base_folder = '../data/raw/'

train_r_file = base_folder + 'train_reactants.sdf'
train_ts_file = base_folder + 'train_ts.sdf'
train_p_file = base_folder + 'train_products.sdf'

test_r_file = base_folder + 'test_reactants.sdf'
test_ts_file = base_folder + 'test_ts.sdf'
test_p_file = base_folder + 'test_products.sdf'


def _load_sdf(path):
    """Read every record of the SDF file at ``path`` into a list of molecules.

    Explicit hydrogens are kept (``removeHs=False``) and RDKit sanitization is
    skipped (``sanitize=False``). Per the RDKit docs, the supplier yields
    ``None`` for records it cannot parse. No filtering is done here, so list
    position ``k`` always corresponds to record ``k`` of the file.
    NOTE(review): presumably the reactant/TS/product files are parallel
    (same record order), which is why positions must stay aligned — confirm.
    """
    return list(Chem.SDMolSupplier(path, removeHs=False, sanitize=False))


train_r = _load_sdf(train_r_file)
train_ts = _load_sdf(train_ts_file)
train_p = _load_sdf(train_p_file)

test_r = _load_sdf(test_r_file)
test_ts = _load_sdf(test_ts_file)
test_p = _load_sdf(test_p_file)

In [6]:
# Weight arrays exported from the trained model (the "mit_" prefix presumably
# names the model/dataset they came from — confirm). Both paths are relative
# to the notebook's current working directory.
#   atom_weights -> per-atom importance weights, one row per sample
#   raw_weights  -> masked raw weights; each sample's entry is reshaped to an
#                   (n_atoms, n_atoms) matrix when plotting
atom_weights_file = 'mit_atom_importance_weights.npy'
raw_atom_weights_file = 'mit_masked_raw_weights.npy'

atom_weights = np.load(atom_weights_file)
raw_weights = np.load(raw_atom_weights_file)

In [10]:
# get bond "scores" based on weight matrices
def get_bond_scores(w, mol):
    """Look up one score per bond of ``mol`` in the pairwise matrix ``w``.

    Args:
        w: Matrix supporting ``w[i, j]`` indexing (e.g. a 2-D NumPy array),
            indexed by atom indices of ``mol``.
        mol: Molecule whose ``GetBonds()`` yields bonds exposing
            ``GetBeginAtomIdx()`` and ``GetEndAtomIdx()``.

    Returns:
        A list holding ``w[begin, end]`` for each bond, in ``GetBonds()``
        order. Only the begin->end entry is read, so for a non-symmetric
        ``w`` the stored bond direction matters.
    """
    return [
        w[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]
        for bond in mol.GetBonds()
    ]

In [9]:

def show_important_bonds(plot_idx, *, atom_power=4, bond_power=12):
    """Display reactant, weight-highlighted TS, and product drawings for a sample.

    The TS molecule is drawn with per-atom scores taken from
    ``atom_weights[plot_idx]`` (passed to ``draw`` as ``atomScores``). Its
    bonds are coloured from the ``raw_weights[plot_idx]`` matrix, which is
    scaled by its largest entry and looked up per bond via ``get_bond_scores``.

    Args:
        plot_idx: Sample index into ``ts_mols``/``r_mols``/``p_mols`` and the
            weight arrays.
        atom_power: Exponent applied to the atom weights before drawing
            (default 4). Larger values suppress small weights more strongly.
        bond_power: Exponent applied to the scaled bond scores (default 12).

    Relies on names defined elsewhere in the notebook: ``ts_mols``,
    ``r_mols``, ``p_mols``, ``draw``, ``display``, ``SVG``, ``atom_weights``,
    ``raw_weights`` and ``get_bond_scores``.
    """
    mol = ts_mols[plot_idx]
    n_atoms = mol.GetNumAtoms()

    # Rebuild the n_atoms x n_atoms weight matrix by discarding the zeroed
    # (masked) entries.
    # NOTE(review): this assumes no genuine entry is exactly 0; otherwise the
    # reshape fails or entries shift position. If raw_weights[plot_idx] is a
    # 2-D matrix padded at the end of both axes,
    # raw_weights[plot_idx][:n_atoms, :n_atoms] would be robust — confirm the
    # layout before switching.
    w = raw_weights[plot_idx][raw_weights[plot_idx] != 0].reshape(n_atoms, n_atoms)
    # Scale by the single largest weight (max over the whole matrix).
    bond_weights = w / np.max(w)

    # Take exactly one score per atom. Trimming trailing zeros instead would
    # also drop real atoms whose weight (or its power, after float underflow)
    # is 0, leaving the scores misaligned with the molecule's atoms.
    # Assumes per-atom weights come first, with any padding after them.
    scores = atom_weights[plot_idx][:n_atoms] ** atom_power
    bond_scores = np.array(get_bond_scores(bond_weights, mol)) ** bond_power

    # Colour tuple (1 - s, 1, 1 - s) per bond: presumably RGB, shading from
    # white (s = 0) toward green (s = 1). Bond indices follow GetBonds() order,
    # which is the order get_bond_scores returns.
    highlight_bonds = list(range(len(bond_scores)))
    highlight_bond_colors = {
        idx: (1 - s, 1, 1 - s) for idx, s in enumerate(bond_scores)
    }

    # Atoms are shaded through atomScores, so the explicit atom-highlight
    # arguments are intentionally left empty.
    drawing = draw(mol, highlightAtoms=[], highlightAtomColors=[],
                   atomScores=scores, highlightBonds=highlight_bonds,
                   highlightBondColors=highlight_bond_colors)

    display(SVG(draw(r_mols[plot_idx])))
    display(SVG(drawing))
    display(SVG(draw(p_mols[plot_idx])))