# Build new balanced dataset

In [None]:
import string
import tqdm, os

from rdkit import Chem
from collections import Counter

import pandas as pd

# shut up warning
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

# When True, every molecule parsed from the raw files is passed through
# Chem.AddHs before its atoms are counted.
ADD_Hs = True
# When True, reactions that could not be balanced are still kept in the
# dataset, and the stored output uses the "USPTO_partially_balanced" names.
PARTIAL_BALANCE = False

## New dataset of balanced reactions
### helper functions

In [None]:
import matplotlib.pyplot as plt
from IPython.display import clear_output
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import rdDepictor
rdDepictor.SetPreferCoordGen(True)
from rdkit.Chem.Draw import IPythonConsole
from IPython.display import SVG

# render a single molecule and write the drawing to disk
def drawMyMol(mol, fname, image_size, myFontSize):
    """Draw `mol` with a Cairo 2D drawer and write the result to `fname`.

    `image_size` is indexed at positions 0 and 1 to size the canvas, and
    `myFontSize` is forwarded to the drawer's font-size setter.
    """
    canvas = rdMolDraw2D.MolDraw2DCairo(image_size[0], image_size[1])
    canvas.SetFontSize(myFontSize)
    canvas.DrawMolecule(mol)
    canvas.FinishDrawing()
    canvas.WriteDrawingText(fname)

# draw and store the reactions of interest
def draw_and_store_reactions(c, dirpath, reactants, reagents, products, pred):
    """Save per-molecule drawings and a SMILES summary for one reaction.

    Everything is written to the sub-directory ``<dirpath>/<c>/``:

    * ``react_<i>.png``, ``reag_<i>.png`` and ``ground_truth_<i>.png`` for the
      reactants, reagents and products;
    * ``pred_0.png`` when a prediction is given;
    * ``smiles.txt`` with the hydrogen-stripped SMILES of every group.

    Args:
        c: identifier of the reaction, used as the sub-directory name.
        dirpath: root directory under which the sub-directory is created.
        reactants: list of RDKit molecules on the left-hand side.
        reagents: list of RDKit molecules acting as reagents.
        products: list of RDKit molecules of the ground truth.
        pred: predicted molecule, or None when there is no prediction.
    """
    image_size = (500, 500)
    font_size = 40

    def smiles_without_hs(mols):
        # one dot-separated SMILES string for the whole group
        return ".".join(Chem.MolToSmiles(Chem.RemoveHs(mol)) for mol in mols)

    reactants_smiles = smiles_without_hs(reactants)
    reagents_smiles = smiles_without_hs(reagents)
    products_smiles = smiles_without_hs(products)
    # a single `is None` test decides both the SMILES line and the drawing
    pred_smiles = "" if pred is None else Chem.MolToSmiles(pred)

    clear_output(wait=True)

    # The output directory is derived from `dirpath`; it used to be replaced
    # by a hard-coded "./recovered_reactions_plots/<c>/" path.
    out_dir = os.path.join(dirpath, str(c))
    os.makedirs(out_dir, exist_ok=True)

    def draw_group(mols, prefix):
        for i, mol in enumerate(mols):
            drawMyMol(
                mol,
                os.path.join(out_dir, prefix + str(i) + ".png"),
                image_size=image_size,
                myFontSize=font_size,
            )

    draw_group(reactants, "react_")
    draw_group(reagents, "reag_")

    if pred is not None:
        # The file name no longer depends on an index leaked from the loops
        # above (which was undefined when both lists were empty).
        draw_group([pred], "pred_")

    draw_group(products, "ground_truth_")

    with open(os.path.join(out_dir, "smiles.txt"), "w") as fp:
        fp.write(
            "REACTANTS   :" + reactants_smiles + "\n" +
            "REAGENTS    :" + reagents_smiles + "\n" +
            "PREDICTION  :" + pred_smiles + "\n" +
            "GROUND TRUTH:" + products_smiles + "\n"
        )


# updated version of atom_count: the formal charge is folded into the atom symbol
def atom_count(molecules):
    """Count the atoms of `molecules`, keyed by charge-annotated symbol.

    A neutral atom is keyed by its plain symbol, a singly charged one by the
    symbol followed by "+" or "-", and larger charges by the symbol, the sign
    and the magnitude (e.g. "Fe+2", "O-2").

    Returns:
        (counter, num_atoms): a Counter over the annotated symbols and the
        sum of GetNumAtoms() over all molecules.
    """
    def charged_symbol(atom):
        charge = atom.GetFormalCharge()
        if charge == 0:
            suffix = ""
        elif charge > 0:
            suffix = "+" if charge == 1 else "+" + str(charge)
        else:
            suffix = "-" if charge == -1 else "-" + str(-charge)
        return atom.GetSymbol() + suffix

    tally = Counter()
    total_atoms = 0
    for mol in molecules:
        tally.update(charged_symbol(atom) for atom in mol.GetAtoms())
        total_atoms += mol.GetNumAtoms()

    return tally, total_atoms


def _leftover_molecule(leftover_counter):
    """Build one molecule made of the isolated atoms listed in `leftover_counter`."""
    # every leftover atom becomes its own bracket atom, e.g. "[Cl-][H]"
    init_smiles = "[" + "][".join(leftover_counter.elements()) + "]"
    canon_smiles = Chem.CanonSmiles(init_smiles)
    return Chem.AddHs(Chem.MolFromSmiles(canon_smiles))


