In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os 
import random
import datasets 
from tqdm import tqdm
from datasets import concatenate_datasets
from rdkit import Chem, rdBase

from mol_depict_cdk.cxsmiles_tokenizer import CXSMILESTokenizer

  from .autonotebook import tqdm as notebook_tqdm


### Remove CXSMILES opt to CXSMILES out conversion errors 

In [7]:
dataset_name = "ocxsr_3004"

In [8]:
# Read for dataloader
dataset_hf = concatenate_datasets([
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["train"],
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["test"]
])
print(dataset_hf)

Dataset({
    features: ['id', 'image_path', 'mol', 'cxsmiles', 'cxsmiles_dataset', 'keypoints', 'cells', 'image', 'cxsmiles_opt'],
    num_rows: 217575
})


In [9]:
i_max = float("inf")
verify = False
remove_indices = []
cxsmiles_tokenizer = CXSMILESTokenizer()
for i, sample in tqdm(enumerate(dataset_hf.iter(batch_size=1)), total=min(i_max, len(dataset_hf))):
    if i > i_max:
        break
    id, image, mol, cxsmiles, cxsmiles_dataset, cxsmiles_opt, keypoints, cells = sample["id"][0], sample["image"][0], sample["mol"][0], sample["cxsmiles"][0], sample["cxsmiles_dataset"][0], sample["cxsmiles_opt"][0], sample["keypoints"][0], sample["cells"][0]
    
    if "*" in cxsmiles_opt:
        remove_indices.append(i)
        continue
    if not(verify):
        continue
    try:
        cxsmiles_out = cxsmiles_tokenizer.convert_opt_to_out(cxsmiles_opt)
        molecule = Chem.MolFromSmiles(cxsmiles_out)
        if molecule is None:
            print(cxsmiles_opt)
            remove_indices.append(i)
            continue
    except:
        print("CXSMILES dataset:", cxsmiles_dataset)
        print("CXSMILES CDK:", cxsmiles)
        print("CXSMILES optimized:", cxsmiles_opt)
        break
print(remove_indices)

  0%|          | 8/217575 [00:00<46:52, 77.37it/s]

100%|██████████| 217575/217575 [34:05<00:00, 106.36it/s]

[]





In [10]:
dataset_hf = dataset_hf.select((i for i in range(len(dataset_hf)) if i not in set(remove_indices)))



In [11]:
hf_dataset_name = "ocxsr_3005"
dataset_hf = dataset_hf.train_test_split(test_size=0.1)
dataset_hf.save_to_disk(os.getcwd() + f"/../../data/hf_dataset/{hf_dataset_name}/")

Saving the dataset (0/18 shards):   0%|          | 0/195817 [00:00<?, ? examples/s]

