In [None]:
from SynTemp.SynUtils.utils import load_database, load_from_pickle
data = load_from_pickle('./Data/uspto/uspto_its_graph_rules_cluster.pkl.gz')

In [None]:
from SynTemp.SynRule.rule_cluster import NaiveCluster
node_label_names = ["element", "charge"]
naive_cluster = NaiveCluster(node_label_names=node_label_names, node_label_default=["*", 0], edge_attribute="order")
its_graph_rules_cluster = naive_cluster.process_rules_clustering(data, rule_column='GraphRules')

In [None]:
its_graph_rules_cluster[0]

In [None]:
from SynTemp.SynUtils.utils import stratified_random_sample
import pandas as pd
sampled_data = stratified_random_sample(its_graph_rules_cluster, property_key='naive_cluster', samples_per_class=1, seed=23)
pd.DataFrame(sampled_data)['Reaction Type'].value_counts()

In [None]:
def _rule_graphs(reaction_type):
    """Return element 2 of 'GraphRules' for every sampled record of the given 'Reaction Type'."""
    return [
        record['GraphRules'][2]
        for record in sampled_data
        if record['Reaction Type'] == reaction_type
    ]

# NOTE: `complex` shadows the Python builtin of the same name; the name is kept
# because later cells index into it.
single = _rule_graphs('Single Cyclic')
complex = _rule_graphs('Complex Cyclic')
neither = _rule_graphs('None')
acyclic = _rule_graphs('Acyclic')

In [None]:
from SynTemp.SynVis.chemical_graph_vis import ChemicalGraphVisualizer
vis = ChemicalGraphVisualizer()
vis.graph_vis(complex[4], show_node_labels=True)

In [None]:
from SynTemp.SynRule.rule_decompose import GraphRuleDecompose

In [None]:
from copy import deepcopy
complex_graph = complex[11]
# Add nodes and edges to complex_graph with the required attributes

single_cyclic_graphs = deepcopy(single)
# Define your single cyclic graphs by adding nodes and edges with the required attributes

# Call the function
explained_graphs = GraphRuleDecompose.bfs_remove_isomorphic_subgraphs(complex_graph, single_cyclic_graphs)

if explained_graphs is not None:
    print("List of single cyclic graphs that explain the complex graph:", explained_graphs)
    GraphRuleDecompose.visualize_with_common_subgraphs(complex_graph, explained_graphs)
else:
    print("Some parts of the complex graph could not be explained by any of the single cyclic graphs.")

In [None]:
import pandas as pd 

df = pd.read_csv('./Data/golden/golden_dataset.csv')
df.head(2)

from SynTemp.SynUtils.utils import load_database
data = load_database('./Data/golden/golden_aam_reactions.json.gz')
for key, value in enumerate(data):
    data[key]['ground_truth'] = df.iloc[key,0]

In [None]:
from SynTemp.SynUtils.utils import load_database
data = load_database('./Data/golden/golden_aam_reactions.json.gz')
for key, value in enumerate(data):
    data[key]['ground_truth'] = df.iloc[key,0]

In [None]:
pd.DataFrame(data).info()

In [None]:
from SynTemp.SynAAM.aam_validator import AMMValidator
# Report every record whose ground-truth / RDT pair cannot be checked, together
# with the reason, so the failing entries can be inspected afterwards.
for key, record in enumerate(data):
    try:
        AMMValidator.smiles_check(record['ground_truth'], record['rdt'])
    except Exception as exc:  # a bare `except:` would also swallow KeyboardInterrupt/SystemExit
        print(key, repr(exc))

In [None]:
data[366]['rdt']

In [None]:
from SynTemp.SynAAM.aam_validator import AMMValidator  
results = AMMValidator.validate_smiles(data=data, ground_truth_col='ground_truth', 
                                       mapped_cols=['rxn_mapper', 'graphormer', 'local_mapper', 'rdt'], 
                                       check_method='RC', 
                                       ignore_aromaticity=False, n_jobs=4, verbose=0)