def make_balanced(idx, reactants, products, plot=False):
    """Balance a reaction by adding the missing atoms to the poorer side.

    Atoms present in the reactants but not in the products are gathered into
    one extra molecule appended to the products, and vice versa.

    Args:
        idx: identifier of the reaction, only used when plotting.
        reactants: list of RDKit molecules.
        products: list of RDKit molecules.
        plot: when True, the rebalanced reaction is drawn and stored.

    Returns:
        (balanced, recovered, reactants, products) where `balanced` tells
        whether the returned reaction is balanced, and `recovered` whether
        extra molecules had to be added to get there. When balancing fails
        the original, untouched lists are returned.
    """
    reactants_counter, _ = atom_count(reactants)
    products_counter, _ = atom_count(products)

    # reaction is already balanced
    if reactants_counter == products_counter:
        return True, False, reactants, products

    # if not already balanced, we need to rebalance the reactants, the
    # products, or both; Counter subtraction keeps positive counts only
    remaining_reactants_counter = reactants_counter - products_counter
    remaining_products_counter = products_counter - reactants_counter

    try:
        # add a new molecule with the remaining reactants' atoms to the products
        if remaining_reactants_counter:
            new_products = products + [_leftover_molecule(remaining_reactants_counter)]
        else:
            new_products = products

        # add a new molecule with the remaining products' atoms to the reactants
        if remaining_products_counter:
            new_reactants = reactants + [_leftover_molecule(remaining_products_counter)]
        else:
            new_reactants = reactants
    except Exception:
        # The leftover atoms could not be turned into a molecule. Only real
        # errors are treated as "cannot balance": a bare `except` would also
        # swallow KeyboardInterrupt and SystemExit.
        return False, False, reactants, products

    if plot:
        # NOTE(review): `reagents` is not a parameter; it is read from the
        # module-level variable set by the calling loop.
        draw_and_store_reactions(idx, "./recovered_reactions_plots", new_reactants, reagents, new_products, None)

    return True, True, new_reactants, new_products



 ### Populate balanced dataset

In [None]:


# populate balanced dataset

# Location of the raw reaction files ("src-<name>.txt" / "tgt-<name>.txt").
RAW_DATA_DIR = './raw_data/DataSet-USPTO-main/MIT_separated/'
# Split key used in `balanced_dataset` -> name used in the raw file names.
RAW_SPLIT_NAMES = {"TR": "train", "VL": "valid", "TS": "test"}

# NOTE(review): the previous version of this loop hit a `break` as soon as the
# 1-based reaction counter reached 100, so only the first 99 reactions of each
# split were ever processed. That looks like a debugging leftover; the cap is
# kept to preserve behaviour, but it is now explicit. Set it to None to
# process every reaction.
MAX_REACTIONS_PER_SPLIT = 99
# Reactions whose 1-based counter is below this value are plotted by make_balanced.
PLOT_BELOW_COUNT = 100


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


# One entry per split: the kept reactions (lists of molecules) and counters.
balanced_dataset = {
    split_key: {
        "reagents" : [],
        "reactants": [],
        "products" : [],
        "balanced_count": 0,
        "recovered_count": 0
    }
    for split_key in RAW_SPLIT_NAMES
}

# translation table used for removing whitespaces
remove_whitespaces = str.maketrans('', '', string.whitespace)

# load data in memory
for data_split in ["TR", "VL", "TS"]:
    raw_name = RAW_SPLIT_NAMES[data_split]
    # `with` closes both files even when reading fails
    with open(RAW_DATA_DIR + "src-" + raw_name + ".txt", 'r') as fp_src, \
         open(RAW_DATA_DIR + "tgt-" + raw_name + ".txt", 'r') as fp_tgt:
        src_lines = fp_src.readlines()
        tgt_lines = fp_tgt.readlines()

    c = 0
    for src_line, tgt_line in tqdm.tqdm(zip(src_lines, tgt_lines), total=len(src_lines)):

        c += 1
        if MAX_REACTIONS_PER_SPLIT is not None and c > MAX_REACTIONS_PER_SPLIT:
            break

        # a source line is "<reactants> > <reagents>"
        src_line_splits = src_line.split(">")
        if len(src_line_splits) != 2:
            # explicit check: an `assert` would disappear under `python -O`
            raise ValueError(
                "Malformed source line " + str(c) + " in split " + data_split + ": " + repr(src_line)
            )
        reactants_nonsplit = src_line_splits[0].translate(remove_whitespaces)
        reagents_nonsplit = src_line_splits[1].translate(remove_whitespaces)
        products_nonsplit = tgt_line.translate(remove_whitespaces)

        # sometimes reagents are empty, remove empty strings
        reagents_smiles = reagents_nonsplit.split(".")
        if reagents_smiles == [""]:
            reagents_smiles = []

        reactants = [Chem.MolFromSmiles(smiles) for smiles in reactants_nonsplit.split(".")]
        reagents = [Chem.MolFromSmiles(smiles) for smiles in reagents_smiles]
        products = [Chem.MolFromSmiles(smiles) for smiles in products_nonsplit.split(".")]

        if ADD_Hs:
            reactants = [Chem.AddHs(mol) for mol in reactants]
            reagents = [Chem.AddHs(mol) for mol in reagents]
            products = [Chem.AddHs(mol) for mol in products]

        plot = c < PLOT_BELOW_COUNT

        # NOTE: make_balanced reads the module-level `reagents` when plotting,
        # so that name must hold the reagents of the current reaction here.
        balanced, recovered, out_reactants, out_products = make_balanced(c, reactants, products, plot)

        # nothing that can be done, go to next reaction
        if not balanced:
            # if partial balance, add the unbalanced reaction anyway
            if PARTIAL_BALANCE:
                balanced_dataset[data_split]["reactants"].append(out_reactants)
                balanced_dataset[data_split]["reagents"].append(reagents)
                balanced_dataset[data_split]["products"].append(out_products)
            continue

        balanced_dataset[data_split]["balanced_count"] += 1
        if recovered:
            balanced_dataset[data_split]["recovered_count"] += 1

        # final double check that the reaction is now balanced
        out_reactants_counter, _ = atom_count(out_reactants)
        out_products_counter, _ = atom_count(out_products)
        if out_reactants_counter != out_products_counter:
            raise RuntimeError("Rebalancing did not work!")

        # reaction is now balanced for sure
        balanced_dataset[data_split]["reactants"].append(out_reactants)
        balanced_dataset[data_split]["reagents"].append(reagents)
        balanced_dataset[data_split]["products"].append(out_products)

    assert len(balanced_dataset[data_split]["reactants"]) == len(balanced_dataset[data_split]["reagents"]) == len(balanced_dataset[data_split]["products"])
    balanced_dataset[data_split]["len"] = len(balanced_dataset[data_split]["reactants"])

