In [287]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [288]:
from copy import copy

from tooltoad.chemutils import xyz2mol
from tooltoad.vis import MolTo3DGrid

from rdkit import Chem
from rdkit.Chem import rdDetermineBonds, rdMolAlign, rdchem
from rdkit.Chem.AllChem import ETKDGv3, EmbedMolecule
from rdkit.Chem.rdchem import RWMol
from rdkit.Geometry.rdGeometry import Point3D

from frust.utils.mols import fix_cat_frag, get_molecule_name

In [289]:
# Ligand grafted onto the TS guess. Previously used alternative (N-TIPS pyrrole):
#   "CC([Si](N1C=CC=C1)(C(C)C)C(C)C)C"
ligand_smiles   = "CC(C)(C)NC(=O)Oc1cccc(Cl)c1"
ts_guess_struct = "../structures/ts3_guess.xyz"   # XYZ geometry of the TS guess
pre_name        = "TS1"                            # prefix for the generated TS names
embed_ready     = True                             # also add a ZERO-order B...H bond

In [290]:
# --- Load the TS guess geometry; report the failure reason, then re-raise --- #
try:
    with open(ts_guess_struct, "r") as xyz_file:
        xyz_block = xyz_file.read()
except Exception as exc:
    if isinstance(exc, FileNotFoundError):
        message = f"Error: Transition state structure file not found: {ts_guess_struct}"
    elif isinstance(exc, PermissionError):
        message = f"Error: Permission denied when accessing file: {ts_guess_struct}"
    elif isinstance(exc, IOError):
        message = f"Error: Failed to read transition state structure file {ts_guess_struct}: {exc}"
    else:
        message = f"Unexpected error loading transition state structure from {ts_guess_struct}: {exc}"
    print(message)
    raise

ts = xyz2mol(xyz_block)
MolTo3DGrid(
    ts,
    background_color="lightblue",
    cell_size=(600, 600),
)

In [291]:
# --- Find unique positions and check that they are valid cH --- #
lig_mol = Chem.MolFromSmiles(ligand_smiles)
lig_mol = Chem.AddHs(lig_mol)

cH_patt = Chem.MolFromSmarts('[cH]')
matches = lig_mol.GetSubstructMatches(cH_patt)
cH_atoms = [ind[0] for ind in matches]

atom_rank = list(Chem.CanonicalRankAtoms(lig_mol,breakTies=False))

def find_unique_atoms(lst):
    seen = set()
    result = []
    for i, x in enumerate(lst):
        if x not in seen:
            result.append(i)
            seen.add(x)
    return result

unique_atoms = find_unique_atoms(atom_rank)
unique_cH = set(unique_atoms).intersection(set(cH_atoms))
unique_cH = tuple(unique_cH)

MolTo3DGrid(lig_mol, legends=[f"Uniques: {unique_cH}"])

In [292]:
# --- Locate the N1CCCC1 placeholder ring in the TS guess --- #
ts_rw = RWMol(ts)
ts_rw_origin = copy(ts_rw)
ts_ligand_pattern = Chem.MolFromSmarts("N1CCCC1")
old_ring_match = ts_rw.GetSubstructMatch(ts_ligand_pattern)
if len(old_ring_match) < 3:
    raise ValueError(f"No N1CCCC1 placeholder ring found in {ts_guess_struct}")

# --- Create alignment maps --- #
# The first three query atoms of N1CCCC1 are N, C(alpha), C(beta). Each ligand
# map is (heavy neighbour, reacting C, heavy neighbour) -> (N, C, C).
old_active_site = old_ring_match[0:3]

maps = []
for c_idx in unique_cH:
    heavy_nbs = [
        nb.GetIdx()
        for nb in lig_mol.GetAtomWithIdx(c_idx).GetNeighbors()
        if nb.GetAtomicNum() != 1  # skip hydrogens
    ]
    heavy_nbs.insert(1, c_idx)  # reacting carbon goes in the middle
    maps.append(list(zip(heavy_nbs, old_active_site)))

In [293]:
# Fragments of the unmodified TS guess (the embedding loop below recomputes `frags`)
frags = Chem.GetMolFrags(mol=ts_rw, asMols=True)

