In [None]:
! pip install rxnmapper
! pip install scikit-learn

In [7]:
import pandas as pd
from rdkit import Chem
from tqdm import tqdm
import re

from rdkit import RDLogger
from rxnmapper import RXNMapper
from sklearn.model_selection import train_test_split
RDLogger.DisableLog('rdApp.*')

In [3]:
def smi_tokenizer(smi):
    """
    Tokenize a SMILES molecule or reaction.

    Args:
        smi: SMILES (or reaction SMILES) string to split into tokens.

    Returns:
        The tokens of `smi`, in order, joined by single spaces.

    Raises:
        AssertionError: if `smi` contains characters that the token pattern
            does not cover, i.e. the tokens do not reassemble to `smi`.
    """
    # Raw string so that the regex escapes (\[, \(, \. ...) reach `re` verbatim
    # instead of being (invalid) Python string escapes, which emit a
    # SyntaxWarning on recent Python versions.  `\\` matches a literal backslash.
    pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    tokens = re.findall(pattern, smi)
    # Raised explicitly (not via `assert`) so the check also runs under `python -O`.
    if smi != ''.join(tokens):
        raise AssertionError(f"SMILES contains characters that cannot be tokenized: {smi!r}")
    return ' '.join(tokens)

def canno(smi):
    """
    Round-trip a SMILES string through RDKit.

    `smi` is parsed with `Chem.MolFromSmiles` and the resulting molecule is
    written back out with `Chem.MolToSmiles`; that string is returned.
    """
    mol = Chem.MolFromSmiles(smi)
    return Chem.MolToSmiles(mol)

def rxnF2onmt(reactionF, outpath, ttv="train",canno=False):
    """
    Convert a reaction file into tokenized source/target text files.

    Every line of `reactionF` is split on ``>>`` into a reactant side and a
    product side, and both sides are tokenized with `smi_tokenizer`.  The
    tokenized products are written to ``{outpath}/src-{ttv}.txt`` and the
    tokenized reactants to ``{outpath}/tgt-{ttv}.txt``, one reaction per line,
    in input order.  Lines that cannot be split into exactly two sides or
    cannot be tokenized are skipped, as are pairs with an empty side.

    Args:
        reactionF: path of the input file, one ``reactants>>products`` per line.
        outpath: existing directory that receives the two output files.
        ttv: split name used in the output file names; one of
            "train", "test" or "val".
        canno: when true, both sides are re-written through RDKit before
            tokenization.  Note that inside this function the name shadows the
            module-level `canno` helper.

    Raises:
        AssertionError: if `ttv` is not one of the accepted split names.
    """
    assert ttv in ["train", "test", "val"]

    with open(reactionF,"r") as f:
        reactions = [line.rstrip() for line in f]

    pairs = []
    for reaction in tqdm(reactions):
        try:
            reactant, product = reaction.split('>>')
            if canno:
                # NOTE(review): this parses with MolFromSmarts, whereas the
                # module-level `canno` helper uses MolFromSmiles -- confirm
                # which one is intended before enabling this flag.
                reactant = Chem.MolToSmiles(Chem.MolFromSmarts(reactant))
                product = Chem.MolToSmiles(Chem.MolFromSmarts(product))

            # Tokenize BOTH sides before recording either one, so a failure on
            # the product cannot leave an orphan reactant that would shift
            # every following reactant/product pairing by one.
            pair = (smi_tokenizer(reactant), smi_tokenizer(product))
        except Exception:
            # Best effort: malformed or untokenizable reactions are dropped.
            continue
        pairs.append(pair)

    # Context managers guarantee both files are flushed and closed even if a
    # write fails part-way through.
    with open(f"{outpath}/src-{ttv}.txt","w") as srcF, open(f"{outpath}/tgt-{ttv}.txt","w") as tgtF:
        for reactant, product in pairs:
            if product != '' and reactant != '':
                srcF.write(product + '\n')
                tgtF.write(reactant + '\n')