Saving the dataset (18/18 shards): 100%|██████████| 195817/195817 [00:36<00:00, 5333.96 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 21758/21758 [00:04<00:00, 5140.24 examples/s]


### Remove CXSMILES with multiple Sg sections on the same minimum or maximum atom indices

In [8]:
dataset_name = "ocxsr_2001"

In [9]:
# Read for dataloader
dataset_hf = concatenate_datasets([
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["train"],
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["test"]
])
print(dataset_hf)

Dataset({
    features: ['id', 'image_path', 'mol', 'cxsmiles', 'cxsmiles_dataset', 'cxsmiles_opt', 'keypoints', 'cells', 'image'],
    num_rows: 231996
})


In [10]:
remove_indices = []

cxsmiles_tokenizer = CXSMILESTokenizer()
for i, sample in tqdm(enumerate(dataset_hf.iter(batch_size=1)), total=min(i_max, len(dataset_hf))):
    min_indices, max_indices = [], []
    for i_sample, section in enumerate(cxsmiles_tokenizer.parse_sections(sample["cxsmiles_dataset"][0].split("|")[1])):
        if (len(section) >= 2) and (section[:2] == "Sg"): 
            sg_section = cxsmiles_tokenizer.parse_sg_section(section)
            indices = []
            for index in sg_section[2:]:
                if index == "<atom_list_end>":
                    break
                if index == ",":
                    continue
                indices.append(int(index))
            min_index, max_index = min(indices), max(indices)
            if (min_index in min_indices) or (max_index in max_indices):
                remove_indices.append(i)
                break
            min_indices.append(min_index)
            max_indices.append(max_index)

print(remove_indices)

  0%|          | 11/231996 [00:00<38:31, 100.36it/s]

100%|██████████| 231996/231996 [36:16<00:00, 106.61it/s]

[89, 472, 1208, 1370, 1431, 1458, 1910, 1982, 2663, 3051, 3449, 3475, 3668, 3988, 4277, 4418, 5398, 5540, 5946, 5994, 6120, 6788, 6860, 7328, 7559, 8196, 8314, 8381, 8576, 8736, 9155, 9385, 9990, 11787, 11893, 12183, 12394, 12918, 13077, 13122, 13131, 13337, 13381, 13385, 13524, 14376, 14978, 15040, 15820, 15826, 15965, 16131, 16206, 16618, 16989, 17111, 17279, 17583, 17909, 18166, 18471, 18542, 18692, 18732, 19017, 19485, 21049, 21866, 21885, 21941, 22035, 23066, 23721, 23788, 23825, 24067, 24165, 24612, 24791, 24883, 24925, 25221, 25313, 25396, 25633, 26299, 26347, 26372, 27233, 27660, 27827, 28791, 28996, 29373, 29578, 30050, 30336, 30343, 30497, 30535, 30577, 31068, 31188, 31313, 31336, 31479, 31696, 31915, 32564, 33291, 33537, 33590, 34185, 34367, 34613, 35059, 35351, 35770, 35934, 36257, 36683, 36910, 37003, 37493, 37988, 38281, 38390, 38420, 38447, 38530, 38763, 38947, 39080, 39224, 39491, 39614, 40532, 40936, 41036, 41121, 41974, 42091, 43013, 44369, 44788, 44973, 45291, 45548,




In [11]:
dataset_hf = dataset_hf.select((i for i in range(len(dataset_hf)) if i not in set(remove_indices)))

In [12]:
hf_dataset_name = "ocxsr_2002"
dataset_hf = dataset_hf.train_test_split(test_size=0.1)
dataset_hf.save_to_disk(os.getcwd() + f"/../../data/hf_dataset/{hf_dataset_name}/")

Saving the dataset (0/19 shards):   0%|          | 0/208127 [00:00<?, ? examples/s]

Saving the dataset (19/19 shards): 100%|██████████| 208127/208127 [00:40<00:00, 5178.39 examples/s]
Saving the dataset (3/3 shards): 100%|██████████| 23126/23126 [00:04<00:00, 5296.37 examples/s]


### Clean existing dataset by recomputing cxsmiles opt

In [None]:
# Read for dataloader
dataset_name = "ocxsr_17"
dataset_hf = concatenate_datasets([
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["train"],
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["test"]
])
dataset_hf

In [None]:
max_i = float("inf")
dataset_name = "experiment-cx002_cxsmiles_ocr"
cxsmiles_tokenizer = CXSMILESTokenizer()
verbose = False
new_cxsmiles_opt = []
for index, (cxsmiles, cxsmiles_opt, id) in tqdm(enumerate(zip(dataset_hf["cxsmiles"], dataset_hf["cxsmiles_opt"], dataset_hf["id"])), total=len(dataset_hf)):
    if index > max_i:
        break
    # Detect splitted Sg section with more than one comma (ex: Sg:n:1,2,3:l:ht)
    detected_error = False
    with rdBase.BlockLogs():
        parser_params = Chem.SmilesParserParams()
        parser_params.strictCXSMILES = False
        molecule = Chem.MolFromSmiles(cxsmiles, parser_params)
        cxsmiles = Chem.MolToCXSmiles(molecule, canonical=False)

    rtable = cxsmiles.split("|")[1]
    coordinates = rtable[rtable.find("("): rtable.find(")") + 1]
    rtable = rtable.replace(coordinates, "")
    rtable_opt = ""
    rtable_split = rtable.split(",")
    for i, s in enumerate(rtable_split):
        if ("atomProp" in s):
            continue
        if s == "":
            continue
        if not("Sg" in s):
            continue
        if ("Sg" in s):
            parsed_section = [c for c in s.split(":") if c != ""]
            s = ":".join(parsed_section)
            if len(parsed_section) == 3:
                offset = 1
                next_index = rtable_split[i + offset]
                while len(next_index) == 1:
                    s += "," + next_index
                    offset += 1
                    next_index = rtable_split[i + offset]
                    if offset >= 2:
                        detected_error = True
    if not(detected_error):
        new_cxsmiles_opt.append(cxsmiles_opt)
        continue
    
    molfile_path = os.getcwd() + f"/../../data/dataset/{dataset_name}/molfiles/{id}.mol"
    with rdBase.BlockLogs():
        molecule = Chem.MolFromMolFile(molfile_path, strictParsing=False, removeHs=False)
    if molecule is None:
        if verbose:
            print("Invalid CXSMILES from MOLfile")
        continue
    cxsmiles = Chem.MolToCXSmiles(molecule)
    mol_to_cxsmi_i_mapping = {k: v for k, v in zip(
        list(map(int, molecule.GetProp("_smilesAtomOutputOrder")[1:-2].split(","))),
        range(0, molecule.GetNumAtoms()),
    )}
    
    cxsmiles_opt, keypoints = cxsmiles_tokenizer.convert_cdk_to_opt(cxsmiles, molfile_path, mol_to_cxsmi_i_mapping)

    new_cxsmiles_opt.append(cxsmiles_opt)
    if verbose:
        print(f"Problem fixed for {index, id}")

In [9]:
dataset_hf = dataset_hf.remove_columns("cxsmiles_opt").add_column("cxsmiles_opt", new_cxsmiles_opt)

In [10]:
hf_dataset_name = "ocxsr_18"
dataset_hf = dataset_hf.train_test_split(test_size=0.2)
dataset_hf.save_to_disk(os.getcwd() + f"/../../data/hf_dataset/{hf_dataset_name}/")

Saving the dataset (5/5 shards): 100%|██████████| 53790/53790 [00:53<00:00, 1009.17 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 13448/13448 [00:03<00:00, 4442.88 examples/s]


### Clean existing dataset by recomputing cxsmiles opt 2

In [4]:
# Read for dataloader
dataset_name = "ocxsr_17"
dataset_hf = concatenate_datasets([
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["train"],
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["test"]
])
dataset_hf

Dataset({
    features: ['id', 'image', 'mol', 'cxsmiles', 'cxsmiles_dataset', 'keypoints', 'cxsmiles_opt', 'celss'],
    num_rows: 67238
})

In [None]:
max_i = float("inf")
dataset_name = "experiment-cx002_cxsmiles_ocr"
cxsmiles_tokenizer = CXSMILESTokenizer()
verbose = True
new_cxsmiles_opt = []
for index, (cxsmiles, cxsmiles_opt, id) in tqdm(enumerate(zip(dataset_hf["cxsmiles"], dataset_hf["cxsmiles_opt"], dataset_hf["id"])), total=len(dataset_hf)):
    if index > max_i:
        break
    molfile_path = os.getcwd() + f"/../../data/dataset/{dataset_name}/molfiles/{id}.mol"
    with rdBase.BlockLogs():
        molecule = Chem.MolFromMolFile(molfile_path, strictParsing=False, removeHs=False)
    if molecule is None:
        print("Invalid CXSMILES from MOLfile")
        break
    
    cxsmiles = Chem.MolToCXSmiles(molecule)
    mol_to_cxsmi_i_mapping = {k: v for k, v in zip(
        list(map(int, molecule.GetProp("_smilesAtomOutputOrder")[1:-2].split(","))),
        range(0, molecule.GetNumAtoms()),
    )}
    
    cxsmiles_opt, keypoints = cxsmiles_tokenizer.convert_cdk_to_opt(cxsmiles, molfile_path, mol_to_cxsmi_i_mapping)

    # Check that Sg section
    gt_smiles = cxsmiles_tokenizer.convert_opt_to_out(cxsmiles_opt)
    
    canonical_smiles = canonicalize_markush(gt_smiles)
    if "eu" in canonical_smiles:
        print("index:", index)
        print("cxsmiles_opt:", cxsmiles_opt)
        print("cxsmiles:", cxsmiles)

        cxsmiles_opt, keypoints = cxsmiles_tokenizer.convert_cdk_to_opt(cxsmiles, molfile_path, mol_to_cxsmi_i_mapping, verbose=True)
        break

    new_cxsmiles_opt.append(cxsmiles_opt)
    

In [28]:
dataset_hf = dataset_hf.remove_columns("cxsmiles_opt").add_column("cxsmiles_opt", new_cxsmiles_opt)

In [29]:
hf_dataset_name = "ocxsr_20"
dataset_hf = dataset_hf.train_test_split(test_size=0.2)
dataset_hf.save_to_disk(os.getcwd() + f"/../../data/hf_dataset/{hf_dataset_name}/")

Saving the dataset (5/5 shards): 100%|██████████| 53790/53790 [00:32<00:00, 1633.96 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 13448/13448 [00:02<00:00, 4822.16 examples/s]


### Clean existing dataset by recomputing cxsmiles opt 3

In [3]:
# Read for dataloader
dataset_name = "ocxsr_3003"
dataset_hf = concatenate_datasets([
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["train"],
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["test"]
])
dataset_hf

Dataset({
    features: ['id', 'image_path', 'mol', 'cxsmiles', 'cxsmiles_dataset', 'keypoints', 'cells', 'image', 'cxsmiles_opt'],
    num_rows: 217575
})

In [4]:
max_i = float("inf")
dataset_name = "experiment-cx3000_cxsmiles_ocr"
cxsmiles_tokenizer = CXSMILESTokenizer()
verbose = True
new_cxsmiles_opt = []
for index, (cxsmiles, cxsmiles_opt, id) in tqdm(enumerate(zip(dataset_hf["cxsmiles"], dataset_hf["cxsmiles_opt"], dataset_hf["id"])), total=len(dataset_hf)):
    if index > max_i:
        break
    molfile_path = os.getcwd() + f"/../../data/dataset/{dataset_name}/molfiles/{id}.mol"
    with rdBase.BlockLogs():
        molecule = Chem.MolFromMolFile(molfile_path, strictParsing=False, removeHs=False)
    if molecule is None:
        print("Invalid CXSMILES from MOLfile")
        break
    
    cxsmiles = Chem.MolToCXSmiles(molecule)
    mol_to_cxsmi_i_mapping = {k: v for k, v in zip(
        list(map(int, molecule.GetProp("_smilesAtomOutputOrder")[1:-2].split(","))),
        range(0, molecule.GetNumAtoms()),
    )}
    
    cxsmiles_opt, keypoints = cxsmiles_tokenizer.convert_cdk_to_opt(cxsmiles, molfile_path, mol_to_cxsmi_i_mapping)

    new_cxsmiles_opt.append(cxsmiles_opt)
    

100%|██████████| 217575/217575 [07:43<00:00, 469.45it/s]


In [5]:
dataset_hf = dataset_hf.remove_columns("cxsmiles_opt").add_column("cxsmiles_opt", new_cxsmiles_opt)

In [6]:
hf_dataset_name = "ocxsr_3004"
dataset_hf = dataset_hf.train_test_split(test_size=0.2)
dataset_hf.save_to_disk(os.getcwd() + f"/../../data/hf_dataset/{hf_dataset_name}/")

Saving the dataset (16/16 shards): 100%|██████████| 174060/174060 [01:59<00:00, 1452.10 examples/s]
Saving the dataset (4/4 shards): 100%|██████████| 43515/43515 [00:08<00:00, 4949.25 examples/s]


### Shuffle boxes

In [19]:
# Read for dataloader
dataset_name = "ocxsr_16"
dataset_hf = concatenate_datasets([
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["train"],
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["test"]
])
dataset_hf

Dataset({
    features: ['id', 'image', 'mol', 'cxsmiles', 'cxsmiles_dataset', 'keypoints', 'cells', 'cxsmiles_opt'],
    num_rows: 67238
})

In [25]:
max_i = float("inf")

new_cells_list = []
for index, cells in enumerate(dataset_hf["cells"]):   
    if index > max_i:
        break
    random.shuffle(cells)
    new_cells_list.append(cells)

In [26]:
dataset_hf = dataset_hf.remove_columns("cells").add_column("cells", new_cells_list)

In [27]:
hf_dataset_name = "ocxsr_17"
dataset_hf = dataset_hf.train_test_split(test_size=0.2)
dataset_hf.save_to_disk(os.getcwd() + f"/../../data/hf_dataset/{hf_dataset_name}/")

Saving the dataset (0/5 shards):   0%|          | 0/53790 [00:00<?, ? examples/s]

Saving the dataset (5/5 shards): 100%|██████████| 53790/53790 [00:09<00:00, 5740.62 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 13448/13448 [00:02<00:00, 6019.18 examples/s]


### Rename column "cells"

In [30]:
# Read for dataloader
dataset_name = "ocxsr_20"
dataset_hf = concatenate_datasets([
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["train"],
    datasets.load_from_disk(os.getcwd() + f"/../../data/hf_dataset/{dataset_name}/", keep_in_memory=False)["test"]
])
dataset_hf

Dataset({
    features: ['id', 'image', 'mol', 'cxsmiles', 'cxsmiles_dataset', 'keypoints', 'celss', 'cxsmiles_opt'],
    num_rows: 67238
})

In [31]:
dataset_hf = dataset_hf.rename_column("celss", "cells")

In [32]:
hf_dataset_name = "ocxsr_21"
dataset_hf = dataset_hf.train_test_split(test_size=0.2)
dataset_hf.save_to_disk(os.getcwd() + f"/../../data/hf_dataset/{hf_dataset_name}/")

Saving the dataset (0/5 shards):   0%|          | 0/53790 [00:00<?, ? examples/s]

Saving the dataset (5/5 shards): 100%|██████████| 53790/53790 [00:12<00:00, 4200.27 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 13448/13448 [00:03<00:00, 4405.02 examples/s]


### Recompute cxsmiles opt to remove R-group injection (Ablation study)

In [3]:
# Read for dataloader
dataset_name = "mdu_3008_aug"
dataset_hf = concatenate_datasets([
    datasets.load_from_disk(os.getcwd() + f"/../../../deepsearch-ai-unidoc/data/{dataset_name}/", keep_in_memory=False)["train"],
    datasets.load_from_disk(os.getcwd() + f"/../../../deepsearch-ai-unidoc/data/{dataset_name}/", keep_in_memory=False)["test"]
])
dataset_hf

Dataset({
    features: ['id', 'page_image_path', 'description', 'annotation', 'mol', 'cxsmiles_dataset', 'cxsmiles', 'cxsmiles_opt', 'keypoints', 'cells', 'page_image'],
    num_rows: 235570
})

In [4]:
max_i = float("inf")
dataset_name = "experiment-cx3000_cxsmiles_ocr"
cxsmiles_tokenizer = CXSMILESTokenizer(condense_labels = False)
verbose = False
new_cxsmiles_opt = []
for index, (cxsmiles, cxsmiles_opt, id) in tqdm(enumerate(zip(dataset_hf["cxsmiles"], dataset_hf["cxsmiles_opt"], dataset_hf["id"])), total=len(dataset_hf)):
    if index > max_i:
        break
   
    molfile_path = os.getcwd() + f"/../../data/dataset/{dataset_name}/molfiles/{id}.mol"
    with rdBase.BlockLogs():
        molecule = Chem.MolFromMolFile(molfile_path, strictParsing=False, removeHs=False)
    if molecule is None:
        if verbose:
            print("Invalid CXSMILES from MOLfile")
        continue
    cxsmiles = Chem.MolToCXSmiles(molecule)
    mol_to_cxsmi_i_mapping = {k: v for k, v in zip(
        list(map(int, molecule.GetProp("_smilesAtomOutputOrder")[1:-2].split(","))),
        range(0, molecule.GetNumAtoms()),
    )}
    #cxsmiles_out = cxsmiles_tokenizer.convert_opt_to_out(cxsmiles_opt, condensed_labels=True)
    cxsmiles_opt, keypoints = cxsmiles_tokenizer.convert_cdk_to_opt(cxsmiles, molfile_path, mol_to_cxsmi_i_mapping)
    new_cxsmiles_opt.append(cxsmiles_opt)    
    
    #cxsmiles_out = cxsmiles_tokenizer.convert_opt_to_out(cxsmiles_opt)
    
    if verbose:
        print(f"Problem fixed for {index, id}")

100%|██████████| 235570/235570 [07:37<00:00, 514.78it/s]


In [5]:
dataset_hf = dataset_hf.remove_columns("cxsmiles_opt").add_column("cxsmiles_opt", new_cxsmiles_opt)

In [6]:
hf_dataset_name = "mdu_3008_aug_no_condense"
dataset_hf = dataset_hf.train_test_split(test_size=0.2)
dataset_hf.save_to_disk(os.getcwd() + f"/../../data/hf_dataset/{hf_dataset_name}/")

Saving the dataset (45/45 shards): 100%|██████████| 188456/188456 [05:58<00:00, 525.19 examples/s] 
Saving the dataset (12/12 shards): 100%|██████████| 47114/47114 [00:42<00:00, 1109.74 examples/s]
