In [6]:
import sys
sys.path.append('/auto/home/filya/3DMolGen')
import os
import os.path as osp

import matplotlib.pyplot as plt
import numpy as np
import ast
from loguru import logger as log
from tqdm import tqdm  


import datamol as dm
from rdkit import Chem
from rdkit.Chem import AllChem

from molgen3D.data_processing.preprocessing_forked_ET_Flow import load_pkl, load_json, embed_coordinates
from molgen3D.evaluation.inference import parse_molecule_with_coordinates
from get_cartesian_from_spherical import parse_molecule_with_spherical_coordinates
from get_spherical_from_cartesian import embed_coordinates_spherical

import importlib
import get_cartesian_from_spherical as to_reload1
import get_spherical_from_cartesian as to_reload2
importlib.reload(to_reload1)
importlib.reload(to_reload2)

# import importlib
# importlib.reload(get_cartesian_from_spherical)

embedding_func_selector = {
    "cartesian": embed_coordinates,
    "spherical": embed_coordinates_spherical
}
decoding_func_selector = {
    "cartesian": parse_molecule_with_coordinates,
    "spherical": parse_molecule_with_spherical_coordinates
}

def stat_log(rmsds):
    max_rmsd = np.max(rmsds)
    mean_rmsd = np.mean(rmsds)
    percentile95 = np.percentile(rmsds, 95)
    log.info(f"Maximum RMSD: {max_rmsd}")
    log.info(f"Mean RMSD: {mean_rmsd}")
    log.info(f"95th Percentile: {percentile95}")
    # plt.figure(figsize=(8, 6))
    # plt.hist(rmsds, bins=50, alpha=0.7, color='blue', edgecolor='black')
    # plt.xlabel("RMSD")
    # plt.ylabel("Frequency")
    # plt.title("Distribution of RMSDs")
    # plt.show()
    # print(np.sort(rmsds)[::-1])

def validate(raw_path, embedding_type, limit, precision, single = False, cur_id = -1): #dest_folder_path, indices_path,
    # partitions = ["qm9", "drugs"]
    partitions = ["qm9"]
    rmsds = []
    embedding_function = embedding_func_selector[embedding_type]
    decoding_function = decoding_func_selector[embedding_type]
    if not os.path.exists("molecules"):
        os.makedirs("molecules")
    writer = Chem.SDWriter("molecules/special_mol1.sdf")
    writer1 = Chem.SDWriter("molecules/special-rec1.sdf")
    writer_error = Chem.SDWriter("molecules/error1.sdf")

    for partition in partitions:
        # dest_path = osp.join(dest_folder_path, partition.upper())
        # train_indices = set(sorted(np.load(osp.join(*[indices_path, partition.upper(),
        #                                                 "train_indices.npy"]), allow_pickle=True)))
        # val_indices = set(sorted(np.load(osp.join(*[indices_path, partition.upper(),
        #                                             "val_indices.npy"]), allow_pickle=True)))
        # log.info(f"{partition} indices contain train:{len(train_indices)}, valid:{len(val_indices)},"\
        #             f" total:{len(train_indices)+len(val_indices)} samples")
        mols = load_json(osp.join(raw_path, f"summary_{partition}.json"))
        for id, (mol_id, mol_dict) in tqdm(
            enumerate(mols.items()),
            total=len(mols),
            desc=f"Processing molecules of {partition}",
        ):
            if id < 22165:
                continue
            if id >= limit:
                break
            if id % 1000 == 999:
                stat_log(rmsds)
            mol_pickle = load_pkl(os.path.join(raw_path, mol_dict["pickle_path"]))
            confs = mol_pickle["conformers"]    
            if single and id != cur_id:
                continue   
            try: 
                for conf in confs:
                    mol, geom_id = conf["rd_mol"], conf["geom_id"]
                    # suppl = Chem.SDMolSupplier("/auto/home/filya/3DMolGen/molgen3D/data_processing/mol_9496-orig.sdf")
                    # mol = suppl[0]
                    canonical_smiles = dm.to_smiles(
                        mol,
                        canonical=True,
                        explicit_hs=True,
                        with_atom_indices=False,
                        isomeric=False,
                    )
                    if '.' in canonical_smiles:
                        continue
                    atom_order = list(map(int, ast.literal_eval(mol.GetProp('_smilesAtomOutputOrder'))))
                    print(atom_order)
                    
                    embedded_smiles = embedding_function(mol, canonical_smiles, atom_order, precision)
                    
                    # sample = {"canonical_smiles": canonical_smiles,
                    #             "geom_embed_coordinatesid": geom_id, 
                    #             "embedded_smiles": embedded_smiles}
                    mol1 = decoding_function(embedded_smiles)
                    rmsd = AllChem.GetBestRMS(mol, mol1)
                    rmsds.append(rmsd)
                    if rmsd > 0.3:
                        writer.write(mol)
                        writer1.write(mol1)
                        log.info(canonical_smiles)
                        log.info(len(confs))
                        log.info(id)
                        log.info(rmsd)
                        

            except Exception as e:
                log.error(f"Error: {e} for molecule {canonical_smiles}, id {id}")
                writer_error.write(mol)
    
    writer.close()
    writer1.close()
    writer_error.close()
    
    rmsds = np.array(rmsds)
    return rmsds