In [None]:
pd.DataFrame(results)[['mapper', 'accuracy']]

In [None]:
from SynTemp.SynAAM.aam_validator import AMMValidator  
results = AMMValidator.validate_smiles(data=data, ground_truth_col='Ground turth', 
                                       mapped_cols=['RXNMapper', 'GraphMapper', 'LocalMapper'], 
                                       check_method='RC', 
                                       ignore_aromaticity=False, n_jobs=4, verbose=0)

In [None]:
pd.DataFrame(results)[['mapper', 'accuracy']]

In [None]:
pd.DataFrame(results)[['mapper', 'accuracy']]

In [None]:
from SynTemp.SynUtils.utils import load_database
recon = load_database('./Data/Recon3D/Recon3D_aam_reactions.json.gz')

from SynTemp.SynAAM.aam_validator import AMMValidator  
results, _ = AMMValidator.validate_smiles(data=recon, ground_truth_col='ground_truth', 
                                       mapped_cols=['rxn_mapper', 'graphormer', 'local_mapper', 'rdt'], 
                                       check_method='RC', 
                                       ignore_aromaticity=False, n_jobs=4, verbose=0, ensemble=True)

In [None]:
pd.DataFrame(recon).to_csv('./Data/Recon3D/Recon3D_aam_reactions.csv')

In [None]:
import pandas as pd
pd.DataFrame(results)[['mapper', 'accuracy', 'success_rate']]

In [None]:
import pandas as pd

In [None]:
ecoli = pd.read_csv('./Data/ecoli/ecoli.smiles', header=None)
ecoli.rename({0:'ground_truth'}, axis=1, inplace=True)
ecoli['R-id'] = range(1, len(ecoli) + 1)

In [None]:
from rdkit import Chem
from rdkit.Chem import rdChemReactions
# Positional indices of the ecoli reactions that RDKit can / cannot parse.
ok = []
bug = []
for key, value in enumerate(ecoli['ground_truth']):
    try:
        rdChemReactions.ReactionFromSmarts(value)
    except Exception:  # parser failures only; do not trap KeyboardInterrupt/SystemExit
        bug.append(key)
    else:
        # Recorded outside the `try` so only the parse itself is guarded.
        ok.append(key)

In [None]:
a,b,c =ecoli['ground_truth'][bug[0]].split('>>')

In [None]:
print(ecoli.iloc[bug,:]['ground_truth'])

In [None]:
ecoli = ecoli.iloc[ok, :]
ecoli.reset_index(drop=True, inplace=True)
#ecoli = ecoli.to_dict('records')

In [None]:
from SynTemp.SynUtils.utils import save_database
save_database(ecoli, './Data/ecoli/ecoli_reactions.json.gz')

In [None]:
Chem.MolFromSmiles(ecoli.loc[189, 'reactions'])

In [None]:
bug

In [None]:

rxn = rdChemReactions.ReactionFromSmarts('[C:1](=[O:2])O.[N:3]>>[C:1](=[O:2])[N:3]')
reacts = (Chem.MolFromSmiles('C(=O)O'),Chem.MolFromSmiles('CNC'))
products = rxn.RunReactants(reacts)

In [None]:
from SynTemp.SynUtils.utils import load_database
recon = load_database('./Data/Recon3D/Recon3D_aam_reactions.json.gz')

from SynTemp.SynAAM.aam_validator import AMMValidator  
results, _ = AMMValidator.validate_smiles(data=recon, ground_truth_col='ground_truth', 
                                       mapped_cols=['rxn_mapper', 'graphormer', 'local_mapper'], 
                                       check_method='RC', 
                                       ignore_aromaticity=False, n_jobs=4, verbose=0, ensemble=True)

import pandas as pd
pd.DataFrame(results)[['mapper', 'accuracy', 'success_rate']]

In [None]:
recon[0]

In [None]:
from SynTemp.SynUtils.utils import load_database
recon = load_database('./Data/ecoli/ecoli_aam_reactions.json.gz')

