In [1]:
import jsonlines
import json

In [9]:
# Running 1/2/3-step totals accumulated by get_stats below (reset whenever this cell re-runs).
step_totals = dict.fromkeys((1, 2, 3), 0)

def get_steps(rxn_dict):
    """Return how many mechanism steps hang below `rxn_dict`.

    Follows the first element of 'subsequent mechanisms' at each level and
    counts the levels. A missing key, a non-list value or an empty list ends
    the chain.
    """
    children = rxn_dict.get('subsequent mechanisms')
    if not isinstance(children, list) or not children:
        return 0
    return 1 + get_steps(children[0])

def get_stats(file):
    """Count the reactions in a JSON-lines file by mechanism-step count and print them.

    Steps are measured with `get_steps`. Reactions with more than three steps
    are folded into the three-step bucket. Reactions with zero steps (no
    non-empty 'subsequent mechanisms' list) are NOT folded into it; they are
    reported on an extra line, printed only when there are any. The 1/2/3-step
    counts are also added to the module-level `step_totals` dict.

    Args:
        file: path to a JSON-lines file of reaction records.
    """
    total_rxn = 0
    no_step = 0
    step_counts = {1: 0, 2: 0, 3: 0}
    with jsonlines.open(file, 'r') as db:
        for rxn in db:
            total_rxn += 1
            step = get_steps(rxn)
            if step < 1:
                # Zero-step records used to land in the 3-step bucket.
                no_step += 1
            else:
                step_counts[min(step, 3)] += 1  # count steps > 3 as 3-step
    print("The total number of reactions in the dataset:", total_rxn,
          "Number of single step reactions in the dataset:", step_counts[1],
          "Number of two step reactions in the dataset:", step_counts[2],
          "Number of three step reactions in the dataset:", step_counts[3],
          sep='\n')
    if no_step:
        print("Number of reactions without any mechanism step:", no_step)
    for k in step_totals:
        step_totals[k] += step_counts[k]

# Print step-count statistics for Good_Match.json.
get_stats(file="./raw_dataset/Good_Match.json")


The total number of reactions in the dataset:
21544
Number of single step reactions in the dataset:
538
Number of two step reactions in the dataset:
18038
Number of three step reactions in the dataset:
2968


In [15]:
import json
import jsonlines
from collections import Counter

def get_mechanism(rxn_dict, step):
    """Return the 'Mechanism name' of each of the first `step` mechanisms in the chain.

    The chain starts at the first element of rxn_dict['subsequent mechanisms']
    and follows the first element of each node's 'subsequent mechanisms'.
    The walk stops early, without raising, when the chain ends before `step`
    nodes have been visited. A record with no (or an empty) 'subsequent
    mechanisms' list therefore yields [].

    Args:
        rxn_dict: reaction record.
        step: maximum number of mechanisms to collect.

    Returns:
        List of mechanism names (None for a node lacking 'Mechanism name').
    """
    mech_list = []
    subs = rxn_dict.get('subsequent mechanisms')
    current = subs[0] if subs else None
    for _ in range(step):
        if current is None:
            break  # chain shorter than `step`
        mech_list.append(current.get('Mechanism name'))
        children = current.get('subsequent mechanisms')
        current = children[0] if children else None
    return mech_list

def get_step(rxn_dict):
    """Return the depth of the 'subsequent mechanisms' chain below `rxn_dict`.

    At each level only the first element is followed; a missing or falsy
    'subsequent mechanisms' value ends the chain.
    """
    children = rxn_dict.get('subsequent mechanisms')
    return 1 + get_step(children[0]) if children else 0

def get_stats(path):
    """Print how many reactions in the JSON-lines file at `path` have 1, 2 and 3 steps.

    Steps are measured with `get_step`. Reactions with any other step count
    are tallied but not printed.
    """
    with jsonlines.open(path, 'r') as reader:
        step_counter = Counter(get_step(rxn) for rxn in reader)
    for s in (1, 2, 3):
        print(f"Total number of {s} step reaction: {step_counter.get(s, 0)}")
        
def get_mechanism_stat(path):
    """Print every distinct mechanism-name sequence in a JSON-lines file with its count.

    Each line of `path` is parsed with `json.loads`. The sequence for a record is
    the list of 'Mechanism name' values along its chain (via `get_mechanism`
    with the depth from `get_step`). Sequences print in first-seen order,
    followed by the number of reactions and the number of distinct sequences.
    """
    with open(path, 'r') as file:
        mech_counter = Counter(
            tuple(get_mechanism(rxn_dict, get_step(rxn_dict)))
            for rxn_dict in map(json.loads, file)
        )
    for mech, count in mech_counter.items():
        print({'mech': list(mech), 'count': count})
    print(f"The total number of reactions: {sum(mech_counter.values())}")
    print(f"Total number of mechanism types: {len(mech_counter)}\n")
    

