In [1]:
from copy import copy
from rdkit import Chem
from rdkit.Chem import rdDetermineBonds, rdMolAlign
from rdkit.Chem.AllChem import ETKDGv3, EmbedMolecule
from rdkit.Chem.rdchem import RWMol, Mol
from frust.utils.mols import get_molecule_name
from tooltoad.vis import MolTo3DGrid

In [2]:
ligand_smiles="CC([Si](N1C=CC=C1)(C(C)C)C(C)C)C"
ts_guess_struct="../structures/ts2_guess.xyz"
bonds_to_remove = [(59, 62), (52, 62), (52, 54)]
constraint_atoms=[54, 52, 59, 62] # TS2: H, B1, B2, C ---> B1 = Catalyst, B2 = Pinacolborane
H_idx=54
pre_name="TS"
embed_ready=False

In [3]:
import math
from rdkit import Chem
from rdkit.Chem import rdchem


BH_LEN = 1.18   # Å – average B–H bond length

def _unit(vec):
    """return unit vector (or None if zero length)"""
    norm = math.sqrt(vec.x ** 2 + vec.y ** 2 + vec.z ** 2)
    if norm < 1e-6:
        return None
    return vec.__class__(vec.x / norm, vec.y / norm, vec.z / norm)


def _fix_pin_frag(frag: Chem.Mol) -> Chem.Mol:
    """
    • converts B=O double bonds → single, neutralises B/O  
    • adds a B–H 1.18 Å away from B but does **not** move any other atom
    """
    rw   = Chem.RWMol(frag)
    conf = rw.GetConformer()

    # ---- locate boron ----
    b_idx = next(a.GetIdx() for a in rw.GetAtoms() if a.GetAtomicNum() == 5)
    boron = rw.GetAtomWithIdx(b_idx)

    # ---- make B-O single & neutral ----
    for nb in boron.GetNeighbors():
        bond = rw.GetBondBetweenAtoms(b_idx, nb.GetIdx())
        if bond.GetBondType() == rdchem.BondType.DOUBLE:
            bond.SetBondType(rdchem.BondType.SINGLE)
        if nb.GetAtomicNum() == 8:
            nb.SetFormalCharge(0)
    boron.SetFormalCharge(0)

    # ---- add the missing H (if needed) ----
    if boron.GetTotalDegree() < 3:
        # 1) choose a direction opposite to the average of B→heavy-neighbor vectors
        b_pos = conf.GetAtomPosition(b_idx)
        acc   = b_pos.__class__(0.0, 0.0, 0.0)
        heavy_cnt = 0
        for nb in boron.GetNeighbors():
            if nb.GetAtomicNum() > 1:        # O or C
                n_pos = conf.GetAtomPosition(nb.GetIdx())
                acc.x += n_pos.x - b_pos.x
                acc.y += n_pos.y - b_pos.y
                acc.z += n_pos.z - b_pos.z
                heavy_cnt += 1
        # average & flip
        acc.x *= -1.0 / heavy_cnt
        acc.y *= -1.0 / heavy_cnt
        acc.z *= -1.0 / heavy_cnt
        direction = _unit(acc) or b_pos.__class__(1.0, 0.0, 0.0)  # fallback

        # 2) place H at BH_LEN along that direction
        h_pos = b_pos.__class__(
            b_pos.x + direction.x * BH_LEN,
            b_pos.y + direction.y * BH_LEN,
            b_pos.z + direction.z * BH_LEN,
        )

        # 3) add the atom & bond
        h_idx = rw.AddAtom(Chem.Atom(1))
        rw.AddBond(b_idx, h_idx, rdchem.BondType.SINGLE)
        conf.SetAtomPosition(h_idx, h_pos)

    # ---- sanitize & return ----
    Chem.SanitizeMol(rw)
    rw.RemoveAtom(h_idx)

    return rw.GetMol()


from rdkit import Chem
from rdkit.Chem import rdDetermineBonds
from rdkit.Geometry.rdGeometry import Point3D
import numpy as np

