## This notebook uses RDMC to generate initial TS guesses from user-provided, atom-mapped reaction SMILES

### Currently implemented only for bimolecular H-abstraction reactions
### A parallelized version is also available
### Last updated by Oscar, Sep 8, 2023

In [1]:
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath('/home/oscar/code/RDMC/rdmc')))

import pickle
import multiprocessing as mp

from itertools import chain
from rdmc import RDKitMol
from rdmc.ts import get_formed_and_broken_bonds
from rdmc.view import ts_viewer
from rdmc.view import mol_viewer
from rdkit.Chem import AllChem

from rdmc.forcefield import RDKitFF

import numpy as np
import pandas as pd
import json
import copy

In [2]:
from IPython.display import HTML, display

# Stretch the Jupyter notebook's main container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

# Function to generate a bimolecular H-abstraction (bi-habs) TS guess

In [3]:
# ETKDGv3 distance-geometry parameters used to embed the initial 3D reactant/product
# geometries (both `gen_ts` and `return_opt_spc_bond_distance` pass `ps` to EmbedConformer).
ps = AllChem.ETKDGv3()
# The embedding seed is drawn once per kernel session, from [1, 10000000).
# NOTE(review): RDKit embedding is reproducible for a fixed non-negative seed, so calls that
# reuse `ps` may produce the same embedding within one session (RDKit's default seed, -1,
# randomizes every call). Confirm how RDKitMol.EmbedConformer consumes `ps` before relying
# on repeated attempts for geometric diversity.
ps.randomSeed = np.random.randint(low=1, high=10000000)

def _fragment_containing(fragments, atom_idx):
    """
    Return the fragment (a tuple of atom indices, as produced by ``GetMolFrags``) that contains ``atom_idx``.

    Raises:
        ValueError: if no fragment contains ``atom_idx``.
    """
    for fragment in fragments:
        if atom_idx in fragment:
            return fragment
    raise ValueError(f'Atom index {atom_idx} is not part of any fragment.')


