In [15]:
# Enable IPython's autoreload so edits to imported modules (e.g. the local frust
# package) are picked up before each cell runs; mode 2 reloads all modules.
# NOTE: on a kernel where the extension is already loaded this prints a harmless
# "already loaded" notice (as in the output below).
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
from copy import copy

from rdkit import Chem
from rdkit.Chem import rdDetermineBonds, rdMolAlign
from rdkit.Chem.AllChem import ETKDGv3, EmbedMolecule
from rdkit.Chem.rdchem import RWMol

from tooltoad.vis import MolTo3DGrid
from tooltoad.chemutils import ac2mol, xyz2mol

from frust.utils.mols import fix_pin_frag, get_molecule_name
from frust.embedder import embed_ts
from frust.stepper import Stepper

In [17]:
def _read_xyz_block(path):
    """Return the text of the XYZ file at ``path``.

    Every failure is reported with a message naming the file and then re-raised
    unchanged, so callers still receive the original exception type.
    """
    try:
        with open(path, 'r') as file:
            return file.read()
    except FileNotFoundError:
        print(f"Error: Transition state structure file not found: {path}")
        raise
    except PermissionError:
        print(f"Error: Permission denied when accessing file: {path}")
        raise
    except OSError as e:  # IOError is just an alias of OSError in Python 3
        print(f"Error: Failed to read transition state structure file {path}: {e}")
        raise
    except Exception as e:
        print(f"Unexpected error loading transition state structure from {path}: {e}")
        raise


def _first_match(mol, query, label):
    """Return the first entry of ``mol.GetSubstructMatches(query)``.

    Raises ValueError naming ``label`` when nothing matches, instead of the bare
    IndexError an unchecked ``[0]`` would produce.
    """
    matches = mol.GetSubstructMatches(query)
    if not matches:
        raise ValueError(f"No substructure match for {label} in the combined TS structure")
    return matches[0]


def _unique_aromatic_ch(lig_mol):
    """Return sorted indices of symmetry-unique ``[cH]`` carbons of ``lig_mol``.

    Atoms are grouped by ``CanonicalRankAtoms(breakTies=False)``. The lowest
    index of each rank class is kept, then filtered to atoms matching ``[cH]``.
    """
    ch_atoms = {match[0] for match in lig_mol.GetSubstructMatches(Chem.MolFromSmarts('[cH]'))}
    seen_ranks = set()
    first_of_class = set()
    for idx, rank in enumerate(Chem.CanonicalRankAtoms(lig_mol, breakTies=False)):
        if rank not in seen_ranks:
            seen_ranks.add(rank)
            first_of_class.add(idx)
    return sorted(first_of_class & ch_atoms)


def _alignment_map(lig_mol, carbon_idx, template_site):
    """Pair ligand atoms around ``carbon_idx`` with the template active-site atoms.

    The carbon's heavy-atom neighbours are taken in RDKit neighbour order. The
    carbon itself is inserted at position 1, and the result is zipped with
    ``template_site``, so the carbon pairs with ``template_site[1]``.
    Returns a list of ``(ligand_idx, template_idx)`` pairs for ``AlignMol``.
    """
    order = [
        nb.GetIdx()
        for nb in lig_mol.GetAtomWithIdx(carbon_idx).GetNeighbors()
        if nb.GetAtomicNum() != 1
    ]
    order.insert(1, carbon_idx)
    return list(zip(order, template_site))


def _remove_one_hydrogen(mol, atom_idx):
    """Return a copy of ``mol`` with the first H neighbour of ``atom_idx`` removed.

    Returns ``mol`` unchanged if the atom has no explicit H neighbour.
    """
    for nb in mol.GetAtomWithIdx(atom_idx).GetNeighbors():
        if nb.GetAtomicNum() == 1:
            editable = RWMol(mol)
            editable.RemoveAtom(nb.GetIdx())
            return editable.GetMol()
    return mol


