In [13]:
# Reload edited project modules (such as the frust helpers imported below) before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
from copy import copy

from tooltoad.vis import MolTo3DGrid

from rdkit import Chem
from rdkit.Chem import rdDetermineBonds, rdMolAlign
from rdkit.Chem.AllChem import ETKDGv3, EmbedMolecule
from rdkit.Chem.rdchem import RWMol

from frust.utils.mols import fix_pin_frag, get_molecule_name
from frust.embedder import embed_ts

In [25]:
def _read_xyz_block(path):
    """Return the full text of the XYZ file at ``path``.

    Each failure mode prints a short diagnostic and then re-raises the original
    exception unchanged, so callers still see the original type and traceback.
    """
    try:
        with open(path, 'r') as file:
            return file.read()
    except FileNotFoundError:
        print(f"Error: Transition state structure file not found: {path}")
        raise
    except PermissionError:
        print(f"Error: Permission denied when accessing file: {path}")
        raise
    except IOError as e:
        print(f"Error: Failed to read transition state structure file {path}: {e}")
        raise
    except Exception as e:
        print(f"Unexpected error loading transition state structure from {path}: {e}")
        raise


def _unique_aromatic_ch(lig_mol):
    """Return symmetry-unique aromatic C-H atom indices of ``lig_mol``, ascending.

    ``CanonicalRankAtoms(..., breakTies=False)`` gives atoms that RDKit considers
    equivalent the same rank. The lowest index carrying each rank represents its
    class, and only representatives matching ``[cH]`` are kept.
    """
    ch_atoms = {match[0] for match in lig_mol.GetSubstructMatches(Chem.MolFromSmarts('[cH]'))}
    first_index_of_rank = {}
    for idx, rank in enumerate(Chem.CanonicalRankAtoms(lig_mol, breakTies=False)):
        first_index_of_rank.setdefault(rank, idx)
    # Sorted so the iteration order (and hence the output key order) is deterministic.
    return tuple(sorted(ch_atoms.intersection(first_index_of_rank.values())))


def _build_atom_maps(lig_mol, positions, old_active_site):
    """Build one AlignMol ``atomMap`` of (ligand_idx, template_idx) pairs per position.

    For each carbon, its heavy-atom neighbours are listed with the carbon itself
    inserted at index 1, then zipped against ``old_active_site``. The reacting
    carbon is therefore always mapped onto ``old_active_site[1]``.
    """
    atom_maps = []
    for c_idx in positions:
        heavy_nbs = [
            nb.GetIdx()
            for nb in lig_mol.GetAtomWithIdx(c_idx).GetNeighbors()
            if nb.GetAtomicNum() != 1
        ]
        heavy_nbs.insert(1, c_idx)
        atom_maps.append(list(zip(heavy_nbs, old_active_site)))
    return atom_maps