def gen_ts(rxn_smi,
           angle_X_H_Y,
           dihedral_1,
           dihedral_2,
           h_extractor_regexpr = None, # substring identifying the abstractor species (see docstring)
           bond_length_scale_factor=1.25,
           bond_length_X_H = None, # optional, overrides the estimated value
           bond_length_H_Y = None, # optional, overrides the estimated value
          ):
    """
    Generate a bi-molecular H abstraction TS guess. Other reaction families are not implemented yet.

    Input:
        rxn_smi: atom-mapped reaction SMILES of the form "r1.r2>>p1.p2"
        angle_X_H_Y: bond angle (degree) pivot - H - pivot in the TS
        dihedral_1: dihedral angle (degree) neighbor_r1 - pivot_r1 - H - pivot_r2
        dihedral_2: dihedral angle (degree) pivot_r1 - H - pivot_r2 - neighbor_r2
        h_extractor_regexpr: plain substring (matched with ``in``, NOT a regular expression) that identifies
                             the H abstractor (e.g. the ROO* radical) in the reaction SMILES; passed to
                             determine_rp_order. If None, the SMILES order is used as given, so the abstractor
                             must then already be the first reactant/product.
        bond_length_scale_factor: empirical factor applied to the equilibrium R-H bond length of the
                                  stable species to estimate the stretched TS bond length (can be learned)
    Optional Input:
        bond_length_X_H: TS distance (angstrom) between the H atom and the reacting atom of r1 (the abstractor).
                         If None, estimated as (X-H bond length in product p1) * bond_length_scale_factor.
        bond_length_H_Y: TS distance (angstrom) between the H atom and the reacting atom of r2 (the molecule
                         being attacked). If None, estimated as (H-Y bond length in reactant r2) * bond_length_scale_factor.

    Output (tuple):
        r_complex: RDKitMol of the reactant complex (with an embedded conformer)
        p_complex: RDKitMol of the product complex
        ts: RDKitMol of the TS guess (reactant complex plus the forming bond)
        frags: atom-index fragments of the reactant complex
        ts_geom_info: (h_atom, pivot_r1, pivot_r2, neighbor_r1, neighbor_r2) atom indices
        ts_est_bd_dist: (bond_length_X_H, bond_length_H_Y) used for the TS
        ts_bd_idx: (r1_idx, r2_idx), positions of the r1 / r2 reacting atoms in the internal pivot-atom list

    Notation for TS geometry

    r2_end_bulky_atom - ts_r2_atom --- ts_h_atom --- ts_r1oo_atom - r1oo_o2_atom - r1oo_end_atom

            R2X       -    R2E     ---    TS     ---    R1O       -      R1OO    -     R1X

           this part is reactant 2, r2                  this part is reactant 1, r1
           we assume r2 is the closed shell             we assume r1 is the ROO* radical
           molecule being attacked
    """
    # Build reactant and product complexes (RDKitMol) from the reaction SMILES.
    r_complex, p_complex = [RDKitMol.FromSmiles(smi) for smi in rxn_smi.split(">>")]

    # Perceive the reaction center: the H atom is the atom shared by the formed and the broken bond;
    # the remaining atoms of those bonds are the two pivot (reacting heavy) atoms.
    fbond, bbond = get_formed_and_broken_bonds(r_complex, p_complex)
    the_h_atom = list(set(chain(*fbond)).intersection(chain(*bbond)))
    pivot_atoms = list(set(chain(*(fbond + bbond))).difference(the_h_atom))
    h_atom = the_h_atom[0]

    # Atom-index fragments of each complex.
    frags = r_complex.GetMolFrags()
    frags_p = p_complex.GetMolFrags()

    # Embed a 3D geometry for the reactant complex, then add the forming bond(s) so the
    # molecular graph represents the TS. `conf` is the TS conformer edited in place below.
    r_complex.EmbedConformer(ps)
    ts = r_complex.AddRedundantBonds(fbond)
    conf = ts.GetConformer()

    # Identify which SMILES species is r2 (attacked) and which product is p1 (abstractor + H).
    smi_ordered, _, _ = determine_rp_order(rxn_smi, regexpr=h_extractor_regexpr)
    r2, p1 = smi_ordered[1], smi_ordered[2]

    # Decide which pivot atom belongs to r2 by looking for its atom-map number in r2's SMILES.
    # Atom indices are 0-based while map numbers in the SMILES are 1-based; this assumes RDMC orders
    # atoms by map number (index = map - 1), which is used as the fallback if the map number is unset.
    pivot_map_num = r_complex.GetAtomWithIdx(pivot_atoms[0]).GetAtomMapNum() or pivot_atoms[0] + 1
    if f':{pivot_map_num}]' in r2:
        r2_idx, r1_idx = 0, 1
    else:
        r2_idx, r1_idx = 1, 0
    pivot_r1 = pivot_atoms[r1_idx]
    pivot_r2 = pivot_atoms[r2_idx]

    # H - Y distance (r2 side): the H-Y bond exists in reactant r2, so estimate it from r2's
    # averaged equilibrium bond length scaled by the empirical factor, unless given by the user.
    if bond_length_H_Y is None:
        r2_dist = return_opt_spc_bond_distance(spc_smi=r2,
                                               pivot_atom=pivot_r2,
                                               the_h_atom=h_atom,
                                               atom_map_idx=_fragment_containing(frags, pivot_r2),
                                               )
        bond_length_H_Y = r2_dist * bond_length_scale_factor
    conf.SetBondLength([pivot_r2, h_atom], bond_length_H_Y)

    # X - H distance (r1 side): the X-H bond exists only in product p1, so estimate it from p1.
    # The same scale factor is used for both sides here; they could be separated.
    if bond_length_X_H is None:
        r1_dist = return_opt_spc_bond_distance(spc_smi=p1,
                                               pivot_atom=pivot_r1,
                                               the_h_atom=h_atom,
                                               atom_map_idx=_fragment_containing(frags_p, pivot_r1),
                                               )
        bond_length_X_H = r1_dist * bond_length_scale_factor
    conf.SetBondLength([pivot_r1, h_atom], bond_length_X_H)

    # TS angle pivot - H - pivot (can be learned; ~160 deg is a common choice for bi-mol H abstraction).
    conf.SetAngleDeg([pivot_atoms[0], h_atom, pivot_atoms[1]], angle_X_H_Y)

    # Five connected atoms (neighbor_r1 - pivot_r1 - H - pivot_r2 - neighbor_r2) define two dihedrals.
    # Pick, for each pivot, its first neighbor other than the H atom.
    neighbor_r1 = [atom.GetIdx() for atom in ts.GetAtomWithIdx(pivot_r1).GetNeighbors()
                   if atom.GetIdx() != h_atom][0]
    neighbor_r2 = [atom.GetIdx() for atom in ts.GetAtomWithIdx(pivot_r2).GetNeighbors()
                   if atom.GetIdx() != h_atom][0]

    # Dihedrals are currently user-provided (they could be searched or learned).
    conf.SetTorsionDeg([neighbor_r1, pivot_r1, h_atom, pivot_r2], dihedral_1)
    conf.SetTorsionDeg([pivot_r1, h_atom, pivot_r2, neighbor_r2], dihedral_2)

    ts_geom_info = (h_atom, pivot_r1, pivot_r2, neighbor_r1, neighbor_r2)
    ts_bd_idx = (r1_idx, r2_idx)
    ts_est_bd_dist = (bond_length_X_H, bond_length_H_Y)

    return r_complex, p_complex, ts, frags, ts_geom_info, ts_est_bd_dist, ts_bd_idx