def transformer_ts4(
    ligand_smiles="C1=CC=CO1",
    ts_guess_struct="../structures/ts4.xyz",
    bonds_to_remove=((11, 23),),
    pre_name="TS4",
    embed_ready=True,
    random_seed=0xF00D,
):
    """Build TS4 guess structures by swapping a new ligand into a template TS.

    The template XYZ is read, and its connectivity is perceived from vdW radii.
    The bonds listed in ``bonds_to_remove`` are then deleted. For every
    symmetry-unique aromatic C-H carbon of the ligand (one "reactive position"
    ``rpos``), the ligand is embedded with ETKDGv3 and aligned onto the first
    three atoms of the template's ``S1CCCC1`` ring. One H is removed from the
    reacting carbon. The old ring and its H atoms are then deleted from the
    template, and the ligand is merged in.

    Parameters
    ----------
    ligand_smiles : str
        SMILES of the incoming ligand (explicit H atoms are added).
    ts_guess_struct : str
        Path to the template transition-state XYZ file.
    bonds_to_remove : iterable of (int, int)
        Atom-index pairs (XYZ order) whose perceived bonds are deleted before
        matching. Presumably these are the TS bonds picked up by distance-based
        perception; confirm against the template.
    pre_name : str
        Prefix of the returned dictionary keys.
    embed_ready : bool
        If True, add a temporary B(cat)-B(pin) single bond and set both B
        formal charges to +2 (temporary settings for the downstream embedder).
    random_seed : int
        ETKDGv3 random seed (default matches the previous hard-coded value).

    Returns
    -------
    dict
        ``{f"{pre_name}({name}_rpos({rpos}))": (RWMol, atom_indices_to_keep, ligand_smiles)}``
        where ``atom_indices_to_keep`` is ``[B_cat, N_cat, *H_on_B_cat, B_pin, reactive_C]``.

    Raises
    ------
    ValueError
        Raised if the XYZ or SMILES cannot be parsed, if the template has no
        ``S1CCCC1`` ring, if a required catalyst motif is missing, or if fewer
        than two fragments remain after removing the old ligand.
    RuntimeError
        Raised if ETKDG embedding of the ligand fails.
    """
    import warnings

    # --- Read TS guess structure and determine connectivity --- #
    xyz_block = _read_xyz_block(ts_guess_struct)
    ts = Chem.MolFromXYZBlock(xyz_block)
    if ts is None:
        raise ValueError(f"Could not parse XYZ block from {ts_guess_struct}")
    rdDetermineBonds.DetermineConnectivity(ts, useVdw=True)

    ts_editable = RWMol(ts)
    for bond in bonds_to_remove:
        ts_editable.RemoveBond(bond[0], bond[1])
    # Read-only template; every reactive position starts from a fresh RWMol copy.
    ts_template = Chem.Mol(ts_editable)

    # --- Locate the old ligand ring; its first 3 SMARTS atoms form the alignment site --- #
    old_ring_match = ts_template.GetSubstructMatch(Chem.MolFromSmarts("S1CCCC1"))
    if not old_ring_match:
        raise ValueError(f"No S1CCCC1 ring found in {ts_guess_struct}")
    old_active_site = old_ring_match[0:3]

    # Old-ligand atoms (ring + attached H) to delete. This does not depend on the
    # reactive position, so it is computed once here.
    old_ligand_atoms = ts_template.GetSubstructMatch(Chem.MolFromSmiles("C1CCCS1"))
    atoms_to_remove = set(old_ligand_atoms)
    for idx in old_ligand_atoms:
        atoms_to_remove.update(
            nb.GetIdx()
            for nb in ts_template.GetAtomWithIdx(idx).GetNeighbors()
            if nb.GetAtomicNum() == 1
        )
    # Delete from the highest index down so earlier indices stay valid.
    removal_order = sorted(atoms_to_remove, reverse=True)

    # --- Prepare the incoming ligand and its reactive positions --- #
    lig_heavy = Chem.MolFromSmiles(ligand_smiles)
    if lig_heavy is None:
        raise ValueError(f"Invalid ligand SMILES: {ligand_smiles!r}")
    lig_template = Chem.AddHs(lig_heavy)

    reactive_positions = _unique_aromatic_ch(lig_template)
    if not reactive_positions:
        warnings.warn(f"No symmetry-unique aromatic C-H found in {ligand_smiles!r}; no structures built")
        return {}

    params = ETKDGv3()
    params.randomSeed = random_seed
    cat_pat = Chem.MolFromSmarts('[B]-c1ccccc1-[N]')
    pin_pat = Chem.MolFromSmarts('[B]1OC(C(O1)(C)C)(C)C')
    mol_name = get_molecule_name(ligand_smiles)
    ts_mols = {}

    for rpos in reactive_positions:
        atom_map = _alignment_map(lig_template, rpos, old_active_site)

        # Embed a fresh copy so the template ligand is never mutated between positions.
        lig_mol = Chem.Mol(lig_template)
        if EmbedMolecule(lig_mol, params) == -1:
            raise RuntimeError(f"ETKDG embedding failed for {ligand_smiles!r} (rpos {rpos})")
        ts_rw = RWMol(ts_template)
        rdMolAlign.AlignMol(lig_mol, ts_rw, atomMap=atom_map)

        # The reacting carbon gets a bond to B(pin) below, so drop one of its H atoms.
        # AddHs appends H after the heavy atoms, so ``rpos`` is unaffected.
        lig_mol = _remove_one_hydrogen(lig_mol, rpos)

        # --- Remove old ligand; perceive bond orders per remaining fragment --- #
        for idx in removal_order:
            ts_rw.RemoveAtom(idx)
        frags = Chem.GetMolFrags(ts_rw, asMols=True)
        if len(frags) < 2:
            raise ValueError(
                f"Expected at least 2 fragments after removing the old ligand, got {len(frags)}"
            )
        if len(frags) > 2:
            warnings.warn(
                f"TS template split into {len(frags)} fragments after removing the old "
                "ligand; only the first two are kept"
            )
        rdDetermineBonds.DetermineBonds(frags[0])
        frag1 = fix_pin_frag(frags[1])
        ts_core = RWMol(Chem.CombineMols(frags[0], frag1))

        # --- Combine catalyst core and ligand; ligand indices are shifted by the core size --- #
        ts_rw_combined = RWMol(Chem.CombineMols(ts_core, lig_mol))
        reactive_C = rpos + ts_core.GetNumAtoms()

        cat_match = _first_match(ts_rw_combined, cat_pat, "[B]-c1ccccc1-[N]")
        B_cat_idx, N_cat_idx = cat_match[0], cat_match[7]
        Hs_on_B = [
            nb.GetIdx()
            for nb in ts_rw_combined.GetAtomWithIdx(B_cat_idx).GetNeighbors()
            if nb.GetAtomicNum() == 1
        ]
        B_pin_idx = _first_match(ts_rw_combined, pin_pat, "[B]1OC(C(O1)(C)C)(C)C")[0]

        ts_rw_combined.AddBond(reactive_C, B_pin_idx, Chem.BondType.SINGLE)
        atom_indices_to_keep = [B_cat_idx, N_cat_idx, *Hs_on_B, B_pin_idx, reactive_C]

        if embed_ready:
            # Temporary B-B bond and +2 charges on both borons for the embedder.
            ts_rw_combined.AddBond(B_cat_idx, B_pin_idx, Chem.BondType.SINGLE)
            for boron_idx in (B_pin_idx, B_cat_idx):
                ts_rw_combined.GetAtomWithIdx(boron_idx).SetFormalCharge(2)

        ts_mols[f'{pre_name}({mol_name}_rpos({rpos}))'] = (
            ts_rw_combined,
            atom_indices_to_keep,
            ligand_smiles,
        )

    return ts_mols