def _fix_cat_frag(mol: Chem.Mol, bh_len: float = 1.19) -> Chem.Mol:
    """
    • Finds the single B atom in *mol*
    • Adds one H at a standard B–H distance without moving any existing atoms
    • Neutralises B if it was −1
    • Rebuilds all connectivity from the final 3D coords (rings, valence, aromaticity)
    """
    rw   = Chem.RWMol(mol)
    conf = rw.GetConformer()

    # locate boron and neutralise if needed
    b_idx = next(a.GetIdx() for a in rw.GetAtoms() if a.GetAtomicNum() == 5)
    boron = rw.GetAtomWithIdx(b_idx)
    if boron.GetFormalCharge() == -1:
        boron.SetFormalCharge(0)

    # pick direction roughly opposite its neighbours
    bpos   = np.array(conf.GetAtomPosition(b_idx))
    neighs = [n.GetIdx() for n in boron.GetNeighbors()]
    if neighs:
        pts = np.array([conf.GetAtomPosition(i) for i in neighs])
        v   = bpos - pts.mean(axis=0)
    else:
        v   = np.array([1.0, 0.0, 0.0])
    v /= np.linalg.norm(v) or 1.0

    # add the H at ~bh_len Å from B
    h_pos = bpos + bh_len * v
    h_idx = rw.AddAtom(Chem.Atom(1))

    rw.AddBond(b_idx, h_idx, Chem.BondType.SINGLE)
    conf.SetAtomPosition(h_idx, Point3D(*h_pos))

    # rebuild connectivity purely from 3D coords
    xyz = Chem.MolToXYZBlock(rw)
    mol2 = Chem.MolFromXYZBlock(xyz)
    rdDetermineBonds.DetermineBonds(mol2, useVdw=True)
    Chem.SanitizeMol(mol2)

    rw = RWMol(mol2)
    rw.RemoveAtom(h_idx)

    return rw.GetMol()

In [4]:
  # --- Read TS Guess Structure --- #
try:
    with open(ts_guess_struct, 'r') as file:
        xyz_block = file.read()
except FileNotFoundError:
    print(f"Error: Transition state structure file not found: {ts_guess_struct}")
    raise
except PermissionError:
    print(f"Error: Permission denied when accessing file: {ts_guess_struct}")
    raise
except IOError as e:
    print(f"Error: Failed to read transition state structure file {ts_guess_struct}: {e}")
    raise
except Exception as e:
    print(f"Unexpected error loading transition state structure from {ts_guess_struct}: {e}")
    raise

# --- Determine Connectivity --- #
ts = Chem.MolFromXYZBlock(xyz_block)
rdDetermineBonds.DetermineConnectivity(ts, useVdw=True)
ts_rw = RWMol(ts)
MolTo3DGrid(ts_rw)

In [None]:
bonds_to_remove = bonds_to_remove
for bond in bonds_to_remove:
    ts_rw.RemoveBond(bond[0], bond[1])
ts_rw_origin = copy(ts_rw)

MolTo3DGrid(ts_rw, cell_size=(600,600))

: 

In [123]:
# --- Find ligand in guess ts structure --- #
ts_ligand_pattern = Chem.MolFromSmarts("N1CCCC1")
old_ring_match    = ts_rw.GetSubstructMatch(ts_ligand_pattern)  # e.g. (5,6,7,8,9)

# --- Find unique positions and check that they are valid cH --- #
lig_mol = Chem.MolFromSmiles(ligand_smiles)
lig_mol = Chem.AddHs(lig_mol)

cH_patt = Chem.MolFromSmarts('[cH]')
matches = lig_mol.GetSubstructMatches(cH_patt)
cH_atoms = [ind[0] for ind in matches]

atom_rank = list(Chem.CanonicalRankAtoms(lig_mol,breakTies=False))

def find_unique_atoms(lst):
    seen = set()
    result = []
    for i, x in enumerate(lst):
        if x not in seen:
            result.append(i)
            seen.add(x)
    return result

unique_atoms = find_unique_atoms(atom_rank)
unique_cH = set(unique_atoms).intersection(set(cH_atoms))
unique_cH = tuple(unique_cH)

# --- Create aligned maps --- #
old_active_site = old_ring_match[0:3]

maps = []
for a in unique_cH:
    C_pos = lig_mol.GetAtomWithIdx(a)
    nbs = []
    for nb in C_pos.GetNeighbors():
        if nb.GetAtomicNum() == 1:
            pass # hydrogen
        else:
            nbs.append(nb.GetIdx())
    
    nbs.insert(1, C_pos.GetIdx())
    
    map = []
    for nb, aa in zip(nbs, old_active_site):
        map.append((nb, aa))
    maps.append(map)

