In [3]:
from typing import List, Tuple, Optional
from rdkit import Chem
from rdkit.Chem.rdchem import Mol
from synrbl.SynUtils.chem_utils import get_substruct_matches
import multiprocessing


class SubstructureAnalyzer:
    """
    A class for analyzing substructures within molecules using RDKit.

    Methods:
    remove_substructure_atoms: Removes a specified substructure from a
        molecule and returns the number of resulting fragments.
    sort_substructures_by_fragment_count: Sorts a list of substructures based
        on the number of fragments resulting from their removal.
    identify_optimal_substructure: Identifies the most relevant substructure
        from a list of potential substructures based on fragment count.
    """

    def __init__(self):
        pass

    def remove_substructure_atoms(
        self, parent_mol: Mol, substructure: Tuple[int, ...]
    ) -> int:
        """
        Removes specified atoms (substructure) from a molecule and returns the
        number of resulting fragments.

        Parameters:
        parent_mol (Mol): The parent molecule.
        substructure (Tuple[int, ...]): Indices of atoms in the substructure.

        Returns:
        int: The number of fragments resulting from the removal of the
            substructure.
        """
        rw_mol = Chem.RWMol(parent_mol)
        for atom_idx in sorted(substructure, reverse=True):
            if atom_idx < rw_mol.GetNumAtoms():
                rw_mol.RemoveAtom(atom_idx)
        new_mol = rw_mol.GetMol()
        return len(Chem.GetMolFrags(new_mol))

    def sort_substructures_by_fragment_count(
        self, substructures: List[Tuple[int, ...]], fragment_counts: List[int]
    ) -> List[Tuple[int, ...]]:
        """
        Sorts a list of substructures based on the number of fragments
        resulting from their removal.

        Parameters:
        substructures (List[Tuple[int, ...]]): List of substructures
            represented by atom indices.
        fragment_counts (List[int]): List of fragment counts corresponding to
            each substructure.

        Returns:
        List[Tuple[int, ...]]: Sorted list of substructures based on fragment
            counts.
        """
        paired_list = list(zip(substructures, fragment_counts))
        paired_list.sort(key=lambda x: x[1])
        return [pair[0] for pair in paired_list]

    def identify_optimal_substructure(
        self, parent_mol: Mol, child_mol: Mol, timeout_sec: int = 10
    ) -> Tuple[int, ...]:
        """
        Identifies the most relevant substructure within a parent molecule
        given a child molecule, with a timeout feature for the
        substructure matching process. If the primary matching process times out,
        a fallback search is attempted with a maximum of one match.

        Parameters:
        parent_mol (Mol): The parent molecule.
        child_mol (Mol): The child molecule whose substructures are to be analyzed.
        timeout_sec (int): Timeout in seconds for the substructure search process.

        Returns:
        Tuple[int, ...]: The atom indices of the identified substructure
        in the parent molecule.

        Returns:
        Tuple[int, ...]: The atom indices of the identified substructure in
            the parent molecule.
        """
        substructures = SubstructureAnalyzer.run_with_timeout(
            parent_mol, child_mol, timeout_sec
        )
        #substructures = parent_mol.GetSubstructMatches(child_mol)
        print(substructures)

        if len(substructures) > 1:
            fragment_counts = [
                self.remove_substructure_atoms(parent_mol, substructure)
                for substructure in substructures
            ]
            sorted_substructures = self.sort_substructures_by_fragment_count(
                substructures, fragment_counts
            )
            return sorted_substructures[0]
        else:
            return substructures[0] if substructures else ()

    @staticmethod
    def run_with_timeout(
        parent_mol: Mol, child_mol: Mol, timeout_sec: int
    ) -> Optional[List[Tuple[int, ...]]]:
        """
        Executes a substructure matching with a timeout limit.
        This method runs the substructure matching in a separate
        process to handle timeouts by terminating the process
        if it exceeds the given time limit. If the operation times out,
        a fallback search with maxMatches=1 is performed.

        Parameters:
        parent_mol (Mol): The parent molecule.
        child_mol (Mol): The child molecule.
        timeout_sec (int): The number of seconds to allow for the operation
                            before timing out.

        Returns:
        Optional[List[Tuple[int, ...]]]: A list of tuples representing the atom
        indices of the matches found, or None if the operation times out.
        """
        output = multiprocessing.Queue()
        process = multiprocessing.Process(
            target=get_substruct_matches, args=(output, parent_mol, child_mol)
        )
        process.start()
        process.join(timeout_sec)

        if process.is_alive():
            process.terminate()
            process.join()
            print("Timeout reached, process terminated. Trying with maxMatches=1.")
            return parent_mol.GetSubstructMatches(
                child_mol, maxMatches=1
            )  # Fallback if timeout is reached
        else:
            return output.get() if not output.empty() else None