# One TS4 guess per symmetry-unique aromatic C-H position of the default ligand.
ts_dict = transformer_ts4(
    embed_ready=True,  # add the temporary B-B bond and +2 boron charges
)

In [18]:
# Inspect the TS guesses: name -> (RWMol, constraint atom indices, ligand SMILES).
ts_dict

{'TS4(furan_rpos(0))': (<rdkit.Chem.rdchem.RWMol at 0x168630630>,
  [11, 10, 12, 13, 37, 43],
  'C1=CC=CO1'),
 'TS4(furan_rpos(1))': (<rdkit.Chem.rdchem.RWMol at 0x168630a90>,
  [11, 10, 12, 13, 37, 44],
  'C1=CC=CO1')}

In [19]:
# Show the first TS guess molecule (first item of its value tuple).
ts_entries = list(ts_dict.values())
first_ts_mol = ts_entries[0][0]
MolTo3DGrid(first_ts_mol, kekulize=True, show_charges=False)

In [20]:
# Embed conformers for every TS guess. The energy_uff column of df0 below
# suggests the optimization is UFF-based; confirm in frust.embedder.
embeds = embed_ts(
    ts_dict,
    ts_type="TS4",
    n_confs=2,
    optimize=True,
)

TS4 embed
Embedded 2 conformers on atom 43
TS4 embed
Embedded 2 conformers on atom 44


In [21]:
# Show the molecule stored in the second entry of the embedding results.
embedded_entries = list(embeds.values())
MolTo3DGrid(embedded_entries[1][0], show_charges=False)

In [22]:
# Seed the Stepper workflow with the embedded TS guesses (Stepper is imported in the top import cell).
step_ligands = ["C1=CC=CO1"]
# Each ts_dict entry stores the ligand SMILES it was built from (3rd tuple item);
# guard against this list drifting from the ligand used to build ts_dict.
assert {entry[2] for entry in ts_dict.values()} <= set(step_ligands), "Stepper ligands do not match ts_dict"
step = Stepper(step_ligands, step_type="TS4", save_output_dir=False)
df0 = step.build_initial_df(embeds)
df0

2025-07-17 13:48:31 INFO  frust.stepper: Working dir: .


