In [None]:
import os
import json
import random
import pickle
from rdkit import Chem
import datamol as dm
import numpy as np
import tqdm
import re
import glob
from os import path as osp
import pickle
# NOTE: `base_path` (set in the data-loading cell below) must point at the directory that contains the untarred `rdkit_folder`
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
import ast
# This enables inline rendering of molecules
IPythonConsole.ipython_useSVG=True 
from rdkit.Chem.AllChem import GetBestRMS

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
import py3Dmol


In [None]:
def draw_molecule(smi):
    """Render a SMILES string as an atom-labelled 3D stick model in py3Dmol.

    The SMILES is parsed with RDKit, explicit hydrogens are added, one 3D
    conformer is embedded, and the result is shown in a 400x400 viewer with
    an element-symbol label placed on every atom.

    Args:
        smi: SMILES string of the molecule to draw.

    Raises:
        ValueError: If RDKit cannot parse ``smi`` or cannot embed 3D
            coordinates for it.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError(f"Could not parse SMILES: {smi!r}")
    mol = Chem.AddHs(mol)  # Add hydrogens

    # EmbedMolecule returns -1 when no conformer could be generated; without
    # this check the failure only shows up later as a missing-conformer error.
    if AllChem.EmbedMolecule(mol) < 0:
        raise ValueError(f"Could not embed 3D coordinates for SMILES: {smi!r}")

    # Convert to 3Dmol.js format
    block = Chem.MolToMolBlock(mol)
    view = py3Dmol.view(width=400, height=400)
    view.addModel(block, "mol")  # Add the molecule
    view.setStyle({'stick': {}})  # Set stick style

    # Fetch the conformer once instead of once per atom.
    conformer = mol.GetConformer()
    for atom in mol.GetAtoms():
        pos = conformer.GetAtomPosition(atom.GetIdx())
        view.addLabel(atom.GetSymbol(),
                      {'position': {'x': pos.x, 'y': pos.y, 'z': pos.z},
                       'backgroundColor': 'white',
                       'fontColor': 'black',
                       'fontSize': 12})
    view.zoomTo()  # Zoom to fit the molecule
    view.show()

In [None]:
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem.rdchem import ChiralType
# SMARTS query: any four atoms joined in a chain by bonds of any type.
dihedral_pattern = Chem.MolFromSmarts('[*]~[*]~[*]~[*]')

# Numeric code assigned to each RDKit chiral tag.
chirality = {
    ChiralType.CHI_TETRAHEDRAL_CW: -1.,
    ChiralType.CHI_TETRAHEDRAL_CCW: 1.,
    ChiralType.CHI_UNSPECIFIED: 0,
    ChiralType.CHI_OTHER: 0,
}

# Integer index for each supported bond type, in listing order.
bonds = {
    bond_type: index
    for index, bond_type in enumerate((BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC))
}

# Element symbol -> integer index tables, one per dataset; indices follow listing order.
qm9_types = {symbol: index for index, symbol in enumerate(('H', 'C', 'N', 'O', 'F'))}
drugs_types = {
    symbol: index
    for index, symbol in enumerate((
        'H', 'Li', 'B', 'C', 'N', 'O', 'F', 'Na', 'Mg', 'Al', 'Si',
        'P', 'S', 'Cl', 'K', 'Ca', 'V', 'Cr', 'Mn', 'Cu', 'Zn',
        'Ga', 'Ge', 'As', 'Se', 'Br', 'Ag', 'In', 'Sb', 'I', 'Gd',
        'Pt', 'Au', 'Hg', 'Bi',
    ))
}
def featurize_mol_from_smiles(smiles, dataset='drugs'):
    """Parse a SMILES string and return its hydrogen-added RDKit molecule.

    The molecule is filtered out (``None`` is returned) when any of the
    following holds:

    * the SMILES contains ``'.'`` (multiple fragments),
    * RDKit cannot parse the SMILES,
    * the hydrogen-added molecule has fewer than 4 atoms,
    * the molecule has no match for ``dihedral_pattern``.

    Args:
        smiles: SMILES string to parse.
        dataset: Kept for backward compatibility. It is currently unused
            because this function only filters and returns the molecule; no
            per-dataset atom-type table is consulted here.

    Returns:
        The RDKit molecule with explicit hydrogens, or ``None`` if filtered.
    """
    # filter fragments
    if '.' in smiles:
        return None

    # filter mols rdkit can't intrinsically handle
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    mol = Chem.AddHs(mol)

    # filter out mols model can't make predictions for; the cheap atom-count
    # test runs first so the substructure search is skipped for tiny molecules
    if mol.GetNumAtoms() < 4:
        return None
    if not mol.HasSubstructMatch(dihedral_pattern):
        return None
    return mol

In [None]:
# Root of the local GEOM data directory; the input paths below are derived from it.
base_path = "/mnt/sxtn2/chem/GEOM_data"

test_mols_path = os.path.join(base_path, "geom_processed/test_smiles_corrected.csv")
drugs_file_path = os.path.join(base_path, "rdkit_folder/summary_drugs.json")
destination_path = "./drugs_test_inference.jsonl"

# Summary JSON: later cells look up drugs_summ[smiles]['pickle_path'].
with open(drugs_file_path, "r") as f:
    drugs_summ = json.load(f)

with open(test_mols_path, 'r') as f:
    # The first line is consumed (and echoed) separately from the data rows.
    print(f.readline())
    # Skip blank lines (e.g. a trailing newline at end of file); they would
    # otherwise split into a single field and raise IndexError below.
    test_mols = [line.split(',') for line in f if line.strip()]

# Each row becomes (smiles, n_conformers, corrected_smiles).
test_data = [(m[0].strip(), int(m[1]), m[2].strip()) for m in test_mols]

In [None]:
def clean_confs(smi, confs):
    """Keep the conformers whose hydrogen-stripped SMILES matches ``smi``.

    Both sides are compared as non-isomeric SMILES. The reference is parsed
    without sanitization, and hydrogens are removed from each conformer
    without sanitization before it is converted.

    Args:
        smi: Reference SMILES string.
        confs: Sequence of RDKit molecules (one per conformer).

    Returns:
        List of the matching conformers, in their original order.
    """
    reference = Chem.MolToSmiles(
        Chem.MolFromSmiles(smi, sanitize=False), isomericSmiles=False
    )

    def _same_graph(conformer):
        # Compare on the heavy-atom graph only, ignoring stereo annotations.
        stripped = Chem.RemoveHs(conformer, sanitize=False)
        return Chem.MolToSmiles(stripped, isomericSmiles=False) == reference

    return [conformer for conformer in confs if _same_graph(conformer)]

In [None]:
def load_pkl(file_path: str):
    """Unpickle and return the object stored at ``file_path``.

    Args:
        file_path: Path of the pickle file to read.

    Returns:
        The deserialized object.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    if os.path.exists(file_path):
        with open(file_path, "rb") as handle:
            payload = pickle.load(handle)
        return payload
    raise FileNotFoundError(f"File {file_path} does not exist.")
    