# --- Loop through each map a.k.a reactive position and create the molecule --- #
params = ETKDGv3()
params.randomSeed = 0xF00D  # Use any integer seed
ts_mols = {}
for map in maps:
    rpos = map[1][0]
    EmbedMolecule(lig_mol, params)
    ts_rw = Chem.RWMol(ts_rw_origin)
    rdMolAlign.AlignMol(lig_mol, ts_rw, atomMap=map)

    # --- remove hydrogen from the reacting carbon --- #
    chosen_carbon_idx = rpos
    chosen_carbon = lig_mol.GetAtomWithIdx(chosen_carbon_idx)

    for nb in chosen_carbon.GetNeighbors():
        if nb.GetAtomicNum() == 1:  # hydrogen
            lig_mol_rw = RWMol(lig_mol)
            lig_mol_rw.RemoveAtom(nb.GetIdx())
            lig_mol = lig_mol_rw.GetMol()
            break

    # --- Remove old ligand and determine bond order (to get aromaticity correct for the catalyst) --- #
    n_pattern_full = Chem.MolFromSmiles("CN1CCCC1")
    n_old_indices = ts_rw.GetSubstructMatch(n_pattern_full)

    atoms_to_remove = set()
    for idx in n_old_indices:
        atom = ts_rw.GetAtomWithIdx(idx)
        atoms_to_remove.add(idx)
        for neighbor in atom.GetNeighbors():
            if neighbor.GetAtomicNum() == 1:  # Check if hydrogen
                atoms_to_remove.add(neighbor.GetIdx())

    for idx in sorted(atoms_to_remove, reverse=True):
        ts_rw.RemoveAtom(idx)


    atom_idx_to_remove = H_idx
    atom_to_remove = ts_rw.GetAtomWithIdx(atom_idx_to_remove)
    atom_symbol = atom_to_remove.GetSymbol()
    atom_coords = ts_rw.GetConformer().GetAtomPosition(atom_idx_to_remove)

    ts_rw.RemoveAtom(atom_idx_to_remove)
    rdDetermineBonds.DetermineBonds(ts_rw)

    frags = Chem.GetMolFrags(ts_rw, asMols=True)

    fixed_cat = _fix_cat_frag(frags[0])
    fixed_pin = _fix_pin_frag(frags[1])

    ts_rw = RWMol(Chem.CombineMols(fixed_cat, fixed_pin))

    #  --- Add reactive H back --- #
    new_atom_idx = ts_rw.AddAtom(Chem.Atom(atom_symbol))
    ts_rw.GetConformer().SetAtomPosition(new_atom_idx, atom_coords)          

    # --- Combine ligand and catalyst, add temporary bonds, and set temporary formal charges ---
    ts_combined = Chem.CombineMols(ts_rw, lig_mol)
    ts_rw_combined = RWMol(ts_combined)

    offset = ts_rw.GetNumAtoms()
    reactive_H = offset - 1 # the reactive H is the offset - 1, because it was the last atom added to the mol.
    reactive_C = rpos + offset

    cat_pat = Chem.MolFromSmarts('[B]-c1ccccc1-[N]')
    B_cat_idx = ts_rw_combined.GetSubstructMatches(cat_pat)[0][0]
    N_cat_idx = ts_rw_combined.GetSubstructMatches(cat_pat)[0][7]
    B_nbs = ts_rw_combined.GetAtomWithIdx(B_cat_idx).GetNeighbors()
    Hs_on_B = [nb.GetIdx() for nb in B_nbs if nb.GetAtomicNum() == 1]

    pin_pat = Chem.MolFromSmarts('[B]1OC(C(O1)(C)C)(C)C')
    B_pin_idx = ts_rw_combined.GetSubstructMatches(pin_pat)[0][0]
    print(B_pin_idx)

    atom_indices_to_keep = [B_cat_idx, N_cat_idx]
    atom_indices_to_keep.extend(Hs_on_B)
    atom_indices_to_keep.extend([reactive_H, reactive_C])
    #atom_indices_to_keep = [10, 11, 39, 40, reactive_H, reactive_C]

    embed_ready = True

    if embed_ready:
        ts_rw_combined.AddBond(B_cat_idx, reactive_C, Chem.BondType.ZERO)
        ts_rw_combined.AddBond(B_cat_idx, reactive_H, Chem.BondType.SINGLE)
        ts_rw_combined.AddBond(B_cat_idx, B_pin_idx, Chem.BondType.SINGLE)
        b_atom = ts_rw_combined.GetAtomWithIdx(10)
        b_atom.SetFormalCharge(2)
        c_atom = ts_rw_combined.GetAtomWithIdx(reactive_C)
        c_atom.SetFormalCharge(-1)

    # rpos = lig_match.index(rpos) # reset the index for rpos.

    mol_name = get_molecule_name(ligand_smiles)
    ts_mols[f'{pre_name}({mol_name}_rpos({rpos}))'] = (ts_rw_combined, atom_indices_to_keep)

58
58


