In [1]:
import sys
sys.path.append('../../../')
from SynRBL.rsmi_utils import load_database
import re
from rdkit import Chem
from rdkit.Chem.rdMolDescriptors import CalcNumRings


from copy import deepcopy


def remove_atom_mapping_from_reaction_smiles(reaction_smiles):
    """
    Remove atom mapping from a reaction SMILES string.
    
    Parameters:
    - reaction_smiles (str): A reaction SMILES string with atom mapping.
    
    Returns:
    - str: A reaction SMILES string without atom mapping.
    """
    # Split the reaction SMILES into its components (reactants, agents, products)
    parts = reaction_smiles.split('>>')
    
    # Remove atom mapping from each part
    cleaned_parts = [Chem.CanonSmiles(re.sub(r":\d+", "", part)) for part in parts]
    
    # Concatenate the cleaned parts back into a reaction SMILES string
    cleaned_reaction_smiles = '>>'.join(cleaned_parts)
    
    return cleaned_reaction_smiles



def calculate_chemical_properties(dictionary_list):
    updated_list = deepcopy(dictionary_list)  # Create a deep copy of the original list
    for entry in updated_list:
        reactant_smiles = entry['reactants']
        product_smiles = entry['products']

        # Initialize RDKit molecule objects from SMILES
        reactant_mol = Chem.MolFromSmiles(reactant_smiles)
        product_mol = Chem.MolFromSmiles(product_smiles)

        if reactant_mol is not None and product_mol is not None:
            # Calculate carbon difference
            num_carbon_reactants = sum([atom.GetAtomicNum() == 6 for atom in reactant_mol.GetAtoms()])
            num_carbon_products = sum([atom.GetAtomicNum() == 6 for atom in product_mol.GetAtoms()])
            entry['carbon_difference'] = abs(num_carbon_reactants - num_carbon_products)
            
            # Calculate total number of carbons
            entry['total_carbons'] = num_carbon_reactants + num_carbon_products

            # Calculate total number of bonds
            entry['total_bonds'] = abs(reactant_mol.GetNumBonds() - product_mol.GetNumBonds())

            # Calculate total number of rings
            entry['total_rings'] = abs(CalcNumRings(reactant_mol) - CalcNumRings(product_mol))
        else:
            entry['carbon_difference'] = "Invalid SMILES"
            entry['total_carbons'] = "Invalid SMILES"
            entry['total_bonds'] = "Invalid SMILES"
            entry['total_rings'] = "Invalid SMILES"



        # Process for fragment count calculation
        reactant_fragment_count = len(reactant_smiles.split('.'))
        product_fragment_count = len(product_smiles.split('.'))
        total_fragment_count = reactant_fragment_count + product_fragment_count
        entry['fragment_count'] = total_fragment_count

    return updated_list




def count_boundary_atoms_products_and_calculate_changes(list_of_dicts):
    for item in list_of_dicts:
        count = 0  # Initialize count for boundary_atoms_products
        # Initialize variables for bond and ring changes
        bond_change = 0
        ring_change = 0
        
        if 'boundary_atoms_products' in item and item['boundary_atoms_products']:
            for i in item['boundary_atoms_products']:
                if isinstance(i, dict):
                    count += 1
                elif isinstance(i, list):
                    for j in i:
                        if isinstance(j, dict):
                            count += 1
        
        # Split new_reactions into reactant and product SMILES and calculate changes
        
        reactant_product = item['new_reaction'].split('>>')
        if len(reactant_product) == 2:  # Ensure there are both reactant and product
            reactant_smiles, product_smiles = reactant_product
            reactant_mol = Chem.MolFromSmiles(reactant_smiles)
            product_mol = Chem.MolFromSmiles(product_smiles)
            
            if reactant_mol and product_mol:
                # Calculate bond change
                bond_change = abs(reactant_mol.GetNumBonds() - product_mol.GetNumBonds())
                # Calculate ring change
                ring_change = abs(CalcNumRings(reactant_mol) - CalcNumRings(product_mol))
        
        # Add calculated values to the dictionary
        item['num_boundary'] = count
        item['bond_change_merge'] = bond_change
        item['ring_change_merge'] = ring_change

    return list_of_dicts