print("- Balanced dataset completed!")
print("- TR reactions:", balanced_dataset["TR"]["balanced_count"], "of which", balanced_dataset["TR"]["recovered_count"], "have been recovered.")
print("- VL reactions:", balanced_dataset["VL"]["balanced_count"], "of which", balanced_dataset["VL"]["recovered_count"], "have been recovered.")
print("- TS reactions:", balanced_dataset["TS"]["balanced_count"], "of which", balanced_dataset["TS"]["recovered_count"], "have been recovered.")

# the ratios are NaN instead of a ZeroDivisionError when a split is empty
print("- TR recovered ratio:", _safe_ratio(balanced_dataset["TR"]["recovered_count"], balanced_dataset["TR"]["balanced_count"]))
print("- VL recovered ratio:", _safe_ratio(balanced_dataset["VL"]["recovered_count"], balanced_dataset["VL"]["balanced_count"]))
print("- TS recovered ratio:", _safe_ratio(balanced_dataset["TS"]["recovered_count"], balanced_dataset["TS"]["balanced_count"]))


total_dataset_len = balanced_dataset["TR"]["len"] + balanced_dataset["VL"]["len"] + balanced_dataset["TS"]["len"]

tr_ratio = _safe_ratio(balanced_dataset["TR"]["len"], total_dataset_len)
vl_ratio = _safe_ratio(balanced_dataset["VL"]["len"], total_dataset_len)
ts_ratio = _safe_ratio(balanced_dataset["TS"]["len"], total_dataset_len)

print("TOT len:", total_dataset_len)
print("TR len:", balanced_dataset["TR"]["len"])
print("VL len:", balanced_dataset["VL"]["len"])
print("TS len:", balanced_dataset["TS"]["len"])

print("Balanced dataset tr/vl/ts split is:", round(tr_ratio, 2), "/", round(vl_ratio, 2), "/", round(ts_ratio, 2))


In [None]:
import pandas as pd

# LAST STEP: write the (partially) balanced dataset to text files,
# one "src-<split>.txt" / "tgt-<split>.txt" pair per split
dataset_name = "USPTO_partially_balanced" if PARTIAL_BALANCE else "USPTO_balanced"

if not os.path.exists("./" + dataset_name):
    os.makedirs("./" + dataset_name)

for data_split, split_name in [["TR", "train"], ["VL", "valid"],["TS", "test"]]:
    src_path = "./" + dataset_name + "/src-" + split_name + ".txt"
    tgt_path = "./" + dataset_name + "/tgt-" + split_name + ".txt"

    with open(src_path, "w") as src_fp, open(tgt_path, "w") as dst_fp:
        reactants_data = balanced_dataset[data_split]["reactants"]
        reagents_data  = balanced_dataset[data_split]["reagents"]
        products_data  = balanced_dataset[data_split]["products"]

        print("Storing", data_split, "set..")
        reactions = zip(reactants_data, reagents_data, products_data)
        for reactants, reagents, products in tqdm.tqdm(reactions, total=len(reactants_data)):

            group_lines = []
            for group in (reactants, reagents, products):
                # remove hs only for the partially balanced dataset, which is
                # meant to be compared with vanilla USPTO
                if PARTIAL_BALANCE:
                    group = [Chem.RemoveHs(mol) for mol in group]
                group_lines.append(
                    ".".join(Chem.CanonSmiles(Chem.MolToSmiles(mol)) for mol in group)
                )
            reactants_line, reagents_line, products_line = group_lines

            # the reagents, when present, are appended to both sides
            reagents_suffix = "." + reagents_line if reagents_line != "" else ""

            src_fp.write(reactants_line + reagents_suffix + "\n")
            dst_fp.write(products_line + reagents_suffix + "\n")

## store data in pickle format for training chemformer

In [None]:
def convert_dataset_to_pickle(dataset, add_hs=False):
    """Turn the balanced dataset into a DataFrame with one row per reaction.

    Each group of a reaction (reactants, products, reagents) is merged into a
    single multi-fragment molecule built from the dot-joined canonical SMILES
    of its members.

    Args:
        dataset: dict with the "TR", "VL" and "TS" splits, each holding the
            "reactants", "reagents" and "products" lists and their "len".
        add_hs: when True the merged molecules go through Chem.AddHs; when
            False hydrogens are removed before and after merging.

    Returns:
        DataFrame with the columns "reactants_mol", "products_mol",
        "reagents_mol" and "set" ("train", "valid" or "test").
    """
    columns = ["reactants_mol", "products_mol", "reagents_mol", "set"]

    data_split_to_set = {
        "TR" : "train",
        "VL" : "valid",
        "TS" : "test"
    }

    def merge_group(mols):
        # remove hs only when the output must be comparable with vanilla USPTO
        if not add_hs:
            mols = [Chem.RemoveHs(mol) for mol in mols]
        group_line = ".".join(Chem.CanonSmiles(Chem.MolToSmiles(mol)) for mol in mols)
        merged = Chem.MolFromSmiles(group_line)
        return Chem.AddHs(merged) if add_hs else Chem.RemoveHs(merged)

    # Rows are collected in a list and converted once: growing the DataFrame
    # with one `.loc` assignment per reaction is quadratic.
    rows = []
    for data_split in ["TR", "VL", "TS"]:
        current_data_split = dataset[data_split]
        set_name = data_split_to_set[data_split]

        reactions = zip(
            current_data_split["reactants"],
            current_data_split["reagents"],
            current_data_split["products"],
        )
        for reactants, reagents, products in tqdm.tqdm(reactions, total=current_data_split["len"]):
            # The values follow the column order; products and reagents used
            # to be stored under each other's column.
            rows.append([
                merge_group(reactants),
                merge_group(products),
                merge_group(reagents),
                set_name,
            ])

    return pd.DataFrame(rows, columns=columns)