In [None]:
def find_pickle_files(base_dir, smi):
    """Recursively collect ``.pkl`` files under ``base_dir`` matching ``smi``.

    A file matches when its name ends with ``.pkl`` and its path relative to
    ``base_dir`` contains ``smi`` as a substring.

    Args:
        base_dir: Directory to walk.
        smi: Substring to look for in each file's relative path.

    Returns:
        List of paths, each built by joining ``base_dir`` with the file's
        path relative to ``base_dir``, in ``os.walk`` order.
    """
    def _relative(directory, filename):
        return os.path.relpath(os.path.join(directory, filename), base_dir)

    relative_candidates = (
        _relative(directory, filename)
        for directory, _, filenames in os.walk(base_dir)
        for filename in filenames
        if filename.endswith(".pkl")
    )
    return [
        os.path.join(base_dir, relative)
        for relative in relative_candidates
        if smi in relative
    ]

In [None]:
# Inspect a single test molecule, selected by its corrected SMILES.
# Previously inspected: "CC(C)(C)c1cc[n+](-c2[nH]c(=O)sc2/C=N/NC([S-])=Nc2ccccc2)cc1"
SMI = "C#CCOCCOCCOCCNc1nc(N2CCN(C(=O)[C@H](CCC(=O)O)n3cc([C@H](N)CO)nn3)CC2)nc(N2CCN(C(=O)[C@H](CCC(=O)O)n3cc([C@@H]([NH3+])CO)nn3)CC2)n1"
for en, (smiles,n_conformers,corrected_smiles) in enumerate(test_data):
    if corrected_smiles != SMI:
        continue
    # Build the pickle location from base_path rather than repeating the absolute path.
    pickle_file = os.path.join(base_path, "rdkit_folder", drugs_summ[smiles]['pickle_path'])
    conformers = [conf['rd_mol'] for conf in load_pkl(pickle_file)['conformers']]
    print(en, n_conformers, ' -------------------')
    clean_confs_list = clean_confs(corrected_smiles, conformers)
    # Report when fewer conformers survive the SMILES check than the CSV lists.
    if len(clean_confs_list) < n_conformers:
        print("Found conformers with different SMILES")
        print(f"Original SMILES: {smiles}")
        print(f"Corrected SMILES: {corrected_smiles}")
        print(f"Cleaned Conformers: {len(  clean_confs_list)}")
        print(f"Original Conformers: {n_conformers}")


