In [None]:
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import rdMolDescriptors
from collections import Counter
import os, sys
import tqdm
import string
import pprint as pp
import numpy as np
import matplotlib.pyplot as plt
import pprint as pp
import pandas as pd

# Silence RDKit log output (warnings included) on every 'rdApp.*' channel.
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')


# NOTE(review): not used by any active code in the visible cells; presumably meant as a
# cap on the number of reactions loaded per split for quick runs — TODO wire up or remove.
MAX_REACTIONS = 100
# If True, parsed molecules are given explicit hydrogens with Chem.AddHs
# (dataset loading and some evaluation cells check this flag).
ADD_Hs = False



# Preprocessing data and counting the atoms of reaction molecules.

## Loading data in memory

In [None]:
# USPTO-MIT "separated" split. Line k of src-<name>.txt holds "reactants > reagents"
# and line k of tgt-<name>.txt holds the products of the same reaction, written as
# whitespace-tokenized, dot-separated SMILES.
DATA_DIR = './data/DataSet-USPTO-main/MIT_separated'
SPLIT_FILE_NAMES = {"TR": "train", "VL": "valid", "TS": "test"}

remove_whitespaces = str.maketrans('', '', string.whitespace)  # used for removing whitespaces.


def _parse_molecules(smiles_list, add_hs, where):
    """Convert SMILES strings to RDKit molecules, optionally adding explicit hydrogens.

    Args:
        smiles_list: SMILES strings (one per molecule).
        add_hs: if True, apply ``Chem.AddHs`` to every parsed molecule.
        where: location string (file:line) used in the error message.

    Returns:
        List of RDKit ``Mol`` objects, in input order.

    Raises:
        ValueError: if RDKit cannot parse a SMILES. Failing here, with the offending
            file/line, beats storing ``None`` and crashing in a later cell.
    """
    molecules = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"{where}: RDKit could not parse SMILES {smiles!r}")
        molecules.append(Chem.AddHs(mol) if add_hs else mol)
    return molecules


# load data in memory
dataset = {}
for data_split, file_name in SPLIT_FILE_NAMES.items():
    src_path = f"{DATA_DIR}/src-{file_name}.txt"
    tgt_path = f"{DATA_DIR}/tgt-{file_name}.txt"
    # context managers close both files even if reading fails
    with open(src_path, 'r') as src_fp, open(tgt_path, 'r') as tgt_fp:
        src_lines = src_fp.readlines()
        tgt_lines = tgt_fp.readlines()

    # zip() below would silently drop the tail of the longer file; the files must pair up 1:1.
    if len(src_lines) != len(tgt_lines):
        raise ValueError(
            f"{data_split}: {src_path} has {len(src_lines)} lines "
            f"but {tgt_path} has {len(tgt_lines)}"
        )

    split_data = {
        "src_path": src_path,
        "tgt_path": tgt_path,
        "reagents": [],
        "reactants": [],
        "products": [],
    }
    for line_no, (src_line, tgt_line) in enumerate(
        tqdm.tqdm(zip(src_lines, tgt_lines), total=len(src_lines)), start=1
    ):
        src_line_splits = src_line.split(">")
        assert len(src_line_splits) == 2, f"{src_path}:{line_no}: expected 'reactants > reagents'"
        reactants_nonsplit = src_line_splits[0].translate(remove_whitespaces)
        reagents_nonsplit = src_line_splits[1].translate(remove_whitespaces)
        products_nonsplit = tgt_line.translate(remove_whitespaces)

        reactants = reactants_nonsplit.split(".")
        reagents = reagents_nonsplit.split(".")
        products = products_nonsplit.split(".")

        # sometimes reagents are empty, remove empty strings
        if reagents == [""]:
            reagents = []

        split_data["reactants"].append(_parse_molecules(reactants, ADD_Hs, f"{src_path}:{line_no}"))
        split_data["reagents"].append(_parse_molecules(reagents, ADD_Hs, f"{src_path}:{line_no}"))
        split_data["products"].append(_parse_molecules(products, ADD_Hs, f"{tgt_path}:{line_no}"))

    assert len(split_data["reactants"]) == len(split_data["reagents"]) == len(split_data["products"])
    split_data["len"] = len(split_data["reactants"])
    dataset[data_split] = split_data


## How many reactions are unbalanced w.r.t. total?

In [None]:

def compute_atom_count(molecules, add_hs=False):
    """Count atoms per element symbol over a collection of RDKit molecules.

    Args:
        molecules: iterable of RDKit ``Mol`` objects.
        add_hs: if True, make hydrogens explicit with ``Chem.AddHs`` before counting,
            so they appear as atoms; otherwise only atoms already present in each
            molecule's graph are counted.

    Returns:
        (counter, num_atoms): a ``Counter`` mapping element symbol -> number of atoms,
        and the total number of atoms counted.
    """
    counter, num_atoms = Counter(), 0
    for mol in molecules:
        if add_hs:
            # RDKit spells this ``AddHs``; there is no ``Chem.addHs``.
            mol = Chem.AddHs(mol)
        counter.update(atom.GetSymbol() for atom in mol.GetAtoms())
        num_atoms += mol.GetNumAtoms()
    return counter, num_atoms
    
# For every split: share of reactions whose reactants and products contain exactly the same
# atoms (reagents are not part of the balance), how many reactions have more than one
# product, and a histogram of (#reactant atoms - #product atoms).
for data_split in ["TR", "VL", "TS"]:
    counter_final = Counter()        # "True"/"False" -> number of balanced / unbalanced reactions
    counter_final_diff = Counter()   # (#reactant atoms - #product atoms) -> number of reactions
    multiple_products_reactions = 0

    rdkit_reactants = dataset[data_split]["reactants"]
    rdkit_reagents  = dataset[data_split]["reagents"]
    rdkit_products  = dataset[data_split]["products"]
    assert len(rdkit_reactants) == len(rdkit_reagents) == len(rdkit_products)

    rxn_tot = len(rdkit_reactants)
    if rxn_tot == 0:
        # nothing to count; the rates and the histogram below would fail on an empty split
        print(data_split, "- no reactions, skipping")
        continue

    # for each reaction, compare the per-element atom counts of the reactants and products
    for rxn_reactants, rxn_products in tqdm.tqdm(zip(rdkit_reactants, rdkit_products), total=rxn_tot):
        if len(rxn_products) >= 2:
            multiple_products_reactions += 1

        reactants_counter, reactants_num_atoms = compute_atom_count(rxn_reactants)
        products_counter, products_num_atoms = compute_atom_count(rxn_products)

        counter_final_diff.update([reactants_num_atoms - products_num_atoms])
        counter_final.update([str(reactants_counter == products_counter)])

    multiple_products_rate = multiple_products_reactions / rxn_tot
    # share of balanced reactions over ALL reactions (not the True/False odds)
    balanced_rate = counter_final["True"] / rxn_tot

    print(data_split, "- FINAL COUNT:", counter_final)
    print(data_split, "- ratio of balanced reactions:", balanced_rate)
    print(data_split, "multiple products rxns :", multiple_products_reactions)
    print(data_split, "multiple products rxns rate:", multiple_products_rate)

    # histogram of the per-reaction atom-count difference
    fig, ax = plt.subplots(figsize=(20, 10))
    labels, values = zip(*sorted(counter_final_diff.items()))
    ax.bar(labels, values, width=1)
    ax.set_xticks(labels)
    ax.set_xlabel("#reactant atoms - #product atoms")
    ax.set_ylabel("#reactions")
    ax.set_title(data_split + " - atom difference distribution")
    plt.show()
    

# Check distribution shift (how often each molecule occurs in the TR, VL and TS splits)

In [None]:

# Collect every distinct molecule (as canonical SMILES) and count its occurrences per split.
all_molecules_set = set()

# per split: canonical SMILES -> number of occurrences, pooling reactants, reagents and products
molecules_counters = {
    "TR": Counter(),
    "VL": Counter(),
    "TS": Counter()
}

# first part: count all molecules and relative occurences
print("first part: count all molecules and relative occurences")
for data_split in ["TR", "VL", "TS"]:
    split_dataset = dataset[data_split]
    assert len(split_dataset["reactants"]) == len(split_dataset["reagents"]) == len(split_dataset["products"])

    for i in tqdm.tqdm(range(len(split_dataset["reactants"]))):
        for kind in ["reactants", "reagents", "products"]:
            for mol in split_dataset[kind][i]:
                canon_mol = Chem.CanonSmiles(Chem.MolToSmiles(mol))

                if canon_mol == "":
                    # Report and keep going. An interactive input() pause here would block
                    # (or fail) under "Run All" / non-interactive execution.
                    print(i, "canon SMILES is an empty string!")
                    print("mol.object:", mol)
                    print("non-canon:", Chem.MolToSmiles(mol))

                all_molecules_set.add(canon_mol)
                molecules_counters[data_split].update([canon_mol])

#print(sorted(all_molecules_set))


In [None]:

# second part: for each molecule, compare its number of occurrences in TR+VL with TS
occurence_freqs = {}                 # canonical SMILES -> {"TR", "VL", "TS", "diff"} counts
occurence_diffs_counter = Counter()  # (TR + VL - TS) occurrences -> number of molecules
only_tr_counter = Counter()          # "True"/"False": molecule seen in TR+VL but never in TS

print("second part: for each molecule, compute occurence ratio for each data split")
for mol in tqdm.tqdm(all_molecules_set):
    tr_occs = molecules_counters["TR"].get(mol, 0)
    vl_occs = molecules_counters["VL"].get(mol, 0)
    ts_occs = molecules_counters["TS"].get(mol, 0)

    occurence_freqs[mol] = {
        "TR": tr_occs,
        "VL": vl_occs,
        "TS": ts_occs,
        "diff": tr_occs + vl_occs - ts_occs
    }
    # "only TR" pools the TR and VL splits: the molecule never appears in TS
    only_tr = (tr_occs + vl_occs) > 0 and ts_occs == 0

    occurence_diffs_counter.update([occurence_freqs[mol]["diff"]])
    only_tr_counter.update([str(only_tr)])

print("Most common (TR+VL - TS) occurrence differences:")
pp.pprint(occurence_diffs_counter.most_common(30))

print("Molecules that appear in TR+VL but never in TS:")
pp.pprint(only_tr_counter)


# Check the number of "alchemic reactions" (predictions whose atoms do not match the reactants) produced by Molecular Transformer, Augmented Transformer, Chemformer and Graph2SMILES

In [None]:
# True: strip hydrogens (Chem.RemoveHs) from reactants and predictions before counting,
# so the balance checks compare heavy atoms only.
# False: the evaluation cells call Chem.AddHs on the predictions instead, so hydrogens are counted.
HEAVY_ATOMS_ONLY = True

## Molecular transformer

In [None]:
def atom_count(molecules):
    """Tally atoms by element symbol across a collection of RDKit molecules.

    Hydrogens are only counted when they are explicit atoms in the graph
    (e.g. after ``Chem.AddHs``).

    Returns:
        (Counter of element symbol -> count, total number of atoms)
    """
    symbols = []
    total = 0
    for molecule in molecules:
        symbols.extend(atom.GetSymbol() for atom in molecule.GetAtoms())
        total += molecule.GetNumAtoms()
    return Counter(symbols), total

def compute_alchemy_predictions(reactants_counter, predictions_counters):
    """Compare each prediction's per-element atom counts with the reactants'.

    Args:
        reactants_counter: Counter of element symbol -> atom count for the reactants.
        predictions_counters: sequence of Counters in the same format, one per prediction.

    Returns:
        Four lists of booleans aligned with ``predictions_counters``:
        balanced (identical counts), too_few (some element is more abundant in the
        reactants), too_many (some element is more abundant in the prediction) and
        too_few_many (both too_few and too_many).
    """
    # Counter subtraction keeps only strictly positive differences, so a non-empty
    # difference means the left operand has a surplus of at least one element.
    balanced = [pred_counter == reactants_counter for pred_counter in predictions_counters]
    too_few = [len(reactants_counter - pred_counter) > 0 for pred_counter in predictions_counters]
    too_many = [len(pred_counter - reactants_counter) > 0 for pred_counter in predictions_counters]
    too_few_many = [few and many for few, many in zip(too_few, too_many)]
    return balanced, too_few, too_many, too_few_many

remove_whitespaces = str.maketrans('', '', string.whitespace) # used for removing whitespaces.

# alternative prediction file: "./molecular_transformer_data/results/predictions_MIT_separated_augm_average_20_on_MIT_test"
MOLTRANS_N_BEST = 5  # candidate predictions written per test reaction (one per line)

with open('./molecular_transformer_data/results/predictions_MIT_separated_augm_ensemble_average_2x20av_on_MIT_test.txt', 'r') as fp:
    moltrans_pred_lines = fp.readlines()

# first preprocessing: reshape so that each row holds the candidates of one test reaction.
moltrans_pred_lines = np.array(moltrans_pred_lines).reshape((-1, MOLTRANS_N_BEST))

print("predictions len:", moltrans_pred_lines.shape[0])
print("target len:", dataset["TS"]["len"])

assert moltrans_pred_lines.shape[0] == dataset["TS"]["len"]

total_balanced_counter = Counter()
total_too_few_counter = Counter()
total_too_many_counter = Counter()
total_too_few_many_counter = Counter()

# one prediction is scored per test reaction; derived from the data rather than hard-coded
total_predictions = moltrans_pred_lines.shape[0]
total_valid_predictions = 0

for i in tqdm.tqdm(range(moltrans_pred_lines.shape[0])):
    # preprocessing: remove whitespaces, convert to Mol, drop invalid SMILES, optionally add Hs.
    predictions = [Chem.MolFromSmiles(smiles.translate(remove_whitespaces)) for smiles in moltrans_pred_lines[i]]
    predictions = [pred for pred in predictions if pred is not None]
    if ADD_Hs:
        predictions = [Chem.AddHs(pred) for pred in predictions]

    # Keep a single prediction. NOTE(review): invalid SMILES were dropped above, so this is the
    # highest-ranked *valid* candidate (the top-1 only when the top-1 parses).
    predictions = predictions[:1]
    total_valid_predictions += len(predictions)

    reactants = dataset["TS"]["reactants"][i]

    # if considering only heavy atoms, strip explicit Hs before computing metrics
    if HEAVY_ATOMS_ONLY:
        predictions = [Chem.RemoveHs(pred) for pred in predictions]
        reactants = [Chem.RemoveHs(react) for react in reactants]
    else:
        predictions = [Chem.AddHs(pred) for pred in predictions]
        reactants = [Chem.AddHs(react) for react in reactants]

    reactants_counter, _ = atom_count(reactants)
    predictions_counters = [atom_count([pred])[0] for pred in predictions]

    balanced_preds, too_few_preds, too_many_preds, too_few_many_preds = compute_alchemy_predictions(reactants_counter, predictions_counters)

    total_balanced_counter.update(balanced_preds)
    total_too_few_counter.update(too_few_preds)
    total_too_many_counter.update(too_many_preds)
    total_too_few_many_counter.update(too_few_many_preds)

print("- Total predictions:", total_predictions)
print("- Total valid predictions:", total_valid_predictions)

# ratios fall back to 0.0 when their denominator is zero
valid_ratio = total_valid_predictions / total_predictions if total_predictions else 0.0
balanced_ratio = total_balanced_counter[True] / total_valid_predictions if total_valid_predictions else 0.0
too_few_ratio = total_too_few_counter[True] / total_valid_predictions if total_valid_predictions else 0.0
too_many_ratio = total_too_many_counter[True] / total_valid_predictions if total_valid_predictions else 0.0
too_few_many_ratio = total_too_few_many_counter[True] / total_valid_predictions if total_valid_predictions else 0.0

print("The ratio of valid predictions is:", valid_ratio)
print("The ratio of balanced predictions is:", balanced_ratio)
print("The ratio of too-few predictions is:", too_few_ratio)
print("The ratio of too-many predictions is:", too_many_ratio)
print("The ratio of too-few-too-many predictions is:", too_few_many_ratio)


## Augmented Transformer

In [None]:
# copy-paste the same code of molecular transformer, but need to change the initial preprocessing step

import pandas as pd

def atom_count(molecules):
    """Tally atoms by element symbol across a collection of RDKit molecules.

    Hydrogens are only counted when they are explicit atoms in the graph
    (e.g. after ``Chem.AddHs``).

    Returns:
        (Counter of element symbol -> count, total number of atoms)
    """
    symbols = []
    total = 0
    for molecule in molecules:
        symbols.extend(atom.GetSymbol() for atom in molecule.GetAtoms())
        total += molecule.GetNumAtoms()
    return Counter(symbols), total

def compute_alchemy_predictions(reactants_counter, predictions_counters):
    """Compare each prediction's per-element atom counts with the reactants'.

    Args:
        reactants_counter: Counter of element symbol -> atom count for the reactants.
        predictions_counters: sequence of Counters in the same format, one per prediction.

    Returns:
        Four lists of booleans aligned with ``predictions_counters``:
        balanced (identical counts), too_few (some element is more abundant in the
        reactants), too_many (some element is more abundant in the prediction) and
        too_few_many (both too_few and too_many).
    """
    # Counter subtraction keeps only strictly positive differences, so a non-empty
    # difference means the left operand has a surplus of at least one element.
    balanced = [pred_counter == reactants_counter for pred_counter in predictions_counters]
    too_few = [len(reactants_counter - pred_counter) > 0 for pred_counter in predictions_counters]
    too_many = [len(pred_counter - reactants_counter) > 0 for pred_counter in predictions_counters]
    too_few_many = [few and many for few, many in zip(too_few, too_many)]
    return balanced, too_few, too_many, too_few_many

remove_whitespaces = str.maketrans('', '', string.whitespace) # used for removing whitespaces.

df_test_pred = pd.read_csv("./Augmented_Transformer_master/uspto-mit/sep/result_ibm-test20_from_100.csv")

# candidate columns, best first: "target", "target_2", ..., "target_10"
AUGM_N_BEST = 10
prediction_columns = ["target"] + [f"target_{n}" for n in range(2, AUGM_N_BEST + 1)]

total_balanced_counter = Counter()
total_too_few_counter  = Counter()
total_too_many_counter = Counter()
total_too_few_many_counter = Counter()
total_valid_predictions = 0
total_predictions = df_test_pred.shape[0]  # one scored prediction per test reaction

for idx, row in tqdm.tqdm(df_test_pred.iterrows(), total=df_test_pred.shape[0]):
    # first part: preprocessing. "input" is "reactants>reagents"; only reactants enter the balance.
    reactants_smiles = row["input"].split(">")[0]
    reactants = Chem.MolFromSmiles(reactants_smiles)
    if reactants is None:
        raise ValueError(f"row {idx}: RDKit could not parse reactants SMILES {reactants_smiles!r}")

    # non-string cells (presumably NaN for missing candidates) are skipped; invalid SMILES dropped
    predictions = [Chem.MolFromSmiles(row[col]) for col in prediction_columns if isinstance(row[col], str)]
    predictions = [pred for pred in predictions if pred is not None]

    if ADD_Hs:
        reactants = Chem.AddHs(reactants)
        predictions = [Chem.AddHs(pred) for pred in predictions]

    # Keep a single prediction. NOTE(review): invalid SMILES were dropped above, so this is the
    # highest-ranked *valid* candidate (the top-1 only when the top-1 parses).
    predictions = predictions[:1]
    total_valid_predictions += len(predictions)

    # if considering only heavy atoms, strip explicit Hs before computing metrics
    if HEAVY_ATOMS_ONLY:
        new_predictions = []
        for pred in predictions:
            try:
                new_predictions.append(Chem.RemoveHs(pred))
            except Exception:
                # report the prediction that failed and skip it
                print("EXCEPTION PRED:", Chem.MolToSmiles(pred))
        predictions = new_predictions
        reactants = Chem.RemoveHs(reactants)
    else:
        predictions = [Chem.AddHs(pred) for pred in predictions]
        reactants = Chem.AddHs(reactants)

    # second part: count atoms and alchemy predictions
    reactants_counter, _ = atom_count([reactants])
    predictions_counters = [atom_count([pred])[0] for pred in predictions]

    balanced_preds, too_few_preds, too_many_preds, too_few_many_preds = compute_alchemy_predictions(reactants_counter, predictions_counters)

    total_balanced_counter.update(balanced_preds)
    total_too_few_counter.update(too_few_preds)
    total_too_many_counter.update(too_many_preds)
    total_too_few_many_counter.update(too_few_many_preds)

print("- Total predictions:", total_predictions)
print("- Total valid predictions:", total_valid_predictions)

# ratios fall back to 0.0 when their denominator is zero
valid_ratio = total_valid_predictions / total_predictions if total_predictions else 0.0
balanced_ratio = total_balanced_counter[True] / total_valid_predictions if total_valid_predictions else 0.0
too_few_ratio = total_too_few_counter[True] / total_valid_predictions if total_valid_predictions else 0.0
too_many_ratio = total_too_many_counter[True] / total_valid_predictions if total_valid_predictions else 0.0
too_few_many_ratio = total_too_few_many_counter[True] / total_valid_predictions if total_valid_predictions else 0.0

print("The ratio of valid predictions is:", valid_ratio)
print("The ratio of balanced predictions is:", balanced_ratio)
print("The ratio of too-few predictions is:", too_few_ratio)
print("The ratio of too-many predictions is:", too_many_ratio)
print("The ratio of too-few-too-many predictions is:", too_few_many_ratio)


## Chemformer

In [None]:
import string
import pandas as pd

def atom_count(molecules):
    """Tally atoms by element symbol across a collection of RDKit molecules.

    Hydrogens are only counted when they are explicit atoms in the graph
    (e.g. after ``Chem.AddHs``).

    Returns:
        (Counter of element symbol -> count, total number of atoms)
    """
    symbols = []
    total = 0
    for molecule in molecules:
        symbols.extend(atom.GetSymbol() for atom in molecule.GetAtoms())
        total += molecule.GetNumAtoms()
    return Counter(symbols), total

def compute_alchemy_predictions(reactants_counter, predictions_counters):
    """Compare each prediction's per-element atom counts with the reactants'.

    Args:
        reactants_counter: Counter of element symbol -> atom count for the reactants.
        predictions_counters: sequence of Counters in the same format, one per prediction.

    Returns:
        Four lists of booleans aligned with ``predictions_counters``:
        balanced (identical counts), too_few (some element is more abundant in the
        reactants), too_many (some element is more abundant in the prediction) and
        too_few_many (both too_few and too_many).
    """
    # Counter subtraction keeps only strictly positive differences, so a non-empty
    # difference means the left operand has a surplus of at least one element.
    balanced = [pred_counter == reactants_counter for pred_counter in predictions_counters]
    too_few = [len(reactants_counter - pred_counter) > 0 for pred_counter in predictions_counters]
    too_many = [len(pred_counter - reactants_counter) > 0 for pred_counter in predictions_counters]
    too_few_many = [few and many for few, many in zip(too_few, too_many)]
    return balanced, too_few, too_many, too_few_many


remove_whitespaces = str.maketrans('', '', string.whitespace) # used for removing whitespaces.

# NOTE(review): read_pickle unpickles arbitrary objects (and can execute code while doing so);
# only load pickles produced by a trusted pipeline.
df_test_data = pd.read_pickle("./chemformer_data/uspto_sep.pickle")
df_test_pred = pd.read_pickle("./MOLBART_LAST_TIME/uspto_sep_out.pickle")

df_test_data = df_test_data[df_test_data["set"] == "test"]

print("data shape:", df_test_data.shape, ". pred shape:", df_test_pred.shape)

# rows of the two frames are paired by position below, so they must have the same length
assert df_test_data.shape[0] == df_test_pred.shape[0]

# candidate columns, best first: "prediction_0" ... "prediction_9"
CHEMFORMER_N_BEST = 10
prediction_columns = [f"prediction_{n}" for n in range(CHEMFORMER_N_BEST)]

total_balanced_counter = Counter()
total_too_few_counter  = Counter()
total_too_many_counter = Counter()
total_too_few_many_counter = Counter()
total_valid_predictions = 0
total_predictions = df_test_pred.shape[0]  # one scored prediction per test reaction

for (_, row_data), (_, row_pred) in tqdm.tqdm(zip(df_test_data.iterrows(), df_test_pred.iterrows()), total=df_test_pred.shape[0]):
    # presumably an RDKit Mol prepared by the Chemformer data pipeline — TODO confirm
    reactants = row_data["reactants_mol"]

    # Non-string cells are skipped and invalid SMILES dropped. Any other error propagates
    # instead of silently stopping the loop and reporting statistics on a partial test set.
    predictions = [Chem.MolFromSmiles(row_pred[col]) for col in prediction_columns if isinstance(row_pred[col], str)]
    predictions = [pred for pred in predictions if pred is not None]

    if ADD_Hs:
        reactants = Chem.AddHs(reactants)
        predictions = [Chem.AddHs(pred) for pred in predictions]

    # Keep a single prediction. NOTE(review): invalid SMILES were dropped above, so this is the
    # highest-ranked *valid* candidate (the top-1 only when the top-1 parses).
    predictions = predictions[:1]
    total_valid_predictions += len(predictions)

    # if considering only heavy atoms, strip explicit Hs before computing metrics
    if HEAVY_ATOMS_ONLY:
        predictions = [Chem.RemoveHs(pred) for pred in predictions]
        reactants = Chem.RemoveHs(reactants)
    else:
        # count hydrogens on BOTH sides, otherwise every prediction looks H-excess
        predictions = [Chem.AddHs(pred) for pred in predictions]
        reactants = Chem.AddHs(reactants)

    # second part: count atoms and alchemy predictions
    reactants_counter, _ = atom_count([reactants])
    predictions_counters = [atom_count([pred])[0] for pred in predictions]

    balanced_preds, too_few_preds, too_many_preds, too_few_many_preds = compute_alchemy_predictions(reactants_counter, predictions_counters)

    total_balanced_counter.update(balanced_preds)
    total_too_few_counter.update(too_few_preds)
    total_too_many_counter.update(too_many_preds)
    total_too_few_many_counter.update(too_few_many_preds)

print("- Total predictions:", total_predictions)
print("- Total valid predictions:", total_valid_predictions)

# ratios fall back to 0.0 when their denominator is zero
valid_ratio = total_valid_predictions / total_predictions if total_predictions else 0.0
balanced_ratio = total_balanced_counter[True] / total_valid_predictions if total_valid_predictions else 0.0
too_few_ratio = total_too_few_counter[True] / total_valid_predictions if total_valid_predictions else 0.0
too_many_ratio = total_too_many_counter[True] / total_valid_predictions if total_valid_predictions else 0.0
too_few_many_ratio = total_too_few_many_counter[True] / total_valid_predictions if total_valid_predictions else 0.0

print("The ratio of valid predictions is:", valid_ratio)
print("The ratio of balanced predictions is:", balanced_ratio)
print("The ratio of too-few predictions is:", too_few_ratio)
print("The ratio of too-many predictions is:", too_many_ratio)
print("The ratio of too-few-too-many predictions is:", too_few_many_ratio)


## Graph2Smiles

In [None]:
import string
import numpy as np

def atom_count(molecules):
    """Tally atoms by element symbol across a collection of RDKit molecules.

    Hydrogens are only counted when they are explicit atoms in the graph
    (e.g. after ``Chem.AddHs``).

    Returns:
        (Counter of element symbol -> count, total number of atoms)
    """
    symbols = []
    total = 0
    for molecule in molecules:
        symbols.extend(atom.GetSymbol() for atom in molecule.GetAtoms())
        total += molecule.GetNumAtoms()
    return Counter(symbols), total

def compute_alchemy_predictions(reactants_counter, predictions_counters):
    """Compare each prediction's per-element atom counts with the reactants'.

    Args:
        reactants_counter: Counter of element symbol -> atom count for the reactants.
        predictions_counters: sequence of Counters in the same format, one per prediction.

    Returns:
        Four lists of booleans aligned with ``predictions_counters``:
        balanced (identical counts), too_few (some element is more abundant in the
        reactants), too_many (some element is more abundant in the prediction) and
        too_few_many (both too_few and too_many).
    """
    # Counter subtraction keeps only strictly positive differences, so a non-empty
    # difference means the left operand has a surplus of at least one element.
    balanced = [pred_counter == reactants_counter for pred_counter in predictions_counters]
    too_few = [len(reactants_counter - pred_counter) > 0 for pred_counter in predictions_counters]
    too_many = [len(pred_counter - reactants_counter) > 0 for pred_counter in predictions_counters]
    too_few_many = [few and many for few, many in zip(too_few, too_many)]
    return balanced, too_few, too_many, too_few_many

remove_whitespaces = str.maketrans('', '', string.whitespace) # used for removing whitespaces.

G2S_N_BEST = 10  # keep only the top-10 candidates of each line

with open('./GRAPH2SMILES_predictions/results/USPTO_480k_g2s_series_rel_smiles_smiles.final_run_dgat.result.txt', 'r') as fp:
    moltrans_pred_lines = fp.readlines()

# one line per test reaction, candidates separated by commas
moltrans_pred_lines = np.array(moltrans_pred_lines)

# print the same quantity the assert compares (lines, i.e. test reactions)
print("predictions len:", moltrans_pred_lines.shape[0])
print("target len:", dataset["TS"]["len"])

assert moltrans_pred_lines.shape[0] == dataset["TS"]["len"]

total_balanced_counter = Counter()
total_too_few_counter = Counter()
total_too_many_counter = Counter()
total_too_few_many_counter = Counter()

# one prediction is scored per test reaction; derived from the data rather than hard-coded
total_predictions = moltrans_pred_lines.shape[0]
total_valid_predictions = 0

for i in tqdm.tqdm(range(moltrans_pred_lines.shape[0])):
    smiles_preds = moltrans_pred_lines[i].translate(remove_whitespaces).split(",")[:G2S_N_BEST]
    predictions = [Chem.MolFromSmiles(pred) for pred in smiles_preds]
    predictions = [pred for pred in predictions if pred is not None]

    # Keep a single prediction. NOTE(review): invalid SMILES were dropped above, so this is the
    # highest-ranked *valid* candidate (the top-1 only when the top-1 parses).
    predictions = predictions[:1]
    total_valid_predictions += len(predictions)

    reactants = dataset["TS"]["reactants"][i]

    # if considering only heavy atoms, strip explicit Hs before computing metrics
    if HEAVY_ATOMS_ONLY:
        new_predictions = []
        for pred in predictions:
            try:
                new_predictions.append(Chem.RemoveHs(pred))
            except Exception:
                # report the prediction that failed and skip it
                print("EXCEPTION PRED:", Chem.MolToSmiles(pred))
        predictions = new_predictions

        new_reactants = []
        for react in reactants:
            try:
                new_reactants.append(Chem.RemoveHs(react))
            except Exception:
                # a reactant that cannot be processed invalidates the comparison: fail loudly
                print("EXCEPTION REACT:", Chem.MolToSmiles(react))
                raise
        reactants = new_reactants
    else:
        predictions = [Chem.AddHs(pred) for pred in predictions]
        reactants = [Chem.AddHs(react) for react in reactants]

    reactants_counter, _ = atom_count(reactants)
    predictions_counters = [atom_count([pred])[0] for pred in predictions]

    balanced_preds, too_few_preds, too_many_preds, too_few_many_preds = compute_alchemy_predictions(reactants_counter, predictions_counters)

    total_balanced_counter.update(balanced_preds)
    total_too_few_counter.update(too_few_preds)
    total_too_many_counter.update(too_many_preds)
    total_too_few_many_counter.update(too_few_many_preds)

print("- Total predictions:", total_predictions)
print("- Total valid predictions:", total_valid_predictions)

# ratios fall back to 0.0 when their denominator is zero
valid_ratio = total_valid_predictions / total_predictions if total_predictions else 0.0
balanced_ratio = total_balanced_counter[True] / total_valid_predictions if total_valid_predictions else 0.0
too_few_ratio = total_too_few_counter[True] / total_valid_predictions if total_valid_predictions else 0.0
too_many_ratio = total_too_many_counter[True] / total_valid_predictions if total_valid_predictions else 0.0
too_few_many_ratio = total_too_few_many_counter[True] / total_valid_predictions if total_valid_predictions else 0.0

print("The ratio of valid predictions is:", valid_ratio)
print("The ratio of balanced predictions is:", balanced_ratio)
print("The ratio of too-few predictions is:", too_few_ratio)
print("The ratio of too-many predictions is:", too_many_ratio)
print("The ratio of too-few-too-many predictions is:", too_few_many_ratio)


# Plot some alchemic reactions produced by the models

In [None]:
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import rdDepictor
# Prefer the CoordGen library when computing 2D depiction coordinates.
rdDepictor.SetPreferCoordGen(True)
from rdkit.Chem.Draw import IPythonConsole
from IPython.display import SVG
import rdkit
print(rdkit.__version__)

# Smoke test: draw one example molecule with matplotlib at a small size.
size = (120, 120)
m = Chem.MolFromSmiles('C[C@@H](CC1=CC=CC=C1)NC')
fig = Draw.MolToMPL(m, size=size)


## Molecular Transformer

In [None]:
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Heavy-atoms-only switch, re-set for this plotting section. Assumed to be used like the flag in
# the evaluation section (True -> hydrogens stripped before counting atoms) — the loop that
# reads it continues past this view.
HEAVY_ATOMS_ONLY = True

def atom_count(molecules):
    """Tally atoms by element symbol across a collection of RDKit molecules.

    Hydrogens are only counted when they are explicit atoms in the graph
    (e.g. after ``Chem.AddHs``).

    Returns:
        (Counter of element symbol -> count, total number of atoms)
    """
    symbols = []
    total = 0
    for molecule in molecules:
        symbols.extend(atom.GetSymbol() for atom in molecule.GetAtoms())
        total += molecule.GetNumAtoms()
    return Counter(symbols), total

def compute_alchemy_predictions(reactants_counter, predictions_counters):
    """Compare each prediction's per-element atom counts with the reactants'.

    Args:
        reactants_counter: Counter of element symbol -> atom count for the reactants.
        predictions_counters: sequence of Counters in the same format, one per prediction.

    Returns:
        Four lists of booleans aligned with ``predictions_counters``:
        balanced (identical counts), too_few (some element is more abundant in the
        reactants), too_many (some element is more abundant in the prediction) and
        too_few_many (both too_few and too_many).
    """
    # Counter subtraction keeps only strictly positive differences, so a non-empty
    # difference means the left operand has a surplus of at least one element.
    balanced = [pred_counter == reactants_counter for pred_counter in predictions_counters]
    too_few = [len(reactants_counter - pred_counter) > 0 for pred_counter in predictions_counters]
    too_many = [len(pred_counter - reactants_counter) > 0 for pred_counter in predictions_counters]
    too_few_many = [few and many for few, many in zip(too_few, too_many)]
    return balanced, too_few, too_many, too_few_many


def moltosvg(mol, size=(300,300), kekulize=True):
    """Render an RDKit molecule as an SVG string.

    Args:
        mol: ``rdkit.Chem.Mol`` to draw. A copy is drawn; ``mol`` itself is not modified.
        size: ``(width, height)`` of the canvas in pixels.
        kekulize: if True, try to kekulize the copy before drawing; if that
            fails, a fresh, unmodified copy is drawn instead.

    Returns:
        The SVG markup with every ``svg:`` namespace prefix removed.
    """
    mc = Chem.Mol(mol.ToBinary())
    if kekulize:
        try:
            Chem.Kekulize(mc)
        except Exception:
            # Catch only real errors (a bare `except` would also swallow
            # KeyboardInterrupt). Restart from a clean copy in case the failed
            # call left `mc` partially modified.
            mc = Chem.Mol(mol.ToBinary())
    if not mc.GetNumConformers():
        # No coordinates yet: compute a 2D depiction.
        rdDepictor.Compute2DCoords(mc)
    drawer = rdMolDraw2D.MolDraw2DSVG(size[0], size[1])
    drawer.DrawMolecule(mc)
    drawer.FinishDrawing()
    return drawer.GetDrawingText().replace('svg:', '')


def drawMyMol(mol, fname, image_size, myFontSize):
    """Draw ``mol`` with the Cairo backend and write the image to ``fname``.

    Args:
        mol: ``rdkit.Chem.Mol`` to draw.
        fname: output file path.
        image_size: ``(width, height)`` of the image in pixels.
        myFontSize: value passed to the drawer's ``SetFontSize``.
    """
    width, height = image_size[0], image_size[1]
    canvas = rdMolDraw2D.MolDraw2DCairo(width, height)
    canvas.SetFontSize(myFontSize)
    canvas.DrawMolecule(mol)
    canvas.FinishDrawing()
    canvas.WriteDrawingText(fname)
    

# Translation table that deletes every whitespace character, turning the
# space-separated SMILES tokens back into plain SMILES strings.
remove_whitespaces = str.maketrans('', '', string.whitespace)

# Alternative predictions file:
# "./molecular_transformer_data/results/predictions_MIT_separated_augm_average_20_on_MIT_test"

# The file stores 5 consecutive candidate lines per test reaction; regroup them
# so that row i holds the 5 candidates for reaction i.
with open('./molecular_transformer_data/results/predictions_MIT_separated_augm_ensemble_average_2x20av_on_MIT_test.txt', 'r') as fp:
    moltrans_pred_lines = np.array(fp.readlines()).reshape((-1, 5))

print("predictions len:", len(moltrans_pred_lines))
print("target len:", dataset["TS"]["len"])

assert len(moltrans_pred_lines) == dataset["TS"]["len"]

# Tallies of the per-prediction alchemy flags (keys True/False).
total_balanced_counter = Counter()
total_too_few_counter = Counter()
total_too_many_counter = Counter()
total_too_few_many_counter = Counter()

# Only one (top-1) prediction per test reaction is scored, so the number of
# scored predictions is the number of prediction rows, derived from the data
# instead of being hard-coded.
total_predictions = moltrans_pred_lines.shape[0]
total_valid_predictions = 0

# 1-based reaction counter; it also names the ./plots/<c>_<k>/ folders.
c = 0
os.makedirs("./plots/", exist_ok=True)
    
# Image settings for the "alchemic" examples dumped below.
IMAGE_sz = 500
FONTSIZE = 40

for i in tqdm.tqdm(range(moltrans_pred_lines.shape[0]), total=moltrans_pred_lines.shape[0]):
    c += 1

    # Score the top-1 candidate only (column 0 of the reshaped predictions).
    # It is valid iff RDKit can parse it; lower-ranked candidates are NOT used
    # as a fallback, otherwise "top-1" validity would be overstated.
    top1 = Chem.MolFromSmiles(moltrans_pred_lines[i, 0].translate(remove_whitespaces))
    predictions = [] if top1 is None else [top1]
    if ADD_Hs:
        predictions = [Chem.AddHs(pred) for pred in predictions]
    total_valid_predictions += len(predictions)

    reactants = dataset["TS"]["reactants"][i]
    reagents = dataset["TS"]["reagents"][i]
    products = dataset["TS"]["products"][i]

    # Heavy-atom mode strips hydrogens before counting; otherwise make them explicit.
    if HEAVY_ATOMS_ONLY:
        predictions = [Chem.RemoveHs(pred) for pred in predictions]
        reactants = [Chem.RemoveHs(react) for react in reactants]
    else:
        predictions = [Chem.AddHs(pred) for pred in predictions]
        reactants = [Chem.AddHs(react) for react in reactants]

    reactants_counter, reactants_n_atoms = atom_count(reactants)
    predictions_counters = [atom_count([pred])[0] for pred in predictions]

    balanced_preds, too_few_preds, too_many_preds, too_few_many_preds = compute_alchemy_predictions(reactants_counter, predictions_counters)

    # Dump every "alchemic" prediction (atoms both missing and added w.r.t. the
    # reactants) to ./plots/<c>_<k>/ for manual inspection.
    for k, (pred, alchemic) in enumerate(zip(predictions, too_few_many_preds), start=1):
        if not alchemic:
            continue

        reactants_smiles = ".".join(Chem.MolToSmiles(Chem.RemoveHs(react)) for react in reactants)
        reagents_smiles = ".".join(Chem.MolToSmiles(Chem.RemoveHs(reag)) for reag in reagents)
        products_smiles = ".".join(Chem.MolToSmiles(Chem.RemoveHs(prod)) for prod in products)
        pred_smiles = Chem.MolToSmiles(pred)

        clear_output(wait=True)

        dirpath = "./plots/" + str(c) + "_" + str(k) + "/"
        os.makedirs(dirpath, exist_ok=True)

        # Dedicated index names keep the reaction index `i` intact and give
        # every image a deterministic file name.
        for react_idx, react in enumerate(reactants):
            drawMyMol(react, os.path.join(dirpath, "react_" + str(react_idx) + ".png"), image_size=(IMAGE_sz, IMAGE_sz), myFontSize=FONTSIZE)

        for reag_idx, reag in enumerate(reagents):
            drawMyMol(reag, os.path.join(dirpath, "reag_" + str(reag_idx) + ".png"), image_size=(IMAGE_sz, IMAGE_sz), myFontSize=FONTSIZE)

        # Each folder holds exactly one prediction.
        drawMyMol(pred, os.path.join(dirpath, "pred_0.png"), image_size=(IMAGE_sz, IMAGE_sz), myFontSize=FONTSIZE)

        for prod_idx, prod in enumerate(products):
            drawMyMol(prod, os.path.join(dirpath, "ground_truth_" + str(prod_idx) + ".png"), image_size=(IMAGE_sz, IMAGE_sz), myFontSize=FONTSIZE)

        with open(os.path.join(dirpath, "smiles.txt"), "w") as fp:
            fp.write(
                "REACTANTS   :" + reactants_smiles + "\n" +
                "REAGENTS    :" + reagents_smiles + "\n" +
                "PREDICTION  :" + pred_smiles + "\n" +
                "GROUND TRUTH:" + products_smiles + "\n"
            )

    total_balanced_counter.update(balanced_preds)
    total_too_few_counter.update(too_few_preds)
    total_too_many_counter.update(too_many_preds)
    total_too_few_many_counter.update(too_few_many_preds)

print("- Total predictions:", total_predictions)
print("- Total valid predictions:", total_valid_predictions)


# Guarded divisions: an empty prediction file, or a run without any valid
# prediction, reports 0.0 instead of raising ZeroDivisionError.
valid_ratio = total_valid_predictions / total_predictions if total_predictions else 0.0
if total_valid_predictions:
    # Alchemy ratios are relative to the valid predictions only.
    balanced_ratio = total_balanced_counter[True] / total_valid_predictions
    too_few_ratio = total_too_few_counter[True] / total_valid_predictions
    too_many_ratio = total_too_many_counter[True] / total_valid_predictions
    too_few_many_ratio = total_too_few_many_counter[True] / total_valid_predictions
else:
    balanced_ratio = too_few_ratio = too_many_ratio = too_few_many_ratio = 0.0

print("The ratio of valid predictions is:", valid_ratio)
print("The ratio of balanced predictions is:", balanced_ratio)
print("The ratio of too-few predictions is:", too_few_ratio)
print("The ratio of too-many predictions is:", too_many_ratio)
print("The ratio of too-few-too-many predictions is:", too_few_many_ratio)


# Evaluate previously pretrained models on (doubled) reactions

In [None]:
# Hydrogen handling for the evaluation cells below: when True, molecules go
# through Chem.RemoveHs before comparison and atom counting; when False,
# Chem.AddHs is applied instead. (Same role as HEAVY_ATOMS_ONLY earlier,
# under a different name.)
ONLY_HEAVY_ATOMS = True

## Create a new MIT separated dataset where the reactants, reagents and products are doubled

### Separated

In [None]:
import os
    
# Build a "doubled" USPTO-MIT separated test set: each side of every reaction
# (reactants, reagents, products) is repeated as "X . X", keeping the
# space-separated token format and the "reactants > reagents" layout.
with open('./data/DataSet-USPTO-main/MIT_separated/src-test.txt', 'r') as fp_src:
    reactants_lines = fp_src.readlines()
with open('./data/DataSet-USPTO-main/MIT_separated/tgt-test.txt', 'r') as fp_tgt:
    products_lines = fp_tgt.readlines()

assert len(reactants_lines) == len(products_lines)

doubled_reacts = []
doubled_prods  = []

for react_line, prod_line in zip(reactants_lines, products_lines):
    react_splits = react_line.split(">")

    assert len(react_splits) == 2

    # Normalise whitespace (single spaces between tokens, line terminator
    # dropped). Chopping the last character instead would eat part of a token
    # on a final line that has no trailing newline.
    reactants = " ".join(react_splits[0].split())
    reagents = " ".join(react_splits[1].split())
    products = " ".join(prod_line.split())

    # Duplicate each non-empty side: "X" -> "X . X".
    reactants = reactants + " . " + reactants if reactants else reactants
    reagents = reagents + " . " + reagents if reagents else reagents
    products = products + " . " + products if products else products

    # An empty reagent side is written as a bare trailing ">".
    src_str = reactants + " > " + reagents + "\n" if reagents else reactants + " >\n"
    prod_str = products + "\n"

    doubled_reacts.append(src_str)
    doubled_prods.append(prod_str)


os.makedirs("./double_uspto_mit_separated/", exist_ok=True)

with open('./double_uspto_mit_separated/src-test.txt', 'w') as fp_src, \
        open('./double_uspto_mit_separated/tgt-test.txt', 'w') as fp_tgt:
    fp_src.writelines(doubled_reacts)
    fp_tgt.writelines(doubled_prods)

print("DONE")

### Combined

In [None]:
import os
from rdkit import Chem
    
# Build a "doubled" USPTO-MIT test set in the combined format: reactants and
# reagents are merged into one source (no ">" separator) and every side of the
# reaction is repeated as "X . X".
with open('./data/DataSet-USPTO-main/MIT_separated/src-test.txt', 'r') as fp_src:
    reactants_lines = fp_src.readlines()
with open('./data/DataSet-USPTO-main/MIT_separated/tgt-test.txt', 'r') as fp_tgt:
    products_lines = fp_tgt.readlines()

assert len(reactants_lines) == len(products_lines)

doubled_reacts = []
doubled_prods  = []

valid = 0
for react_line, prod_line in zip(reactants_lines, products_lines):
    react_splits = react_line.split(">")

    assert len(react_splits) == 2

    # Normalise whitespace: single spaces between tokens, no line terminator.
    reactants = " ".join(react_splits[0].split())
    reagents = " ".join(react_splits[1].split())
    products = " ".join(prod_line.split())

    # Duplicate each non-empty side: "X" -> "X . X".
    reactants = reactants + " . " + reactants if reactants else reactants
    reagents = reagents + " . " + reagents if reagents else reagents
    products = products + " . " + products if products else products

    # Combined format: reagents become extra fragments of the source; empty
    # parts are left out.
    src_str = " . ".join(part for part in (reactants, reagents) if part) + "\n"
    prod_str = products + "\n"

    # Sanity check that the doubled strings still form RDKit molecules. Token
    # separators are removed first because the space-separated form is not a
    # single SMILES string. Failing lines are still written (the files stay
    # line-aligned with the original test set) but are not counted as valid.
    src_mol = Chem.MolFromSmiles("".join(src_str.split()))
    tgt_mol = Chem.MolFromSmiles("".join(prod_str.split()))
    if src_mol is not None and tgt_mol is not None:
        valid += 1

    doubled_reacts.append(src_str)
    doubled_prods.append(prod_str)


os.makedirs("./double_uspto_mit_combined/", exist_ok=True)

with open('./double_uspto_mit_combined/src-test.txt', 'w') as fp_src, \
        open('./double_uspto_mit_combined/tgt-test.txt', 'w') as fp_tgt:
    fp_src.writelines(doubled_reacts)
    fp_tgt.writelines(doubled_prods)

n_reactions = len(reactants_lines)
print("Number of valid molecules:", valid, "/", n_reactions, "invalid:", n_reactions - valid)
print("DONE")

## Collect MolecularTransformer predictions on double_mit_dataset and compute accuracy

In [None]:

# Evaluation set to score: "all" (double_uspto_mit_separated), "one"
# (2xUSPTO-MIT_one_separated) or "balanced" (USPTO_balanced, which the loop
# below treats as the combined format without a ">" separator).
TYPE = "balanced" # "all"

def atom_count(molecules):
    """Count atoms per element symbol over a collection of molecules.

    Args:
        molecules: iterable of objects exposing ``GetAtoms()`` (yielding atoms
            with ``GetSymbol()``) and ``GetNumAtoms()``, e.g. ``rdkit.Chem.Mol``.

    Returns:
        ``(symbol_counts, total_atoms)``: a ``Counter`` mapping element symbol
        to its number of occurrences summed over all molecules, and the sum of
        ``GetNumAtoms()`` over all molecules.
    """
    symbol_counts = Counter()
    total_atoms = 0
    for molecule in molecules:
        symbol_counts.update(atom.GetSymbol() for atom in molecule.GetAtoms())
        total_atoms += molecule.GetNumAtoms()
    return symbol_counts, total_atoms

def compute_alchemy_predictions(reactants_counter, predictions_counters):
    """Classify the element balance of each prediction against the reactants.

    Args:
        reactants_counter: ``Counter`` of element symbol -> atom count for the reactants.
        predictions_counters: sequence of such ``Counter`` objects, one per prediction.

    Returns:
        Four parallel lists of booleans, one entry per prediction:
        balanced (identical element counts), too-few (some reactant atoms are
        missing from the prediction), too-many (the prediction contains atoms
        not covered by the reactants), and too-few-and-too-many (both at once).
    """
    def _short_of(have, need):
        # Counter subtraction keeps only positive counts, so any remaining key
        # means `need` holds more of some element than `have`.
        return bool(need - have)

    balanced = [pred == reactants_counter for pred in predictions_counters]
    too_few = [_short_of(pred, reactants_counter) for pred in predictions_counters]
    too_many = [_short_of(reactants_counter, pred) for pred in predictions_counters]
    too_few_many = [few and many for few, many in zip(too_few, too_many)]
    return balanced, too_few, too_many, too_few_many

# check whether the model is able to predict at least 1x of products
def is_half_accurate(pred_canonsmiles, true_canonsmiles):
    """Check whether the prediction matches the de-duplicated ground truth.

    Used for the doubled datasets, where every product fragment appears twice
    in the target: predicting each distinct product once counts as "half
    accurate". Duplicates are removed entirely (via a set), not halved.

    Args:
        pred_canonsmiles: canonical SMILES of the prediction.
        true_canonsmiles: canonical SMILES of the target ("."-separated fragments).

    Returns:
        True if the canonicalized set of distinct target fragments equals the
        prediction.
    """
    distinct_fragments = set(true_canonsmiles.split("."))
    return Chem.CanonSmiles(".".join(distinct_fragments)) == pred_canonsmiles

remove_whitespaces = str.maketrans('', '', string.whitespace) # used for removing whitespaces.

assert TYPE in ("one", "all", "balanced"), f"unknown TYPE: {TYPE!r}"

if TYPE == "all":
    with open('./double_uspto_mit_separated/src-test.txt', 'r') as fp:
        moltrans_source_lines = fp.readlines()

    with open('./double_uspto_mit_separated/tgt-test.txt', 'r') as fp:
        moltrans_target_lines = fp.readlines()

    with open('./MOL_T/predictions_MIT_separated_augm_model_average_20.pt_on_double_uspto_mit_separated_test.txt', 'r') as fp:
        moltrans_pred_lines = fp.readlines()
elif TYPE == "one":
    with open('./2xUSPTO-MIT_one_separated/src-test.txt', 'r') as fp:
        moltrans_source_lines = fp.readlines()

    with open('./2xUSPTO-MIT_one_separated/tgt-test.txt', 'r') as fp:
        moltrans_target_lines = fp.readlines()

    # NOTE(review): this is the same predictions file as TYPE == "all"
    # (predictions on double_uspto_mit_separated) while sources/targets come
    # from 2xUSPTO-MIT_one_separated; confirm the intended predictions file.
    with open('./MOL_T/predictions_MIT_separated_augm_model_average_20.pt_on_double_uspto_mit_separated_test.txt', 'r') as fp:
        moltrans_pred_lines = fp.readlines()
else:
    with open('./USPTO_balanced/src-test.txt', 'r') as fp:
        moltrans_source_lines = fp.readlines()

    with open('./USPTO_balanced/tgt-test.txt', 'r') as fp:
        moltrans_target_lines = fp.readlines()

    with open('./molecular_transformer_data/predictions_MIT_mixed_augm_model_average_20_on_USPTO_balanced_tokenized_test.txt', 'r') as fp:
        moltrans_pred_lines = fp.readlines()

# The scoring loop zips sources, targets and predictions; zip() would silently
# truncate to the shortest list, so all three lengths must agree.
assert len(moltrans_pred_lines) == len(moltrans_target_lines)
assert len(moltrans_source_lines) == len(moltrans_target_lines)

total_balanced_counter = Counter()
total_too_few_counter = Counter()
total_too_many_counter = Counter()
total_too_few_many_counter = Counter()

total_predictions = len(moltrans_pred_lines)
total_valid_predictions = 0
total_accurate_predictions = 0
total_half_accurate_predictions = 0
c = 0
for source, target, prediction in tqdm.tqdm(zip(moltrans_source_lines, moltrans_target_lines, moltrans_pred_lines), total=total_predictions):
    c += 1

    # USPTO_balanced uses the combined format (no ">" separator): append an
    # empty reagent part so the split below always yields two fields.
    if TYPE == "balanced":
        source += " > "

    source_splits = source.split(">")

    # Strip token separators and parse; MolFromSmiles returns None on failure.
    reactant = Chem.MolFromSmiles(source_splits[0].translate(remove_whitespaces))
    reagent = Chem.MolFromSmiles(source_splits[1].translate(remove_whitespaces))
    target = Chem.MolFromSmiles(target.translate(remove_whitespaces))
    prediction = Chem.MolFromSmiles(prediction.translate(remove_whitespaces))

    # Unparseable predictions/targets count as invalid and are skipped.
    if prediction is None or target is None:
        continue
    total_valid_predictions += 1

    # Heavy-atom mode strips hydrogens; otherwise they are made explicit.
    fix_hs = Chem.RemoveHs if ONLY_HEAVY_ATOMS else Chem.AddHs
    reactant = fix_hs(reactant)
    # Reagents enter no metric below, so a reagent RDKit cannot parse must not
    # abort the whole evaluation (RemoveHs/AddHs reject None).
    if reagent is not None:
        reagent = fix_hs(reagent)
    prediction = fix_hs(prediction)
    target = fix_hs(target)

    # Top-1 exact-match accuracy on canonical SMILES.
    target_canonsmiles = Chem.CanonSmiles(Chem.MolToSmiles(target))
    prediction_canonsmiles = Chem.CanonSmiles(Chem.MolToSmiles(prediction))

    if target_canonsmiles == prediction_canonsmiles:
        total_accurate_predictions += 1

    if is_half_accurate(prediction_canonsmiles, target_canonsmiles):
        total_half_accurate_predictions += 1

    reactants_counter, reactants_n_atoms = atom_count([reactant])
    prediction_counter, prediction_n_atoms = atom_count([prediction])

    balanced_preds, too_few_preds, too_many_preds, too_few_many_preds = compute_alchemy_predictions(reactants_counter, [prediction_counter])

    total_balanced_counter.update(balanced_preds)
    total_too_few_counter.update(too_few_preds)
    total_too_many_counter.update(too_many_preds)
    total_too_few_many_counter.update(too_few_many_preds)

print("- Total predictions:", total_predictions)
print("- Valid predictions:", total_valid_predictions)
print("- Accurate predictions:", total_accurate_predictions)
print("- Half accurate predictions:", total_half_accurate_predictions)

# Guarded divisions: an empty evaluation set, or a run without any valid
# prediction, reports 0.0 instead of raising ZeroDivisionError.
if total_predictions:
    accuracy = total_accurate_predictions / total_predictions
    half_accuracy = total_half_accurate_predictions / total_predictions
    valid_ratio = total_valid_predictions / total_predictions
else:
    accuracy = half_accuracy = valid_ratio = 0.0

if total_valid_predictions:
    # Alchemy ratios are relative to the valid predictions only.
    balanced_ratio = total_balanced_counter[True] / total_valid_predictions
    too_few_ratio = total_too_few_counter[True] / total_valid_predictions
    too_many_ratio = total_too_many_counter[True] / total_valid_predictions
    too_few_many_ratio = total_too_few_many_counter[True] / total_valid_predictions
else:
    balanced_ratio = too_few_ratio = too_many_ratio = too_few_many_ratio = 0.0

print("The ratio of valid predictions is:", valid_ratio)
print("The ratio of accurate prediction is:", accuracy)
print("The ratio of half-accurate prediction is:", half_accuracy)
print("The ratio of balanced predictions is:", balanced_ratio)
print("The ratio of too-few predictions is:", too_few_ratio)
print("The ratio of too-many predictions is:", too_many_ratio)
print("The ratio of too-few-too-many predictions is:", too_few_many_ratio)


## Collect Chemformer predictions on double_mit_dataset and compute accuracy

In [None]:
import string
import pandas as pd
from collections import Counter
import tqdm
from rdkit import Chem

# shut up warning
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

# Configuration for this cell: hydrogen-handling switch (True = heavy atoms
# only; NOTE(review): confirm the scoring loop below honours it) and which
# evaluation set to score ("all", "one" or "balanced").
ONLY_HEAVY_ATOMS = True
TYPE = "balanced"

def atom_count(molecules):
    """Count atoms per element symbol over a collection of molecules.

    Args:
        molecules: iterable of objects exposing ``GetAtoms()`` (yielding atoms
            with ``GetSymbol()``) and ``GetNumAtoms()``, e.g. ``rdkit.Chem.Mol``.

    Returns:
        ``(symbol_counts, total_atoms)``: a ``Counter`` mapping element symbol
        to its number of occurrences summed over all molecules, and the sum of
        ``GetNumAtoms()`` over all molecules.
    """
    symbol_counts = Counter()
    total_atoms = 0
    for molecule in molecules:
        symbol_counts.update(atom.GetSymbol() for atom in molecule.GetAtoms())
        total_atoms += molecule.GetNumAtoms()
    return symbol_counts, total_atoms

def compute_alchemy_predictions(reactants_counter, predictions_counters):
    """Classify the element balance of each prediction against the reactants.

    Args:
        reactants_counter: ``Counter`` of element symbol -> atom count for the reactants.
        predictions_counters: sequence of such ``Counter`` objects, one per prediction.

    Returns:
        Four parallel lists of booleans, one entry per prediction:
        balanced (identical element counts), too-few (some reactant atoms are
        missing from the prediction), too-many (the prediction contains atoms
        not covered by the reactants), and too-few-and-too-many (both at once).
    """
    def _short_of(have, need):
        # Counter subtraction keeps only positive counts, so any remaining key
        # means `need` holds more of some element than `have`.
        return bool(need - have)

    balanced = [pred == reactants_counter for pred in predictions_counters]
    too_few = [_short_of(pred, reactants_counter) for pred in predictions_counters]
    too_many = [_short_of(reactants_counter, pred) for pred in predictions_counters]
    too_few_many = [few and many for few, many in zip(too_few, too_many)]
    return balanced, too_few, too_many, too_few_many

# check whether the model is able to predict at least 1x of products
def is_half_accurate(pred_canonsmiles, true_canonsmiles):
    """Check whether the prediction matches the de-duplicated ground truth.

    Used for the doubled datasets, where every product fragment appears twice
    in the target: predicting each distinct product once counts as "half
    accurate". Duplicates are removed entirely (via a set), not halved.

    Args:
        pred_canonsmiles: canonical SMILES of the prediction.
        true_canonsmiles: canonical SMILES of the target ("."-separated fragments).

    Returns:
        True if the canonicalized set of distinct target fragments equals the
        prediction.
    """
    distinct_fragments = set(true_canonsmiles.split("."))
    return Chem.CanonSmiles(".".join(distinct_fragments)) == pred_canonsmiles

remove_whitespaces = str.maketrans('', '', string.whitespace) # used for removing whitespaces.


assert TYPE in ("one", "all", "balanced"), f"unknown TYPE: {TYPE!r}"

# NOTE: pd.read_pickle can execute arbitrary code while unpickling; only load
# result files produced by a trusted Chemformer run.
if TYPE == "all":

    with open('./double_uspto_mit_separated/src-test.txt', 'r') as fp:
        molbart_source_lines = fp.readlines()

    with open('./double_uspto_mit_separated/tgt-test.txt', 'r') as fp:
        molbart_target_lines = fp.readlines()

    df_test_pred = pd.read_pickle("./MOLBART_LAST_TIME/double_uspto_sep_out.pickle")

elif TYPE == "one":
    with open('./2xUSPTO-MIT_one_separated/src-test.txt', 'r') as fp:
        molbart_source_lines = fp.readlines()

    with open('./2xUSPTO-MIT_one_separated/tgt-test.txt', 'r') as fp:
        molbart_target_lines = fp.readlines()

    df_test_pred = pd.read_pickle("./chemformer_data/results/2xUSPTO-MIT_one_separated_out.pickle")

elif TYPE == "balanced":
    with open('./USPTO_balanced/src-test.txt', 'r') as fp:
        molbart_source_lines = fp.readlines()

    with open('./USPTO_balanced/tgt-test.txt', 'r') as fp:
        molbart_target_lines = fp.readlines()

    df_test_pred = pd.read_pickle("./chemformer_data/results/USPTO_balanced_out.pickle")


# Only the top-1 prediction column is scored. Reading the column directly
# avoids building a Series for every row as iterrows() does.
molbart_pred_lines = df_test_pred["prediction_0"].tolist()

# The scoring loop zips sources, targets and predictions; zip() would silently
# truncate to the shortest list, so all three lengths must agree.
assert len(molbart_pred_lines) == len(molbart_target_lines)
assert len(molbart_source_lines) == len(molbart_target_lines)

total_balanced_counter = Counter()
total_too_few_counter = Counter()
total_too_many_counter = Counter()
total_too_few_many_counter = Counter()

total_predictions = len(molbart_pred_lines)
total_valid_predictions = 0
total_accurate_predictions = 0
total_half_accurate_predictions = 0
c = 0
for source, target, prediction in tqdm.tqdm(zip(molbart_source_lines, molbart_target_lines, molbart_pred_lines), total=total_predictions):
    c += 1

    # USPTO_balanced uses the combined format (no ">" separator): append an
    # empty reagent part so the split below always yields two fields.
    if TYPE == "balanced":
        source += " > "

    source_splits = source.split(">")

    # Strip token separators and parse; MolFromSmiles returns None on failure.
    reactant = Chem.MolFromSmiles(source_splits[0].translate(remove_whitespaces))
    reagent = Chem.MolFromSmiles(source_splits[1].translate(remove_whitespaces))
    target = Chem.MolFromSmiles(target.translate(remove_whitespaces))
    prediction = Chem.MolFromSmiles(prediction.translate(remove_whitespaces))

    # Unparseable predictions/targets count as invalid and are skipped.
    if prediction is None or target is None:
        continue
    total_valid_predictions += 1

    # Honour ONLY_HEAVY_ATOMS like the Molecular Transformer evaluation:
    # strip hydrogens in heavy-atom mode, otherwise make them explicit.
    fix_hs = Chem.RemoveHs if ONLY_HEAVY_ATOMS else Chem.AddHs
    prediction = fix_hs(prediction)
    target = fix_hs(target)
    if reactant is not None:
        reactant = fix_hs(reactant)

    # Top-1 exact-match accuracy on canonical SMILES.
    try:
        target_canonsmiles = Chem.CanonSmiles(Chem.MolToSmiles(target))
        prediction_canonsmiles = Chem.CanonSmiles(Chem.MolToSmiles(prediction))
    except Exception:
        # Re-canonicalisation failed: log the pair and fall back to the
        # MolToSmiles output. Only real errors are caught (a bare `except`
        # would also swallow KeyboardInterrupt).
        print(Chem.MolToSmiles(target))
        print(Chem.MolToSmiles(prediction))

        target_canonsmiles = Chem.MolToSmiles(target)
        prediction_canonsmiles = Chem.MolToSmiles(prediction)

    if target_canonsmiles == prediction_canonsmiles:
        total_accurate_predictions += 1

    if is_half_accurate(prediction_canonsmiles, target_canonsmiles):
        total_half_accurate_predictions += 1

    reactants_counter, reactants_n_atoms = atom_count([reactant])
    prediction_counter, prediction_n_atoms = atom_count([prediction])

    balanced_preds, too_few_preds, too_many_preds, too_few_many_preds = compute_alchemy_predictions(reactants_counter, [prediction_counter])

    total_balanced_counter.update(balanced_preds)
    total_too_few_counter.update(too_few_preds)
    total_too_many_counter.update(too_many_preds)
    total_too_few_many_counter.update(too_few_many_preds)

for _label, _count in (
    ("- Total predictions:", total_predictions),
    ("- Valid predictions:", total_valid_predictions),
    ("- Accurate predictions:", total_accurate_predictions),
    ("- Half accurate predictions:", total_half_accurate_predictions),
):
    print(_label, _count)

accuracy = total_accurate_predictions / total_predictions
half_accuracy = total_half_accurate_predictions / total_predictions
valid_ratio = total_valid_predictions / total_predictions

# Alchemy ratios are taken over valid predictions only; with none, report 0.0.
if total_valid_predictions:
    balanced_ratio = total_balanced_counter[True] / total_valid_predictions
    too_few_ratio = total_too_few_counter[True] / total_valid_predictions
    too_many_ratio = total_too_many_counter[True] / total_valid_predictions
    too_few_many_ratio = total_too_few_many_counter[True] / total_valid_predictions
else:
    balanced_ratio = too_few_ratio = too_many_ratio = too_few_many_ratio = 0.0

for _label, _ratio in (
    ("The ratio of valid predictions is:", valid_ratio),
    ("The ratio of accurate prediction is:", accuracy),
    ("The ratio of half-accurate prediction is:", half_accuracy),
    ("The ratio of balanced predictions is:", balanced_ratio),
    ("The ratio of too-few predictions is:", too_few_ratio),
    ("The ratio of too-many predictions is:", too_many_ratio),
    ("The ratio of too-few-too-many predictions is:", too_few_many_ratio),
):
    print(_label, _ratio)


## Collect balance-fine-tuned Chemformer predictions on the USPTO-MIT separated test set and compute accuracy
### Chemformer is fine-tuned on both the partial_balanced and fully_balanced datasets

In [None]:
def atom_count(molecules):
    """Count atoms per element symbol over a collection of molecules.

    Args:
        molecules: iterable of objects exposing ``GetAtoms()`` (yielding atoms
            with ``GetSymbol()``) and ``GetNumAtoms()``, e.g. ``rdkit.Chem.Mol``.

    Returns:
        ``(symbol_counts, total_atoms)``: a ``Counter`` mapping element symbol
        to its number of occurrences summed over all molecules, and the sum of
        ``GetNumAtoms()`` over all molecules.
    """
    symbol_counts = Counter()
    total_atoms = 0
    for molecule in molecules:
        symbol_counts.update(atom.GetSymbol() for atom in molecule.GetAtoms())
        total_atoms += molecule.GetNumAtoms()
    return symbol_counts, total_atoms

def compute_alchemy_predictions(reactants_counter, predictions_counters):
    """Classify the element balance of each prediction against the reactants.

    Args:
        reactants_counter: ``Counter`` of element symbol -> atom count for the reactants.
        predictions_counters: sequence of such ``Counter`` objects, one per prediction.

    Returns:
        Four parallel lists of booleans, one entry per prediction:
        balanced (identical element counts), too-few (some reactant atoms are
        missing from the prediction), too-many (the prediction contains atoms
        not covered by the reactants), and too-few-and-too-many (both at once).
    """
    def _short_of(have, need):
        # Counter subtraction keeps only positive counts, so any remaining key
        # means `need` holds more of some element than `have`.
        return bool(need - have)

    balanced = [pred == reactants_counter for pred in predictions_counters]
    too_few = [_short_of(pred, reactants_counter) for pred in predictions_counters]
    too_many = [_short_of(reactants_counter, pred) for pred in predictions_counters]
    too_few_many = [few and many for few, many in zip(too_few, too_many)]
    return balanced, too_few, too_many, too_few_many

# check whether the model is able to predict at least 1x of products
def is_half_accurate(pred_canonsmiles, true_canonsmiles):
    """Check whether the prediction matches the de-duplicated ground truth.

    Used for the doubled datasets, where every product fragment appears twice
    in the target: predicting each distinct product once counts as "half
    accurate". Duplicates are removed entirely (via a set), not halved.

    Args:
        pred_canonsmiles: canonical SMILES of the prediction.
        true_canonsmiles: canonical SMILES of the target ("."-separated fragments).

    Returns:
        True if the canonicalized set of distinct target fragments equals the
        prediction.
    """
    distinct_fragments = set(true_canonsmiles.split("."))
    return Chem.CanonSmiles(".".join(distinct_fragments)) == pred_canonsmiles

remove_whitespaces = str.maketrans('', '', string.whitespace) # used for removing whitespaces.

with open('./uspto_mit_separated_nowhitespace/src-test.txt', 'r') as fp:
    molbart_source_lines = fp.readlines()

with open('./uspto_mit_separated_nowhitespace/tgt-test.txt', 'r') as fp:
    molbart_target_lines = fp.readlines()

# NOTE: pd.read_pickle can execute arbitrary code while unpickling; only load
# result files produced by a trusted run.
df_test_pred = pd.read_pickle("./Chemformer_finetuned/data/uspto_sep_out_partial_balance_fine_tuned.pickle")

# Only the top-1 prediction column is scored; read it directly instead of
# building a Series for every row with iterrows().
molbart_pred_lines = df_test_pred["prediction_0"].tolist()

print("prediction lines:", len(molbart_pred_lines))
print("target_lines:", len(molbart_target_lines))
assert len(molbart_pred_lines) == len(molbart_target_lines)
# The scoring loop zips sources with targets; zip() would silently truncate.
assert len(molbart_source_lines) == len(molbart_target_lines)

total_balanced_counter = Counter()
total_too_few_counter = Counter()
total_too_many_counter = Counter()
total_too_few_many_counter = Counter()

total_predictions = len(molbart_pred_lines)
total_valid_predictions = 0
total_accurate_predictions = 0
total_half_accurate_predictions = 0
c = 0
for source, target, prediction in tqdm.tqdm(zip(molbart_source_lines, molbart_target_lines, molbart_pred_lines), total=total_predictions):
    c += 1

    source_splits = source.split(">")

    # Strip whitespace and parse; MolFromSmiles returns None on failure.
    # Only the reactant side of the source is used below.
    reactant = Chem.MolFromSmiles(source_splits[0].translate(remove_whitespaces))
    target = Chem.MolFromSmiles(target.translate(remove_whitespaces))

    # Keep only the text before the first comma of the prediction cell
    # (presumably the first of several candidates — confirm against the pickle).
    prediction = Chem.MolFromSmiles(prediction.translate(remove_whitespaces).split(",")[0])

    # Unparseable predictions/targets count as invalid and are skipped.
    if prediction is None or target is None:
        continue
    total_valid_predictions += 1

    # Accuracy is computed on hydrogen-free canonical SMILES.
    target_canonsmiles = Chem.CanonSmiles(Chem.MolToSmiles(Chem.RemoveHs(target)))

    try:
        prediction_canonsmiles = Chem.CanonSmiles(Chem.MolToSmiles(Chem.RemoveHs(prediction)))
    except Exception:
        # Catch only real errors (a bare `except` would also swallow
        # KeyboardInterrupt). NOTE(review): this prediction was already counted
        # as valid above but is skipped by every other counter; confirm that
        # denominator is intended.
        continue

    if target_canonsmiles == prediction_canonsmiles:
        total_accurate_predictions += 1

    if is_half_accurate(prediction_canonsmiles, target_canonsmiles):
        total_half_accurate_predictions += 1

    reactants_counter, reactants_n_atoms = atom_count([reactant])
    prediction_counter, prediction_n_atoms = atom_count([prediction])

    balanced_preds, too_few_preds, too_many_preds, too_few_many_preds = compute_alchemy_predictions(reactants_counter, [prediction_counter])

    total_balanced_counter.update(balanced_preds)
    total_too_few_counter.update(too_few_preds)
    total_too_many_counter.update(too_many_preds)
    total_too_few_many_counter.update(too_few_many_preds)

for _label, _count in (
    ("- Total predictions:", total_predictions),
    ("- Valid predictions:", total_valid_predictions),
    ("- Accurate predictions:", total_accurate_predictions),
    ("- Half accurate predictions:", total_half_accurate_predictions),
):
    print(_label, _count)

accuracy = total_accurate_predictions / total_predictions
half_accuracy = total_half_accurate_predictions / total_predictions
valid_ratio = total_valid_predictions / total_predictions

# Alchemy ratios are taken over valid predictions only; with none, report 0.0.
if total_valid_predictions:
    balanced_ratio = total_balanced_counter[True] / total_valid_predictions
    too_few_ratio = total_too_few_counter[True] / total_valid_predictions
    too_many_ratio = total_too_many_counter[True] / total_valid_predictions
    too_few_many_ratio = total_too_few_many_counter[True] / total_valid_predictions
else:
    balanced_ratio = too_few_ratio = too_many_ratio = too_few_many_ratio = 0.0

for _label, _ratio in (
    ("The ratio of valid predictions is:", valid_ratio),
    ("The ratio of accurate prediction is:", accuracy),
    ("The ratio of half-accurate prediction is:", half_accuracy),
    ("The ratio of balanced predictions is:", balanced_ratio),
    ("The ratio of too-few predictions is:", too_few_ratio),
    ("The ratio of too-many predictions is:", too_many_ratio),
    ("The ratio of too-few-too-many predictions is:", too_few_many_ratio),
):
    print(_label, _ratio)


## Collect GRAPH2SMILES predictions on double_mit_dataset and compute accuracy

In [None]:
import string, tqdm
from collections import Counter
import rdkit
from rdkit import Chem

# Evaluation set to score ("all", "one" or "balanced"; checked by the assert
# below) and hydrogen-handling switch (presumably True = heavy atoms only, as
# in the earlier evaluation cells — confirm in the scoring loop).
TYPE = "balanced"
ONLY_HEAVY_ATOMS = True

# Silence RDKit log output (e.g. messages for SMILES that fail to parse).
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')


def atom_count(molecules):
    """Count atoms per element symbol over a collection of molecules.

    Args:
        molecules: iterable of objects exposing ``GetAtoms()`` (yielding atoms
            with ``GetSymbol()``) and ``GetNumAtoms()``, e.g. ``rdkit.Chem.Mol``.

    Returns:
        ``(symbol_counts, total_atoms)``: a ``Counter`` mapping element symbol
        to its number of occurrences summed over all molecules, and the sum of
        ``GetNumAtoms()`` over all molecules.
    """
    symbol_counts = Counter()
    total_atoms = 0
    for molecule in molecules:
        symbol_counts.update(atom.GetSymbol() for atom in molecule.GetAtoms())
        total_atoms += molecule.GetNumAtoms()
    return symbol_counts, total_atoms

def compute_alchemy_predictions(reactants_counter, predictions_counters):
    """Classify the atom balance of each prediction against the reactants.

    Args:
        reactants_counter: Counter of element symbols in the reactants.
        predictions_counters: iterable of Counters, one per prediction. It is
            consumed in a single pass, so generators are supported (previously a
            generator was exhausted by the first comprehension and the other
            three result lists came back empty).

    Returns:
        Four lists of bools, aligned with ``predictions_counters``:
        ``balanced`` (identical element counts), ``too_few`` (the reactants have
        more of some element than the prediction), ``too_many`` (the prediction
        has more of some element than the reactants) and ``too_few_many``
        (both at once).
    """
    balanced_preds, too_few_preds, too_many_preds, too_few_many_preds = [], [], [], []
    for prediction_counter in predictions_counters:
        # Counter subtraction drops non-positive counts, so a non-empty difference
        # means at least one element is under- (resp. over-) represented.
        missing = bool(reactants_counter - prediction_counter)
        extra = bool(prediction_counter - reactants_counter)
        balanced_preds.append(prediction_counter == reactants_counter)
        too_few_preds.append(missing)
        too_many_preds.append(extra)
        too_few_many_preds.append(missing and extra)
    return balanced_preds, too_few_preds, too_many_preds, too_few_many_preds

def is_half_accurate(pred_canonsmiles, true_canonsmiles):
    """Check whether the model predicted each distinct product at least once.

    The target is split into its "."-separated fragments, duplicates are
    dropped, and the remaining fragments are re-joined and re-canonicalised with
    RDKit. The result is compared with the prediction.

    Args:
        pred_canonsmiles: canonical SMILES of the prediction.
        true_canonsmiles: canonical SMILES of the target.

    Returns:
        True if the prediction equals the de-duplicated target.
    """
    distinct_fragments = set(true_canonsmiles.split("."))
    deduplicated_target = Chem.CanonSmiles(".".join(distinct_fragments))
    return deduplicated_target == pred_canonsmiles

remove_whitespaces = str.maketrans('', '', string.whitespace) # used for removing whitespaces.

# Test-split directory and GRAPH2SMILES result file for each evaluation TYPE.
EVAL_FILES = {
    "all": (
        './double_uspto_mit_separated',
        './GRAPH2SMILES_predictions/results/double_uspto_mit_combined_g2s_series_rel_smiles_smiles.final_run_dgat.result.txt',
    ),
    "one": (
        './2xUSPTO-MIT_one_separated',
        './GRAPH2SMILES_predictions/results/2xUSPTO-MIT_one_combined_g2s_series_rel_smiles_smiles.2xONE_dgcn.result.txt',
    ),
    "balanced": (
        './USPTO_balanced',
        './GRAPH2SMILES_predictions/results/USPTO_balanced_g2s_series_rel_smiles_smiles.BALANCED_dgat.result.txt',
    ),
}
assert TYPE in EVAL_FILES, f"TYPE must be one of {sorted(EVAL_FILES)}, got {TYPE!r}"

split_dir, predictions_path = EVAL_FILES[TYPE]
with open(split_dir + '/src-test.txt', 'r') as fp:
    moltrans_source_lines = fp.readlines()
with open(split_dir + '/tgt-test.txt', 'r') as fp:
    moltrans_target_lines = fp.readlines()
with open(predictions_path, 'r') as fp:
    moltrans_pred_lines = fp.readlines()

# zip() silently truncates to the shortest input, so all three files must line up.
assert len(moltrans_source_lines) == len(moltrans_target_lines) == len(moltrans_pred_lines), (
    f"line count mismatch: {len(moltrans_source_lines)} sources, "
    f"{len(moltrans_target_lines)} targets, {len(moltrans_pred_lines)} predictions"
)

# Hydrogen handling is the same for every molecule, so pick the RDKit function once.
normalize_hs = Chem.RemoveHs if ONLY_HEAVY_ATOMS else Chem.AddHs

total_balanced_counter = Counter()
total_too_few_counter = Counter()
total_too_many_counter = Counter()
total_too_few_many_counter = Counter()

total_predictions = len(moltrans_pred_lines)
total_valid_predictions = 0
total_accurate_predictions = 0
total_half_accurate_predictions = 0

rows = zip(moltrans_source_lines, moltrans_target_lines, moltrans_pred_lines)
for line_no, (source_line, target_line, prediction_line) in enumerate(
        tqdm.tqdm(rows, total=total_predictions), start=1):
    # Reactants are the part before ">" ("balanced" sources have no ">", so the
    # whole line is used). Reagents after ">" are not used by any metric below,
    # so they are not parsed (an unparsable reagent used to crash the loop).
    reactant_mol = Chem.MolFromSmiles(source_line.split(">")[0].translate(remove_whitespaces))
    target_mol = Chem.MolFromSmiles(target_line.translate(remove_whitespaces))
    # The result file lists comma-separated candidates; keep only the top-1.
    prediction_mol = Chem.MolFromSmiles(prediction_line.split(",")[0].translate(remove_whitespaces))

    # A prediction counts as valid only if both it and its target parse.
    if prediction_mol is None or target_mol is None:
        continue
    total_valid_predictions += 1

    if reactant_mol is None:
        raise ValueError(
            f"{split_dir}/src-test.txt line {line_no}: unparsable reactant SMILES {source_line!r}")

    reactant_mol = normalize_hs(reactant_mol)
    prediction_mol = normalize_hs(prediction_mol)
    target_mol = normalize_hs(target_mol)

    # Top-1 exact match on canonical SMILES.
    target_canonsmiles = Chem.CanonSmiles(Chem.MolToSmiles(target_mol))
    prediction_canonsmiles = Chem.CanonSmiles(Chem.MolToSmiles(prediction_mol))
    if target_canonsmiles == prediction_canonsmiles:
        total_accurate_predictions += 1
    if is_half_accurate(prediction_canonsmiles, target_canonsmiles):
        total_half_accurate_predictions += 1

    # Atom bookkeeping: compare element counts of the reactants and the prediction.
    reactants_counter, reactants_n_atoms = atom_count([reactant_mol])
    prediction_counter, prediction_n_atoms = atom_count([prediction_mol])
    balanced_preds, too_few_preds, too_many_preds, too_few_many_preds = compute_alchemy_predictions(
        reactants_counter, [prediction_counter])

    total_balanced_counter.update(balanced_preds)
    total_too_few_counter.update(too_few_preds)
    total_too_many_counter.update(too_many_preds)
    total_too_few_many_counter.update(too_few_many_preds)

print("- Total predictions:", total_predictions)
print("- Valid predictions:", total_valid_predictions)
print("- Accurate predictions:", total_accurate_predictions)
print("- Half accurate predictions:", total_half_accurate_predictions)

# Guard every division, so that an empty or fully unparsable prediction file
# reports 0.0 instead of raising ZeroDivisionError.
accuracy = total_accurate_predictions / total_predictions if total_predictions else 0.0
half_accuracy = total_half_accurate_predictions / total_predictions if total_predictions else 0.0
valid_ratio = total_valid_predictions / total_predictions if total_predictions else 0.0
if total_valid_predictions:
    balanced_ratio = total_balanced_counter[True] / total_valid_predictions
    too_few_ratio = total_too_few_counter[True] / total_valid_predictions
    too_many_ratio = total_too_many_counter[True] / total_valid_predictions
    too_few_many_ratio = total_too_few_many_counter[True] / total_valid_predictions
else:
    balanced_ratio = too_few_ratio = too_many_ratio = too_few_many_ratio = 0.0

print("The ratio of valid predictions is:", valid_ratio)
print("The ratio of accurate prediction is:", accuracy)
print("The ratio of half-accurate prediction is:", half_accuracy)
print("The ratio of balanced predictions is:", balanced_ratio)
print("The ratio of too-few predictions is:", too_few_ratio)
print("The ratio of too-many predictions is:", too_many_ratio)
print("The ratio of too-few-too-many predictions is:", too_few_many_ratio)


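# Create a new dataset where only one reactant/product is doubled (a randomly chosen reactant or product molecule is added to both the reactant and the product side)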

### Separated

In [None]:
import os
import random

# Fixed seed, so that re-running this cell regenerates the same dataset. An unseeded
# run silently desynchronises the files from predictions already computed on them.
# NOTE(review): files generated before this seed was introduced will differ.
rng = random.Random(0)

with open('./data/DataSet-USPTO-main/MIT_separated/src-test.txt', 'r') as fp_src:
    reactants_lines = fp_src.readlines()
with open('./data/DataSet-USPTO-main/MIT_separated/tgt-test.txt', 'r') as fp_tgt:
    products_lines = fp_tgt.readlines()

assert len(reactants_lines) == len(products_lines)

doubled_reacts = []
doubled_prods  = []

for react_line, prod_line in zip(reactants_lines, products_lines):
    react_splits = react_line.split(">")
    assert len(react_splits) == 2

    # Molecules are "."-separated and tokens inside a molecule are space-separated.
    # Normalising every fragment with split()/join() drops the line terminator and
    # stray spaces. The old "overwrite the last character" trick clobbered the final
    # SMILES token when the last line had no trailing newline.
    single_reactants = [" ".join(frag.split()) for frag in react_splits[0].split(".") if frag.strip()]
    single_reagents  = [" ".join(frag.split()) for frag in react_splits[1].split(".") if frag.strip()]
    single_products  = [" ".join(frag.split()) for frag in prod_line.split(".") if frag.strip()]

    # Pick one molecule, from the reactant or the product side with equal
    # probability, and add a copy of it to both sides.
    candidates = single_reactants if rng.randint(0, 1) == 0 else single_products
    mol_to_double = rng.choice(candidates)
    single_reactants.append(mol_to_double)
    single_products.append(mol_to_double)

    rng.shuffle(single_reactants)
    rng.shuffle(single_products)

    # Reagents keep their original order; without reagents the line ends in ">".
    src_str = (" . ".join(single_reactants) + " > " + " . ".join(single_reagents)).rstrip() + "\n"
    prod_str = " . ".join(single_products) + "\n"

    doubled_reacts.append(src_str)
    doubled_prods.append(prod_str)

os.makedirs("./2xUSPTO-MIT_one_separated/", exist_ok=True)

with open('./2xUSPTO-MIT_one_separated/src-test.txt', 'w') as fp_src:
    fp_src.writelines(doubled_reacts)
with open('./2xUSPTO-MIT_one_separated/tgt-test.txt', 'w') as fp_tgt:
    fp_tgt.writelines(doubled_prods)

print("DONE")

### Combined (reagents merged into the reactant set, no `>` separator)

In [None]:
import os
import random

# Fixed seed, so that re-running this cell regenerates the same dataset.
# NOTE(review): files generated before this seed was introduced will differ.
rng = random.Random(0)

with open('./data/DataSet-USPTO-main/MIT_separated/src-test.txt', 'r') as fp_src:
    reactants_lines = fp_src.readlines()
with open('./data/DataSet-USPTO-main/MIT_separated/tgt-test.txt', 'r') as fp_tgt:
    products_lines = fp_tgt.readlines()

assert len(reactants_lines) == len(products_lines)

doubled_reacts = []
doubled_prods  = []

for react_line, prod_line in zip(reactants_lines, products_lines):
    react_splits = react_line.split(">")
    assert len(react_splits) == 2

    # Molecules are "."-separated and tokens inside a molecule are space-separated.
    # Normalising every fragment and dropping empty ones (e.g. when a reaction has
    # no reagents) replaces the old chain of " . ." / leading " . " clean-ups, which
    # referenced an undefined name (`prod.str`) and could leave trailing spaces.
    # It also no longer overwrites the last character of the line, which clobbered
    # the final SMILES token when the last line had no trailing newline.
    single_reactants = [" ".join(frag.split()) for frag in react_splits[0].split(".") if frag.strip()]
    single_reagents  = [" ".join(frag.split()) for frag in react_splits[1].split(".") if frag.strip()]
    single_products  = [" ".join(frag.split()) for frag in prod_line.split(".") if frag.strip()]

    # Pick one molecule, from the reactant or the product side with equal
    # probability, and add a copy of it to both sides.
    candidates = single_reactants if rng.randint(0, 1) == 0 else single_products
    mol_to_double = rng.choice(candidates)
    single_reactants.append(mol_to_double)
    single_products.append(mol_to_double)

    # "Combined" format: reagents are merged into the reactant set.
    single_reactants += single_reagents

    rng.shuffle(single_reactants)
    rng.shuffle(single_products)

    src_str = " . ".join(single_reactants) + "\n"
    prod_str = " . ".join(single_products) + "\n"

    doubled_reacts.append(src_str)
    doubled_prods.append(prod_str)

os.makedirs("./2xUSPTO-MIT_one_combined/", exist_ok=True)

with open('./2xUSPTO-MIT_one_combined/src-test.txt', 'w') as fp_src:
    fp_src.writelines(doubled_reacts)
with open('./2xUSPTO-MIT_one_combined/tgt-test.txt', 'w') as fp_tgt:
    fp_tgt.writelines(doubled_prods)

print("DONE")

# Remove whitespaces from MIT_separated

In [None]:
import os

# Translation table that deletes every whitespace character (spaces, tabs, newlines).
remove_whitespaces = str.maketrans('', '', string.whitespace) # used for removing whitespaces.

with open('./data/DataSet-USPTO-main/MIT_separated/src-test.txt', 'r') as fp_src, \
        open('./data/DataSet-USPTO-main/MIT_separated/tgt-test.txt', 'r') as fp_tgt:
    reactants_lines = fp_src.readlines()
    products_lines = fp_tgt.readlines()

assert len(reactants_lines) == len(products_lines)

# Remove all whitespace from each tokenized line (terminator included) and
# re-terminate it with exactly one newline.
no_whitespaced_reacts = [line.translate(remove_whitespaces) + "\n" for line in reactants_lines]
no_whitespaced_prods  = [line.translate(remove_whitespaces) + "\n" for line in products_lines]

if not os.path.exists("./uspto_mit_separated_nowhitespace/"):
    os.makedirs("./uspto_mit_separated_nowhitespace/")

with open('./uspto_mit_separated_nowhitespace/src-test.txt', 'w') as fp_src, \
        open('./uspto_mit_separated_nowhitespace/tgt-test.txt', 'w') as fp_tgt:
    fp_src.writelines(no_whitespaced_reacts)
    fp_tgt.writelines(no_whitespaced_prods)

print("DONE")

# Create doubled dataset without whitespaces

In [None]:
import os
import string

remove_whitespaces = str.maketrans('', '', string.whitespace) # used for removing whitespaces.


def double_smiles(smiles):
    """Return ``smiles`` written twice as "X.X"; an empty string stays empty (not ".")."""
    return (smiles + "." + smiles) if smiles else ""


with open('./data/DataSet-USPTO-main/MIT_separated/src-test.txt', 'r') as fp_src, \
        open('./data/DataSet-USPTO-main/MIT_separated/tgt-test.txt', 'r') as fp_tgt:
    reactants_lines = fp_src.readlines()
    products_lines = fp_tgt.readlines()

assert len(reactants_lines) == len(products_lines)

doubled_reacts = []
doubled_prods  = []

for react_line, prod_line in zip(reactants_lines, products_lines):
    # translate() deletes every whitespace character, line terminator included, so
    # no "[:-1]" slicing is needed. That slicing cut off the last SMILES character
    # when the final line of the file had no trailing newline.
    react_splits = react_line.translate(remove_whitespaces).split(">")
    assert len(react_splits) == 2
    reactants_smiles, reagents_smiles = react_splits
    products_smiles = prod_line.translate(remove_whitespaces)

    # Reactions without reagents now keep an empty reagent field instead of ".".
    src_str = double_smiles(reactants_smiles) + ">" + double_smiles(reagents_smiles) + "\n"
    prod_str = double_smiles(products_smiles) + "\n"

    doubled_reacts.append(src_str)
    doubled_prods.append(prod_str)

os.makedirs("./double_uspto_mit_separated_nowhitespace/", exist_ok=True)

with open('./double_uspto_mit_separated_nowhitespace/src-test.txt', 'w') as fp_src, \
        open('./double_uspto_mit_separated_nowhitespace/tgt-test.txt', 'w') as fp_tgt:
    fp_src.writelines(doubled_reacts)
    fp_tgt.writelines(doubled_prods)

print("DONE")

# Create the double_t2 (2xUSPTO-MIT_one_separated) dataset without whitespaces

In [None]:
import os, string

# Translation table that deletes every whitespace character (spaces, tabs, newlines).
remove_whitespaces = str.maketrans('', '', string.whitespace) # used for removing whitespaces.

with open('./2xUSPTO-MIT_one_separated/src-test.txt', 'r') as fp_src, \
        open('./2xUSPTO-MIT_one_separated/tgt-test.txt', 'r') as fp_tgt:
    reactants_lines = fp_src.readlines()
    products_lines = fp_tgt.readlines()

assert len(reactants_lines) == len(products_lines)

# Remove all whitespace from each tokenized line (terminator included) and
# re-terminate it with exactly one newline.
doubled_t2_reacts = [line.translate(remove_whitespaces) + "\n" for line in reactants_lines]
doubled_t2_prods  = [line.translate(remove_whitespaces) + "\n" for line in products_lines]

if not os.path.exists('./2xUSPTO-MIT_one_separated_nowhitespace/'):
    os.makedirs('./2xUSPTO-MIT_one_separated_nowhitespace/')

with open('./2xUSPTO-MIT_one_separated_nowhitespace/src-test.txt', 'w') as fp_src, \
        open('./2xUSPTO-MIT_one_separated_nowhitespace/tgt-test.txt', 'w') as fp_tgt:
    fp_src.writelines(doubled_t2_reacts)
    fp_tgt.writelines(doubled_t2_prods)

print("DONE")