In [2]:
import pandas as pd
import pandas as pd
from SynRBL.rsmi_utils import load_database
from IPython.display import clear_output
def process_and_combine_datasets(list_data, pipeline_path, data_path, remove_undetected=True):
    """
    Processes and combines datasets from specified paths.

    Parameters:
    - list_data (list): List of dataset names.
    - pipeline_path (str): Path to the pipeline files.
    - data_path (str): Path to the data files.

    Returns:
    - pd.DataFrame: Combined DataFrame of all processed datasets.
    """
    data_all = pd.DataFrame()
    
    for data_name in list_data:
        # Load dataset CSV and adjust columns
        data_csv_path = f'{pipeline_path}/Validation/Analysis/SynRBL - {data_name}.csv'
        data = pd.read_csv(data_csv_path).drop(['Note'], axis=1)
        #print(data.shape)
        data.loc[data['Result'] == 'CONSIDER', 'Result'] = False
        data.loc[data['Result'] == 'FALSE', 'Result'] = False
        data.loc[data['Result'] == 'TRUE', 'Result'] = True
        #data['Result'] =

        # Load and process additional data
        merge_data_path = f'{data_path}/Validation_set/{data_name}/MCS/MCS_Impute.json.gz'
        mcs_data_path = f'{data_path}/Validation_set/{data_name}/mcs_based_reactions.json.gz'
        
        merge_data = load_database(merge_data_path)
        #print(len(merge_data))
        merge_data = count_boundary_atoms_products_and_calculate_changes(merge_data)
        mcs_data = load_database(mcs_data_path)
        id = [value['R-id'] for value in merge_data]
        mcs_data = [value for value in mcs_data if value['R-id'] in id]
        mcs_data = calculate_chemical_properties(mcs_data)
        #print(len(mcs_data))
        #clear_output(wait=False)
        
        # Combine data
        combined_data = pd.concat([
            pd.DataFrame(mcs_data)[['R-id', 'reactions', 'carbon_difference', 'fragment_count', 'total_carbons', 'total_bonds', 'total_rings']],
            data,
            pd.DataFrame(merge_data)[['mcs_carbon_balanced', 'num_boundary', 'ring_change_merge', 'bond_change_merge']],
        ], axis=1)
        #print(combined_data.isnull().sum().sum())
        combined_data.loc[(combined_data['mcs_carbon_balanced'] == False) & (combined_data['Result'] == True), 'Result']=False
        if remove_undetected:
            combined_data = combined_data[combined_data['mcs_carbon_balanced'] == True]
        
        data_all = pd.concat([data_all, combined_data], axis=0)
    data_all = data_all.reset_index(drop=True)
    unnamed_columns = [col for col in data_all.columns if 'Unnamed' in col]
    data_all = data_all.drop(unnamed_columns, axis=1)

    return data_all


list_data = ['golden_dataset', 'Jaworski', 'USPTO_random_class', 'USPTO_diff', 'USPTO_unbalance_class']
pipeline_path = '../../../Pipeline'
data_path = '../../../Data'

data_total = process_and_combine_datasets(list_data, pipeline_path, data_path, remove_undetected=False)




In [3]:
from drfp import DrfpEncoder
rxn_smiles=data_total['reactions'].tolist()
fps = DrfpEncoder.encode(rxn_smiles)

In [4]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from sklearn.pipeline import Pipeline
from xgboost import XGBClassifier
from sklearn.metrics import classification_report
from imblearn.combine import SMOTEENN
from imblearn.pipeline import Pipeline as Pipelinelit
X= fps
y = data_total['Result']

le = LabelEncoder()
y = le.fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [6]:
steps = [('scaler', MinMaxScaler()), ('model', XGBClassifier(random_state=42))]
pipeline = Pipeline(steps=steps)

pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)

# Print classification report
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       0.50      0.20      0.28        87
           1       0.83      0.95      0.89       364

    accuracy                           0.81       451
   macro avg       0.67      0.57      0.58       451
weighted avg       0.77      0.81      0.77       451