In [4]:
# Calculate the reacting R-H bond distance of a closed-shell reactant/product species, averaged over N embedded, MMFF94s-optimized conformers.
# A helper function used in gen_ts.
def return_opt_spc_bond_distance(
    spc_smi,
    pivot_atom,
    the_h_atom,
    atom_map_idx,
    averaged = 5,
    ):

    """
    Return the force-field optimized length of the reacting R-H bond of one species, averaged over conformers.
    Only use on non-radical RH species.

    Each of the ``averaged`` rounds re-parses ``spc_smi``, embeds a conformer with the module-level
    ETKDG parameters ``ps``, optimizes it with MMFF94s and measures the H - pivot bond length.

    Args:
        spc_smi: atom-mapped SMILES of the single species.
        pivot_atom: atom index (in the reaction complex) of the heavy atom bonded to the H.
        the_h_atom: atom index (in the reaction complex) of the reacting H atom.
        atom_map_idx: complex-level atom indices of this species' fragment; the position of an atom in
            this sequence is taken as its index within the standalone species.
        averaged: number of conformers to average over; must be a positive integer.

    Returns:
        The mean bond length (angstrom).

    Raises:
        ValueError: if ``averaged`` is not a positive integer, or if ``the_h_atom`` / ``pivot_atom``
            is not in ``atom_map_idx``.
    """
    # Reject values that previously divided by zero (0) or never terminated (negative / fractional).
    if averaged < 1 or int(averaged) != averaged:
        raise ValueError(f'`averaged` must be a positive integer, got {averaged!r}.')

    # Map complex-level indices to species-level indices once; this does not change between rounds.
    h_idx = atom_map_idx.index(the_h_atom)
    pivot_idx = atom_map_idx.index(pivot_atom)

    bd_list = []
    for _ in range(int(averaged)):
        r = RDKitMol.FromSmiles(spc_smi)
        ff = RDKitFF('mmff94s')
        r.EmbedConformer(ps)
        ff.setup(r)
        ff.optimize()
        m = ff.get_optimized_mol()

        # Copy the optimized coordinates back onto the freshly parsed species and measure the bond.
        r_conf = r.GetConformer()
        r_conf.SetPositions(m.GetPositions())
        bd_list.append(r_conf.GetBondLength([h_idx, pivot_idx]))

    return sum(bd_list) / len(bd_list)