# store data in pickle format for training chemformer
# (the partially balanced variant is stored without explicit hydrogens)
if PARTIAL_BALANCE:
    subdir, add_hs = "USPTO_partially_balanced", False
    store_name = "uspto_partially_balanced_prot3.pickle"
else:
    subdir, add_hs = "USPTO_balanced", True
    store_name = "uspto_balanced_prot3.pickle"

pickle_dir = "./pickle_datasets/" + subdir + "/"
if not os.path.exists(pickle_dir):
    os.makedirs(pickle_dir)

print("Converting dataset to pickle...")
pickle_data = convert_dataset_to_pickle(balanced_dataset, add_hs)
pickle_data.to_pickle(pickle_dir + store_name, protocol=3)


# Add whitespace to dataset for Molecular Transformer

In [None]:
import os, tqdm

def smi_tokenizer(smi):
    """
    Tokenize a SMILES molecule or reaction

    `smi` may hold several lines; every input line yields exactly one output
    line made of its tokens separated by single spaces and terminated by a
    newline. A trailing newline in the input does not produce an extra empty
    output line.

    Raises AssertionError when a line contains characters that are not
    covered by the token pattern.
    """
    import re

    # Raw string: the pattern is full of regex escapes that are invalid
    # string escapes in a normal literal.
    pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]|\{[0-9]+\})"
    regex = re.compile(pattern)

    tokenized_lines = []
    for smi_line in smi.splitlines():
        tokens = regex.findall(smi_line)

        # The tokens must rebuild the line exactly. This is an explicit check
        # rather than an `assert` statement, so it survives `python -O`.
        if smi_line != ''.join(tokens):
            remaining = regex.sub("", smi_line)
            print("SMI_LINE:", smi_line)
            print("TOKENS:", tokens)
            print("REMAINING:", remaining)
            raise AssertionError("Cannot tokenize SMILES line, leftover characters: " + repr(remaining))

        tokenized_lines.append(' '.join(tokens) + "\n")

    return "".join(tokenized_lines)


## MAIN ##
# tokenize every .txt file of the balanced dataset into a sibling directory

balanced_dir = "./USPTO_balanced/"
tokenized_dir = "./USPTO_balanced_tokenized/"

if not os.path.exists(tokenized_dir):
    os.makedirs(tokenized_dir)

# only the files sitting directly inside the balanced directory are considered
filenames = next(os.walk(balanced_dir))[2]

for filename in filenames:
    if filename.endswith(".txt"):
        with open(balanced_dir + filename, "r") as fp:
            tokens = smi_tokenizer(fp.read())

        with open(tokenized_dir + filename, "w") as fp:
            fp.write(tokens)

print("DONE!")

# Create augmentations by changing stoichiometric coefficients
## Helper functions

In [None]:
import chempy
from chempy import balance_stoichiometry
import sympy, random, numpy as np
import copy
from rdkit.Chem.rdMolDescriptors import CalcMolFormula

def atom_count(molecules):
    """Count atoms across `molecules`, folding the formal charge into the key.

    Keys look like "C" (neutral), "N+" / "O-" (single charge) or "Fe+2" /
    "O-2" (multiple charge).

    Returns:
        (counter, num_atoms): the Counter of annotated symbols and the total
        of GetNumAtoms() over the molecules.
    """
    def annotate(atom):
        label = atom.GetSymbol()
        charge = atom.GetFormalCharge()
        if charge > 0:
            label += "+"
            if charge != 1:
                label += str(charge)
        elif charge < 0:
            label += "-"
            if charge != -1:
                label += str(-charge)
        return label

    counts = Counter()
    atoms_seen = 0
    for molecule in molecules:
        for atom in molecule.GetAtoms():
            counts[annotate(atom)] += 1
        atoms_seen += molecule.GetNumAtoms()

    return counts, atoms_seen

