In [None]:
import sys
import json
import gzip
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit import RDLogger
from joblib import Parallel, delayed

# Make the directory two levels above the current working directory importable
# (presumably the repository root, so the local `SynRBL` package resolves).
# Guarded so that re-running this cell does not append duplicate entries.
if '../../' not in sys.path:
    sys.path.append('../../')

def configure_logging(level=None):
    """
    Configure the RDKit logger level.

    With the default (``level=None``) the level is set to ``RDLogger.ERROR``:
    debug, info and warning messages are suppressed, while error messages are
    still reported.

    Args:
        level (int, optional): An ``RDLogger`` level constant (e.g.
            ``RDLogger.WARNING``). Defaults to ``RDLogger.ERROR``.
    """
    if level is None:
        level = RDLogger.ERROR
    lg = RDLogger.logger()
    lg.setLevel(level)

def load_data(file_path, **read_csv_kwargs):
    """
    Load the dataset from a CSV file.

    Args:
        file_path (str): Path to the CSV file.
        **read_csv_kwargs: Extra keyword arguments forwarded unchanged to
            ``pandas.read_csv`` (e.g. ``usecols``, ``nrows``).

    Returns:
        DataFrame: A pandas DataFrame containing the loaded data.

    Raises:
        SystemExit: With exit code 1 if ``file_path`` does not exist. The error
            message is written to stderr so it is not mixed with regular output.
    """
    try:
        return pd.read_csv(file_path, **read_csv_kwargs)
    except FileNotFoundError:
        print(f"Error: File '{file_path}' not found.", file=sys.stderr)
        sys.exit(1)



if __name__ == "__main__":
    # Quiet RDKit's logger first, then read the USPTO-50K validation CSV into
    # `df`, the frame consumed by the processing cells below.
    configure_logging()
    dataset_path = '../../Data/Validation_set/USPTO_50K.csv'
    df = load_data(dataset_path)

# **1. SynExtract**

Here's a brief description of how we can approach the task:

1. **Input and Parsing of Chemical Reactions**: The script accepts chemical reactions as reaction SMILES strings, in which reactants and products are separated by `>>` and individual molecules on each side are separated by `.` (e.g., `A.B>>C.D`). The script parses these strings to identify the reactants and products.

2. **Standardization of Chemical Formulas**: The script will standardize the chemical formulas to ensure consistency. This involves formatting element symbols and quantities properly.

3. **Decomposition into Molecular Formulas**: Each reactant and product will be broken down into its molecular formula. For example, decomposing H2O into {'H': 2, 'O': 1}.

4. **Labeling Types of Reactions**:

    - **Balance Check**: The script will check if the reaction is balanced by comparing the count of each element on both sides of the reaction.
    - **Labeling**: Based on the balance check, reactions will be labeled as:
        - **Balanced**: If the number of each type of atom is the same on both sides.
        - **Unbalanced in Reactants**: If any reactant element is not balanced.
        - **Unbalanced in Products**: If any product element is not balanced.
        - **Unbalanced in Both**: If elements in both reactants and products are unbalanced.

## **1.1. Extract SMILES**

In [None]:
from SynRBL.SynProcessor import RSMIProcessing

# Split the reaction SMILES in the 'reactions' column; the output presumably
# carries the 'reactants'/'products' fields used by the decomposer in 1.3.
process = RSMIProcessing(
    data=df,
    data_name='USPTO_50K',
    rsmi_col='reactions',
    parallel=True,
    n_jobs=10,
    save_json=False,
    save_path_name=None,
)
reactions = process.data_splitter().to_dict('records')
reactions[0]


## **1.2. Check Carbon Balance**

In [None]:
from SynRBL.SynProcessor import CheckCarbonBalance

# Tag each reaction with a 'carbon_balance_check' value by comparing carbon ('C')
# atoms on either side of the '>>' separator.
check = CheckCarbonBalance(
    reactions,
    rsmi_col='reactions',
    symbol='>>',
    atom_type='C',
    n_jobs=4,
)
reactions = check.check_carbon_balance()
reactions[0]

In [None]:
# Route carbon-balanced reactions to the rule-based pipeline and everything else to
# the MCS-based pipeline, in a single pass that preserves the original order.
rules_based, mcs_based = [], []
for reaction in reactions:
    if reaction['carbon_balance_check'] == 'balanced':
        rules_based.append(reaction)
    else:
        mcs_based.append(reaction)
print(len(rules_based), len(mcs_based))

## **1.3. Molecular Decomposer**

In [None]:
from SynRBL.SynProcessor import RSMIDecomposer

