In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# the local mol_depict_cdk package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os 
import io
import copy
from tqdm import tqdm
from PIL import ImageFont, ImageDraw
import matplotlib.pyplot as plt
from rdkit import Chem, rdBase
from rdkit.Chem import Draw
import datasets 
from pprint import pprint
from PIL import Image 
import cairosvg

from mol_depict_cdk.generation import get_boxes, get_cells
from mol_depict_cdk.generate_hf_dataset_ocr_boxes import CXSMILESTokenizer

# Font used for the index labels drawn on top of each OCR cell.
font = ImageFont.truetype(font="../../data/fonts/arial.ttf", size=50)

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Load the saved Hugging Face dataset from disk and keep its "train" split.
dataset_name_hf = "ocxsr_3002"
hf_dataset_dir = os.getcwd() + f"/../../data/hf_dataset/{dataset_name_hf}/"
dataset_splits = datasets.load_from_disk(hf_dataset_dir, keep_in_memory=False)
dataset_hf = dataset_splits["train"]
dataset_hf

Dataset({
    features: ['id', 'image_path', 'mol', 'cxsmiles', 'cxsmiles_dataset', 'cxsmiles_opt', 'keypoints', 'cells', 'image'],
    num_rows: 63
})

### Visualize Hugging Face dataset

In [None]:
# Visualize from Hugging Face dataset
# For the first `i_max + 1` samples: print the id and the three CXSMILES strings,
# draw every OCR cell on a copy of the image (red box + blue index), and print
# the matching "index text" line for each cell.
i_max = 15
for i, sample in enumerate(dataset_hf.iter(batch_size=1)):
    if i > i_max:
        break
    # Batches hold a single sample, so every field is unwrapped with [0].
    id = sample["id"][0]
    image = sample["image"][0]
    mol = sample["mol"][0]
    cxsmiles = sample["cxsmiles"][0]
    cxsmiles_dataset = sample["cxsmiles_dataset"][0]
    cxsmiles_opt = sample["cxsmiles_opt"][0]
    keypoints = sample["keypoints"][0]
    cells = sample["cells"][0]

    # Optional filter: select only molecules with explicit hydrogens
    # if not any("H" in ocr_cell["text"] for ocr_cell in cells):
    #     continue

    # Optional filter: select only molecules with sg
    # if not("Sg" in cxsmiles):
    #     continue

    print(i)
    print(id)
    print(cxsmiles_dataset)
    print(cxsmiles)
    print(cxsmiles_opt)

    # Draw on a copy so the image object coming from the dataset is left untouched.
    image = copy.deepcopy(image)
    draw = ImageDraw.Draw(image)
    image_w, image_h = image.size

    # Use a dedicated index for the cells so the sample index `i` is not overwritten.
    for cell_i, ocr_cell in enumerate(cells):
        # Display cells and atom mapping.
        # The stored bbox values are scaled by the image width/height to get pixel corners.
        ocr_bbox = [
            ocr_cell["bbox"][0] * image_w,
            ocr_cell["bbox"][1] * image_h,
            ocr_cell["bbox"][2] * image_w,
            ocr_cell["bbox"][3] * image_h,
        ]
        draw.rectangle(((ocr_bbox[0], ocr_bbox[1]), (ocr_bbox[2], ocr_bbox[3])), outline="red", width=5)
        draw.text((ocr_bbox[0], ocr_bbox[1]), str(cell_i), (0, 0, 255), font=font)
        print(cell_i, ocr_cell["text"])

    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(image)
    plt.show()
    # Close this figure explicitly so figures do not pile up across iterations.
    plt.close(fig)

In [None]:
# Collect the indices of the samples whose optimized CXSMILES contains "*".
# Only the `cxsmiles_opt` column is read, so the remaining features (including
# the images) do not have to be loaded for this scan.
invalid_i = [
    sample_i
    for sample_i, sample_cxsmiles_opt in tqdm(
        enumerate(dataset_hf["cxsmiles_opt"]), total=len(dataset_hf)
    )
    if "*" in sample_cxsmiles_opt
]

## Debugging

In [5]:
# Name of the raw dataset folder: the debugging cells below read its MOL files
# and SVG images from data/dataset/{dataset_name}/molfiles and .../images.
dataset_name = "experiment-cx004_cxsmiles_ocr"

### Debug CXSMILES optimization 

In [None]:
# Round-trip check of the CXSMILES tokenizer on selected dataset rows:
# MOL file -> RDKit CXSMILES -> optimized form -> reconstructed CXSMILES,
# followed by drawings of both molecules and the source SVG for comparison.

# Lenient CXSMILES parsing, shared by both MolFromSmiles calls below.
parser_params = Chem.SmilesParserParams()
parser_params.strictCXSMILES = False