In [294]:
# --- Build one TS candidate per unique aromatic C-H position --- #
# Every candidate is stored in `ts_mols` (keyed by the reacting ligand atom
# index). `ts_rw_combined` / `atom_indices_to_keep` keep the values of the last
# candidate for the inspection cells below.
params = ETKDGv3()
params.randomSeed = 0xF00D
placeholder_patt = Chem.MolFromSmiles("CN1CCCC1")
cat_pat = Chem.MolFromSmarts('[B]-c1ccccc1-[N]')
ts_mols = {}
for atom_map in maps:
    rpos = atom_map[1][0]  # ligand index of the reacting carbon
    # EmbedMolecule returns -1 on failure; AlignMol would then fail obscurely.
    if EmbedMolecule(lig_mol, params) == -1:
        raise RuntimeError(f"ETKDG embedding failed for ligand {ligand_smiles!r}")
    ts_rw = Chem.RWMol(ts_rw_origin)
    rdMolAlign.AlignMol(lig_mol, ts_rw, atomMap=atom_map)

    # --- remove hydrogen from the reacting carbon --- #
    h_idx = next(
        (nb.GetIdx() for nb in lig_mol.GetAtomWithIdx(rpos).GetNeighbors()
         if nb.GetAtomicNum() == 1),
        None,
    )
    if h_idx is None:
        raise ValueError(f"Ligand atom {rpos} has no hydrogen to replace")
    lig_mol_rw = RWMol(lig_mol)
    lig_mol_rw.RemoveAtom(h_idx)
    # RemoveAtom renumbers every later atom down by one.
    lig_rpos = rpos - 1 if h_idx < rpos else rpos

    # --- Remove old ligand and determine bond order (to get aromaticity correct for the catalyst) --- #
    atoms_to_remove = set()
    for idx in ts_rw.GetSubstructMatch(placeholder_patt):
        atoms_to_remove.add(idx)
        atoms_to_remove.update(
            nb.GetIdx() for nb in ts_rw.GetAtomWithIdx(idx).GetNeighbors()
            if nb.GetAtomicNum() == 1
        )
    # Delete from the highest index down so pending indices stay valid.
    for idx in sorted(atoms_to_remove, reverse=True):
        ts_rw.RemoveAtom(idx)

    frags = Chem.GetMolFrags(ts_rw, asMols=True)
    # NOTE(review): assumes frags[0] is the catalyst and frags[1] holds the two
    # reacting H atoms -- confirm for new TS guesses.
    fixed_cat = fix_cat_frag(frags[0])
    fixed_Hs = frags[1]

    ts_rw = RWMol(Chem.CombineMols(fixed_cat, fixed_Hs))
    offset = ts_rw.GetNumAtoms()
    reactive_H1 = offset - 1
    reactive_H2 = offset - 2

    # Ligand atoms follow the catalyst + H fragment in the combined molecule.
    ts_rw_combined = RWMol(Chem.CombineMols(ts_rw, lig_mol_rw))
    reactive_C = offset + lig_rpos

    cat_matches = ts_rw_combined.GetSubstructMatches(cat_pat)
    if not cat_matches:
        raise ValueError("Catalyst motif [B]-c1ccccc1-[N] not found in combined TS")
    B_cat_idx = cat_matches[0][0]
    N_cat_idx = cat_matches[0][7]

    Hs_on_B = [
        nb.GetIdx()
        for nb in ts_rw_combined.GetAtomWithIdx(B_cat_idx).GetNeighbors()
        if nb.GetAtomicNum() == 1
    ]

    ts_rw_combined.AddBond(B_cat_idx, reactive_C, Chem.BondType.SINGLE)

    atom_indices_to_keep = [B_cat_idx, N_cat_idx, *Hs_on_B, reactive_H1, reactive_H2, reactive_C]

    if embed_ready:
        ts_rw_combined.AddBond(B_cat_idx, reactive_H1, Chem.BondType.ZERO)

    ts_mols[rpos] = (ts_rw_combined, atom_indices_to_keep)

In [295]:
# Indices in ts_rw_combined of B, N, the H atoms on B, both reactive H atoms and
# the reacting ligand carbon -- for the last candidate built in the loop above.
atom_indices_to_keep

[10, 17, 39, 41, 40, 56]

In [296]:
# 3D view of the last TS candidate assembled in the loop above
MolTo3DGrid(
    ts_rw_combined,
    cell_size=(600, 600),
)