from SynTemp.SynAAM.aam_validator import AMMValidator  
results, _ = AMMValidator.validate_smiles(data=recon, ground_truth_col='ground_truth', 
                                       mapped_cols=['rxn_mapper', 'graphormer', 'local_mapper', 'rdt', 'ground_truth'], 
                                       check_method='RC', 
                                       ignore_aromaticity=False, n_jobs=4, verbose=0, ensemble=False)

import pandas as pd
pd.DataFrame(results)[['mapper', 'accuracy', 'success_rate']]

In [None]:
# Build the results table once and copy each mapper's 'results' entry next to the reactions.
# NOTE(review): the row numbers presumably follow the order of `mapped_cols` in the
# validate_smiles call that produced `results` — confirm if that list changes.
results_df = pd.DataFrame(results)
test = pd.DataFrame(recon).drop(columns=['reactions'])
for column, row in (
    ('local_mapper_result', 2),
    ('rxn_mapper_result', 0),
    ('graphormer_result', 1),
    ('rdt_result', 3),
):
    test[column] = results_df.loc[row, 'results']

In [None]:
test.to_csv('./Data/ecoli/ecoli_aam_reactions.csv')

In [None]:
pd.DataFrame(results)

In [None]:
pd.DataFrame(results).loc[2, 'results']

In [None]:
test_2 = test[['local_mapper_result', 'rxn_mapper_result', 'graphormer_result', 'rdt_result']]

In [None]:
def ensemble_results(df, threshold):
    """
    Vote across the columns of `df`, row by row.

    Each row's values are summed (so True counts as 1) and the row is accepted
    when that sum reaches `threshold`.

    Args:
        df: DataFrame whose row-wise sum is compared against the threshold.
        threshold: Minimum row sum required for a True result.

    Returns:
        Boolean Series, one entry per row of `df`.
    """
    return df.sum(axis=1) >= threshold

test_3 = ensemble_results(test_2, 2)

In [None]:
# NOTE(review): 273 is a hard-coded denominator — presumably the number of reactions
# behind `test_3`; confirm, and switch to len(test_3) if that is what is meant.
N_REACTIONS = 273
test_3.sum() / N_REACTIONS

In [None]:
uspto_sample = pd.read_csv('./Data/aam_benchmark/USPTO_sampled.csv')

In [None]:
uspto_sample['LocalMapper_correct'].sum()

In [None]:
uspto_sample

# Bug

In [None]:
import pandas as pd

# Benchmark tables that can be inspected in this section, keyed by a short name.
BENCHMARK_CSVS = {
    'golden': './Data/aam_benchmark/Golden_mappings.csv',
    'benchmark': './Data/aam_benchmark/benchmark.csv',
    'natcomm': './Data/aam_benchmark/NatComm_mappings.csv',
    'uspto_sampled': './Data/aam_benchmark/USPTO_sampled.csv',
    'ecoli': './Data/ecoli/ecoli_aam_reactions.csv',
    'recon3d': './Data/Recon3D/Recon3D_aam_reactions.csv',
}

# Read only the table under inspection. Previously every file was read in turn
# and all but the last assignment to `df` were thrown away.
DATASET = 'ecoli'
df = pd.read_csv(BENCHMARK_CSVS[DATASET])
df.head(1)

In [None]:
from SynTemp.SynAAM.aam_validator import AMMValidator  
results, _ = AMMValidator.validate_smiles(data=df, ground_truth_col='ground_truth', 
                                       mapped_cols=['ground_truth'], 
                                       check_method='RC', 
                                       ignore_aromaticity=False, n_jobs=4, verbose=0, ensemble=False)

import pandas as pd
pd.DataFrame(results)[['mapper', 'accuracy', 'success_rate']]

In [None]:
# Take an explicit copy before adding a column: assigning into a column
# selection of `df` raises pandas' SettingWithCopyWarning and leaves it unclear
# whether `df` itself is being modified.
test_0 = df[['ground_truth', 'R-id']].copy()
test_0['results'] = results[0]['results']