In [124]:
mol = ts_mols.get("TS(tri(propan-2-yl)-pyrrol-1-ylsilane_rpos(5))")[0]
MolTo3DGrid(mol, cell_size=(500,500))

In [125]:
fix_pinacol = _fix_pin_frag(frags[1])
MolTo3DGrid(fix_pinacol)

In [126]:
_fix_cat_frag = _fix_cat_frag(frags[0])
MolTo3DGrid(_fix_cat_frag, cell_size=(600,600))

In [127]:
ts_frags_combined = Chem.CombineMols(_fix_cat_frag, fix_pinacol)
MolTo3DGrid(ts_frags_combined)

In [128]:
ts_mols

{'TS(tri(propan-2-yl)-pyrrol-1-ylsilane_rpos(4))': (<rdkit.Chem.rdchem.RWMol at 0x1098df600>,
  [38, 10, 39, 61, 66]),
 'TS(tri(propan-2-yl)-pyrrol-1-ylsilane_rpos(5))': (<rdkit.Chem.rdchem.RWMol at 0x13a1ba250>,
  [38, 10, 39, 61, 67])}

In [129]:
MolTo3DGrid(ts_mols.get('TS(tri(propan-2-yl)-pyrrol-1-ylsilane_rpos(4))')[0])

## The Final Function