In [5]:
# A helper function (used in gen_ts) that splits the reaction SMILES and determines the expected reactant/product order.
# The abstractor is identified by a plain substring match on `regexpr`.
# This function is fairly hard-coded and could be improved.
def determine_rp_order(rxn_smi, regexpr):
    """
    Split a two-reactant / two-product reaction SMILES and put the abstractor species first on each side.

    Args:
        rxn_smi: reaction SMILES of the form 'a.b>>c.d' (exactly two '.'-separated species per side).
        regexpr: plain substring (matched with ``in``, not a regular expression) that marks the
            abstractor species. When None, the given order is kept. Otherwise, if the substring is not
            found in the first species of a side, the two species on that side are swapped.

    Returns:
        (r1, r2, p1, p2): species in the expected order (abstractor first),
        (_r1, _r2, _p1, _p2): species in the original SMILES order,
        (r1_idx, r2_idx, p1_idx, p2_idx): original position of each ordered species on its side.

    Raises:
        ValueError: if a side does not contain exactly two species.
    """
    sides = rxn_smi.split(">>")
    first_r, second_r = sides[0].split('.')
    first_p, second_p = sides[1].split('.')

    def _order(first, second):
        # Keep the pair as-is when no marker is given or the marker is in the first species.
        if regexpr is None or regexpr in first:
            return (first, second), (0, 1)
        return (second, first), (1, 0)

    (r1, r2), (r1_idx, r2_idx) = _order(first_r, second_r)
    (p1, p2), (p1_idx, p2_idx) = _order(first_p, second_p)

    ordered = (r1, r2, p1, p2)
    original = (first_r, second_r, first_p, second_p)
    positions = (r1_idx, r2_idx, p1_idx, p2_idx)
    return ordered, original, positions

In [6]:
# Test of determine_rp_order
# Note that specifying `regexpr` changes the returned order (compare the two outputs below).

In [7]:
# Example reaction: the peroxy radical CH3SCH2OO* (...[O:9][O:8]) abstracts H:11 from C:16 of CH3CH2F; H:11 ends up on O:8.
rxn_smi = '[H:11][C:16]([H:12])([H:13])[C:17]([H:14])([H:15])[F:18].[H:1][C:6]([H:2])([H:3])[S:10][C:7]([H:4])([H:5])[O:9][O:8]>>[H:12][C:16]([H:13])[C:17]([H:14])([H:15])[F:18].[H:1][C:6]([H:2])([H:3])[S:10][C:7]([H:4])([H:5])[O:9][O:8][H:11]'

In [8]:
# Without `regexpr`, the species are returned in the order written in the SMILES.
determine_rp_order(rxn_smi, regexpr=None)

(('[H:11][C:16]([H:12])([H:13])[C:17]([H:14])([H:15])[F:18]',
  '[H:1][C:6]([H:2])([H:3])[S:10][C:7]([H:4])([H:5])[O:9][O:8]',
  '[H:12][C:16]([H:13])[C:17]([H:14])([H:15])[F:18]',
  '[H:1][C:6]([H:2])([H:3])[S:10][C:7]([H:4])([H:5])[O:9][O:8][H:11]'),
 ('[H:11][C:16]([H:12])([H:13])[C:17]([H:14])([H:15])[F:18]',
  '[H:1][C:6]([H:2])([H:3])[S:10][C:7]([H:4])([H:5])[O:9][O:8]',
  '[H:12][C:16]([H:13])[C:17]([H:14])([H:15])[F:18]',
  '[H:1][C:6]([H:2])([H:3])[S:10][C:7]([H:4])([H:5])[O:9][O:8][H:11]'),
 (0, 1, 0, 1))

In [9]:
# With `regexpr`, the species containing '[O:9][O:8]' (the peroxy radical / its ROOH product) is moved to r1 / p1.
determine_rp_order(rxn_smi, regexpr='[O:9][O:8]')