In [None]:
test_0_bug = test_0.loc[test_0['results']==False, :]

In [None]:
rsmi = test_0_bug['ground_truth'][74]

In [None]:
from SynTemp.SynITS.its_construction import ITSConstruction
from SynTemp.SynITS.its_extraction import ITSExtraction
reactants, products = rsmi.split('>>')
G, H = ITSExtraction.graph_from_smiles(reactants), ITSExtraction.graph_from_smiles(products)

In [None]:
rsmi

In [None]:
reactants

In [None]:
from rdkit import Chem
Chem.MolFromSmiles(reactants)

In [None]:
from SynTemp.SynVis.reaction_visualizer import ReactionVisualizer
vis = ReactionVisualizer()

In [None]:
vis.visualize_reaction(test_0_bug.iloc[0]['ground_truth'])

In [None]:
test_0_bug.iloc[0]['ground_truth']

In [None]:
import pandas as pd


df = pd.read_csv('./Data/USPTO_50K/USPTO_50K.csv')
df.drop_duplicates(subset=['reactions'], inplace=True)
df['R-id'] = ['USPTO-' + str(i) for i in range(len(df))]
df.head(2)

# MOD

In [None]:
from SynTemp.SynRule.rule_executor import RuleExecutor
test = RuleExecutor.reaction_prediction(input_smiles=['C=C1C(=C)C2OC1C1=C2CC(C(C)=O)CC1'],
                                        rule_file_path='./Data/uspto/Rule/USPTO_50K_31.gml',
                                        prediction_type='backward', repeat_times=1, print_results=False)

test

In [None]:
from SynTemp.SynRule.rule_executor import RuleExecutor
from SynTemp.SynUtils.utils import load_database
database = load_database('./test_database.json.gz')
test = RuleExecutor.reaction_database_prediction(database=database[:],  rule_file_path='./Data/uspto/Rule/',
                                         original_rsmi_col='reactions', prediction_type = 'backward', repeat_times=1)


In [None]:
from SynTemp.SynRule.rule_benchmark import RuleBenchmark
from SynTemp.SynUtils.utils import load_database
database = load_database('./test_database.json.gz')
fw, bw = RuleBenchmark.reproduce_reactions(database=database[:],  id_col='R-id', rule_file_path='./Data/uspto/Rule',
                                         original_rsmi_col='reactions', repeat_times=1)

In [None]:
import pandas as pd
pd.DataFrame(bw).info()

# Ranking

In [None]:
from SynTemp.SynUtils.utils import load_database
from SynTemp.SynRule.rule_benchmark import RuleBenchmark
database = load_database('./test_database.json.gz')
fw, bw = RuleBenchmark.reproduce_reactions(database=database[:],  id_col='R-id', rule_file_path='./Data/uspto/Rule',
                                         original_rsmi_col='reactions', repeat_times=1, prior=True)

In [None]:
import pandas as pd
pd.DataFrame(bw).info()

In [None]:
from SynTemp.SynRule.similarity_ranking import SimilarityRanking


processed_dicts = SimilarityRanking.process_list_of_dicts(fw, 'unrank', ['FCFP6'])
# Keep k in one place so the printed label always matches the value passed on
# (the label used to say "Top 5" while k=2 was evaluated).
top_k = 2
print(f"Top {top_k} accuracy:", RuleBenchmark.TopKAccuracy(processed_dicts, 'reactions','rank', top_k, ignore_stero=True))

In [None]:
processed_dicts = SimilarityRanking.process_list_of_dicts(fw, 'unrank', ['RDK7'])
print("Top 5 accuracy:", RuleBenchmark.TopKAccuracy(processed_dicts, 'reactions','rank', 5, ignore_stero=True))

## Visualize

In [None]:
from SynTemp.SynUtils.utils import load_database
data = load_database('./Data/uspto_sample/uspto_sample_aam_reactions.json.gz')