def rxnF2rtfm(file_path:str,output_path:str,ttv="train"):
    """
    Atom-map a reaction file with RXNMapper and write the result as a CSV.

    Every line of `file_path` is passed, one at a time, to
    `RXNMapper.get_attention_guided_atom_maps`.  The final table written to
    ``{output_path}/raw_{ttv}.csv`` has the columns ``id``, ``class`` (always
    "UNK"), ``reactants>reagents>production`` (the mapped reaction) and
    ``origin_SMILES`` (the input line).  Reactions for which mapping raised
    fall back to the unmapped input; rows whose mapped reaction equals the
    input are left out of the CSV.

    Args:
        file_path: path of the input file, one reaction per line.
        output_path: existing directory that receives ``raw_{ttv}.csv``.
        ttv: split name used in the output file name; one of
            "train", "test" or "val".

    Raises:
        AssertionError: if `ttv` is not one of the accepted split names.
    """
    # Validate the cheap argument first so a typo does not cost a mapper load.
    assert ttv in ["train", "test", "val"]
    rxn_mapper = RXNMapper()

    with open(file_path,"r") as f:
        targets = [line.rstrip() for line in f]

    out_csv = f"{output_path}/raw_{ttv}.csv"
    mapped = []
    n_failed = 0
    for i, row in enumerate(tqdm(targets)):
        try:
            # One reaction per call so that a failure only affects that row.
            mapped_rxn = rxn_mapper.get_attention_guided_atom_maps([row])[0]['mapped_rxn']
        except Exception:
            # Fall back to the unmapped input; such rows are filtered out below.
            mapped_rxn = row
            n_failed += 1
        mapped.append(mapped_rxn)

        # Periodic checkpoint of the mapped reactions so far (overwritten by
        # the final table).  Kept outside the try block so an I/O error is not
        # mistaken for a mapping failure and cannot add a second entry for the
        # same row.
        if i % 10000 == 0:
            print()  # presumably moves the output past the tqdm bar line
            pd.Series(mapped, name='mapped_rxn').to_csv(out_csv)

    if n_failed:
        print(f"{n_failed} of {len(targets)} reactions could not be atom-mapped and are dropped from the output")

    newdf = pd.DataFrame({"id":list(range(len(mapped))),"class":["UNK"] * len(mapped),"reactants>reagents>production": mapped,"origin_SMILES": targets})
    # Keep only rows where the mapper actually changed the reaction string.
    am_newdf = newdf[newdf['origin_SMILES']!= newdf['reactants>reagents>production']]
    am_newdf.to_csv(out_csv)

In [None]:
# Read the lines from the original dataset
%cd example
with open('origin_dataset.txt', 'r') as file:
    lines = file.readlines()

# Shuffle the lines for randomness
import random
random.shuffle(lines)

# Calculate the number of lines for each set
total_lines = len(lines)
train_size = int(0.9 * total_lines)
test_size = int(0.05 * total_lines)

# Split the dataset
train_set, test_valid_set = train_test_split(lines, train_size=train_size, shuffle=False)
test_set, valid_set = train_test_split(test_valid_set, train_size=test_size, shuffle=False)

# Write the split datasets to separate files
with open('train.txt', 'w') as file:
    file.writelines(train_set)

with open('test.txt', 'w') as file:
    file.writelines(test_set)

with open('valid.txt', 'w') as file:
    file.writelines(valid_set)

! wc -l *.txt

In [12]:
target_path = '/path/to/READRetro/scripts/preprocessing/example'


def _read_reactions(path):
    """Return the lines of `path` with trailing whitespace stripped."""
    with open(path) as f:
        return [l.rstrip() for l in f]


def _canonicalize_reactions(reactions):
    """
    Rewrite each ``reactants>>products`` line with both sides passed through `canno`.

    Lines that do not split into exactly two sides, or for which `canno`
    raises, are skipped instead of aborting the whole cell.
    """
    canonical = []
    for line in tqdm(reactions):
        try:
            r, p = line.split(">>")
            cr, cp = canno(r), canno(p)
        except Exception:
            continue
        canonical.append(f"{cr}>>{cp}")
    return canonical


def _drop_duplicates(reactions):
    """Remove repeated reactions, keeping the first occurrence and the order."""
    return list(dict.fromkeys(reactions))


def _write_reactions(path, reactions):
    """Write `reactions` to `path`, one per line."""
    with open(path, "w") as nf:
        for line in tqdm(reactions):
            nf.write(line+"\n")