(('[H:1][C:6]([H:2])([H:3])[S:10][C:7]([H:4])([H:5])[O:9][O:8]',
  '[H:11][C:16]([H:12])([H:13])[C:17]([H:14])([H:15])[F:18]',
  '[H:1][C:6]([H:2])([H:3])[S:10][C:7]([H:4])([H:5])[O:9][O:8][H:11]',
  '[H:12][C:16]([H:13])[C:17]([H:14])([H:15])[F:18]'),
 ('[H:11][C:16]([H:12])([H:13])[C:17]([H:14])([H:15])[F:18]',
  '[H:1][C:6]([H:2])([H:3])[S:10][C:7]([H:4])([H:5])[O:9][O:8]',
  '[H:12][C:16]([H:13])[C:17]([H:14])([H:15])[F:18]',
  '[H:1][C:6]([H:2])([H:3])[S:10][C:7]([H:4])([H:5])[O:9][O:8][H:11]'),
 (1, 0, 1, 0))

In [10]:
# Check whether the TS guess has atoms colliding with each other.
# The 0.4 angstrom threshold is an empirical parameter.
def check_hard_collision(ts, threshold=0.4):
    """
    Raise ``ValueError`` if any atoms of ``ts`` collide; otherwise return None.

    Args:
        ts: molecule exposing ``HasCollidingAtoms(threshold=...)`` (an RDKitMol in this notebook).
        threshold: collision threshold forwarded unchanged to ``HasCollidingAtoms``.
    """
    colliding = ts.HasCollidingAtoms(threshold=threshold)
    if not colliding:
        return
    raise ValueError('Atom collision detected.')

In [11]:
# Check whether the two fragments of the TS are too close to each other.
# Since the TS bond distance is about 1.3 angstrom for C-H, fragments closer than this may form additional bonds (collision).
def check_fragment_collision(ts, frags, threshold=1.3, relax_cutoff=2.0):
    """
    Check that the fragments of a TS guess are not too close, and score how crowded their contact is.

    Inter-fragment distances are taken between the atoms of ``frags[0]`` and the atoms of all other
    fragments, selected by membership. Fragments are therefore not required to be contiguous index ranges.

    Args:
        ts: molecule exposing ``GetDistanceMatrix()`` (an RDKitMol in this notebook).
        frags: sequence of atom-index tuples, one per fragment (e.g. from ``GetMolFrags``).
        threshold: any inter-fragment distance (angstrom) below this counts as a collision.
        relax_cutoff: inter-fragment distances (angstrom) below this contribute to the score.

    Returns:
        relax_score = sum((d - relax_cutoff)**2) over inter-fragment distances d < relax_cutoff.
        It is 0 when no pair is closer than ``relax_cutoff``, and it grows as the fragments get closer.
        Lower is better; gen_n_ts_confs sorts ascending, so its first entry is the least crowded guess.

    Raises:
        ValueError: if any inter-fragment distance is below ``threshold``.
    """
    distance_matrix = np.asarray(ts.GetDistanceMatrix())
    first = list(frags[0])
    others = [idx for frag in frags[1:] for idx in frag]
    inter_fragment = distance_matrix[np.ix_(first, others)]

    if np.any(inter_fragment < threshold):
        raise ValueError('Reactants collision detected.')

    small_distances = inter_fragment[inter_fragment < relax_cutoff]
    return np.sum((small_distances - relax_cutoff) ** 2)