def transformer_ts4(
    ligand_smiles="C1=CC=CO1",
    ts_guess_struct="../structures/ts4.xyz",
    bonds_to_remove=((11, 23),),
    pre_name="TS3",
    embed_ready=True,
    random_seed=0xF00D,
):
    """Build TS4 guesses by replacing the S1CCCC1 ligand of a template TS with ``ligand_smiles``.

    The template geometry is read from ``ts_guess_struct`` and its connectivity is
    perceived from distances. The new ligand is embedded and aligned onto the
    old ring once for every symmetry-unique aromatic C-H carbon. That carbon loses
    one hydrogen and is bonded to the Bpin boron.

    Args:
        ligand_smiles: SMILES of the incoming ligand.
        ts_guess_struct: Path to the XYZ file of the template TS.
        bonds_to_remove: (i, j) template atom-index pairs whose perceived bonds
            are deleted before processing.
        pre_name: Prefix of the returned dictionary keys.
            NOTE(review): the default is "TS3" although this builds TS4 guesses;
            it is kept so existing keys do not change.
        embed_ready: If True, add a B(cat)-B(pin) bond and set a +2 formal
            charge on both borons.
        random_seed: ETKDG random seed used to embed the ligand.

    Returns:
        dict mapping ``f"{pre_name}({name}_rpos({rpos}))"`` to
        ``(combined RWMol, atom indices to keep, ligand_smiles)``.

    Raises:
        Exceptions from reading ``ts_guess_struct`` (re-raised after a message).
        ValueError if the XYZ or SMILES cannot be parsed, a required fragment or
        pattern is missing, or the ligand cannot be embedded.
    """
    # --- Read the template TS and perceive connectivity from interatomic distances --- #
    ts = Chem.MolFromXYZBlock(_read_xyz_block(ts_guess_struct))
    if ts is None:
        raise ValueError(f"Could not parse XYZ structure: {ts_guess_struct}")
    rdDetermineBonds.DetermineConnectivity(ts, useVdw=True)
    # ts_template is never modified after this; each iteration works on a copy.
    ts_template = RWMol(ts)
    for begin_idx, end_idx in bonds_to_remove:
        ts_template.RemoveBond(begin_idx, end_idx)

    # --- Locate the old ligand; its first three ring atoms anchor the alignment --- #
    old_ring_match = ts_template.GetSubstructMatch(Chem.MolFromSmarts("S1CCCC1"))
    if not old_ring_match:
        raise ValueError(f"No S1CCCC1 ligand ring found in {ts_guess_struct}")
    old_active_site = old_ring_match[:3]

    # Old ligand = the ring atoms plus their hydrogens. This is the same for every
    # position, so it is computed once. Remove the highest indices first so the
    # remaining indices stay valid.
    old_ligand_atoms = set(old_ring_match)
    for ring_idx in old_ring_match:
        for nb in ts_template.GetAtomWithIdx(ring_idx).GetNeighbors():
            if nb.GetAtomicNum() == 1:
                old_ligand_atoms.add(nb.GetIdx())
    removal_order = sorted(old_ligand_atoms, reverse=True)

    # --- Parse the new ligand and enumerate candidate reactive carbons --- #
    lig_heavy = Chem.MolFromSmiles(ligand_smiles)
    if lig_heavy is None:
        raise ValueError(f"Could not parse ligand SMILES: {ligand_smiles!r}")
    lig_template = Chem.AddHs(lig_heavy)
    reactive_positions = _unique_aromatic_ch(lig_template)
    if not reactive_positions:
        return {}
    atom_maps = _build_atom_maps(lig_template, reactive_positions, old_active_site)

    params = ETKDGv3()
    params.randomSeed = random_seed
    cat_pat = Chem.MolFromSmarts('[B]-c1ccccc1-[N]')  # query atom 0 = B, atom 7 = N
    pin_pat = Chem.MolFromSmarts('[B]1OC(C(O1)(C)C)(C)C')  # query atom 0 = B
    mol_name = get_molecule_name(ligand_smiles)
    ts_mols = {}

    for rpos, atom_map in zip(reactive_positions, atom_maps):
        # Fresh ligand copy per position, because EmbedMolecule and AlignMol
        # modify its conformer in place.
        lig_mol = Chem.Mol(lig_template)
        if EmbedMolecule(lig_mol, params) < 0:
            raise ValueError(f"ETKDG embedding failed for ligand {ligand_smiles!r}")
        ts_rw = RWMol(ts_template)
        rdMolAlign.AlignMol(lig_mol, ts_rw, atomMap=atom_map)

        # --- Drop one hydrogen from the reacting carbon (it bonds to B_pin below) --- #
        # RDKit's AddHs appends hydrogens after the heavy atoms, so removing one
        # leaves rpos unchanged.
        h_on_c = next(
            (nb.GetIdx() for nb in lig_mol.GetAtomWithIdx(rpos).GetNeighbors()
             if nb.GetAtomicNum() == 1),
            None,
        )
        if h_on_c is not None:
            lig_rw = RWMol(lig_mol)
            lig_rw.RemoveAtom(h_on_c)
            lig_mol = lig_rw.GetMol()

        # --- Remove the old ligand; re-perceive catalyst bond orders (aromaticity) --- #
        for idx in removal_order:
            ts_rw.RemoveAtom(idx)
        frags = Chem.GetMolFrags(ts_rw, asMols=True)
        if len(frags) < 2:
            raise ValueError(
                f"Expected catalyst and Bpin fragments, got {len(frags)} fragment(s)"
            )
        # NOTE(review): assumes frags[0] is the catalyst and frags[1] the Bpin
        # fragment (presumably fixed by XYZ atom order); extra fragments are dropped.
        catalyst = frags[0]
        rdDetermineBonds.DetermineBonds(catalyst)
        ts_core = Chem.CombineMols(catalyst, fix_pin_frag(frags[1]))

        # --- Merge the ligand into the TS; ligand atoms come after the core atoms --- #
        ts_rw_combined = RWMol(Chem.CombineMols(ts_core, lig_mol))
        reactive_C = rpos + ts_core.GetNumAtoms()

        cat_matches = ts_rw_combined.GetSubstructMatches(cat_pat)
        if not cat_matches:
            raise ValueError("Catalyst pattern [B]-c1ccccc1-[N] not found in combined TS")
        pin_matches = ts_rw_combined.GetSubstructMatches(pin_pat)
        if not pin_matches:
            raise ValueError("Bpin pattern not found in combined TS")
        B_cat_idx, N_cat_idx = cat_matches[0][0], cat_matches[0][7]
        B_pin_idx = pin_matches[0][0]

        # Neighbour lookups happen before any bonds are added below.
        Hs_on_B = [
            nb.GetIdx()
            for nb in ts_rw_combined.GetAtomWithIdx(B_cat_idx).GetNeighbors()
            if nb.GetAtomicNum() == 1
        ]
        # Select the hydrogen on B_pin by element rather than trusting neighbour order.
        H_pin_idx = next(
            (nb.GetIdx() for nb in ts_rw_combined.GetAtomWithIdx(B_pin_idx).GetNeighbors()
             if nb.GetAtomicNum() == 1),
            None,
        )
        if H_pin_idx is None:
            raise ValueError("No hydrogen bonded to the Bpin boron")

        ts_rw_combined.AddBond(reactive_C, B_pin_idx, Chem.BondType.SINGLE)
        atom_indices_to_keep = [B_cat_idx, N_cat_idx, *Hs_on_B, B_pin_idx, H_pin_idx, reactive_C]

        if embed_ready:
            # Temporary B-B bond and +2 formal charges on both borons.
            ts_rw_combined.AddBond(B_cat_idx, B_pin_idx, Chem.BondType.SINGLE)
            for boron_idx in (B_pin_idx, B_cat_idx):
                ts_rw_combined.GetAtomWithIdx(boron_idx).SetFormalCharge(2)

        ts_mols[f'{pre_name}({mol_name}_rpos({rpos}))'] = (ts_rw_combined, atom_indices_to_keep, ligand_smiles)

    return ts_mols