In [4]:
from rdkit import Chem
from rdkit.Chem import rdFMCS
from rdkit.Chem import rdRascalMCES
#from synrbl.SynMCSImputer.SubStructure.substructure_analyzer import SubstructureAnalyzer


class MCSMissingGraphAnalyzer:
    """A class for detecting missing graph in reactants and products using MCS
    and RDKit."""

    def __init__(self):
        """Initialize the MolecularOperations class."""
        pass

    @staticmethod
    def get_smiles(reaction_dict):
        """
        Extract reactant and product SMILES strings from a reaction dictionary.

        Parameters:
        - reaction_dict: dict
            A dictionary containing 'reactants' and 'products' as keys.

        Returns:
        - tuple
            A tuple containing reactant SMILES and product SMILES strings.
        """
        return reaction_dict["reactants"], reaction_dict["products"]

    @staticmethod
    def convert_smiles_to_molecule(smiles):
        """
        Convert a SMILES string to an RDKit molecule object.

        Parameters:
        - smiles: str
            The SMILES string representing a molecule.

        Returns:
        - rdkit.Chem.Mol
            The RDKit molecule object.
        """
        return Chem.MolFromSmiles(smiles)

    @staticmethod
    def IterativeMCSReactionPairs(
        reactant_mol_list,
        product_mol,
        params=None,
        method="MCIS",
        sort="MCIS",
        remove_substructure=True,
    ):
        """
        Find the MCS for each reactant fragment with the product, updating the
        product after each step. Reactants are processed based on the size of
        their MCS with the product at each iteration.

        Parameters:
        - reactant_mol_list: list of rdkit.Chem.Mol
            List of RDKit molecule objects for reactants.
        - product_mol: rdkit.Chem.Mol
            RDKit molecule object for the product.
        - sort (str):
            Method of sorting reactants, either 'MCS' or 'Fragments'.
        - remove_substructure (bool):
            If True, update the product by removing the MCS substructure.
        - params (rdkit.Chem.rdFMCS.MCSParameters):
            Parameters for RDKit's rdFMCS.

        Returns:
        - list of rdkit.Chem.Mol
            List of RDKit molecule objects representing the MCS for each
            reactant-product pair.
        - list of rdkit.Chem.Mol
            Sorted list of reactant molecule objects.
        """

        # Sort reactants based on the specified method
        if sort == "MCIS":
            if params is None:
                params = rdFMCS.MCSParameters()
            mcs_results = [
                (reactant, rdFMCS.FindMCS([reactant, product_mol], params))
                for reactant in reactant_mol_list
            ]
            mcs_results = [
                (reactant, mcs_result)
                for reactant, mcs_result in mcs_results
                if not mcs_result.canceled
            ]
            sorted_reactants = sorted(
                mcs_results, key=lambda x: x[1].numAtoms, reverse=True
            )
        elif sort == "MCES":
            if params is None:
                params = rdRascalMCES.RascalOptions()
            mcs_results = [
                (reactant, rdRascalMCES.FindMCES(reactant, product_mol, params)[0])
                for reactant in reactant_mol_list
            ]
            mcs_results = [
                (reactant, mcs_result)
                for reactant, mcs_result in mcs_results
                if hasattr(mcs_result, "atomMatches")
            ]
            sorted_reactants = sorted(
                mcs_results, key=lambda x: len(x[1].atomMatches()), reverse=True
            )
        elif sort == "Fragments":
            sorted_reactants = sorted(
                reactant_mol_list, key=lambda x: x.GetNumAtoms(), reverse=True
            )
        else:
            raise ValueError("Invalid sort method. Choose 'MCS' or 'Fragments'.")

        mcs_list = []
        current_product = product_mol
        for reactant, _ in sorted_reactants:
            # Calculate the MCS with the current product
            try:
                if method == "MCIS":
                    mcs_result = rdFMCS.FindMCS([reactant, current_product], params)
                elif method == "MCES":
                    mcs_result = rdRascalMCES.FindMCES(
                        reactant, current_product, params
                    )[0]
                else:
                    raise ValueError("Invalid method. Choose 'MCIS' or 'MCES'.")

                if (
                    not mcs_result.canceled
                    if method == "MCIS"
                    else hasattr(mcs_result, "atomMatches")
                ):
                    mcs_smarts = (
                        mcs_result.smartsString
                        if method == "MCIS"
                        else mcs_result.smartsString.split(".")[0]
                    )
                    mcs_mol = Chem.MolFromSmarts(mcs_smarts)
                    mcs_list.append(mcs_mol)
                    # Conditional substructure removal
                    if remove_substructure:
                        # Identify the optimal substructure
                        analyzer = SubstructureAnalyzer()
                        optimal_substructure = analyzer.identify_optimal_substructure(
                            parent_mol=current_product, child_mol=mcs_mol
                        )
                        if optimal_substructure:
                            rw_mol = Chem.RWMol(current_product)
                            # Remove atoms in descending order of their indices
                            for atom_idx in sorted(optimal_substructure, reverse=True):
                                if (
                                    atom_idx < rw_mol.GetNumAtoms()
                                ):  # Check if the index is valid
                                    rw_mol.RemoveAtom(atom_idx)
                                else:
                                    pass
                            current_product = rw_mol.GetMol()

                    try:
                        Chem.SanitizeMol(current_product)
                    except Exception:
                        pass
            except Exception:
                mcs_list.append(None)
                pass

        return mcs_list, [reactant for reactant, _ in sorted_reactants]

    @staticmethod
    def fit(
        reaction_dict,
        RingMatchesRingOnly=True,
        CompleteRingsOnly=True,
        timeout=1,
        similarityThreshold=0.5,
        sort="MCIS",
        method="MCIS",
        remove_substructure=True,
        ignore_atom_map=False,
        ignore_bond_order=False,
    ):
        """
        Process a reaction dictionary to find MCS, missing parts in reactants
        and products.

        Parameters:
        - reaction_dict: dict
            A dictionary containing 'reactants' and 'products' as keys.

        Returns:
        - tuple
            A tuple containing lists of MCS, missing parts in reactants,
            missing parts in products, reactant molecules, and product
            molecules.
        """

        # define parameters

        if method == "MCIS":
            params = rdFMCS.MCSParameters()
            params.Timeout = timeout
            params.BondCompareParameters.RingMatchesRingOnly = RingMatchesRingOnly
            params.BondCompareParameters.CompleteRingsOnly = CompleteRingsOnly
            if ignore_bond_order:
                params.BondTyper = rdFMCS.BondCompare.CompareAny
            if ignore_atom_map:
                params.AtomTyper = rdFMCS.AtomCompare.CompareAny

        elif method == "MCES":
            params = rdRascalMCES.RascalOptions()
            params.singleLargestFrag = False
            params.returnEmptyMCES = True
            params.timeout = timeout
            params.similarityThreshold = similarityThreshold

        else:
            raise ValueError("Method '{}' is not implemented.".format(method))

        if reaction_dict["carbon_balance_check"] in ["products", "balanced"]:
            # Calculate the MCS for each reactant with the product
            reactant_smiles, product_smiles = MCSMissingGraphAnalyzer.get_smiles(
                reaction_dict
            )
            reactant_mol_list = [
                MCSMissingGraphAnalyzer.convert_smiles_to_molecule(smiles)
                for smiles in reactant_smiles.split(".")
            ]
            product_mol = MCSMissingGraphAnalyzer.convert_smiles_to_molecule(
                product_smiles
            )

            (
                mcs_list,
                sorted_parents,
            ) = MCSMissingGraphAnalyzer.IterativeMCSReactionPairs(
                reactant_mol_list,
                product_mol,
                params,
                method=method,
                sort=sort,
                remove_substructure=remove_substructure,
            )

            return mcs_list, sorted_parents, reactant_mol_list, product_mol

        elif reaction_dict["carbon_balance_check"] == "reactants":
            # Calculate the MCS for each product with the reactant
            reactant_smiles, product_smiles = MCSMissingGraphAnalyzer.get_smiles(
                reaction_dict
            )
            product_mol_list = [
                MCSMissingGraphAnalyzer.convert_smiles_to_molecule(smiles)
                for smiles in product_smiles.split(".")
            ]
            reactant_mol = MCSMissingGraphAnalyzer.convert_smiles_to_molecule(
                reactant_smiles
            )

            (
                mcs_list,
                sorted_parents,
            ) = MCSMissingGraphAnalyzer.IterativeMCSReactionPairs(
                product_mol_list,
                reactant_mol,
                params,
                method=method,
                sort=sort,
                remove_substructure=remove_substructure,
            )

            return mcs_list, sorted_parents, product_mol_list, reactant_mol
        else:
            raise RuntimeError(
                "Invalid carbon_balance_check value: '{}'".format(
                    reaction_dict["carbon_balance_check"]
                )
            )


In [5]:
rsmi ="COC(C)=O>>OC(C)=O"

reactant, product = rsmi.split(">>")

mcs = MCSMissingGraphAnalyzer()
mcs.IterativeMCSReactionPairs(
    [Chem.MolFromSmiles(reactant)], Chem.MolFromSmiles(product),
)

((0, 1, 2, 3),)


([<rdkit.Chem.rdchem.Mol at 0x135f8cdd0>],
 [<rdkit.Chem.rdchem.Mol at 0x135f8d850>])