# builds a single augmentation of a reaction
def build_augmentation(reactants, reagents, products, aug_type, k_set, representation="smiles"):
    """Return one augmented (source, target) pair for a reaction.

    Every molecule is prefixed with its stoichiometric coefficient written as
    "{k}", the molecules of each side are shuffled and joined with ".".

    Args:
        reactants, reagents, products: lists of SMILES strings.
        aug_type: 1 draws a single coefficient shared by every molecule;
            2 draws one coefficient per molecule plus a random `k_min`, and
            adds to each side the other side's molecules whose coefficient
            exceeds `k_min`, with the difference as coefficient.
        k_set: sequence of values the coefficients are sampled from.
        representation: "smiles" keeps the strings as given, "formula"
            replaces each molecule by its molecular formula.

    Raises:
        ValueError: for an unknown `representation` or `aug_type`.
    """
    if representation != "formula" and representation != "smiles":
        raise ValueError("Unknown representation:", representation)

    if representation == "formula":
        def as_formulas(smiles_list):
            return [CalcMolFormula(Chem.MolFromSmiles(smiles)) for smiles in smiles_list]

        reactants = as_formulas(reactants)
        reagents = as_formulas(reagents)
        products = as_formulas(products)

    def with_coeff(coeff, mol):
        return "{" + str(coeff) + "}" + mol

    if aug_type == 1:
        # type 1: the same coefficient for every molecule
        base_k = random.choice(k_set)
        lhs = [with_coeff(base_k, mol) for mol in reactants + reagents]
        rhs = [with_coeff(base_k, mol) for mol in products + reagents]

    elif aug_type == 2:
        # type 2: a separate coefficient for each molecule
        coeffs = {}
        for mol in reactants + reagents + products:
            coeffs[mol] = random.choice(k_set)
        k_min = random.randint(1, min(coeffs.values()))

        # left-hand side: sampled coefficients, then the excess products
        lhs = [with_coeff(coeffs[mol], mol) for mol in reactants + reagents]
        for prod in products:
            excess = coeffs[prod] - k_min
            if excess > 0:
                lhs.append(with_coeff(excess, prod))

        # right-hand side: mirror image, with the excess reactants
        rhs = [with_coeff(coeffs[mol], mol) for mol in products + reagents]
        for react in reactants:
            excess = coeffs[react] - k_min
            if excess > 0:
                rhs.append(with_coeff(excess, react))

    else:
        raise ValueError("Unknown aug_type:", aug_type)

    random.shuffle(lhs)
    random.shuffle(rhs)

    return ".".join(lhs), ".".join(rhs)
    

## Make augmentations

In [None]:

def build_augmented_data_split(dataset_name, data_split, split_filename, n_augmentations, aug_type, k_set, representation):
    """Write the augmented source/target files of one split of `balanced_dataset`.

    For every reaction of `balanced_dataset[data_split]`, `n_augmentations`
    augmentations are generated with build_augmentation and written, one per
    line, to "src-<split_filename>.txt" and "tgt-<split_filename>.txt" inside
    "./chemalebra_datasets/<dataset_name>".
    """
    out_dir = "./chemalebra_datasets/" + dataset_name
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    def canonical_smiles(mols):
        return [Chem.CanonSmiles(Chem.MolToSmiles(Chem.AddHs(mol))) for mol in mols]

    src_path = out_dir + "/src-" + split_filename + ".txt"
    tgt_path = out_dir + "/tgt-" + split_filename + ".txt"

    with open(src_path, "w") as src_fp, open(tgt_path, "w") as dst_fp:
        reactants_data = balanced_dataset[data_split]["reactants"]
        reagents_data  = balanced_dataset[data_split]["reagents"]
        products_data  = balanced_dataset[data_split]["products"]

        reactions = zip(reactants_data, reagents_data, products_data)
        for reactants, reagents, products in tqdm.tqdm(reactions, total=len(reactants_data)):
            reactants_smiles = canonical_smiles(reactants)
            reagents_smiles = canonical_smiles(reagents)
            products_smiles = canonical_smiles(products)

            # each augmentation becomes one line in both files
            for _ in range(n_augmentations):
                aug_src, aug_tgt = build_augmentation(
                    reactants_smiles, reagents_smiles, products_smiles,
                    aug_type=aug_type,
                    k_set=k_set,
                    representation=representation
                )
                src_fp.write(aug_src + "\n")
                dst_fp.write(aug_tgt + "\n")


# second part: save the dataset on disk
print("- Building starting dataset...")
for data_split, split_filename in [["TR", "train"], ["VL", "valid"],["TS", "test"]]:
    build_augmented_data_split(
        "USPTO_augmented_start", data_split, split_filename,
        n_augmentations=1, aug_type=1, k_set=[1], representation="smiles"
    )

# Files written for every augmented dataset, in this order:
# (source split, output file name, coefficient values to sample from).
# "test_x" and "test_out" use coefficients outside the training range.
augmentation_plan = (
    ("TR", "train",    (1, 2, 3, 4, 5)),
    ("VL", "valid",    (1, 2, 3, 4, 5)),
    ("TS", "test_in",  (1, 2, 3, 4, 5)),
    ("TR", "test_x",   (6, 7, 8, 9, 10)),
    ("TS", "test_out", (6, 7, 8, 9, 10)),
)

# TYPE 1 AUGMENTATIONS, then TYPE 2 AUGMENTATIONS
augmentation_kinds = (
    (1, "USPTO_augmented_type1", "Building augmentations type 1..."),
    (2, "USPTO_augmented_type2", "Building augmentations type 2..."),
)

for aug_type, dataset_prefix, banner in augmentation_kinds:
    print(banner)
    for n_augm in [1, 5, 10]:
        for rep in ["smiles", "formula"]:
            print("- n_augmentations:", n_augm, ", rep:", rep)
            dataset_name = dataset_prefix + "_x" + str(n_augm) + "_" + rep

            for plan_split, plan_filename, plan_k_values in augmentation_plan:
                build_augmented_data_split(
                    dataset_name, data_split=plan_split, split_filename=plan_filename,
                    n_augmentations=n_augm, aug_type=aug_type, k_set=list(plan_k_values), representation=rep
                )



# Tokenize datasets for Molecular Transformer

In [None]:
import os, tqdm