data_train = f"{target_path}/train.txt"
data_test = f"{target_path}/test.txt"
data_valid = f"{target_path}/valid.txt"

ds_train = _read_reactions(data_train)
ds_test = _read_reactions(data_test)
ds_valid = _read_reactions(data_valid)
print(len(ds_train),len(ds_test),len(ds_valid))

cano_ds_train = _canonicalize_reactions(ds_train)
cano_ds_test = _canonicalize_reactions(ds_test)
cano_ds_valid = _canonicalize_reactions(ds_valid)
print(len(cano_ds_train),len(cano_ds_test),len(cano_ds_valid))

# NOTE(review): duplicates are removed within each split only; a reaction that
# occurs in two different splits is kept in both.
drop_cano_ds_train = _drop_duplicates(cano_ds_train)
drop_cano_ds_test = _drop_duplicates(cano_ds_test)
drop_cano_ds_valid = _drop_duplicates(cano_ds_valid)
print(len(drop_cano_ds_train),len(drop_cano_ds_test),len(drop_cano_ds_valid))

_write_reactions(f"{target_path}/new_train.txt", drop_cano_ds_train)
_write_reactions(f"{target_path}/new_test.txt", drop_cano_ds_test)
_write_reactions(f"{target_path}/new_valid.txt", drop_cano_ds_valid)

84557 4697 4699


100%|██████████| 84557/84557 [00:36<00:00, 2319.94it/s]
100%|██████████| 4697/4697 [00:02<00:00, 2281.41it/s]
100%|██████████| 4699/4699 [00:02<00:00, 2323.75it/s]


79885 4442 4458
79852 4442 4458


100%|██████████| 79852/79852 [00:00<00:00, 2433506.96it/s]
100%|██████████| 4442/4442 [00:00<00:00, 1965098.45it/s]
100%|██████████| 4458/4458 [00:00<00:00, 1724926.87it/s]


In [13]:
! mkdir $target_path/onmt
rxnF2onmt(f"{target_path}/new_train.txt",f'{target_path}/onmt',ttv='train')
rxnF2onmt(f"{target_path}/new_valid.txt",f'{target_path}/onmt',ttv='val')
rxnF2onmt(f"{target_path}/new_test.txt",f'{target_path}/onmt',ttv='test')

! cp -r $target_path/onmt $target_path/g2s

100%|██████████| 79852/79852 [00:01<00:00, 47513.76it/s]
100%|██████████| 4458/4458 [00:00<00:00, 52951.13it/s]
100%|██████████| 4442/4442 [00:00<00:00, 52956.49it/s]


In [14]:
! mkdir $target_path/retroformer
rxnF2rtfm(f"{target_path}/new_train.txt",f'{target_path}/retroformer',ttv="train")
rxnF2rtfm(f"{target_path}/new_valid.txt",f'{target_path}/retroformer',ttv="val")
rxnF2rtfm(f"{target_path}/new_test.txt",f'{target_path}/retroformer',ttv="test")

  0%|          | 11/79852 [00:01<1:44:17, 12.76it/s]




  0%|          | 110/79852 [00:02<17:53, 74.30it/s] Token indices sequence length is longer than the specified maximum sequence length for this model (649 > 512). Running this sequence through the model will result in indexing errors
 13%|█▎        | 10004/79852 [01:39<14:18, 81.39it/s]




 25%|██▌       | 20000/79852 [03:19<09:59, 99.86it/s] 




 38%|███▊      | 29999/79852 [05:00<07:51, 105.67it/s]




 50%|█████     | 39998/79852 [06:41<07:15, 91.57it/s] 




 63%|██████▎   | 49993/79852 [08:22<04:55, 101.18it/s]




 88%|████████▊ | 69990/79852 [11:42<01:40, 98.03it/s] 




100%|██████████| 79852/79852 [13:22<00:00, 99.50it/s] 
  0%|          | 9/4458 [00:00<00:54, 81.64it/s]




 74%|███████▎  | 3279/4458 [00:32<00:11, 103.95it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (800 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 4458/4458 [00:43<00:00, 101.41it/s]
  0%|          | 17/4442 [00:00<00:54, 81.75it/s]




  4%|▍         | 195/4442 [00:01<00:39, 106.84it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (555 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 4442/4442 [00:43<00:00, 101.37it/s]