# Decompose the 'reactants' and 'products' of every carbon-balanced reaction;
# n_jobs=-1 presumably means "all cores" (joblib convention) — confirm in SynRBL.
decompose = RSMIDecomposer(
    smiles=None,
    data=rules_based,
    reactant_col='reactants',
    product_col='products',
    parallel=True,
    n_jobs=-1,
    verbose=1,
)
react_dict, product_dict = decompose.data_decomposer()
react_dict[0]

In [None]:
# Same expression as the last line of the decomposer cell: the first entry of the
# `react_dict` returned by RSMIDecomposer.data_decomposer().
react_dict[0]

In [None]:
# First entry of `product_dict`, the product-side output of data_decomposer().
product_dict[0]

## **1.4. Molecular Comparator**

In [None]:
from SynRBL.SynProcessor import RSMIComparator
from SynRBL.SynUtils.data_utils import save_database, load_database
import pandas as pd

In [None]:
# Compare reactant-side and product-side decompositions; the two results are
# later attached to each reaction as 'Unbalance' and 'Diff_formula'.
comp = RSMIComparator(
    reactants=react_dict,
    products=product_dict,
    n_jobs=-1,
)
unbalance, diff_formula = comp.run_parallel(
    reactants=react_dict,
    products=product_dict,
)

## **1.5. Rules for Reactions Unbalanced on Both Sides**

In [None]:
from SynRBL.SynProcessor import BothSideReact
# Post-process the comparator output with BothSideReact.
# NOTE(review): the constructor takes (unbalance, diff_formula) but fit() is unpacked as
# (diff_formula, unbalance), and both names are overwritten in place. Re-running this cell
# without first re-running the comparator cell feeds already-processed values back in.
both_side = BothSideReact(react_dict, product_dict, unbalance, diff_formula)
diff_formula, unbalance= both_side.fit()

In [None]:
# Attach the per-reaction 'Unbalance' and 'Diff_formula' results to the carbon-balanced
# reactions (pd.concat aligns the frames on their index), then convert back to records.
rules_based_df = pd.DataFrame(rules_based)
unbalance_df = pd.DataFrame([unbalance]).T.rename(columns={0: 'Unbalance'})
diff_formula_df = pd.DataFrame([diff_formula]).T.rename(columns={0: 'Diff_formula'})
reactions_clean = pd.concat(
    [rules_based_df, unbalance_df, diff_formula_df], axis=1
).to_dict(orient='records')
reactions_clean[0]

# **2. SynRuleEngine - Rule Generation**

## **2.1. Manual Rule Extraction**

In [None]:
from SynRBL.rsmi_utils import save_database, load_database, filter_data, sort_by_key_length
from SynRBL.SynRuleImputer.rule_data_manager import RuleImputeManager

# Start from an empty rule database and add hand-written entries to it.
rules = []
db = RuleImputeManager(rules)

# Example entries: carbon dioxide (SMILES O=C=O; 'C=O' would be formaldehyde, CH2O)
# and a deliberately invalid entry, to show what add_entries reports back.
entries = [{'formula': 'CO2', 'smiles': 'O=C=O'}, {'formula': 'Invalid', 'smiles': 'Invalid'}]
invalid_entries = db.add_entries(entries)
print(f"Invalid entries: {invalid_entries}")

# Filter on the carbon count of each rule's 'Composition' (presumably 0-1 C atoms),
# then sort the filtered rules — not db.database, which would discard the filter.
rules = filter_data(db.database, formula_key='Composition', element_key='C', min_count=0, max_count=1)
rules = sort_by_key_length(rules, lambda x: x['Composition'])

rules

## **2.2. Automatic Rule Extraction**

In [None]:
from SynRBL.SynRuleImputer.auto_extract_smiles import AutomaticSmilesExtraction
from SynRBL.SynRuleImputer.auto_extract_rules import AutomaticRulesExtraction

# Thresholds passed to get_fragments; their exact semantics are defined by SynRBL.
FRAGMENT_MW_THRESHOLD = 500
FRAGMENT_N_C_THRESHOLD = 0

# Gather candidate SMILES together with their molecular weight and carbon count.
smi_extractor = AutomaticSmilesExtraction(reactions_clean, n_jobs=4, verbose=1)
input_dict = dict(
    smiles=smi_extractor.smiles_list,
    mw=smi_extractor.mw,
    n_C=smi_extractor.n_C,
)
filtered_fragments = AutomaticSmilesExtraction.get_fragments(
    input_dict,
    mw=FRAGMENT_MW_THRESHOLD,
    n_C=FRAGMENT_N_C_THRESHOLD,
    combination='intersection',
)
print("Filtered Fragments:", len(filtered_fragments))