In [None]:
# Scan every test molecule and record those for which fewer conformers pass
# clean_confs than the CSV's conformer count.
#   bad_smi_list: (smiles, n_conformers, corrected_smiles) of each flagged molecule
#   bad_smi:      number of flagged molecules
#   zero_smi:     number of flagged molecules with no surviving conformer at all
bad_smi_list, bad_smi, zero_smi = [], 0,0
for en, (smiles,n_conformers,corrected_smiles) in enumerate(test_data):
    # Build the pickle location from base_path rather than repeating the absolute path.
    pickle_file = os.path.join(base_path, "rdkit_folder", drugs_summ[smiles]['pickle_path'])
    conformers = [conf['rd_mol'] for conf in load_pkl(pickle_file)['conformers']]
    print(en, n_conformers, ' -------------------')
    clean_confs_list = clean_confs(corrected_smiles, conformers)
    if len(clean_confs_list) >= n_conformers:
        continue
    bad_smi += 1
    if len(clean_confs_list) == 0:
        zero_smi += 1
    print("Found conformers with different SMILES")
    print(f"Original SMILES: {smiles}")
    print(f"Corrected SMILES: {corrected_smiles}")
    print(f"Cleaned Conformers: {len(  clean_confs_list)}")
    print(f"Original Conformers: {n_conformers}")
    bad_smi_list.append((smiles,n_conformers,corrected_smiles))
print("bad_smi", bad_smi)
print("zero_smi", zero_smi)

In [None]:
# For each flagged molecule, print its record and render the corrected
# structure followed by the original one.
for smiles, n_conformers, corrected_smiles in bad_smi_list:
    print(
        f"Original SMILES: {smiles}",
        f"Corrected SMILES: {corrected_smiles}",
        f"Original Conformers: {n_conformers}",
        sep="\n",
    )
    for smiles_to_draw in (corrected_smiles, smiles):
        draw_molecule(smiles_to_draw)
    print("--------------------------------------------------")

In [None]:
# For every test molecule, find the first conformer whose isomeric canonical
# SMILES (hydrogens removed) differs from the corrected SMILES, report it, and
# record the molecule in bad_smi_iso.
#
# bad_smi_noiso and bad_smi_chi are created but never filled here: the two
# checks that fed them are disabled. For reference they were
#   * canonical_noiso != corrected_smiles                      -> bad_smi_noiso
#   * canonical_iso != corrected_smiles and the two strings are
#     equal once '/' and '\' are stripped                     -> bad_smi_chi
bad_smi_iso, bad_smi_noiso, bad_smi_chi = [], [], []
for en, (smiles,n_conformers,corrected_smiles) in enumerate(test_data):
    # Build the pickle location from base_path rather than repeating the absolute path.
    pickle_file = os.path.join(base_path, "rdkit_folder", drugs_summ[smiles]['pickle_path'])
    conformers = load_pkl(pickle_file)['conformers']
    print(en, n_conformers, ' -------------------')
    for conf in conformers:
        mol1 = conf['rd_mol']
        # Strip hydrogens once and reuse the result for every SMILES variant.
        heavy_atom_mol = Chem.RemoveHs(mol1)
        canonical_iso = Chem.MolToSmiles(heavy_atom_mol, canonical=True, isomericSmiles=True)
        if canonical_iso == corrected_smiles:
            continue
        # The remaining variants are only needed for the mismatch report.
        canonical_noiso = Chem.MolToSmiles(heavy_atom_mol, canonical=True, isomericSmiles=False)
        canonical_chiral = canonical_iso.replace('/', '').replace("\\", '')
        print("iso Canonical SMILES do not match:  canonical_iso, canonical_noiso, smi, smi_corr\n")
        print(f"{'canonical_iso:':<20} {canonical_iso}")
        print(f"{'canonical_noiso:':<20} {canonical_noiso}")
        print(f"{'canonical_chiral:':<20} {canonical_chiral}")
        print(f"{'smiles:':<20} {smiles}")
        print(f"{'corrected_smiles:':<20} {corrected_smiles}")
        bad_smi_iso.append((smiles,n_conformers,corrected_smiles))
        # One mismatching conformer is enough to flag the molecule.
        break

In [None]:
# Number of distinct entries in each mismatch list (iso, noiso, chi).
# NOTE(review): in the visible notebook only bad_smi_iso is ever appended to,
# so the last two counts are expected to be 0 - confirm this is intended.
len(set(bad_smi_iso)), len(set(bad_smi_noiso)), len(set(bad_smi_chi))

In [None]:
# Render the '@@'-tagged form of F/Cl/Br-substituted CH; the next cell renders the '@' form for comparison.
draw_molecule('F[C@@H](Cl)Br')