In [12]:
# a helper function to optimize the transition state using force field
# note: this is a constrained optimization that fixes the two TS bond distances first, 
# then, we optimize the geometry to a stable point to relax it (not to a saddle point as in typical ts search)
# the idea here is to relax each fragment of the ts to make the geometry more reasonable 
def opt_ts_ff(rxn_smi, ts, ts_geom_info, ts_est_bd_dist, ts_bd_idx):
    """
    Relax a TS guess with MMFF94s while holding the two reacting H distances fixed.

    This is a constrained minimization, not a saddle-point search. Its purpose is only to relax each
    fragment around the fixed reaction center.

    Args:
        rxn_smi: atom-mapped reaction SMILES; its reactant side supplies the connectivity for the force field.
        ts: TS guess (RDKitMol) whose coordinates are used as the starting point.
        ts_geom_info: (h_atom, pivot_r1, pivot_r2, ...) atom indices as returned by gen_ts.
        ts_est_bd_dist: (bond_length_X_H, bond_length_H_Y) constraint values (angstrom).
        ts_bd_idx: unused; kept so the signature matches gen_ts's outputs.

    Returns:
        A deep copy of ``ts`` with the optimized coordinates.
    """
    bond_length_X_H, bond_length_H_Y = ts_est_bd_dist
    the_h_atom, pivot_atom_r1, pivot_atom_r2 = ts_geom_info[:3]

    # Use the reactant-complex connectivity, i.e. without the forming bond that gen_ts adds via
    # AddRedundantBonds (presumably because MMFF cannot type the TS graph). The embedded conformer
    # only serves as a container for the TS coordinates.
    fake_ts = RDKitMol.FromSmiles(rxn_smi.split('>>')[0])
    fake_ts.EmbedConformer()
    fake_ts.SetPositions(ts.GetPositions())

    ff = RDKitFF('mmff94s')
    ff.setup(fake_ts)

    # Keep the two reacting H distances at their estimated TS values during the relaxation.
    ff.add_distance_constraint(atoms=[the_h_atom, pivot_atom_r1], value=bond_length_X_H)
    ff.add_distance_constraint(atoms=[the_h_atom, pivot_atom_r2], value=bond_length_H_Y)

    ff.optimize()
    m = ff.get_optimized_mol()

    # Copy so the caller's TS guess is left unchanged.
    ts_new = copy.deepcopy(ts)
    ts_new.SetPositions(m.GetPositions())

    return ts_new

In [13]:
# A helper function that attempts to generate at least one valid TS guess (one without colliding atoms or fragments).
# It embeds the TS complex up to max_iter times and stops at the first valid guess (it may still fail after the maximum number of attempts).
# Expensive step; can be optimized.
def gen_valid_ts(rxn_smi,
           h_extractor_regexpr,
           bond_length_scale_factor,
           angle_X_H_Y,
           dihedral_1,
           dihedral_2,
           max_iter,
           opt = True,
          ):
    """
    Try up to ``max_iter`` times to build a TS guess free of atom and fragment collisions.

    Each attempt calls gen_ts and check_hard_collision. If ``opt`` is set, it then relaxes the
    guess with opt_ts_ff, and finally scores it with check_fragment_collision. The loop stops at
    the first attempt that produces a score.

    Returns:
        (r_complex, p_complex, ts, frags, relax_score, ts_geom_info, ts_est_bd_dist).
        ``relax_score`` is None when no attempt succeeded. In that case the other values come from the
        last attempt that got past gen_ts, or are all None if none did.
    """
    # Initialize every returned value, so that zero attempts or all-failed attempts return None values
    # instead of raising UnboundLocalError.
    r_complex = p_complex = ts = frags = None
    ts_geom_info = ts_est_bd_dist = None
    relax_score = None

    counter = 0
    while relax_score is None and counter < max_iter:
        counter += 1
        try:
            (r_complex, p_complex, ts, frags,
             ts_geom_info, ts_est_bd_dist, ts_bd_idx) = gen_ts(rxn_smi=rxn_smi,
                                                               angle_X_H_Y=angle_X_H_Y,
                                                               dihedral_1=dihedral_1,
                                                               dihedral_2=dihedral_2,
                                                               h_extractor_regexpr=h_extractor_regexpr,
                                                               bond_length_scale_factor=bond_length_scale_factor,
                                                               )

            check_hard_collision(ts)

            if opt:
                ts = opt_ts_ff(rxn_smi, ts, ts_geom_info, ts_est_bd_dist, ts_bd_idx)

            # Slightly below the shortest estimated TS distance, so that the H...pivot distance
            # across the two fragments does not itself count as a collision.
            threshold = min(ts_est_bd_dist) * 0.98
            relax_score = check_fragment_collision(ts, frags, threshold=threshold)

        except Exception:
            # Best-effort retry: a failed embedding or collision only uses up one attempt.
            # (KeyboardInterrupt/SystemExit are no longer swallowed.)
            pass

    return r_complex, p_complex, ts, frags, relax_score, ts_geom_info, ts_est_bd_dist