for i in range(10, 11):
    id = dataset_hf["id"][i]
    print(id)
    molfile_path = os.getcwd() + f"/../../data/dataset/{dataset_name}/molfiles/{id}.mol"
    original_cxsmiles = dataset_hf['cxsmiles_dataset'][i]
    print(f"Original CXSMILES: {original_cxsmiles}")
    # RDKit log output is blocked while the MOL file is parsed and serialised.
    with rdBase.BlockLogs():
        m = Chem.MolFromMolFile(molfile_path, strictParsing=False, removeHs=False)
        if m is None:
            print("Invalid CXSMILES from MOLfile")
            continue
        cxsmiles_mol = Chem.MolToCXSmiles(m)
        # The property is a bracketed, comma-separated list of atom indices
        # (e.g. "[3,0,1,2,]"). Empty tokens are skipped so the parse works with
        # or without a trailing comma instead of relying on a fixed slice.
        atom_output_order = [
            int(token)
            for token in m.GetProp("_smilesAtomOutputOrder").strip("[]").split(",")
            if token.strip()
        ]
        # Maps each atom index of the MOL file molecule to its position in the written SMILES.
        mol_to_cxsmi_i_mapping = dict(zip(atom_output_order, range(0, m.GetNumAtoms())))
    print(f"Original CXSMILES from MOLfile: {cxsmiles_mol}")

    # Labels come from the first `$...$` field of the extension block (`|...|`).
    # A CXSMILES without extension block or label field yields no labels
    # rather than raising IndexError.
    extension_parts = original_cxsmiles.split("|")
    label_fields = extension_parts[1].split("$") if len(extension_parts) > 1 else []
    if len(label_fields) > 1:
        original_r_labels = [c for c in label_fields[1].split(";") if c != ""]
    else:
        original_r_labels = []
    # Every label of the dataset CXSMILES must also appear in the MOL-file CXSMILES.
    if not all(r in cxsmiles_mol for r in original_r_labels):
        print("Invalid CXSMILES from MOLfile")
        continue

    # Convert molfile to cxsmiles opt
    cxsmiles_tokenizer = CXSMILESTokenizer()
    cxsmiles_opt, keypoints = cxsmiles_tokenizer.convert_cdk_to_opt(cxsmiles_mol, molfile_path, mol_to_cxsmi_i_mapping)
    print(f"Optimized CXSMILES: {cxsmiles_opt}")
    # Convert cxsmiles opt to cxsmiles
    cxsmiles_pred = cxsmiles_tokenizer.convert_opt_to_out(cxsmiles_opt)
    print(f"Reconstructed CXSMILES: {cxsmiles_pred}")

    # Display molecules (atoms annotated with their 1-based index).
    # MolFromSmiles returns None on unparsable input; report it instead of crashing.
    m = Chem.MolFromSmiles(cxsmiles_mol, parser_params)
    if m is None:
        print(f"RDKit could not parse the CXSMILES from MOLfile: {cxsmiles_mol}")
    else:
        for atom in m.GetAtoms():
            atom.SetProp("atomNote", str(atom.GetIdx()+1))
        display(Draw.MolToImage(m, size=(450, 450)))

    m_reconstructed = Chem.MolFromSmiles(cxsmiles_pred, parser_params)
    if m_reconstructed is None:
        print(f"RDKit could not parse the reconstructed CXSMILES: {cxsmiles_pred}")
    else:
        for atom in m_reconstructed.GetAtoms():
            atom.SetProp("atomNote", str(atom.GetIdx()+1))
        display(Draw.MolToImage(m_reconstructed, size=(450, 450)))

    # Display image
    image_path = os.getcwd() + f"/../../data/dataset/{dataset_name}/images/{id}.svg"
    display(Image.open(io.BytesIO(cairosvg.svg2png(url=image_path))))
    

### Debug cell generation

In [None]:
# Re-run box extraction and cell generation for one sample, printing every
# intermediate result together with the rendered image and the raw SVG source.
# NOTE(review): `hf_dataset_i` and `id` are set independently - confirm that
# both refer to the same sample before trusting the output.
hf_dataset_i = 47
id = 645256
svg_path = os.getcwd() + f"/../../data/dataset/{dataset_name}/images/{id}.svg"
molfile_path = os.getcwd() + f"/../../data/dataset/{dataset_name}/molfiles/{id}.mol"
cxsmiles = dataset_hf["cxsmiles_dataset"][hf_dataset_i]
print(cxsmiles)
atom_boxes, smt_boxes = get_boxes(svg_path)
pprint(atom_boxes)
pprint(smt_boxes)
print("\n")
cells = get_cells(cxsmiles, molfile_path, atom_boxes, smt_boxes)
pprint(cells)
display(Image.open(io.BytesIO(cairosvg.svg2png(url=svg_path))))
# Dump the raw SVG; the context manager closes the file handle right after reading.
with open(svg_path, 'r') as svg_file:
    print(svg_file.read())