In [135]:
def transformer_ts2(
    ligand_smiles="CC([Si](N1C=CC=C1)(C(C)C)C(C)C)C",
    ts_guess_struct="../structures/ts2_guess.xyz",
    bonds_to_remove = [(59, 62), (52, 62), (52, 54)],
    constraint_atoms=[54, 52, 59, 62], # TS2: H, B1, B2, C ---> B1 = Catalyst, B2 = Pinacolborane
    H_idx=54,
    pre_name="TS2",
    embed_ready=False,
):
    # --- Read TS Guess Structure --- #
    try:
        with open(ts_guess_struct, 'r') as file:
            xyz_block = file.read()
    except FileNotFoundError:
        print(f"Error: Transition state structure file not found: {ts_guess_struct}")
        raise
    except PermissionError:
        print(f"Error: Permission denied when accessing file: {ts_guess_struct}")
        raise
    except IOError as e:
        print(f"Error: Failed to read transition state structure file {ts_guess_struct}: {e}")
        raise
    except Exception as e:
        print(f"Unexpected error loading transition state structure from {ts_guess_struct}: {e}")
        raise

    # --- Determine Connectivity --- #
    from rdkit import Chem
    ts = Chem.MolFromXYZBlock(xyz_block)
    rdDetermineBonds.DetermineConnectivity(ts, useVdw=True)
    ts_rw = RWMol(ts)

    BH_LEN = 1.18   # Å – average B–H bond length

    def _unit(vec):
        """return unit vector (or None if zero length)"""
        norm = math.sqrt(vec.x ** 2 + vec.y ** 2 + vec.z ** 2)
        if norm < 1e-6:
            return None
        return vec.__class__(vec.x / norm, vec.y / norm, vec.z / norm)


    def _fix_pin_frag(frag: Chem.Mol) -> Chem.Mol:
        """
        • converts B=O double bonds → single, neutralises B/O  
        • adds a B–H 1.18 Å away from B but does **not** move any other atom
        """
        rw   = Chem.RWMol(frag)
        conf = rw.GetConformer()

        # ---- locate boron ----
        b_idx = next(a.GetIdx() for a in rw.GetAtoms() if a.GetAtomicNum() == 5)
        boron = rw.GetAtomWithIdx(b_idx)

        # ---- make B-O single & neutral ----
        for nb in boron.GetNeighbors():
            bond = rw.GetBondBetweenAtoms(b_idx, nb.GetIdx())
            if bond.GetBondType() == rdchem.BondType.DOUBLE:
                bond.SetBondType(rdchem.BondType.SINGLE)
            if nb.GetAtomicNum() == 8:
                nb.SetFormalCharge(0)
        boron.SetFormalCharge(0)

        # ---- add the missing H (if needed) ----
        if boron.GetTotalDegree() < 3:
            # 1) choose a direction opposite to the average of B→heavy-neighbor vectors
            b_pos = conf.GetAtomPosition(b_idx)
            acc   = b_pos.__class__(0.0, 0.0, 0.0)
            heavy_cnt = 0
            for nb in boron.GetNeighbors():
                if nb.GetAtomicNum() > 1:        # O or C
                    n_pos = conf.GetAtomPosition(nb.GetIdx())
                    acc.x += n_pos.x - b_pos.x
                    acc.y += n_pos.y - b_pos.y
                    acc.z += n_pos.z - b_pos.z
                    heavy_cnt += 1
            # average & flip
            acc.x *= -1.0 / heavy_cnt
            acc.y *= -1.0 / heavy_cnt
            acc.z *= -1.0 / heavy_cnt
            direction = _unit(acc) or b_pos.__class__(1.0, 0.0, 0.0)  # fallback

            # 2) place H at BH_LEN along that direction
            h_pos = b_pos.__class__(
                b_pos.x + direction.x * BH_LEN,
                b_pos.y + direction.y * BH_LEN,
                b_pos.z + direction.z * BH_LEN,
            )

            # 3) add the atom & bond
            h_idx = rw.AddAtom(Chem.Atom(1))
            rw.AddBond(b_idx, h_idx, rdchem.BondType.SINGLE)
            conf.SetAtomPosition(h_idx, h_pos)

        # ---- sanitize & return ----
        Chem.SanitizeMol(rw)
        rw.RemoveAtom(h_idx)

        return rw.GetMol()

    def _fix_cat_frag(mol: Chem.Mol, bh_len: float = 1.19) -> Chem.Mol:
        """
        • Finds the single B atom in *mol*
        • Adds one H at a standard B–H distance without moving any existing atoms
        • Neutralises B if it was −1
        • Rebuilds all connectivity from the final 3D coords (rings, valence, aromaticity)
        """
        rw   = Chem.RWMol(mol)
        conf = rw.GetConformer()

        # locate boron and neutralise if needed
        b_idx = next(a.GetIdx() for a in rw.GetAtoms() if a.GetAtomicNum() == 5)
        boron = rw.GetAtomWithIdx(b_idx)
        if boron.GetFormalCharge() == -1:
            boron.SetFormalCharge(0)

        # pick direction roughly opposite its neighbours
        bpos   = np.array(conf.GetAtomPosition(b_idx))
        neighs = [n.GetIdx() for n in boron.GetNeighbors()]
        if neighs:
            pts = np.array([conf.GetAtomPosition(i) for i in neighs])
            v   = bpos - pts.mean(axis=0)
        else:
            v   = np.array([1.0, 0.0, 0.0])
        v /= np.linalg.norm(v) or 1.0

        # add the H at ~bh_len Å from B
        h_pos = bpos + bh_len * v
        h_idx = rw.AddAtom(Chem.Atom(1))

        rw.AddBond(b_idx, h_idx, Chem.BondType.SINGLE)
        conf.SetAtomPosition(h_idx, Point3D(*h_pos))

        # rebuild connectivity purely from 3D coords
        xyz = Chem.MolToXYZBlock(rw)
        mol2 = Chem.MolFromXYZBlock(xyz)
        rdDetermineBonds.DetermineBonds(mol2, useVdw=True)
        Chem.SanitizeMol(mol2)

        rw = RWMol(mol2)
        rw.RemoveAtom(h_idx)

        return rw.GetMol()
    
    bonds_to_remove = bonds_to_remove
    for bond in bonds_to_remove:
        ts_rw.RemoveBond(bond[0], bond[1])
    ts_rw_origin = copy(ts_rw)

    # --- Find ligand in guess ts structure --- #
    ts_ligand_pattern = Chem.MolFromSmarts("N1CCCC1")
    old_ring_match    = ts_rw.GetSubstructMatch(ts_ligand_pattern)  # e.g. (5,6,7,8,9)

    # --- Find unique positions and check that they are valid cH --- #
    lig_mol = Chem.MolFromSmiles(ligand_smiles)
    lig_mol = Chem.AddHs(lig_mol)

    cH_patt = Chem.MolFromSmarts('[cH]')
    matches = lig_mol.GetSubstructMatches(cH_patt)
    cH_atoms = [ind[0] for ind in matches]

    atom_rank = list(Chem.CanonicalRankAtoms(lig_mol, breakTies=False))

    def find_unique_atoms(lst):
        seen = set()
        result = []
        for i, x in enumerate(lst):
            if x not in seen:
                result.append(i)
                seen.add(x)
        return result

    unique_atoms = find_unique_atoms(atom_rank)
    unique_cH = set(unique_atoms).intersection(set(cH_atoms))
    unique_cH = tuple(unique_cH)

    # --- Create aligned maps --- #
    old_active_site = old_ring_match[0:3]

    maps = []
    for a in unique_cH:
        C_pos = lig_mol.GetAtomWithIdx(a)
        nbs = []
        for nb in C_pos.GetNeighbors():
            if nb.GetAtomicNum() == 1:
                pass # hydrogen
            else:
                nbs.append(nb.GetIdx())
        
        nbs.insert(1, C_pos.GetIdx())
        
        map = []
        for nb, aa in zip(nbs, old_active_site):
            map.append((nb, aa))
        maps.append(map)

    # --- Loop through each map a.k.a reactive position and create the molecule --- #
    params = ETKDGv3()
    params.randomSeed = 0xF00D  # Use any integer seed
    ts_mols = {}
    for map in maps:
        rpos = map[1][0]
        EmbedMolecule(lig_mol, params)
        ts_rw = Chem.RWMol(ts_rw_origin)
        rdMolAlign.AlignMol(lig_mol, ts_rw, atomMap=map)

        # --- remove hydrogen from the reacting carbon --- #
        chosen_carbon_idx = rpos
        chosen_carbon = lig_mol.GetAtomWithIdx(chosen_carbon_idx)

        for nb in chosen_carbon.GetNeighbors():
            if nb.GetAtomicNum() == 1:  # hydrogen
                lig_mol_rw = RWMol(lig_mol)
                lig_mol_rw.RemoveAtom(nb.GetIdx())
                lig_mol = lig_mol_rw.GetMol()
                break

        # --- Remove old ligand and determine bond order (to get aromaticity correct for the catalyst) --- #
        n_pattern_full = Chem.MolFromSmiles("CN1CCCC1")
        n_old_indices = ts_rw.GetSubstructMatch(n_pattern_full)

        atoms_to_remove = set()
        for idx in n_old_indices:
            atom = ts_rw.GetAtomWithIdx(idx)
            atoms_to_remove.add(idx)
            for neighbor in atom.GetNeighbors():
                if neighbor.GetAtomicNum() == 1:  # Check if hydrogen
                    atoms_to_remove.add(neighbor.GetIdx())

        for idx in sorted(atoms_to_remove, reverse=True):
            ts_rw.RemoveAtom(idx)

        atom_idx_to_remove = H_idx
        atom_to_remove = ts_rw.GetAtomWithIdx(atom_idx_to_remove)
        atom_symbol = atom_to_remove.GetSymbol()
        atom_coords = ts_rw.GetConformer().GetAtomPosition(atom_idx_to_remove)

        ts_rw.RemoveAtom(atom_idx_to_remove)
        rdDetermineBonds.DetermineBonds(ts_rw)

        frags = Chem.GetMolFrags(ts_rw, asMols=True)

        fixed_cat = _fix_cat_frag(frags[0])
        fixed_pin = _fix_pin_frag(frags[1])

        ts_rw = RWMol(Chem.CombineMols(fixed_cat, fixed_pin))

        #  --- Add reactive H back --- #
        new_atom_idx = ts_rw.AddAtom(Chem.Atom(atom_symbol))
        ts_rw.GetConformer().SetAtomPosition(new_atom_idx, atom_coords)          

        # --- Combine ligand and catalyst, add temporary bonds, and set temporary formal charges ---
        ts_combined = Chem.CombineMols(ts_rw, lig_mol)
        ts_rw_combined = RWMol(ts_combined)

        offset = ts_rw.GetNumAtoms()
        reactive_H = offset - 1 # the reactive H is the offset - 1, because it was the last atom added to the mol.
        reactive_C = rpos + offset

        cat_pat = Chem.MolFromSmarts('[B]-c1ccccc1-[N]')
        B_cat_idx = ts_rw_combined.GetSubstructMatches(cat_pat)[0][0]
        N_cat_idx = ts_rw_combined.GetSubstructMatches(cat_pat)[0][7]
        B_nbs = ts_rw_combined.GetAtomWithIdx(B_cat_idx).GetNeighbors()
        Hs_on_B = [nb.GetIdx() for nb in B_nbs if nb.GetAtomicNum() == 1]

        pin_pat = Chem.MolFromSmarts('[B]1OC(C(O1)(C)C)(C)C')
        B_pin_idx = ts_rw_combined.GetSubstructMatches(pin_pat)[0][0]

        atom_indices_to_keep = [B_cat_idx, N_cat_idx]
        atom_indices_to_keep.extend(Hs_on_B)
        atom_indices_to_keep.extend([B_pin_idx])
        atom_indices_to_keep.extend([reactive_H, reactive_C])

        if embed_ready:
            ts_rw_combined.AddBond(B_cat_idx, reactive_C, Chem.BondType.ZERO)
            ts_rw_combined.AddBond(B_cat_idx, reactive_H, Chem.BondType.SINGLE)
            ts_rw_combined.AddBond(B_cat_idx, B_pin_idx, Chem.BondType.SINGLE)
            b_atom = ts_rw_combined.GetAtomWithIdx(10)
            b_atom.SetFormalCharge(2)
            c_atom = ts_rw_combined.GetAtomWithIdx(reactive_C)
            c_atom.SetFormalCharge(-1)

        mol_name = get_molecule_name(ligand_smiles)
        ts_mols[f'{pre_name}({mol_name}_rpos({rpos}))'] = (ts_rw_combined, atom_indices_to_keep)

    return ts_mols