def smi_tokenizer(smi):
    """
    Tokenize a SMILES molecule or reaction

    `smi` may hold several lines; every input line yields exactly one output
    line made of its tokens separated by single spaces and terminated by a
    newline. A trailing newline in the input does not produce an extra empty
    output line.

    Raises AssertionError when a line contains characters that are not
    covered by the token pattern.
    """
    import re

    # Raw string: the pattern is full of regex escapes that are invalid
    # string escapes in a normal literal.
    pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]|\{[0-9]+\})"
    regex = re.compile(pattern)

    tokenized_lines = []
    for smi_line in smi.splitlines():
        tokens = regex.findall(smi_line)

        # The tokens must rebuild the line exactly. This is an explicit check
        # rather than an `assert` statement, so it survives `python -O`.
        if smi_line != ''.join(tokens):
            remaining = regex.sub("", smi_line)
            print("SMI_LINE:", smi_line)
            print("TOKENS:", tokens)
            print("REMAINING:", remaining)
            raise AssertionError("Cannot tokenize SMILES line, leftover characters: " + repr(remaining))

        tokenized_lines.append(' '.join(tokens) + "\n")

    return "".join(tokenized_lines)

def smi_formula_tokenizer(smi):
    """
    Tokenize a molecule formula

    Tokens are element symbols, ".", a "+" or "-" sign optionally followed by
    digits, plain numbers and "{k}" coefficients. `smi` may hold several
    lines; every input line yields exactly one output line made of its
    tokens separated by single spaces and terminated by a newline. A trailing
    newline in the input does not produce an extra empty output line.

    Raises AssertionError when a line contains characters that are not
    covered by the token pattern.
    """
    import re

    # Raw string: the regex escapes are invalid string escapes in a normal
    # literal.
    pattern = r"([A-Z][a-z]?|\.|-[0-9]*|\+[0-9]*|[0-9]+|\{[0-9]+\})"
    regex = re.compile(pattern)

    tokenized_lines = []
    for smi_line in smi.splitlines():
        tokens = regex.findall(smi_line)

        # The tokens must rebuild the line exactly. This is an explicit check
        # rather than an `assert` statement, so it survives `python -O`.
        if smi_line != ''.join(tokens):
            remaining = regex.sub("", smi_line)
            print("SMI_LINE:", smi_line)
            print("TOKENS:", tokens)
            print("REMAINING:", remaining)
            raise AssertionError("Cannot tokenize formula line, leftover characters: " + repr(remaining))

        tokenized_lines.append(' '.join(tokens) + "\n")

    return "".join(tokenized_lines)


## MAIN ##
# tokenize every .txt file of every dataset directory, mirroring the layout

source_root = "./new_datasets/"
target_root = "./tokenized_datasets/"

if not os.path.exists(target_root):
    os.makedirs(target_root)

# one sub-directory per dataset
subdirs = next(os.walk(source_root))[1]

for subdir in tqdm.tqdm(subdirs):
    filenames = next(os.walk(source_root + subdir))[2]

    print("- Tokenizing", subdir, "...")

    target_dir = target_root + subdir + "/"
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # datasets whose name ends with "formula" hold formulas, the rest SMILES
    tokenize = smi_formula_tokenizer if subdir.endswith("formula") else smi_tokenizer

    for filename in filenames:
        if filename.endswith(".txt"):
            with open(source_root + subdir + "/" + filename, "r") as fp:
                tokens = tokenize(fp.read())

            with open(target_dir + filename, "w") as fp:
                fp.write(tokens)

print("DONE!")

# Compute metrics on models' predictions

In [None]:
import os, re, string, tqdm
from rdkit import Chem
from pprint import pprint
import numpy as np

# shut up warning
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')


def _metrics_from_counts(tp, fp, fn):
    """Turn tp/fp/fn counts into the reported scores (0.0 when a denominator is 0)."""
    total = tp + fp + fn
    return {
        "accuracy": 0.0 if total == 0 else tp / total,
        "precision": 0.0 if (tp + fp) == 0 else tp / (tp + fp),
        "recall": 0.0 if (tp + fn) == 0 else tp / (tp + fn),
        "f1_score": 0.0 if (2 * tp + fp + fn) == 0 else (2 * tp) / (2 * tp + fp + fn),
        "hamming_loss": 0.0 if total == 0 else (fp + fn) / total,
    }


def accuracies_top_1(mol_matching_fn, true, pred):
    """Compare one predicted set of (coefficient, molecule) pairs with the truth.

    Each true molecule is paired, in order, with the first not-yet-used
    predicted molecule for which ``mol_matching_fn`` returns a truthy value.
    Two sets of counts are accumulated:

    * "mol only": every pair is one true positive, every unpaired true
      molecule one false negative, every unpaired prediction one false
      positive.
    * "mol and k": counts are weighted by the stoichiometric coefficients.
      A pair contributes ``min(k_true, k_pred)`` true positives and the
      difference as false negatives/positives. Pairs where either
      coefficient is falsy (``None`` or 0) contribute nothing, and unpaired
      molecules with a falsy coefficient are ignored as well.

    Args:
        mol_matching_fn: Callable ``(mol_true, mol_pred) -> bool``.
        true: Iterable of items whose ``[0]`` is the coefficient (or ``None``)
            and whose ``[1]`` is the molecule.
        pred: Same layout as ``true``, for the prediction.

    Returns:
        ``(exact_match, results_mol_only, results_mol_andk)``. Both result
        values are dicts with the keys ``accuracy``, ``precision``,
        ``recall``, ``f1_score`` and ``hamming_loss``. ``exact_match`` is
        ``True`` when the coefficient-aware counts contain neither false
        positives nor false negatives.

    Raises:
        AssertionError: If ``mol_matching_fn`` raises; the original exception
            is attached as ``__cause__``.
    """
    true_entries = [(entry[0], entry[1]) for entry in true]
    pred_entries = [(entry[0], entry[1]) for entry in pred]

    # pred_claimed[j] is True once prediction j has been paired with a true molecule.
    pred_claimed = [False] * len(pred_entries)

    mol_only = {"tp": 0, "fp": 0, "fn": 0}
    mol_andk = {"tp": 0, "fp": 0, "fn": 0}

    for k_true, mol_true in true_entries:
        for j, (k_pred, mol_pred) in enumerate(pred_entries):
            # A prediction can be paired at most once, so there is no need
            # to run the (possibly expensive) matcher on it again.
            if pred_claimed[j]:
                continue

            try:
                is_same = mol_matching_fn(mol_true, mol_pred)
            except Exception as exc:
                # Explicit raise instead of `print` + `assert False`: keeps
                # the AssertionError contract, preserves the root cause and
                # still fires when Python runs with -O.
                raise AssertionError(
                    f"mol_matching_fn failed for true={mol_true!r} "
                    f"pred={mol_pred!r} (pred coefficient={k_pred!r})"
                ) from exc

            if not is_same:
                continue

            pred_claimed[j] = True
            mol_only["tp"] += 1

            # Coefficients are compared only when both are present.
            if k_true and k_pred:
                mol_andk["tp"] += min(k_true, k_pred)
                if k_true > k_pred:
                    mol_andk["fn"] += k_true - k_pred
                elif k_true < k_pred:
                    mol_andk["fp"] += k_pred - k_true
            break
        else:
            # No prediction could be paired with this true molecule.
            mol_only["fn"] += 1
            mol_andk["fn"] += k_true if k_true else 0

    # Predictions left unpaired are false positives.
    for claimed, (k_pred, _) in zip(pred_claimed, pred_entries):
        if not claimed:
            mol_only["fp"] += 1
            mol_andk["fp"] += k_pred if k_pred else 0

    results_mol_only = _metrics_from_counts(
        mol_only["tp"], mol_only["fp"], mol_only["fn"]
    )
    results_mol_andk = _metrics_from_counts(
        mol_andk["tp"], mol_andk["fp"], mol_andk["fn"]
    )

    # Exact match is judged on the coefficient-aware counts only.
    exact_match = mol_andk["fp"] == 0 and mol_andk["fn"] == 0

    return exact_match, results_mol_only, results_mol_andk