Unnamed: 0,custom_name,ligand_name,rpos,constraint_atoms,cid,smiles,atoms,coords_embedded,energy_uff
0,TS4(furan_rpos(0)),furan,0,"[11, 10, 12, 13, 37, 43]",0,C1=CC=CO1,"[C, C, C, C, C, C, H, H, H, H, N, B, H, H, C, ...","[(2.876582655183915, 1.15256349700219, -0.3497...",2140.946474
1,TS4(furan_rpos(0)),furan,0,"[11, 10, 12, 13, 37, 43]",1,C1=CC=CO1,"[C, C, C, C, C, C, H, H, H, H, N, B, H, H, C, ...","[(2.8205884376302985, -0.8577919014765315, 0.2...",2386.51642
2,TS4(furan_rpos(1)),furan,1,"[11, 10, 12, 13, 37, 44]",0,C1=CC=CO1,"[C, C, C, C, C, C, H, H, H, H, N, B, H, H, C, ...","[(2.90808230470615, 1.0050871180228258, -0.213...",2326.487932
3,TS4(furan_rpos(1)),furan,1,"[11, 10, 12, 13, 37, 44]",1,C1=CC=CO1,"[C, C, C, C, C, C, H, H, H, H, N, B, H, H, C, ...","[(2.9600286726629146, 0.8972117566006804, 0.00...",3086.964095


In [23]:
# GFN-FF optimization with the constraint atoms held (None values presumably mean bare xtb flags).
xtb_options = dict(gfnff=None, opt=None)
df1 = step.xtb(df0, options=xtb_options, constraint=True)
df1

Step type TS4
2025-07-17 13:48:31 INFO  frust.stepper: [xtb-gfnff-opt] row 0 (TS4(furan_rpos(0)))…
2025-07-17 13:48:31 INFO  frust.stepper: [xtb-gfnff-opt] row 1 (TS4(furan_rpos(0)))…
2025-07-17 13:48:31 INFO  frust.stepper: [xtb-gfnff-opt] row 2 (TS4(furan_rpos(1)))…
2025-07-17 13:48:31 INFO  frust.stepper: [xtb-gfnff-opt] row 3 (TS4(furan_rpos(1)))…


Unnamed: 0,custom_name,ligand_name,rpos,constraint_atoms,cid,smiles,atoms,coords_embedded,energy_uff,xtb-gfnff-opt-electronic_energy,xtb-gfnff-opt-normal_termination,xtb-gfnff-opt-opt_coords
0,TS4(furan_rpos(0)),furan,0,"[11, 10, 12, 13, 37, 43]",0,C1=CC=CO1,"[C, C, C, C, C, C, H, H, H, H, N, B, H, H, C, ...","[(2.876582655183915, 1.15256349700219, -0.3497...",2140.946474,-8.479152,True,"[[3.1894009018721, 0.83536373955833, -0.231875..."
1,TS4(furan_rpos(0)),furan,0,"[11, 10, 12, 13, 37, 43]",1,C1=CC=CO1,"[C, C, C, C, C, C, H, H, H, H, N, B, H, H, C, ...","[(2.8205884376302985, -0.8577919014765315, 0.2...",2386.51642,-8.481184,True,"[[3.14378235713521, -0.70817703417045, -0.1026..."
2,TS4(furan_rpos(1)),furan,1,"[11, 10, 12, 13, 37, 44]",0,C1=CC=CO1,"[C, C, C, C, C, C, H, H, H, H, N, B, H, H, C, ...","[(2.90808230470615, 1.0050871180228258, -0.213...",2326.487932,-8.469095,True,"[[3.17566957196959, 0.83378037774196, -0.14237..."
3,TS4(furan_rpos(1)),furan,1,"[11, 10, 12, 13, 37, 44]",1,C1=CC=CO1,"[C, C, C, C, C, C, H, H, H, H, N, B, H, H, C, ...","[(2.9600286726629146, 0.8972117566006804, 0.00...",3086.964095,-8.4214,True,"[[3.10759659045436, 0.84385826497559, 0.036353..."


In [24]:
idx = 0  # positional row of df1 to compare
atoms = df1["atoms"].iloc[idx]
coords1, coords2 = (
    df1[column].iloc[idx]
    for column in ("coords_embedded", "xtb-gfnff-opt-opt_coords")
)
all_coords = [coords1, coords2]
all_mols = [ac2mol(atoms, xyz) for xyz in all_coords]
MolTo3DGrid(all_mols, legends=["embed", "xtb-opt"])

In [25]:
# Reference view: the raw TS4 template (same file transformer_ts4 reads by default),
# perceived with xyz2mol (imported in the top import cell).
with open("../structures/ts4.xyz") as file:
    xyzblock = file.read()

mol = xyz2mol(xyzblock)

MolTo3DGrid(mol)