In [14]:
# A helper function that attempts to generate N valid TS guesses, each with a relax score, up to a maximum number of iterations.
# Expensive step; can be optimized.
def gen_n_ts_confs(
           rxn_smi,
           angle_X_H_Y,
           dihedral_1,
           dihedral_2,
           h_extractor_regexpr,
           bond_length_scale_factor=1.25,
           max_iter_per_conf = 50,
           num_confs = 5,
           max_total_iter = 20,
):
    """
    Collect up to ``num_confs`` valid TS guesses using at most ``max_total_iter`` calls to gen_valid_ts.

    Returns:
        (result, result_g_xyz).
        ``result`` is a list of (relax_score, r_complex, p_complex, ts, frags, g_xyz, ts_geom_info,
        ts_est_bd_dist) tuples, sorted by relax_score in ascending order (least crowded first).
        ``result_g_xyz`` holds the matching coordinate blocks: the XYZ text without its two header
        lines, followed by a blank line.

    Raises:
        ValueError: if no valid TS guess was generated.
    """
    result = []

    iter_counter = 0
    while len(result) < num_confs and iter_counter < max_total_iter:
        iter_counter += 1
        try:
            (r_complex, p_complex, ts, frags, relax_score,
             ts_geom_info, ts_est_bd_dist) = gen_valid_ts(rxn_smi=rxn_smi,
                                                          h_extractor_regexpr=h_extractor_regexpr,
                                                          bond_length_scale_factor=bond_length_scale_factor,
                                                          angle_X_H_Y=angle_X_H_Y,
                                                          dihedral_1=dihedral_1,
                                                          dihedral_2=dihedral_2,
                                                          max_iter=max_iter_per_conf,
                                                          )
            # Compare with None rather than relying on truthiness: relax_score == 0.0 is the best
            # possible score and must not be discarded.
            if relax_score is None or any(x is None for x in (r_complex, p_complex, ts, frags)):
                continue
            xyz = ts.ToXYZ()
            g_xyz = "\n".join(xyz.splitlines()[2:]) + "\n\n"
            result.append((relax_score, r_complex, p_complex, ts, frags, g_xyz, ts_geom_info, ts_est_bd_dist))
        except Exception:
            # Best-effort: a failed attempt only uses up one iteration.
            continue

    if not result:
        raise ValueError('Failed to generate TS conformers.')

    result.sort(key=lambda entry: entry[0])
    result_g_xyz = [entry[5] for entry in result]
    return result, result_g_xyz

# Test TS generation (expensive step)

In [15]:
# Example reaction: HOO* ([H:1][O:3][O:2]) abstracts the carboxylic-acid H:18 (bonded to O:6) of a bromo-chloro-substituted benzoic acid; H:18 ends up on O:2.
rxn_smi = '[Br:4][c:10]1[c:8]([C:14]([O:6][H:18])=[O:7])[c:9]([H:15])[c:11]([Cl:5])[c:12]([H:16])[c:13]1[H:17].[H:1][O:3][O:2]>>[Br:4][c:10]1[c:8]([C:14]([O:6])=[O:7])[c:9]([H:15])[c:11]([Cl:5])[c:12]([H:16])[c:13]1[H:17].[H:1][O:3][O:2][H:18]'

In [16]:
# Generate one valid TS guess for `rxn_smi` (up to 20 attempts): 120 degree X-H-Y angle, both dihedrals 0,
# followed by the constrained force-field relaxation.
r_complex, p_complex, ts, frags, relax_score, ts_geom_info, ts_est_bd_dist = gen_valid_ts(
    rxn_smi=rxn_smi,
    h_extractor_regexpr='[O:3][O:2]',  # substring that marks the HOO* abstractor
    bond_length_scale_factor=1.25,
    angle_X_H_Y=120,
    dihedral_1=0,
    dihedral_2=0,
    max_iter=20,
    opt=True,
)

In [17]:
# Visualize the generated TS guess with RDMC's ts_viewer.
ts_viewer(r_complex, p_complex, ts, only_ts=True)