# One TS guess per symmetry-unique aromatic C-H of the ligand (the default furan SMILES, spelled out).
ts_dict = transformer_ts4(ligand_smiles="C1=CC=CO1", embed_ready=True)

In [26]:
# Keys: "<pre_name>(<ligand name>_rpos(<atom idx>))"; values: (combined RWMol, atom indices to keep, ligand SMILES).
ts_dict

{'TS3(furan_rpos(0))': (<rdkit.Chem.rdchem.RWMol at 0x13db0b330>,
  [11, 10, 12, 13, 37, 35, 43],
  'C1=CC=CO1'),
 'TS3(furan_rpos(1))': (<rdkit.Chem.rdchem.RWMol at 0x13db27f60>,
  [11, 10, 12, 13, 37, 35, 44],
  'C1=CC=CO1')}

In [28]:
# Visualise the molecule of the first TS guess; kekulize=True is passed straight to tooltoad's MolTo3DGrid.
MolTo3DGrid(list(ts_dict.values())[0][0], kekulize=True)

In [33]:
# Embed conformers for every TS guess with frust's embed_ts.
# NOTE(review): n_confs=10, but the captured output reports 9 and 7 embedded conformers — confirm failures are expected.
embeds = embed_ts(ts_dict, ts_type="TS4-NEW", n_confs=10, optimize=True)

[11, 10, 12, 13, 37, 35, 43]
Embedded 9 conformers on atom 43
[11, 10, 12, 13, 37, 35, 44]
Embedded 7 conformers on atom 44


In [34]:
# Visualise the first embedded result; assumes embed_ts values are indexable with [0] -> molecule — TODO confirm.
MolTo3DGrid(list(embeds.values())[0][0])