In [None]:
from SynTemp.SynVis.chemical_reaction_visualizer import ChemicalReactionVisualizer
vis = ChemicalReactionVisualizer()
vis.visualize_and_compare_reactions(data[0], num_cols=3)

In [None]:
from SynTemp.SynVis.its_visualizer import ITSVisualizer
from IPython.display import Image
its_vis = ITSVisualizer(data[0]['rxn_mapper'])
display(Image(its_vis.draw_product_with_modified_bonds()))

In [None]:
import pandas as pd

df = pd.read_csv('./Data/USPTO_50K/USPTO_50K.csv')

In [None]:
df['reactions'][0]

In [None]:
from SynTemp.SynVis.chemical_reaction_visualizer import ChemicalReactionVisualizer
vis = ChemicalReactionVisualizer()
vis.visualize_reaction(df['reactions'][2], show_atom_map=False)

In [None]:
df['reactions'][2]

In [None]:
Chem.MolFromSmiles('C(O)(O)C=CO')

In [None]:
from fgutils import FGQuery
smiles = "C(O)(O)C=CO" # acyclic test SMILES: a carbon bearing two OH groups next to a C=C-O unit (not acetylsalicylic acid)
query = FGQuery(use_smiles=True) # use_smiles requires rdkit to be installed
query.get(smiles)

In [None]:
Chem.MolFromSmiles('C(O)(O)')

In [None]:
Chem.MolFromSmiles('C=CO')

In [None]:
import pandas as pd

df = pd.read_csv('./Data/uspto_sample/uspto_sample.csv')
df['GroundTruth'] = df['GraphMapper']
df.loc[df['GraphMapper_correct']==False, 'GroundTruth'] = df.loc[df['GraphMapper_correct']==False, 'LocalMapper']
df = df[['sampled_id', 'RXNMapper', 'GraphMapper', 'LocalMapper', 'GroundTruth']]

In [None]:
df.to_csv('uspto_3k')

In [None]:
df = pd.read_csv('./Data/uspto_sample/uspto_sample.csv')

In [None]:
df['GraphMapper_correct'].sum()

In [None]:
df

In [None]:
from rdkit import Chem
mol = Chem.MolFromSmiles('[BH3-][NH3+]')

In [None]:
mol

In [None]:
Chem.MolFromSmiles('[Na+]')

In [None]:
from rdkit.Chem.MolStandardize import rdMolStandardize
def uncharge_molecule(mol: Chem.Mol) -> Chem.Mol:
    """
    Neutralize a molecule with RDKit's ``rdMolStandardize.Uncharger``.

    A fresh ``Uncharger`` with default settings is created on every call and
    the result of its ``uncharge`` method is returned as is.

    NOTE(review): per the RDKit documentation the Uncharger neutralizes charged
    atoms by adding/removing hydrogens; it does not appear to strip counter-ion
    fragments, so salts are presumably left in place — confirm before relying
    on this function for de-salting.

    Args:
        mol: RDKit Mol object.

    Returns:
        The Mol object returned by ``Uncharger.uncharge``.
    """
    uncharger = rdMolStandardize.Uncharger()
    return uncharger.uncharge(mol)

In [None]:
from SynTemp.SynStandardizer.deionize import Deionize
smiles = "[NH4+].[Cl-]"
uncharged_smiles = Deionize.uncharge_smiles(smiles)
uncharged_smiles

In [None]:
Chem.CanonSmiles('[Na]O')

In [None]:
uncharge_molecule(mol)

In [None]:
Chem.CanonSmiles('C[N+](C)(C)C.[Cl-]')

In [None]:
from rdkit import Chem
from joblib import Parallel, delayed
from typing import List, Dict, Union, Tuple
from SynTemp.SynUtils.chemutils import get_combined_molecular_formula