ts2_mols = transformer_ts2(embed_ready=False)

In [136]:
ts2_mols

{'TS2(tri(propan-2-yl)-pyrrol-1-ylsilane_rpos(4))': (<rdkit.Chem.rdchem.RWMol at 0x13acfd940>,
  [38, 10, 39, 58, 61, 66]),
 'TS2(tri(propan-2-yl)-pyrrol-1-ylsilane_rpos(5))': (<rdkit.Chem.rdchem.RWMol at 0x13acfa9d0>,
  [38, 10, 39, 58, 61, 67])}

In [138]:
MolTo3DGrid(ts2_mols.get("TS2(tri(propan-2-yl)-pyrrol-1-ylsilane_rpos(5))")[0])

## TS1 Function

In [82]:
def transformer_ts(
    ligand_smiles       = "CC([Si](N1C=CC=C1)(C(C)C)C(C)C)C",
    ts_guess_struct     = "ts1_guess.xyz",
    constraint_atoms    = [10, 11, 12, 41], # H, B, N, C
    H_idx               = 10,
    pre_name            = "TS",
    embed_ready         = True,
):
    
    # --- Read TS Guess Structure --- #
    try:
        with open(ts_guess_struct, 'r') as file:
            xyz_block = file.read()
    except FileNotFoundError:
        print(f"Error: Transition state structure file not found: {ts_guess_struct}")
        raise
    except PermissionError:
        print(f"Error: Permission denied when accessing file: {ts_guess_struct}")
        raise
    except IOError as e:
        print(f"Error: Failed to read transition state structure file {ts_guess_struct}: {e}")
        raise
    except Exception as e:
        print(f"Unexpected error loading transition state structure from {ts_guess_struct}: {e}")
        raise

    # --- Determine Connectivity --- #
    ts = Chem.MolFromXYZBlock(xyz_block)
    rdDetermineBonds.DetermineConnectivity(ts, useVdw=True)
    ts_rw = RWMol(ts)

    # --- Remove Bonds --- #
    bonds_to_remove = [
        (constraint_atoms[0], constraint_atoms[3]),
        (constraint_atoms[0], constraint_atoms[2]),
        (constraint_atoms[1], constraint_atoms[3]),
    ]

    bonds_to_remove = bonds_to_remove # Finding these bonds might need to be automated.
    for bond in bonds_to_remove:
        ts_rw.RemoveBond(bond[0], bond[1])
    ts_rw_origin = copy(ts_rw)

    # --- Find ligand in guess ts structure --- #
    ts_ligand_pattern = Chem.MolFromSmarts("N1CCCC1")
    old_ring_match    = ts_rw.GetSubstructMatch(ts_ligand_pattern)  # e.g. (5,6,7,8,9)

    # --- Find unique positions and check that they are valid cH --- #
    lig_mol = Chem.MolFromSmiles(ligand_smiles)
    lig_mol = Chem.AddHs(lig_mol)

    cH_patt = Chem.MolFromSmarts('[cH]')
    matches = lig_mol.GetSubstructMatches(cH_patt)
    cH_atoms = [ind[0] for ind in matches]

    atom_rank = list(Chem.CanonicalRankAtoms(lig_mol,breakTies=False))
    
    def find_unique_atoms(lst):
        seen = set()
        result = []
        for i, x in enumerate(lst):
            if x not in seen:
                result.append(i)
                seen.add(x)
        return result

    unique_atoms = find_unique_atoms(atom_rank)
    unique_cH = set(unique_atoms).intersection(set(cH_atoms))
    unique_cH = tuple(unique_cH)

    # --- Create aligned maps --- #
    old_active_site = old_ring_match[0:3]

    maps = []
    for a in unique_cH:
        C_pos = lig_mol.GetAtomWithIdx(a)
        nbs = []
        for nb in C_pos.GetNeighbors():
            if nb.GetAtomicNum() == 1:
                pass # hydrogen
            else:
                nbs.append(nb.GetIdx())
        
        nbs.insert(1, C_pos.GetIdx())
        
        map = []
        for nb, aa in zip(nbs, old_active_site):
            map.append((nb, aa))
        maps.append(map)

    # --- Loop through each map a.k.a reactive position and create the molecule --- #
    params = ETKDGv3()
    params.randomSeed = 0xF00D  # Use any integer seed
    ts_mols = {}
    for map in maps:
        rpos = map[1][0]
        EmbedMolecule(lig_mol, params)
        ts_rw = Chem.RWMol(ts_rw_origin)
        rdMolAlign.AlignMol(lig_mol, ts_rw, atomMap=map)

        # --- remove hydrogen from the reacting carbon --- #
        chosen_carbon_idx = rpos
        chosen_carbon = lig_mol.GetAtomWithIdx(chosen_carbon_idx)

        for nb in chosen_carbon.GetNeighbors():
            if nb.GetAtomicNum() == 1:  # hydrogen
                lig_mol_rw = RWMol(lig_mol)
                lig_mol_rw.RemoveAtom(nb.GetIdx())
                lig_mol = lig_mol_rw.GetMol()
                break

        # --- Remove old ligand and determine bond order (to get aromaticity correct for the catalyst) --- #
        n_pattern_full = Chem.MolFromSmiles("CN1CCCC1")
        n_old_indices = ts_rw.GetSubstructMatch(n_pattern_full)

        atoms_to_remove = set()
        for idx in n_old_indices:
            atom = ts_rw.GetAtomWithIdx(idx)
            atoms_to_remove.add(idx)
            for neighbor in atom.GetNeighbors():
                if neighbor.GetAtomicNum() == 1:  # Check if hydrogen
                    atoms_to_remove.add(neighbor.GetIdx())

        for idx in sorted(atoms_to_remove, reverse=True):
            ts_rw.RemoveAtom(idx)

        atom_idx_to_remove = H_idx
        atom_to_remove = ts_rw.GetAtomWithIdx(atom_idx_to_remove)
        atom_symbol = atom_to_remove.GetSymbol()
        atom_coords = ts_rw.GetConformer().GetAtomPosition(atom_idx_to_remove)

        ts_rw.RemoveAtom(atom_idx_to_remove)

        rdDetermineBonds.DetermineBonds(ts_rw)

        #  --- Add reactive H back --- #
        new_atom_idx = ts_rw.AddAtom(Chem.Atom(atom_symbol))
        ts_rw.GetConformer().SetAtomPosition(new_atom_idx, atom_coords)          

        # --- Combine ligand and catalyst, add temporary bonds, and set temporary formal charges ---
        ts_combined = Chem.CombineMols(ts_rw, lig_mol)
        ts_rw_combined = RWMol(ts_combined)

        offset = ts_rw.GetNumAtoms()
        reactive_H = offset - 1 # the reactive H is the offset - 1, because it was the last atom added to the mol.
        reactive_C = rpos + offset
        atom_indices_to_keep = [10, 11, 39, 40, reactive_H, reactive_C]

        if embed_ready:
            ts_rw_combined.AddBond(10, reactive_C, Chem.BondType.ZERO)
            ts_rw_combined.AddBond(10, reactive_H, Chem.BondType.SINGLE)
            b_atom = ts_rw_combined.GetAtomWithIdx(10)
            b_atom.SetFormalCharge(2)
            c_atom = ts_rw_combined.GetAtomWithIdx(reactive_C)
            c_atom.SetFormalCharge(-1)

        # rpos = lig_match.index(rpos) # reset the index for rpos.

        mol_name = get_molecule_name(ligand_smiles)  
        ts_mols[f'{pre_name}({mol_name}_rpos({rpos}))'] = (ts_rw_combined, atom_indices_to_keep)
    
    return ts_mols

ts1_mols = transformer_ts(ts_guess_struct="../structures/ts1_guess.xyz", embed_ready=True)
#ts2_mols = transformer_ts(ts_guess_struct="../structures/ts2_guess.xyz", H_idx=54, bonds_to_remove=[(59, 62), (52, 62), (52, 54)], embed_ready=False)

In [83]:
ts_rw_combined.GetAtomWithIdx(10).GetNeighbors()[2].GetAtomicNum()

6

In [84]:
ts1_mols.get('TS(tri(propan-2-yl)-pyrrol-1-ylsilane_rpos(4))')

(<rdkit.Chem.rdchem.RWMol at 0x13a326430>, [10, 11, 39, 40, 41, 46])

In [None]:
MolTo3DGrid(ts1_mols.get('TS(tri(propan-2-yl)-pyrrol-1-ylsilane_rpos(4))')[0])