<py3Dmol.view at 0x147d81176ef0>

In [18]:
# Estimated TS distances (bond_length_X_H, bond_length_H_Y) in angstrom used for this guess.
ts_est_bd_dist

(1.2200915388985287, 1.225853404121895)

In [19]:
# Collect up to 2 valid TS guesses (at most 10 outer attempts, each allowing up to 20 embeddings).
# `result` is returned sorted by relax_score, lowest (least crowded) first.
result, result_g_xyz = gen_n_ts_confs(
    rxn_smi=rxn_smi,
    h_extractor_regexpr='[O:3][O:2]',
    bond_length_scale_factor=1.25,
    angle_X_H_Y=120,
    dihedral_1=0,
    dihedral_2=0,
    num_confs=2,
    max_iter_per_conf=20,
    max_total_iter=10,
)

In [20]:
# best conformer

In [21]:
# Take the first entry of `result` (sorted ascending by relax_score, so the least crowded guess).
(relax_score, r_complex, p_complex, ts, frags,
 g_xyz, ts_geom_info, ts_est_bd_dist) = result[0]

In [22]:
# Visualize the best-scoring TS guess with RDMC's ts_viewer.
ts_viewer(r_complex, p_complex, ts, only_ts=True)

<py3Dmol.view at 0x147d81177280>

In [23]:
# Lower relax_score = fewer / less severe inter-fragment contacts below 2 angstrom.
print(relax_score)

0.676849902910381


In [24]:
# Element symbol and Cartesian coordinates of the TS guess (XYZ text without its two header lines).
print(g_xyz)

H      0.320639    0.224111    1.415470
O      0.901271   -0.457798   -0.368588
O      0.102366   -0.512611    0.800851
Br     5.396268    1.743884    0.776691
Cl     2.010394    7.123008    1.611902
O      2.391214    1.478813   -0.370729
O      1.310524    1.639306    1.600773
C      2.936238    3.284210    0.983649
C      2.213121    4.465333    1.200435
C      4.335832    3.307930    0.954951
C      2.893738    5.669664    1.357452
C      4.286296    5.706183    1.310398
C      5.007654    4.525409    1.113902
C      2.150993    2.053032    0.820957
H      1.126389    4.441207    1.239063
H      4.816979    6.647443    1.433956
H      6.094886    4.565416    1.095170
H      1.659895    0.497401   -0.409232




In [25]:
# worst conformer

In [26]:
# Take the last entry of `result` (sorted ascending by relax_score, so the most crowded guess kept).
(relax_score, r_complex, p_complex, ts, frags,
 g_xyz, ts_geom_info, ts_est_bd_dist) = result[-1]

In [27]:
# Visualize the worst-scoring TS guess with RDMC's ts_viewer.
ts_viewer(r_complex, p_complex, ts, only_ts=True)

<py3Dmol.view at 0x147d81176ce0>

In [28]:
# Lower relax_score = fewer / less severe inter-fragment contacts below 2 angstrom.
print(relax_score)

0.6801572975871264


In [29]:
# Element symbol and Cartesian coordinates of the TS guess (XYZ text without its two header lines).
print(g_xyz)

H     -0.207407   -1.044333    0.129154
O      1.457452   -0.097870   -0.434650
O      0.186545   -0.670309   -0.691737
Br    -1.676484   -0.459719    4.311452
Cl     4.259308   -1.642387    6.426352
O      2.043109   -0.170581    1.936908
O      0.076535   -1.242365    1.824917
C      1.180089   -0.944222    3.959305
C      2.471670   -1.189765    4.456624
C      0.102546   -0.847880    4.850961
C      2.676631   -1.338302    5.826245
C      1.607749   -1.246607    6.713368
C      0.322334   -0.999848    6.226693
C      1.004986   -0.827023    2.494358
H      3.314988   -1.269603    3.773685
H      1.767016   -1.362181    7.782894
H     -0.503153   -0.922132    6.931755
H      1.772968   -0.117144    0.744182