In [297]:
def transformer_ts3(
    ligand_smiles       = "CC([Si](N1C=CC=C1)(C(C)C)C(C)C)C",
    ts_guess_struct     = "ts1_guess.xyz",
    pre_name            = "TS1",
    embed_ready         = True,
):
    """Graft each symmetry-unique aromatic C-H of a ligand onto a TS guess.

    Steps:
        1. Read the TS guess from an XYZ file and perceive connectivity.
        2. For every symmetry-unique ``[cH]`` carbon of ``ligand_smiles``, embed
           the ligand (ETKDGv3, fixed seed) and align it onto the first three
           atoms (N, C, C) of the ``N1CCCC1`` placeholder ring in the guess.
        3. Remove the H on the reacting carbon and the ``CN1CCCC1`` placeholder
           (plus its H atoms).
        4. Combine the catalyst and reactive-H fragments with the ligand, then
           bond B to the reacting carbon.

    NOTE(review): assumes that, after removing the placeholder, fragment 0 is
    the catalyst and fragment 1 holds the two reacting H atoms -- confirm for
    new TS guesses.

    Args:
        ligand_smiles: SMILES of the incoming ligand.
        ts_guess_struct: Path to the XYZ file holding the TS guess.
        pre_name: Prefix used in the keys of the returned dict.
        embed_ready: If True, also add a ZERO-order bond between B and the
            first reactive H.

    Returns:
        Dict mapping ``f"{pre_name}({mol_name}_rpos({rpos}))"`` to
        ``(RWMol, atom_indices_to_keep)``. The index list holds B, N, the H
        atoms on B, the two reactive H atoms and the reacting ligand carbon.
        Entries are ordered by ascending ``rpos``. An empty dict is returned if
        the ligand has no aromatic C-H.

    Raises:
        OSError: If the XYZ file cannot be read (other read errors are re-raised too).
        ValueError: If the XYZ/SMILES cannot be parsed, or the placeholder ring
            or catalyst motif is not found.
        RuntimeError: If ETKDG embedding of the ligand fails.
    """
    # --- Load the TS guess and perceive connectivity --- #
    try:
        with open(ts_guess_struct, 'r') as file:
            xyz_block = file.read()
    except FileNotFoundError:
        print(f"Error: Transition state structure file not found: {ts_guess_struct}")
        raise
    except PermissionError:
        print(f"Error: Permission denied when accessing file: {ts_guess_struct}")
        raise
    except IOError as e:
        print(f"Error: Failed to read transition state structure file {ts_guess_struct}: {e}")
        raise
    except Exception as e:
        print(f"Unexpected error loading transition state structure from {ts_guess_struct}: {e}")
        raise

    ts = Chem.MolFromXYZBlock(xyz_block)
    if ts is None:
        raise ValueError(f"Could not parse XYZ block in {ts_guess_struct}")
    rdDetermineBonds.DetermineConnectivity(ts, useVdw=True)

    # --- Find symmetry-unique aromatic C-H positions on the ligand --- #
    lig_mol = Chem.MolFromSmiles(ligand_smiles)
    if lig_mol is None:
        raise ValueError(f"Invalid ligand SMILES: {ligand_smiles!r}")
    lig_mol = Chem.AddHs(lig_mol)

    cH_atoms = {m[0] for m in lig_mol.GetSubstructMatches(Chem.MolFromSmarts('[cH]'))}

    # breakTies=False gives symmetry-equivalent atoms the same rank. The first
    # (lowest) atom index of each rank represents that symmetry class.
    first_of_rank = {}
    for idx, rank in enumerate(Chem.CanonicalRankAtoms(lig_mol, breakTies=False)):
        first_of_rank.setdefault(rank, idx)
    # Sorted so candidates (and the returned dict) have a deterministic order.
    unique_cH = sorted(cH_atoms.intersection(first_of_rank.values()))

    # --- Locate the placeholder ring in the TS guess --- #
    ts_origin = RWMol(ts)
    old_ring_match = ts_origin.GetSubstructMatch(Chem.MolFromSmarts("N1CCCC1"))
    if len(old_ring_match) < 3:
        raise ValueError(f"No N1CCCC1 placeholder ring found in {ts_guess_struct}")
    old_active_site = old_ring_match[:3]  # N, C(alpha), C(beta)

    # --- Create alignment maps: (neighbour, reacting C, neighbour) -> (N, C, C) --- #
    maps = []
    for c_idx in unique_cH:
        heavy_nbs = [
            nb.GetIdx()
            for nb in lig_mol.GetAtomWithIdx(c_idx).GetNeighbors()
            if nb.GetAtomicNum() != 1
        ]
        heavy_nbs.insert(1, c_idx)
        maps.append(list(zip(heavy_nbs, old_active_site)))

    ts_mols = {}
    if not maps:
        return ts_mols

    # Loop-invariant work, hoisted out of the per-position loop.
    mol_name = get_molecule_name(ligand_smiles)
    placeholder_patt = Chem.MolFromSmiles("CN1CCCC1")
    cat_pat = Chem.MolFromSmarts('[B]-c1ccccc1-[N]')
    params = ETKDGv3()
    params.randomSeed = 0xF00D

    for atom_map in maps:
        rpos = atom_map[1][0]  # ligand index of the reacting carbon
        # EmbedMolecule returns -1 on failure; AlignMol would then fail obscurely.
        if EmbedMolecule(lig_mol, params) == -1:
            raise RuntimeError(f"ETKDG embedding failed for ligand {ligand_smiles!r}")
        ts_rw = RWMol(ts_origin)
        rdMolAlign.AlignMol(lig_mol, ts_rw, atomMap=atom_map)

        # --- remove hydrogen from the reacting carbon --- #
        h_idx = next(
            (nb.GetIdx() for nb in lig_mol.GetAtomWithIdx(rpos).GetNeighbors()
             if nb.GetAtomicNum() == 1),
            None,
        )
        if h_idx is None:
            raise ValueError(f"Ligand atom {rpos} has no hydrogen to replace")
        lig_mol_rw = RWMol(lig_mol)
        lig_mol_rw.RemoveAtom(h_idx)
        # RemoveAtom renumbers every later atom down by one.
        lig_rpos = rpos - 1 if h_idx < rpos else rpos

        # --- Remove old ligand (and its H atoms) from the TS guess --- #
        atoms_to_remove = set()
        for idx in ts_rw.GetSubstructMatch(placeholder_patt):
            atoms_to_remove.add(idx)
            atoms_to_remove.update(
                nb.GetIdx() for nb in ts_rw.GetAtomWithIdx(idx).GetNeighbors()
                if nb.GetAtomicNum() == 1
            )
        # Delete from the highest index down so pending indices stay valid.
        for idx in sorted(atoms_to_remove, reverse=True):
            ts_rw.RemoveAtom(idx)

        # --- Fix the catalyst fragment and reassemble --- #
        frags = Chem.GetMolFrags(ts_rw, asMols=True)
        fixed_cat = fix_cat_frag(frags[0])
        ts_core = RWMol(Chem.CombineMols(fixed_cat, frags[1]))
        offset = ts_core.GetNumAtoms()
        reactive_H1 = offset - 1
        reactive_H2 = offset - 2

        # Ligand atoms follow the catalyst + H fragment in the combined molecule.
        ts_rw_combined = RWMol(Chem.CombineMols(ts_core, lig_mol_rw))
        reactive_C = offset + lig_rpos

        cat_matches = ts_rw_combined.GetSubstructMatches(cat_pat)
        if not cat_matches:
            raise ValueError("Catalyst motif [B]-c1ccccc1-[N] not found in combined TS")
        B_cat_idx = cat_matches[0][0]
        N_cat_idx = cat_matches[0][7]

        # Collected before any new B bonds are added.
        Hs_on_B = [
            nb.GetIdx()
            for nb in ts_rw_combined.GetAtomWithIdx(B_cat_idx).GetNeighbors()
            if nb.GetAtomicNum() == 1
        ]

        ts_rw_combined.AddBond(B_cat_idx, reactive_C, Chem.BondType.SINGLE)

        atom_indices_to_keep = [B_cat_idx, N_cat_idx, *Hs_on_B, reactive_H1, reactive_H2, reactive_C]

        if embed_ready:
            ts_rw_combined.AddBond(B_cat_idx, reactive_H1, Chem.BondType.ZERO)

        ts_mols[f'{pre_name}({mol_name}_rpos({rpos}))'] = (ts_rw_combined, atom_indices_to_keep)

    return ts_mols

In [298]:
# Pass the configured ligand/prefix explicitly: relying on the function's
# defaults silently builds TS structures for a different ligand than the one
# explored in the cells above.
ts3_mols = transformer_ts3(
    ligand_smiles=ligand_smiles,
    ts_guess_struct=ts_guess_struct,
    pre_name=pre_name,
    embed_ready=embed_ready,
)

In [299]:
# Show the first generated TS candidate (each value is (mol, atom_indices_to_keep)).
if not ts3_mols:
    raise ValueError("transformer_ts3 produced no TS candidates (no unique aromatic C-H found)")
first_ts_mol, _ = next(iter(ts3_mols.values()))
MolTo3DGrid(first_ts_mol)