class BalanceReactionCheck:
    """
    A class to check the balance of chemical reactions given in SMILES format.

    A reaction counts as balanced when `get_combined_molecular_formula` yields
    the same value for its reactant side and its product side. Batches are
    evaluated through joblib and every result keeps the fields of its input.
    """

    def __init__(
        self,
        n_jobs: int = 4,
        verbose: int = 0,
    ):
        """
        Stores the joblib settings used by `dicts_balance_check`.

        Parameters:
        - n_jobs (int): The number of parallel jobs to run for balance checking (default: 4).
        - verbose (int): The verbosity level of joblib parallel execution (default: 0).
        """
        self.n_jobs = n_jobs
        self.verbose = verbose

    @staticmethod
    def parse_input(
        input_data: Union[str, List[Union[str, Dict[str, str]]]],
        rsmi_column: str = "reactions",
    ) -> List[Dict[str, str]]:
        """
        Parses the input data into a standardized list containing dictionaries for each reaction.

        Parameters:
        - input_data (Union[str, List[Union[str, Dict[str, str]]]]): A single reaction SMILES
          string, or a list whose items are reaction SMILES strings or dictionaries.
        - rsmi_column (str): The key under which the reaction SMILES is stored.

        Returns:
        - List[Dict[str, str]]: A list of dictionaries with reaction SMILES strings.
          Dictionaries from the input are passed through as the same objects. List items
          that are neither strings nor dictionaries containing `rsmi_column` are skipped.

        Raises:
        - ValueError: If `input_data` is neither a string nor a list.
        """
        if isinstance(input_data, str):
            return [{rsmi_column: input_data}]
        if not isinstance(input_data, list):
            raise ValueError("Unsupported input type")

        standardized_input = []
        for item in input_data:
            if isinstance(item, str):
                standardized_input.append({rsmi_column: item})
            elif isinstance(item, dict) and rsmi_column in item:
                standardized_input.append(item)
            # Any other item carries no usable reaction and is left out.
        return standardized_input

    @staticmethod
    def parse_reaction(reaction_smiles: str) -> Tuple[str, str]:
        """
        Splits a reaction SMILES string into its reactant and product sides.

        Parameters:
        - reaction_smiles (str): A reaction SMILES of the form "reactants>>products".

        Returns:
        - Tuple[str, str]: The reactant-side SMILES and the product-side SMILES
          (each side is returned as one string, not split into molecules).

        Raises:
        - ValueError: If the string does not contain exactly one ">>" separator.
        """
        sides = reaction_smiles.split(">>")
        if len(sides) != 2:
            # Same exception type as the tuple-unpacking failure this replaces,
            # but with a message that names the offending input.
            raise ValueError(
                f"Expected exactly one '>>' in reaction SMILES, got: {reaction_smiles!r}"
            )
        reactants_smiles, products_smiles = sides
        return reactants_smiles, products_smiles

    @staticmethod
    def rsmi_balance_check(reaction_smiles: str) -> bool:
        """
        Checks whether a single reaction SMILES is balanced.

        Parameters:
        - reaction_smiles (str): A reaction SMILES of the form "reactants>>products".

        Returns:
        - bool: True when `get_combined_molecular_formula` gives equal results for
          the reactant side and the product side, False otherwise.
        """
        reactants_smiles, products_smiles = BalanceReactionCheck.parse_reaction(
            reaction_smiles
        )
        reactants_formula = get_combined_molecular_formula(reactants_smiles)
        products_formula = get_combined_molecular_formula(products_smiles)
        return reactants_formula == products_formula

    @staticmethod
    def dict_balance_check(
        reaction_dict: Dict[str, str], rsmi_column: str
    ) -> Dict[str, Union[bool, str]]:
        """
        Checks if a single reaction (in SMILES format) is balanced, maintaining the input format.

        Parameters:
        - reaction_dict (Dict[str, str]): A dictionary containing the reaction SMILES string.
        - rsmi_column (str): The key under which the reaction SMILES is stored.

        Returns:
        - Dict[str, Union[bool, str]]: A new dictionary whose first key is "balanced",
          followed by the original reaction data. The input dictionary is not modified.
        """
        reaction_smiles = reaction_dict[rsmi_column]
        balance = BalanceReactionCheck.rsmi_balance_check(reaction_smiles)
        result = {"balanced": balance, **reaction_dict}
        # If the input already carried a "balanced" field, the unpacking above
        # would overwrite the freshly computed flag; make the computed value win.
        result["balanced"] = balance
        return result

    def dicts_balance_check(
        self,
        input_data: Union[str, List[Union[str, Dict[str, str]]]],
        rsmi_column: str = "reactions",
    ) -> Tuple[List[Dict[str, Union[bool, str]]], List[Dict[str, Union[bool, str]]]]:
        """
        Checks the balance of all reactions in the input data.

        Parameters:
        - input_data (Union[str, List[Union[str, Dict[str, str]]]]): Input accepted by `parse_input`.
        - rsmi_column (str): The key under which the reaction SMILES is stored.

        Returns:
        - Tuple[List[Dict[str, Union[bool, str]]], List[Dict[str, Union[bool, str]]]]: Two lists containing dictionaries
          of balanced and unbalanced reactions, respectively, each in input order.
        """
        reactions = self.parse_input(input_data, rsmi_column)
        results = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
            delayed(self.dict_balance_check)(reaction, rsmi_column)
            for reaction in reactions
        )

        # Split the results in a single pass instead of filtering the list twice.
        balanced_reactions = []
        unbalanced_reactions = []
        for reaction in results:
            if reaction["balanced"]:
                balanced_reactions.append(reaction)
            else:
                unbalanced_reactions.append(reaction)

        return balanced_reactions, unbalanced_reactions