# Per-sequence mechanism counts first, then per-step-count totals, for Good_Match.json.
get_mechanism_stat(path="./raw_dataset/Good_Match.json")
get_stats(path="./raw_dataset/Good_Match.json")


{'mech': ['Dissociation', 'Nucleophilic_addition', 'Isomerization'], 'count': 727}
{'mech': ['Deprotonation', 'SN2_Reaction'], 'count': 9149}
{'mech': ['Deprotonation', 'Nucleophilic_addition', 'Elimination_reaction'], 'count': 245}
{'mech': ['Deprotonation', 'SnAr_reaction'], 'count': 4163}
{'mech': ['Nucleophilic_addition', 'Deprotonation', 'Isomerization'], 'count': 284}
{'mech': ['Nucleophilic_addition', 'Deprotonation', 'Elimination_reaction'], 'count': 933}
{'mech': ['Deprotonation', 'Deprotonation', 'SnAr_reaction'], 'count': 73}
{'mech': ['Nucleophilic_addition', 'Protonation'], 'count': 118}
{'mech': ['Deprotonation', 'SN2_Reaction', 'Deprotonation'], 'count': 38}
{'mech': ['Deprotonation', 'Protonation'], 'count': 1314}
{'mech': ['SN2_Reaction', 'Deprotonation'], 'count': 188}
{'mech': ['Protonation', 'Protonation', 'Deprotonation'], 'count': 13}
{'mech': ['Nucleophilic_addition', 'Deprotonation'], 'count': 1926}
{'mech': ['Dissociation', 'SnAr_reaction'], 'count': 378}
{'mec

## Split training set, validation set and testing set

In [21]:
# shuffle and split the data into training, testing and validation set
import random


# Fixed seed so the train/test/validation split is reproducible across kernel restarts.
SPLIT_SEED = 42

with open('./raw_dataset/Good_Match.json', 'r') as database:
    # Drop blank lines and make every record end with '\n'. Otherwise an
    # unterminated final line could be glued onto another record once shuffled.
    lines = [row if row.endswith('\n') else row + '\n' for row in database if row.strip()]

random.Random(SPLIT_SEED).shuffle(lines)
train_end = int(len(lines) * 0.8)
test_end = int(len(lines) * 0.9)
training = lines[:train_end]             # 80 %
testing = lines[train_end:test_end]      # 10 %
validation = lines[test_end:]            # 10 %

# 'w' (not 'a') so re-running the cell replaces each split instead of appending
# a differently shuffled copy that would leak rows between the splits.
for out_path, subset in (
    ('./raw_dataset/training_set_preprocess.json', training),
    ('./raw_dataset/testing_set_preprocess.json', testing),
    ('./raw_dataset/validation_set_preprocess.json', validation),
):
    with open(out_path, 'w') as split_file:
        split_file.writelines(subset)

## Extract mechanism SMILES

In [22]:
def get_mechanism_smi(rxn_dict, step):
    """Return the 'Mechanism smi' strings of the first `step` mechanisms in the chain.

    The chain starts at the first element of rxn_dict['subsequent mechanisms']
    and follows the first element of each node's 'subsequent mechanisms'.
    When the chain's last node is reached, '_EOS' is appended to its SMILES
    and the walk stops. If `step` is smaller than the chain depth, no '_EOS'
    marker is added.

    Args:
        rxn_dict: reaction record.
        step: maximum number of mechanism steps to collect.

    Returns:
        List of mechanism SMILES; [] when `step` <= 0 or the record has no
        (or an empty) 'subsequent mechanisms' list.

    Raises:
        KeyError: if a visited node lacks 'Mechanism smi'.
    """
    mech_list = []
    subs = rxn_dict.get('subsequent mechanisms')
    if not subs or step <= 0:
        return mech_list
    node = subs[0]
    for _ in range(step):
        mech_list.append(node['Mechanism smi'])
        children = node.get('subsequent mechanisms')
        if not children:
            mech_list[-1] += '_EOS'
            break  # end of chain: do not re-append the last step
        node = children[0]
    return mech_list

def get_rxn_smi(rxn_dict):
    """Return the overall reaction SMILES stored under the record's 'Reaction' key (KeyError if absent)."""
    reaction_smiles = rxn_dict['Reaction']
    return reaction_smiles

### Retrosynthesis model data preparation

In [None]:
import os
import re
import json
import jsonlines
from rdkit.Chem import rdChemReactions

def _strip_atom_maps(smi):
    """Return reaction SMILES `smi` with atom-map numbers removed.

    A '_EOS' marker in `smi` is removed before RDKit parses the string and,
    if present, re-appended to the end of the result.
    """
    end_of_string = '_EOS' in smi
    rxn = rdChemReactions.ReactionFromSmarts(smi.replace('_EOS', ''))
    rdChemReactions.RemoveMappingNumbersFromReactions(rxn)
    unmapped = rdChemReactions.ReactionToSmiles(rxn)
    return unmapped + '_EOS' if end_of_string else unmapped


def preprocess_data(path, mapping=False):
    """Append one source/target line pair per mechanism step for the retrosynthesis model.

    For each record in the JSON-lines file at `path`, the mechanism-step SMILES
    come from `get_mechanism_smi`. Atom maps are removed unless `mapping` is
    True. Each step is split on '>>': reactants go to the sources file and
    products to the targets file.

    Each record is then appended to a statistics file with an extra 'location'
    key: the 0-based line numbers its steps occupy in the sources/targets files.
    The statistics file name depends only on `mapping`, so it is shared by the
    training/testing/validation calls.

    All outputs are opened in append mode; calling this twice on the same
    input appends its rows twice.

    Args:
        path: input JSON-lines file. It must contain 'training', 'testing' or
            'validation', which selects the output sub-directory.
        mapping: keep atom-map numbers and write under 'retrosynthesis_mapped'.

    Raises:
        ValueError: if the dataset type cannot be inferred from `path`.
    """
    dataset_types = ['training', 'testing', 'validation']
    dataset_type = next((dt for dt in dataset_types if dt in path), None)
    if not dataset_type:
        raise ValueError("Path must contain 'training', 'testing', or 'validation'.")

    file_suffix = '_mapped' if mapping else ''
    base_dir = f"./processed_dataset/retrosynthesis{file_suffix}/{dataset_type}"
    os.makedirs(base_dir, exist_ok=True)
    prefix = 'test' if dataset_type == 'testing' else dataset_type[:5]
    sources_file = f"{base_dir}/{prefix}_sources.txt"
    targets_file = f"{base_dir}/{prefix}_targets.txt"

    stats_dir = "./Statistics/retrosynthesis"
    os.makedirs(stats_dir, exist_ok=True)
    stats_file = f"{stats_dir}/reaction_location{file_suffix}.json"

    # Ensure both files exist (append mode never truncates). Count the source
    # lines already present ONCE and advance the offset in memory, instead of
    # re-reading the whole sources file for every reaction (quadratic I/O).
    open(sources_file, 'a').close()
    open(targets_file, 'a').close()
    with open(sources_file, 'r') as db:
        next_idx = sum(1 for _ in db)

    with jsonlines.open(path, 'r') as reader, \
            open(sources_file, 'a') as src, \
            open(targets_file, 'a') as tgt, \
            open(stats_file, 'a') as locate:
        for rxn_dict in reader:
            mech = get_mechanism_smi(rxn_dict, get_step(rxn_dict))
            if not mapping:
                mech = [_strip_atom_maps(smi) for smi in mech]

            rxn_dict['location'] = list(range(next_idx, next_idx + len(mech)))
            for smi in mech:
                reactants, products = smi.split('>>')
                src.write(reactants.strip() + '\n')
                tgt.write(products.strip() + '\n')
            next_idx += len(mech)

            json.dump(rxn_dict, locate)
            locate.write('\n')


In [27]:
# Build the retrosynthesis source/target files (atom maps removed) for each split.
preprocess_data(path='./raw_dataset/training_set_preprocess.json')
preprocess_data(path='./raw_dataset/testing_set_preprocess.json')
preprocess_data(path='./raw_dataset/validation_set_preprocess.json')

### Graph2SMILES data preparation

In [32]:
import pandas as pd
import os
import jsonlines
import rdkit
from rdkit import Chem
from rdkit.Chem import rdChemReactions


# Output directory for the Graph2SMILES raw CSV files written below.
os.makedirs(name='./processed_dataset/g2s/raw', exist_ok=True)
def _canonical_unmapped_step(smi):
    """Return mechanism-step reaction SMILES `smi` without atom maps, each side canonicalised.

    A '_EOS' marker in `smi` is removed before parsing; if one was present the
    result is suffixed with 'END'.
    """
    end_of_string = '_EOS' in smi
    # str.replace instead of re.sub: the pattern is a plain literal, and this
    # avoids relying on `re` having been imported by an earlier cell.
    rxn = rdChemReactions.ReactionFromSmarts(smi.replace('_EOS', ''))
    rdChemReactions.RemoveMappingNumbersFromReactions(rxn)
    sides = rdChemReactions.ReactionToSmiles(rxn).split('>>')
    canonical = '>>'.join(Chem.MolToSmiles(Chem.MolFromSmiles(side)) for side in sides)
    return canonical + 'END' if end_of_string else canonical


def preprocess_data(path):
    """Write every mechanism step of a JSON-lines reaction file to a Graph2SMILES raw CSV.

    Steps come from `get_mechanism_smi`. Each is converted by
    `_canonical_unmapped_step`: atom maps are removed, both sides are
    canonicalised, and the final step gets an 'END' suffix.

    Each record is appended to Statistics/reaction_location_g2s.json with an
    extra 'location' key: the CSV row ids of its steps. That statistics file
    is shared by all splits and opened in append mode, so re-running appends
    duplicate records. The CSV itself is overwritten.

    Args:
        path: input JSON-lines file. If 'training' is in the path the output is
            raw_train.csv; otherwise 'testing' selects raw_test.csv; anything
            else goes to raw_val.csv.
    """
    # TODO: Makedir with mapped and unmapped data, and train model accordingly
    # Create the output directories here so this does not depend on other cells.
    os.makedirs('./processed_dataset/g2s/raw', exist_ok=True)
    os.makedirs('Statistics', exist_ok=True)

    all_mech = []
    with jsonlines.open(path, 'r') as reader, \
            open('Statistics/reaction_location_g2s.json', 'a') as locate:
        for rxn_dict in reader:
            mech = [_canonical_unmapped_step(smi)
                    for smi in get_mechanism_smi(rxn_dict, get_steps(rxn_dict))]
            start = len(all_mech)
            all_mech.extend(mech)
            # 0-based row ids, matching the CSV 'id' column written below.
            rxn_dict['location'] = list(range(start, len(all_mech)))
            json.dump(rxn_dict, locate)
            locate.write('\n')

    if 'training' in path:
        out_csv = './processed_dataset/g2s/raw/raw_train.csv'
    elif 'testing' in path:
        out_csv = './processed_dataset/g2s/raw/raw_test.csv'
    else:
        out_csv = './processed_dataset/g2s/raw/raw_val.csv'
    pd.DataFrame({'rxn_smiles': all_mech}).to_csv(out_csv, index_label='id')


def preprocess_data_rxn(path):
    """Write the overall reaction SMILES of a JSON-lines file to a Graph2SMILES raw CSV.

    Unlike `preprocess_data`, this uses each record's whole 'Reaction' string
    rather than its mechanism steps. Atom maps are removed and every
    '>'-separated field is canonicalised with RDKit. Records whose 'Reaction'
    contains '->' are printed and skipped.

    The result overwrites the same raw_train.csv / raw_test.csv / raw_val.csv
    files that `preprocess_data` writes.

    Args:
        path: input JSON-lines file. If 'training' is in the path the output is
            raw_train.csv; otherwise 'testing' selects raw_test.csv; anything
            else goes to raw_val.csv.
    """
    # TODO: Makedir with mapped and unmapped data, and train model accordingly
    all_mech = []
    with jsonlines.open(path, 'r') as reader:
        for rxn_dict in reader:
            mech = rxn_dict['Reaction']
            if "->" in mech:
                # NOTE(review): presumably '->' marks records that cannot be
                # parsed as reaction SMILES here; they are reported and skipped.
                print(mech)
                continue
            mol_temp = rdChemReactions.ReactionFromSmarts(mech)
            rdChemReactions.RemoveMappingNumbersFromReactions(mol_temp)
            smi = rdChemReactions.ReactionToSmiles(mol_temp)
            # Canonicalise each '>'-separated field separately.
            smi = '>'.join([Chem.MolToSmiles(Chem.MolFromSmiles(i)) for i in smi.split('>')])
            all_mech.append(smi)

    # Write the collected SMILES. to_csv overwrites the file, so the target is
    # only replaced once processing has succeeded (no up-front truncation).
    if 'training' in path:
        out_csv = './processed_dataset/g2s/raw/raw_train.csv'
    elif 'testing' in path:
        out_csv = './processed_dataset/g2s/raw/raw_test.csv'
    else:
        out_csv = './processed_dataset/g2s/raw/raw_val.csv'
    os.makedirs('./processed_dataset/g2s/raw', exist_ok=True)
    pd.DataFrame({'rxn_smiles': all_mech}).to_csv(out_csv, index_label='id')

In [None]:
# Build the Graph2SMILES raw CSVs (one row per mechanism step) for each split.
preprocess_data(path='./raw_dataset/training_set_preprocess.json')
preprocess_data(path='./raw_dataset/testing_set_preprocess.json')
preprocess_data(path='./raw_dataset/validation_set_preprocess.json')