# Parameters for preprocessing
raw_path = "/mnt/sxtn2/chem/GEOM_data/rdkit_folder"
# dest_folder_path = "geom"
# os.makedirs(dest_folder_path, exist_ok=True)
# indices_path = "/mnt/sxtn2/chem/GEOM_data/et_flow_indice/"
embedding_type = "spherical"
limit = 22166

# rmsds = validate(raw_path, embedding_type, limit, precision=4, single=True, cur_id=994) # dest_folder_path, indices_path
rmsds = validate(raw_path, embedding_type, limit, precision=4)
stat_log(rmsds)

Processing molecules of qm9:   0%|          | 0/133258 [00:00<?, ?it/s][32m2025-03-11 14:23:05.439[0m | [31m[1mERROR   [0m | [36mget_spherical_from_cartesian[0m:[36mcalculate_descriptors[0m:[36m191[0m - [31m[1mc2 was not found for atom 18, 9[0m
[32m2025-03-11 14:23:05.446[0m | [1mINFO    [0m | [36m__main__[0m:[36mvalidate[0m:[36m119[0m - [1m[H][C]([H])([H])[C](=[O])[C]#[C][C]([H])([H])[C]([H])([C]([H])([H])[H])[C]([H])([H])[H][0m
[32m2025-03-11 14:23:05.447[0m | [1mINFO    [0m | [36m__main__[0m:[36mvalidate[0m:[36m120[0m - [1m6[0m
[32m2025-03-11 14:23:05.447[0m | [1mINFO    [0m | [36m__main__[0m:[36mvalidate[0m:[36m121[0m - [1m22165[0m
[32m2025-03-11 14:23:05.447[0m | [1mINFO    [0m | [36m__main__[0m:[36mvalidate[0m:[36m122[0m - [1m1.2102655901622865[0m
Processing molecules of qm9:  17%|█▋        | 22166/133258 [00:00<00:00, 435691.69it/s]
[32m2025-03-11 14:23:05.505[0m | [1mINFO    [0m | [36m__main__[0m:[36mstat_log

[15, 14, 16, 17, 12, 13, 11, 10, 9, 18, 19, 3, 8, 4, 5, 6, 7, 0, 1, 2, 20]
[15, 14, 16, 17, 12, 13, 11, 10, 9, 18, 19, 3, 8, 4, 5, 6, 7, 0, 1, 2, 20]
[15, 14, 16, 17, 12, 13, 11, 10, 9, 18, 19, 3, 8, 4, 5, 6, 7, 0, 1, 2, 20]
[15, 14, 16, 17, 12, 13, 11, 10, 9, 18, 19, 3, 8, 4, 5, 6, 7, 0, 1, 2, 20]
[15, 14, 16, 17, 12, 13, 11, 10, 9, 18, 19, 3, 8, 4, 5, 6, 7, 0, 1, 2, 20]
[15, 14, 16, 17, 12, 13, 11, 10, 9, 18, 19, 3, 8, 4, 5, 6, 7, 0, 1, 2, 20]