# Derive rules from the surviving fragments, starting with an empty database.
extractor = AutomaticRulesExtraction(existing_database=[], n_jobs=-1, verbose=1)
extractor.add_new_entries(filtered_fragments)
automated_rules = extractor.extract_rules()
print("Extracted Rules:", len(automated_rules))

# **3. SynRuleImpute**


## **3.1. Rule-based Imputation**

In [None]:
from SynRBL.SynUtils.data_utils import save_database, load_database, filter_data, extract_results_by_key
from SynRBL.SynRuleImputer import SyntheticRuleImputer

rules = load_database('../../Data/Rules/rules_manager.json.gz')
#reactions_clean = load_database('../../Data/reaction_clean.json.gz')


def _filter_by_unbalance(data, unbalance_values):
    """Select entries of ``data`` whose unbalance label is one of ``unbalance_values``.

    Thin wrapper around ``filter_data``, so the three selections below share one set of
    filtering options and differ only in the labels requested.
    """
    return filter_data(data, unbalance_values=unbalance_values,
                       formula_key='Diff_formula', element_key=None, min_count=0, max_count=0)


# Split the cleaned reactions by unbalance label (presumably the 'Unbalance' field).
balance_reactions = _filter_by_unbalance(reactions_clean, ['Balance'])
print('Number of Balanced Reactions:', len(balance_reactions))

unbalance_reactions = _filter_by_unbalance(reactions_clean, ['Reactants', 'Products'])
print('Number of Unbalanced Reactions in one side:', len(unbalance_reactions))

both_side_reactions = _filter_by_unbalance(reactions_clean, ['Both'])
print('Number of Both sides Unbalanced Reactions:', len(both_side_reactions))

In [None]:
# Preview only the first few balanced reactions; rendering the whole list bloats the notebook.
balance_reactions[:5]

In [None]:
# Configure RDKit logging: disable every rdApp channel (debug, info, warning and error)
# so the parallel imputation below does not flood the output. RDLogger is imported here
# so this cell does not depend on the first cell having run.
from rdkit import Chem
from rdkit import RDLogger
import rdkit
RDLogger.DisableLog('rdApp.*')

# Initialize SyntheticRuleImputer and perform parallel imputation
imp = SyntheticRuleImputer(rule_dict=rules, select='all', ranking='ion_priority')
expected_result = imp.parallel_impute(unbalance_reactions)

# Extract solved and unsolved results
solve, unsolve = extract_results_by_key(expected_result)
print('Solved:', len(solve))
print('Unsolved in rules based method:', len(unsolve))

# Combine all unsolved cases: reactions unbalanced on both sides were not passed
# to the rule-based imputer above, so they join the unsolved set.
unsolve = both_side_reactions + unsolve
print('Total unsolved:', len(unsolve))

## **3.2. Uncertain Reactions**

In [None]:
from SynRBL.rsmi_utils import  save_database, load_database
from SynRBL.SynRuleImputer.synthetic_rule_constraint import RuleConstraint

# Split the rule-based solutions into certain/uncertain using the listed banned species
# (how `ban_atoms` is applied is defined by RuleConstraint).
constrain = RuleConstraint(solve, ban_atoms=['[H]','[O].[O]', 'F-F', 'Cl-Cl', 'Br-Br', 'I-I', 'Cl-Br', 'Cl-I', 'Br-I'])
certain_reactions, uncertain_reactions = constrain.fit()

# Fetch the uncertain reactions in their pre-imputation form from reactions_clean.
# A set gives O(1) membership tests instead of scanning the id list for every entry.
id_uncertain = [entry['R-id'] for entry in uncertain_reactions]
uncertain_id_set = set(id_uncertain)
new_uncertain_reactions = [entry for entry in reactions_clean if entry['R-id'] in uncertain_id_set]

# Drop the rule-based bookkeeping keys before the MCS stage. Build new dicts rather than
# calling .pop() in place: these entries are the same objects held by reactions_clean
# (and possibly by other lists), which in-place removal would silently corrupt.
rule_only_keys = ('Unbalance', 'Diff_formula')
unsolve = [
    {key: value for key, value in entry.items() if key not in rule_only_keys}
    for entry in unsolve + new_uncertain_reactions
]

mcs_based = mcs_based + unsolve

In [None]:
# Save solved and unsolved reactions
#save_database(certain_reactions,  '../../Data/rule_based_reactions.json.gz')
#save_database(mcs_based,  '../../Data/mcs_based_reactions.json.gz')

## **3.3. Visualization**

In [None]:
from SynRBL.SynUtils.data_utils import load_database, get_random_samples_by_key