# compute coefficient and molecular accuracies of formulas
def accuracies_top_k(mol_matching_fn, true, preds):
    """Run ``accuracies_top_1`` on every candidate prediction of one sample.

    Args:
        mol_matching_fn: Matching callable forwarded to ``accuracies_top_1``.
        true: Ground-truth entries, forwarded unchanged for every candidate.
        preds: Iterable of candidate predictions, each forwarded on its own.

    Returns:
        Three lists aligned with ``preds``: the exact-match flags, the
        molecule-only results and the molecule-and-coefficient results.
    """
    per_candidate = [
        accuracies_top_1(mol_matching_fn, true, candidate)
        for candidate in preds
    ]

    # Transpose the list of 3-tuples into three parallel lists.
    exact_flags = [outcome[0] for outcome in per_candidate]
    mol_only_scores = [outcome[1] for outcome in per_candidate]
    mol_andk_scores = [outcome[2] for outcome in per_candidate]

    return exact_flags, mol_only_scores, mol_andk_scores

def get_coef_and_mol(split, representation):
    """Split one ``{N}MOLECULE`` string into its coefficient and molecule.

    The expected format is ``{N}MOLECULE``, where ``N`` is a run of digits
    (the stoichiometric coefficient) and ``MOLECULE`` is either a chemical
    formula or a SMILES string. The coefficient is recognised only at the
    very beginning of the string; without it the whole string is taken as
    the molecule.

    Args:
        split: The string to parse.
        representation: When equal to ``"smiles"`` the molecule is parsed
            and re-written with RDKit; any other value leaves the molecule
            text untouched.

    Returns:
        ``(coefficient, molecule)`` where ``coefficient`` is an ``int`` or
        ``None`` when no ``{N}`` prefix is present, or ``None`` (instead of
        a tuple) when the SMILES string cannot be processed by RDKit.
    """
    coef_match = re.match(r"\{([0-9]+)\}", split)

    if coef_match is None:
        # No coefficient: the whole string is the molecule.
        coef = None
        mol_str = split
    else:
        # Standard case: the molecule lies to the right of the coefficient.
        coef = int(coef_match.group(1))
        mol_str = split[coef_match.end():]

    if representation == "smiles":
        try:
            mol_obj = Chem.MolFromSmiles(mol_str)
            # RDKit signals an unparsable SMILES by returning None.
            if mol_obj is None:
                return None
            mol_str = Chem.CanonSmiles(Chem.MolToSmiles(mol_obj))
        except Exception:
            # Any RDKit failure marks the molecule as invalid. Unlike a bare
            # `except:`, this lets KeyboardInterrupt/SystemExit propagate.
            return None

    return (coef, mol_str)



