### This notebook walks you through using the pretrained model to generate your own transition state guesses. If you are using your own model and data, replace the model and data paths with your own.

Import the necessary packages

In [2]:
from rdkit import Chem, Geometry
from rdkit.Chem.Draw import IPythonConsole 

import tensorflow as tf
from model.G2C import G2C

import numpy as np
import py3Dmol

Read in the test data with RDKit. You can optionally change this to use your own file type (instead of SDF), as long as RDKit can read the data without sanitization or hydrogen removal. Note that the reactants and products defined in the SDF files **MUST** preserve atom ordering between them! Also define the saved model path.

In [3]:
model_path = 'log/2layers_256hs_3its/best_model.ckpt'
reactant_file = 'data/test_reactants.sdf'
product_file = 'data/test_products.sdf'

# Keep explicit hydrogens and skip sanitization so both files are read exactly as
# written; reactant/product pairs are matched atom-by-atom by index.
reactant_supplier = Chem.ForwardSDMolSupplier(reactant_file, removeHs=False, sanitize=False)
product_supplier = Chem.ForwardSDMolSupplier(product_file, removeHs=False, sanitize=False)

# RDKit suppliers yield None for records they fail to parse; drop any pair in which
# either side is missing. Each mol must be tested individually: a 2-tuple such as
# (None, mol) is itself always truthy.
test_data = [(r_mol, p_mol)
             for r_mol, p_mol in zip(reactant_supplier, product_supplier)
             if r_mol is not None and p_mol is not None]

Batch preparation code. If you have many molecules to predict, you can use a larger batch size to speed up the predictions.

In [4]:
BATCH_SIZE = 1

# Every molecule is zero-padded to the atom count of the largest reactant in the test set
MAX_SIZE = max(r_mol.GetNumAtoms() for r_mol, _ in test_data)

# Supported elements: node features are a one-hot over these symbols plus one extra
# channel holding the scaled atomic number
elements = "HCNO"
num_elements = len(elements)

def prepare_batch(batch_mols):
    """Featurize a batch of (reactant, product) pairs into zero-padded model inputs.

    Uses the module-level MAX_SIZE, elements and num_elements defined in this cell.

    Args:
        batch_mols: list of (reactant, product) RDKit mols. Both mols of a pair must
            contain the same atoms in the same order.

    Returns:
        (batch_dict, batch_mols). batch_dict maps model placeholder names to numpy
        arrays padded with zeros up to MAX_SIZE atoms:
            "nodes"       float32 (size, MAX_SIZE, num_elements + 1): one-hot element,
                          then atomic number / 10
            "edges"       float32 (size, MAX_SIZE, MAX_SIZE, 3):
                          [0] bond present in both reactant and product and aromatic
                              in the reactant,
                          [1] bond present in both reactant and product,
                          [2] exp(-mean reactant/product 3D distance)
            "sizes"       int32 (size,): number of atoms of each reactant
            "coordinates" float32 (size, MAX_SIZE, 3): reactant 3D coordinates

    Raises:
        ValueError: if a reactant and its product have different atom counts, or an
            atom's element symbol is not in `elements`.
    """
    size = len(batch_mols)
    V = np.zeros((size, MAX_SIZE, num_elements + 1), dtype=np.float32)
    E = np.zeros((size, MAX_SIZE, MAX_SIZE, 3), dtype=np.float32)
    sizes = np.zeros(size, dtype=np.int32)
    coordinates = np.zeros((size, MAX_SIZE, 3), dtype=np.float32)

    for bx, (reactant, product) in enumerate(batch_mols):
        N_atoms = reactant.GetNumAtoms()
        if product.GetNumAtoms() != N_atoms:
            raise ValueError(
                "Reactant and product %d of the batch have different atom counts "
                "(%d vs %d); atoms must match one-to-one."
                % (bx, N_atoms, product.GetNumAtoms()))
        sizes[bx] = N_atoms

        # Mean topological (bond-count) distance. For two distinct atoms it is exactly
        # 1 only when they are bonded in BOTH the reactant and the product.
        D = (Chem.GetDistanceMatrix(reactant) + Chem.GetDistanceMatrix(product)) / 2

        # Edge channel 2: exp(-d) of the mean reactant/product 3D distance. Note d is
        # NOT squared; this must stay identical to the featurization used in training.
        D_3D = (Chem.Get3DDistanceMatrix(reactant) + Chem.Get3DDistanceMatrix(product)) / 2
        E[bx, :N_atoms, :N_atoms, 2] = np.exp(-D_3D)

        # Edge channels 0/1: bonds that survive the reaction, set symmetrically.
        for bond in reactant.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            if D[i, j] == 1.:
                for a, b in ((i, j), (j, i)):
                    E[bx, a, b, 1] = 1.
                    if bond.GetIsAromatic():
                        E[bx, a, b, 0] = 1.

        conformer = reactant.GetConformer()
        for i in range(N_atoms):
            # Reactant 3D coordinates
            pos = conformer.GetAtomPosition(i)
            coordinates[bx, i, :] = [pos.x, pos.y, pos.z]

            # Node features: one-hot element + scaled atomic number
            atom = reactant.GetAtomWithIdx(i)
            try:
                e_ix = elements.index(atom.GetSymbol())
            except ValueError:
                raise ValueError(
                    "Atom %d has element %r, but the model only supports the elements %r."
                    % (i, atom.GetSymbol(), elements))
            V[bx, i, e_ix] = 1.
            V[bx, i, num_elements] = atom.GetAtomicNum() / 10.

    batch_dict = {
        "nodes": V,
        "edges": E,
        "sizes": sizes,
        "coordinates": coordinates
    }
    return batch_dict, batch_mols


def sample_batch():
    """Yield featurized batches covering test_data in order.

    Each item is the (batch_dict, batch_mols) pair returned by prepare_batch for a
    consecutive slice of BATCH_SIZE pairs; the final slice may be shorter. Nothing is
    yielded when test_data is empty.
    """
    for start in range(0, len(test_data), BATCH_SIZE):
        yield prepare_batch(test_data[start:start + BATCH_SIZE])

Initialize the model. The hyperparameters should match those of the previously trained model. (Please ignore the deprecation warnings below.)

In [5]:
# Architecture hyperparameters must match the checkpoint being restored
# (presumably 2 layers / 256 hidden units / 3 iterations, per the model_path directory name).
model = G2C(
    max_size=MAX_SIZE,
    node_features=num_elements + 1,  # one-hot element channels + scaled atomic number
    edge_features=3,
    layers=2,
    hidden_size=256,
    iterations=3,
)

W1007 09:57:50.280781 140199825712512 deprecation_wrapper.py:119] From model/G2C.py:29: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

W1007 09:57:50.304795 140199825712512 deprecation_wrapper.py:119] From model/G2C.py:33: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W1007 09:57:50.343178 140199825712512 deprecation.py:323] From model/GNN.py:87: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.dense instead.
W1007 09:57:50.354985 140199825712512 deprecation.py:506] From /home/lagnajit/anaconda3/envs/ts_gen/lib/python2.7/site-packages/tensorflow/python/ops/init_ops.py:1251: calling __init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W100

Load the trained model and predict transition state geometries!

In [6]:
# Launch a TF1 session, restore the trained weights, then predict TS geometries batch by batch
session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)

with tf.Session(config=session_config) as sess:
    print("Model loading...")
    saver = tf.train.Saver()
    saver.restore(sess, model_path)
    print("Model restored")

    # One (MAX_SIZE, 3) predicted geometry per test pair, filled in batch order
    X = np.empty([len(test_data), MAX_SIZE, 3])
    predicted_coords = model.tensors["X"]

    for step, (batch_inputs, _) in enumerate(sample_batch()):
        feed_dict = {model.placeholders[name]: value for name, value in batch_inputs.items()}
        first = step * BATCH_SIZE
        X[first:first + BATCH_SIZE] = sess.run(predicted_coords, feed_dict=feed_dict)

W1007 09:57:57.568586 140199825712512 deprecation.py:323] From /home/lagnajit/anaconda3/envs/ts_gen/lib/python2.7/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.


Model loading...
Model restored


Convert the predicted geometries into RDKit mol objects and save them to an SDF file.

In [7]:
ts_mols = []
for bx in range(X.shape[0]):
    # Copy the reactant (keeping its atom order and connectivity) and overwrite the
    # copy's conformer coordinates with the predicted TS geometry.
    mol = Chem.Mol(test_data[bx][0])
    conformer = mol.GetConformer()
    for i in range(mol.GetNumAtoms()):
        x, y, z = X[bx, i, :].tolist()
        conformer.SetAtomPosition(i, Geometry.Point3D(x, y, z))
    ts_mols.append(mol)


model_ts_file = 'data/model_ts.sdf'
ts_writer = Chem.SDWriter(model_ts_file)
try:
    for ts_mol in ts_mols:
        ts_writer.write(ts_mol)
finally:
    # Close explicitly so the output is flushed and the file handle released now,
    # rather than whenever the writer object happens to be garbage-collected.
    ts_writer.close()

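Visualize the results. Change `n` to see a different combination of reactant, transition state, and product. Note that each TS is a copy of its reactant with new coordinates, so RDKit gives it the reactant's bonds. We'll clean this up so that only bonds common to both the reactant and the product keep their bond type; the remaining reactant bonds are re-added with RDKit's default (unspecified) bond type.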

In [8]:
def clean_ts(mols):
    """Prepare a (reactant, TS, product) triple for display.

    Bonds of the TS mol that are present in both the reactant and the product are left
    as they are. Every other TS bond is removed and then re-added through
    EditableMol.AddBond without an explicit order, i.e. with AddBond's default bond
    type. The reactant and product are returned unchanged.

    Args:
        mols: sequence (reactant, ts, product) of RDKit mols sharing atom indices.

    Returns:
        Tuple (reactant, cleaned_ts, product).
    """
    reactant, ts, product = mols

    def bond_key(bond):
        # Order-independent (low, high) atom-index pair identifying a bond
        return tuple(sorted((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())))

    shared_bonds = set(bond_key(b) for b in reactant.GetBonds()) & \
        set(bond_key(b) for b in product.GetBonds())

    editable = Chem.EditableMol(ts)
    for begin, end in (bond_key(b) for b in ts.GetBonds()):
        if (begin, end) in shared_bonds:
            continue
        editable.RemoveBond(begin, end)
        editable.AddBond(begin, end)
    return reactant, editable.GetMol(), product


def show_mol(mol, view, grid):
    """Draw `mol` in stick style in one cell of a py3Dmol viewer grid.

    Any models already in that cell are removed first, and the cell is zoomed to fit.

    Args:
        mol: RDKit mol; its conformer coordinates are written to the mol block.
        view: py3Dmol.view created with a `viewergrid`.
        grid: (row, col) of the target cell.

    Returns:
        The same `view` object.
    """
    target = {'viewer': grid}
    view.removeAllModels(**target)
    view.addModel(Chem.MolToMolBlock(mol), 'sdf', **target)
    # The cell was just cleared, so model 0 is the mol added above
    view.setStyle({'model': 0}, {'stick': {}}, **target)
    view.zoomTo(**target)
    return view

In [9]:
n = 1  # index of the test reaction to display
r_mol, p_mol = test_data[n]
view_mols = clean_ts([r_mol, ts_mols[n], p_mol])

# One row, three side-by-side panels: reactant | predicted TS | product
view = py3Dmol.view(width=960, height=500, linked=False, viewergrid=(1, 3))
for col, panel_mol in enumerate(view_mols):
    show_mol(panel_mol, view, grid=(0, col))
view.render()

<py3Dmol.view at 0x7f828136f2d0>