# NOTE(review): these loads replace the in-memory certain_reactions / uncertain_reactions
# computed in section 3.2 with the versions stored on disk.
certain_reactions = load_database('../../Data/Validation_set/USPTO_50K/Solved_reactions.json.gz')
uncertain_reactions = load_database('../../Data/Validation_set/USPTO_50K/Unsolved_reactions.json.gz')

# Sample presumably up to 10 reactions per 'Diff_formula' group, with a fixed seed.
validate_samples = get_random_samples_by_key(
    certain_reactions,
    num_samples_per_group=10,
    random_seed=42,
    stratify_key='Diff_formula',
)
#save_database(validate_samples, '../../Data/validate_samples.json.gz')
len(validate_samples)


In [None]:
from SynRBL.SynVis.reaction_visualizer import ReactionVisualizer

N_PLOTS = 5  # number of uncertain reactions to draw


def _plot_reaction(reaction, new_reaction_col):
    """Plot ``reaction`` with a fresh ReactionVisualizer, comparing 'reactions' to ``new_reaction_col``."""
    vis = ReactionVisualizer()
    vis.plot_reactions(reaction, 'reactions', new_reaction_col, show_atom_numbers=False,
                       compare=True, savefig=False, pathname=None, dpi=300)


# Slicing (instead of range(5) + indexing) avoids an IndexError with fewer than N_PLOTS reactions.
for reaction in uncertain_reactions[:N_PLOTS]:
    try:
        _plot_reaction(reaction, 'new_reaction')
    except Exception:
        # Fall back to plotting the original reaction against itself (presumably when no
        # 'new_reaction' is available). Catching Exception rather than using a bare
        # except still lets KeyboardInterrupt/SystemExit propagate.
        _plot_reaction(reaction, 'reactions')


# **4. MCS Rebalancing**

## **4.1. Maximum common substructure**

In [None]:
import pandas as pd
import numpy as np
from joblib import Parallel, delayed
import sys
from pathlib import Path
from SynRBL.SynMCSImputer.SubStructure.mcs_graph_detector import MCSMissingGraphAnalyzer
from SynRBL.SynMCSImputer.SubStructure.mcs_process import single_mcs
from SynRBL.SynUtils.data_utils import load_database, save_database
from rdkit import Chem
import logging




# To start from a saved run instead, uncomment:
#unsolve = load_database('../../Data/Validation_set/USPTO_50K/Unsolved_reactions.json.gz')

# Run the MCS search on the first reaction routed to the MCS-based pipeline.
first_mcs_reaction = mcs_based[0]
mcs_results_dict = single_mcs(first_mcs_reaction)
mcs_results_dict

In [None]:
from SynRBL.SynVis.mcs_visualizer import  MCSVisualizer

# Highlight the MCS on the first sorted reactant, with atom indices shown.
mcs_vis = MCSVisualizer()
first_reactant = mcs_results_dict['sorted_reactants'][0]
first_mcs = mcs_results_dict['mcs_results'][0]
img = mcs_vis.highlight_molecule(
    first_reactant,
    first_mcs,
    show_atom_numbers=True,
    compare=False,
    missing_graph_smiles=None,
)
img

## **4.2. Find Graph**

In [None]:
from SynRBL.SynMCSImputer.MissingGraph.find_graph_dict import FindMissingGraphs
from SynRBL.SynVis.mcs_visualizer import  MCSVisualizer

# Parse the sorted reactants and their MCS results into RDKit molecules.
mol_list = list(map(Chem.MolFromSmiles, mcs_results_dict['sorted_reactants']))
mcs_list = list(map(Chem.MolFromSmiles, mcs_results_dict['mcs_results']))

# Locate the parts of each reactant not covered by its MCS.
find_graph = FindMissingGraphs()
missing_graph, boundary, neighbor = find_graph.find_missing_parts_pairs(
    mol_list=mol_list,
    mcs_list=mcs_list,
    use_findMCS=True,
)

print('Reactants')
print('Neighbor Atom:', neighbor)
print('Missing compounds in Products')
print('Boundary Atom:', boundary)

# Redraw the first reactant, this time comparing against its first missing fragment.
mcs_vis = MCSVisualizer()
missing_smiles = Chem.MolToSmiles(missing_graph[0])
img = mcs_vis.highlight_molecule(
    mcs_results_dict['sorted_reactants'][0],
    mcs_results_dict['mcs_results'][0],
    show_atom_numbers=True,
    compare=True,
    missing_graph_smiles=missing_smiles,
)
img


## **4.3. Mol Merge**

## **4.4. Check and Re-Impute with the Rule-Based Method**