In [None]:
# Render the '@'-tagged form of the same molecule drawn in the previous cell with '@@'.
draw_molecule('F[C@H](Cl)Br')

In [None]:
# Display the molecule parsed from 'OCC' after explicit hydrogens are added.
Chem.AddHs(Chem.MolFromSmiles('OCC'))

In [None]:
# Render the double-bond stereo form written with '/' on both sides; the next cell renders the '/' ... '\' form.
draw_molecule('F/C=C/F')

In [None]:
# Raw string: in a normal literal '\F' is an invalid escape sequence, which
# newer Python versions warn about. The SMILES passed in is unchanged.
draw_molecule(r'F/C=C\F')

In [None]:
# Re-check the first 11 flagged molecules against the processed per-conformer
# pickles: count how many stored molecules have an isomeric canonical SMILES
# (hydrogens removed) that differs from the corrected SMILES.
#
# The loop iterates bad_smi_list (the list of flagged
# (smiles, n_conformers, corrected_smiles) tuples); bad_smi is an integer
# counter and cannot be iterated. Mismatches go to their own list so the
# sequence being iterated is never modified during the loop.
processed_dir = os.path.join(base_path, "GEOM_data_MCF/data/processed_drugs/test_1000")
mismatched_pickles = []
for en, (smiles,n_conformers,corrected_smiles) in enumerate(bad_smi_list[:11]):
    pickles = [load_pkl(p) for p in find_pickle_files(processed_dir, smiles)]
    print(en,' -------------------')
    bads = 0
    for pic in pickles:
        mol1 = pic['mol']
        canonical = Chem.MolToSmiles(Chem.RemoveHs(mol1), canonical=True, isomericSmiles=True)
        if canonical != corrected_smiles:
            print(f"Canonical SMILES do not match: canonical, smi, smi_corr\n{canonical}")
            print(f"{smiles}")
            print(f"{corrected_smiles}")
            mismatched_pickles.append((smiles,n_conformers,corrected_smiles))
            bads += 1
    print(f"Total bad SMILES: {bads} out of {n_conformers}")

In [None]:
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors

# Two SMILES that differ only in the chirality tag ('@@' vs '@') on the first
# tagged carbon; check whether RDKit assigns them the same InChIKey.
mol1, mol2 = (
    Chem.MolFromSmiles(candidate)
    for candidate in (
        "CCOC(=O)[C@@H](C(=O)OCC)[C@@H](C(=O)OCC)N1CCOCC1",
        "CCOC(=O)[C@H](C(=O)OCC)[C@@H](C(=O)OCC)N1CCOCC1",
    )
)

# Compare InChIKeys
inchi1, inchi2 = Chem.MolToInchiKey(mol1), Chem.MolToInchiKey(mol2)

print("Same InChIKey?", inchi1 == inchi2)

In [None]:
# NOTE(review): no module-level `mol` is assigned before this cell in the
# visible notebook (the first one is the loop variable two cells below), so
# this likely raises NameError on Restart & Run All - confirm or remove.
mol

In [None]:
# Load the pickled test molecules from the torsional-diffusion data directory.
# Later cells index the result by SMILES string and then by integer position.
torsional = load_pkl(os.path.join(base_path, "GEOM_data_torsional_diff/DRUGS/test_mols.pkl"))

In [None]:
# Print the canonical SMILES (hydrogens removed) of every entry stored for this molecule.
for mol in torsional["CC(C)(C)c1cc[n+](-c2[nH]c(=O)sc2/C=N/NC([S-])=Nc2ccccc2)cc1"]:
    heavy_atom_mol = Chem.RemoveHs(mol)
    print(Chem.MolToSmiles(heavy_atom_mol, canonical=True))

In [None]:
# Display the entry at position 13 for this molecule; the next cell shows position 14.
torsional['CC(C)(C)c1cc[n+](-c2[nH]c(=O)sc2/C=N/NC([S-])=Nc2ccccc2)cc1'][13]

In [None]:
# Display the entry at position 14 for this molecule (compare with position 13 in the previous cell).
torsional['CC(C)(C)c1cc[n+](-c2[nH]c(=O)sc2/C=N/NC([S-])=Nc2ccccc2)cc1'][14]

In [None]:
# Best RMS between the entries at positions 13 and 14 of this molecule.
# NOTE(review): GetBestRMS is documented to align its first argument in place - confirm that is acceptable here.
torsional_entries = torsional['CC(C)(C)c1cc[n+](-c2[nH]c(=O)sc2/C=N/NC([S-])=Nc2ccccc2)cc1']
GetBestRMS(torsional_entries[13], torsional_entries[14])