def compute_accuracy(prediction_path, data_path_src, data_path_tgt, representation, top_k=1, beam_width=1, test_split="test-in"):
    """Score a predictions file against its targets and print averaged metrics.

    Args:
        prediction_path: Text file holding ``beam_width`` consecutive
            prediction lines per sample.
        data_path_src: Source file; only its number of lines is used, as a
            consistency check against targets and predictions.
        data_path_tgt: Target file, one line per sample.
        representation: ``"smiles"`` or ``"formula"``; selects how molecules
            are parsed and compared.
        top_k: Number of leading candidates kept per sample. Must not exceed
            ``beam_width``. Only the first candidate currently contributes
            to the reported numbers.
        beam_width: Number of prediction lines per sample.
        test_split: Label printed in the results header.

    Returns:
        The dict that is also pretty-printed: ``"EM"`` (mean exact match)
        plus ``"only_mol"`` and ``"mol_andk"``, each mapping ``ACC``,
        ``PRE``, ``REC``, ``F1`` and ``HL`` to the mean over all samples.

    Raises:
        AssertionError: On an unknown ``representation``, when
            ``top_k > beam_width``, or when the three files do not describe
            the same number of samples.
    """
    # Explicit raises (not `assert`) so validation survives `python -O`;
    # AssertionError is kept for backward compatibility.
    if representation not in ("smiles", "formula"):
        raise AssertionError(
            f"representation must be 'smiles' or 'formula', got {representation!r}"
        )
    if top_k > beam_width:
        raise AssertionError(
            f"top_k ({top_k}) must not exceed beam_width ({beam_width})"
        )

    with open(data_path_src, 'r') as fp:
        source_lines = fp.readlines()
    with open(data_path_tgt, 'r') as fp:
        target_lines = fp.readlines()
    with open(prediction_path, "r") as fp:
        pred_lines = fp.readlines()

    # The predictions file is flat (beam_width lines per sample): regroup it
    # into one list per sample, then keep only the top_k leading candidates.
    new_pred_lines = []
    for c, pred in enumerate(pred_lines):
        if c % beam_width == 0:
            new_pred_lines.append([])
        new_pred_lines[-1].append(pred)
    new_pred_lines = [beam[:top_k] for beam in new_pred_lines]

    print("source lines:", len(source_lines), "target_lines:", len(target_lines), "pred_lines:", len(new_pred_lines))

    if not (len(source_lines) == len(target_lines) == len(new_pred_lines)):
        raise AssertionError(
            "source, target and prediction files disagree on the number of samples"
        )

    remove_whitespaces = str.maketrans('', '', string.whitespace)  # used for removing whitespaces.

    def parse_line(line):
        # Strip whitespace, split on ".", parse every piece and drop the
        # pieces get_coef_and_mol could not parse (it returns None for those).
        parsed = (
            get_coef_and_mol(kmol, representation)
            for kmol in line.translate(remove_whitespaces).split(".")
        )
        return [k_mol for k_mol in parsed if k_mol]

    # The matching function decides whether two strings represent the same
    # molecule. It is defined once here rather than once per sample.
    if representation == "smiles":
        def matching_fn(mol_true, mol_pred):
            return Chem.CanonSmiles(mol_true) == Chem.CanonSmiles(mol_pred)
    else:
        def matching_fn(mol_true, mol_pred):
            return mol_true == mol_pred

    # (key in the final report, key in the per-sample result dicts)
    metric_keys = (
        ("ACC", "accuracy"),
        ("PRE", "precision"),
        ("REC", "recall"),
        ("F1", "f1_score"),
        ("HL", "hamming_loss"),
    )

    exact_flags = []
    only_mol_values = {short: [] for short, _ in metric_keys}
    mol_andk_values = {short: [] for short, _ in metric_keys}

    # The source lines are not parsed: nothing below depends on their content.
    for target, top_k_predictions in tqdm.tqdm(zip(target_lines, new_pred_lines), total=len(target_lines)):
        tgt_k_mol = parse_line(target)
        pred_k_mols = [parse_line(pred) for pred in top_k_predictions]

        exact_matches, metrics_only_mol, metrics_mol_andk = accuracies_top_k(
            matching_fn,
            tgt_k_mol,
            pred_k_mols
        )

        # TODO for the time being consider only top-1
        exact_flags.append(exact_matches[0])
        for short, long_name in metric_keys:
            only_mol_values[short].append(metrics_only_mol[0][long_name])
            mol_andk_values[short].append(metrics_mol_andk[0][long_name])

    final_results = {
        "EM": np.array(exact_flags).mean(),
        "only_mol": {
            short: np.array(values).mean()
            for short, values in only_mol_values.items()
        },
        "mol_andk": {
            short: np.array(values).mean()
            for short, values in mol_andk_values.items()
        },
    }

    print("*** FINAL RESULTS for", test_split,"***")
    pprint(final_results)

    print("ALL GOOD!")

    # Returned in addition to being printed so callers can reuse the numbers.
    return final_results
    
    
###################################   
#          ### MAIN ###           #
###################################

# model_name = "TRANS_FINAL_RUN"
model_name = "TRANS_FINAL_RUN_CROSS"

_TYPE = 2
N_AUGM = 5
REPRESENTATION = "smiles"

TOP_K = 3
BEAM_WIDTH = 10

# The dataset name encodes augmentation type, augmentation factor and representation.
dataset_name = f"USPTO_augmented_type{_TYPE}_x{N_AUGM}_{REPRESENTATION}"
predictions_path_prefix = os.path.join("./ChemAlgebra_results", "final_predictions", f"{model_name}_on_{dataset_name}")
data_path_prefix = os.path.join("./ChemAlgebra_results", "data", dataset_name)

#predictions_path_prefix = "./ChemAlgebra_results/final_predictions/TRANS_STANDARD_ARCH_smiles_type1_x1_on_USPTO_augmented_type1_x1_smiles"
#data_path_prefix = "./ChemAlgebra_results/data/USPTO_augmented_type1_x1_smiles"

# One row per test split:
# (split label, predictions-file suffix, source file, target file, enabled)
# Flip the last field to choose which splits are evaluated.
evaluation_runs = [
    ("test_in", "_test_in_final.txt", "/src-test_in_x1.txt", "/tgt-test_in_x1.txt", False),
    ("test_out", "_test_out_final.txt", "/src-test_out_x1.txt", "/tgt-test_out_x1.txt", False),
    ("test_cross", "_test_cross_final.txt", "/src-test_cross.txt", "/tgt-test_cross.txt", True),
]

for split_label, pred_suffix, src_name, tgt_name, enabled in evaluation_runs:
    if not enabled:
        continue

    compute_accuracy(
        predictions_path_prefix + pred_suffix,
        data_path_prefix + src_name,
        data_path_prefix + tgt_name,
        top_k=TOP_K,
        representation=REPRESENTATION,
        beam_width=BEAM_WIDTH,
        test_split=split_label
    )