In [None]:
Chem.CanonSmiles('O=O')

In [None]:
from SynTemp.SynUtils.utils import  load_database
data_test = load_database('./Data/uspto/uspto_balance_reactions.json.gz')

check = BalanceReactionCheck(n_jobs=1, verbose=2)
check.dicts_balance_check(data_test[0:10000], 'reactions')

In [None]:
import importlib.resources
# Resolve <SynTemp package>/Data/uspto/uspto_its_incorrect.pkl.gz. The file name
# used to be joined first and 'Data/uspto/<file>' appended beneath it, which
# produced a path nested under the file name itself.
data_path = importlib.resources.files('SynTemp') / 'Data' / 'uspto' / 'uspto_its_incorrect.pkl.gz'

In [74]:
from SynTemp.SynUtils.utils import load_database, load_from_pickle
data = load_from_pickle('./Data/uspto/uspto_its_incorrect.pkl.gz')

In [75]:
data[0]

{'R-id': 'USPTO_50K_31',
 'rxn_mapper': (<networkx.classes.graph.Graph at 0x7fe3657b44d0>,
  <networkx.classes.graph.Graph at 0x7fe3656c2150>,
  <networkx.classes.graph.Graph at 0x7fe365449510>),
 'graphormer': (<networkx.classes.graph.Graph at 0x7fe3654a7310>,
  <networkx.classes.graph.Graph at 0x7fe3654a73d0>,
  <networkx.classes.graph.Graph at 0x7fe3653e6850>),
 'local_mapper': (<networkx.classes.graph.Graph at 0x7fe365404290>,
  <networkx.classes.graph.Graph at 0x7fe365404110>,
  <networkx.classes.graph.Graph at 0x7fe365406a50>),
 'equivariant': 0,
 'ITSGraph': (<networkx.classes.graph.Graph at 0x7fe365419190>,
  <networkx.classes.graph.Graph at 0x7fe36541aa10>,
  <networkx.classes.graph.Graph at 0x7fe36541ba10>),
 'GraphRules': (<networkx.classes.graph.Graph at 0x7fe3654d0e90>,
  <networkx.classes.graph.Graph at 0x7fe3654d1d90>,
  <networkx.classes.graph.Graph at 0x7fe3654d1ad0>)}