In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import rdkit
import re
from rdkit import Chem
from posebusters import PoseBusters
import json
import torch
from tqdm import tqdm
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
# This enables inline rendering of molecules
IPythonConsole.ipython_useSVG=True 

In [7]:
!export PYTHONPATH=$PYTHONPATH:/auto/home/menuab/code/3DMolGen

In [8]:
import sys
import os

# Get the absolute path to the project root
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))  # Assuming notebook is one level deep inside the project

# Add the project root to sys.path
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from molgen3D.utils.data_processing_utils import decode_cartesian_raw

ModuleNotFoundError: No module named 'molgen3D'

In [9]:
cart_1x_path = "/nfs/h100/raid/chem/checkpoints/hf/yerevann/Llama-3.2-1B_conformers/c037e75255bc41c19c716939/step-4500"
cart_2x_path = "/nfs/h100/raid/chem/checkpoints/hf/yerevann/Llama-3.2-1B_conformers/d267db61f57d4b428baa604a/step-9000"
cart_4x_path = "/nfs/h100/raid/chem/checkpoints/hf/yerevann/Llama-3.2-1B_conformers/3408e9758572478c80393771/step-18000"
cart_6e_path = "/nfs/h100/raid/chem/checkpoints/hf/yerevann/Llama-3.2-1B_conformers/301b8328481243c6aa8d8003/step-27000"
cart_8e_path = "/nfs/h100/raid/chem/checkpoints/hf/yerevann/Llama-3.2-1B_conformers/c13311b27056459eaccf5877/step-36000"
m100_100 = "/nfs/h100/raid/chem/checkpoints/hf/yerevann/Llama-3.2-27M_conformers/afebcc510dec403f9532dff6/step-42600"
m100_120p_path = "/nfs/h100/raid/chem/checkpoints/hf/yerevann/Llama-3.2-100M_conformers/c29bd453f4ff497d8c99c8f7/step-30000"


tokenizer  = AutoTokenizer.from_pretrained("/auto/home/menuab/code/YNNtitan/torchtitan/tokenizers/Llama-3.2-chem-1B-v1", padding_side='left')
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(m100_120p_path, 
                                             torch_dtype=torch.float32).to("cuda:1")
model.generation_config.pad_token_id = tokenizer.pad_token_id
model.dtype, model.device

(torch.float32, device(type='cuda', index=1))

In [10]:
import transformers
top_p_sampling_config = transformers.GenerationConfig (
  do_sample=True,
  temperature=0.8,
  top_p=0.9
)


In [18]:
canonical_smiles = '[H][C]([H])=[C]([H])[C]([H])([H])[O][C](=[O])[C]([H])([H])[C]([H])([H])[C](=[O])[N]([H])[c]1[c]([H])[c]([H])[c]([H])[c]2[c]([H])[c]([H])[c]([H])[c]([H])[c]12'
pad_len = 0
prompt = f"[SMILES]{canonical_smiles}[/SMILES]"
prompt = tokenizer(prompt, 
                   padding='max_length', 
                   max_length=len(tokenizer(prompt)["input_ids"])+pad_len, 
                   return_tensors="pt", 
                   add_special_tokens=True).to(model.device)
output = model.generate(input_ids=prompt["input_ids"], 
                        attention_mask=prompt["attention_mask"], 
                        max_new_tokens=2000, 
                        eos_token_id=128329, 
                        generation_config=top_p_sampling_config)
output = tokenizer.batch_decode(output)[0]
generated_conformer = output[output.find("[CONFORMER]")+len("[CONFORMER]"):output.find("[/CONFORMER]")]
generated_smiles = re.sub(r'<[^>]*>', '', generated_conformer) 
print(output)
print(canonical_smiles)
print(generated_smiles)

<|begin_of_text|>[SMILES][H][C]([H])=[C]([H])[C]([H])([H])[O][C](=[O])[C]([H])([H])[C]([H])([H])[C](=[O])[N]([H])[c]1[c]([H])[c]([H])[c]([H])[c]2[c]([H])[c]([H])[c]([H])[c]([H])[c]12[/SMILES][CONFORMER][H<-3.4386,-1.7210,2.7221>][C<-3.2844,-1.6904,1.6540>]([H<-3.9768,-2.2526,1.0446>])=[C<-2.3256,-0.9643,1.1130>]([H<-1.6615,-0.3755,1.7304>])[C<-2.0398,-0.8939,-0.3566>]([H<-2.6273,-1.6212,-0.9206>])([H<-0.9747,-1.0575,-0.5448>])[O<-2.4109,0.3883,-0.8830>][C<-3.6528,0.8059,-0.6420>](=[O<-4.4766,0.2068,-0.0088>])[C<-3.8874,2.1477,-1.3157>]([H<-3.4205,2.1118,-2.2959>])([H<-4.9655,2.3122,-1.4156>])[C<-3.2508,3.2557,-0.4629>]([H<-3.8272,3.3847,0.4453>])([H<-3.2398,4.1897,-1.0307>])[C<-1.8350,2.8642,-0.0636>](=[O<-1.1587,3.5253,0.6875>])[N<-1.4476,1.6842,-0.6270>]([H<-2.1489,1.2179,-1.1938>])[c<-0.3155,0.9195,-0.3070>]1[c<0.8377,1.4874,0.1927>]([H<0.8690,2.5432,0.4100>])[c<1.9694,0.7016,0.4472>]([H<2.8538,1.1747,0.8530>])[c<1.9469,-0.6365,0.1855>]([H<2.8059,-1.2632,0.4102>])[c<0.8033,-1.2407,-

In [None]:
# file = open('drugs_test_mols_inference.jsonl','w') 
ref, gen, incs = [], [], []
inc = 0
import re
for en, mol_dict in enumerate(tqdm(test_mols)):
    canonical_smiles = mol_dict["canonical_smiles"]
    canonical_smiles = '[H][O][C]1=[C]([C](=[N][C]([H])([H])[C]([H])([H])[c]2[c]([H])[c]([H])[c]([H])[c]([H])[c]2[H])[C]([H])([H])[H])[C](=[O])[S][C]1([H])[H]'

    geom_smiles = mol_dict["geom_smiles"]
    num_generations = mol_dict["num_confs"] * 2
    print(f"mol num: {en+1} generating {num_generations} conformers for {geom_smiles}")    
    generations = []
    prompt = f"[SMILES]{canonical_smiles}[/SMILES]"
    prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(model.device).input_ids
    
    # output = model.generate(prompt, max_new_tokens=2000, eos_token_id=128329, do_sample=False)
    output = model.generate(prompt, max_length=2000, eos_token_id=128329, 
                        # num_beams=10,
                        # num_beam_groups=2,
                        # diversity_penalty=1.1,
                        num_return_sequences=2, 
                        do_sample=True,
                        top_p=0.90,
                        temperature=0.8,
                        # top_k=3
                        )
    print(f"len prompt toks: {len(prompt[0])}, len gen toks: {len(output[0])-len(prompt[0])}")
    output = tokenizer.batch_decode(output)
    # print("raw output: ", output)
    # print("canonical_smiles: ", canonical_smiles)
    # display(Chem.MolFromSmiles(canonical_smiles))
    for out in output:
        generated_conformer = out[out.find("[CONFORMER]")+len("[CONFORMER]"):out.find("[/CONFORMER]")]
        # print(generated_conformer)
        generated_smiles = re.sub(r'<[^>]*>', '', generated_conformer)
        if generated_smiles == canonical_smiles:
            ref.append(geom_smiles)
            gen.append(generated_smiles)
            print(f"{canonical_smiles=}")
            print(f"{generated_smiles=}")
            print(f"{out=}")
            sample = {
                "geom_smiles": geom_smiles,
                "generated_conformer": generated_conformer
            }
            file.write(f"{json.dumps(sample)}\n")
        else:
            print("smiles didn't match for ")
            print(f"{canonical_smiles=}")
            print(f"{generated_smiles=}")
            print(out)
            incs.append(out)
            inc += 1
        # print("generated_smiles: ", generated_smiles)
        # display(Chem.MolFromSmiles(generated_smiles))
    print('----------------------')
    file.close()
    break
    # if en==100:
    #     break
    

  0%|          | 0/1000 [00:00<?, ?it/s]

mol num: 1 generating 138 conformers for C#CCNC(=O)C1=C[C@@H](c2ccc(Br)cc2)C[C@@H](OCc2ccc(CO)cc2)O1


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
  0%|          | 0/1000 [00:11<?, ?it/s]

len prompt toks: 88, len gen toks: 602
canonical_smiles='[H][O][C]1=[C]([C](=[N][C]([H])([H])[C]([H])([H])[c]2[c]([H])[c]([H])[c]([H])[c]([H])[c]2[H])[C]([H])([H])[H])[C](=[O])[S][C]1([H])[H]'
generated_smiles='[H][O][C]1=[C]([C](=[N][C]([H])([H])[C]([H])([H])[c]2[c]([H])[c]([H])[c]([H])[c]([H])[c]2[H])[C]([H])([H])[H])[C](=[O])[S][C]1([H])[H]'
out='<|begin_of_text|>[SMILES][H][O][C]1=[C]([C](=[N][C]([H])([H])[C]([H])([H])[c]2[c]([H])[c]([H])[c]([H])[c]([H])[c]2[H])[C]([H])([H])[H])[C](=[O])[S][C]1([H])[H][/SMILES][CONFORMER][H<-2.1227,-2.4066,0.0171>][O<-3.0615,-2.0513,0.0322>][C<-2.9815,-0.7348,0.0014>]1=[C<-1.9039,0.0974,-0.0384>]([C<-0.5432,-0.4122,-0.0478>](=[N<0.4296,0.4126,-0.1234>][C<1.8135,0.0420,-0.1334>]([H<1.9969,-0.8204,-0.7864>])([H<2.1302,-0.2397,0.8801>])[C<2.6484,1.2332,-0.6214>]([H<2.4324,1.4031,-1.6782>])([H<2.3361,2.1208,-0.0685>])[c<4.1106,0.9633,-0.4145>]2[c<4.7994,0.1287,-1.2863>]([H<4.2861,-0.2955,-2.1375>])[c<6.1366,-0.1581,-1.0765>]([H<6.6606,-0.8065,-1.7634>]




In [6]:
m_gen = '[H<-0.0541,-1.9089,0.0001>][O<-0.9993,-2.2537,0.0001>][C<-1.7999,-1.2063,0.0001>]1=[C<-1.4564,0.1106,0.0001>]([C<-0.0788,0.5689,0.0001>](=[N<0.8417,-0.3277,0.0001>][C<2.2485,-0.0474,0.0001>]([H<2.5225,0.5473,0.8810>])([H<2.5319,0.5279,-0.8907>])[C<3.0100,-1.3801,0.0118>]([H<2.7208,-1.9349,0.9052>])([H<2.7059,-1.9615,-0.8595>])[c<4.4899,-1.1329,0.0004>]2[c<5.2423,-1.3378,1.1489>]([H<4.7600,-1.7082,2.0424>])[c<6.6015,-1.0774,1.1557>]([H<7.1753,-1.2451,2.0552>])[c<7.2245,-0.6015,0.0147>]([H<8.2852,-0.3985,0.0200>])[c<6.4818,-0.3901,-1.1344>]([H<6.9632,-0.0206,-2.0281>])[c<5.1237,-0.6519,-1.1403>]2[H<4.5471,-0.4887,-2.0396>])[C<0.1665,2.0521,-0.0001>]([H<1.2285,2.2721,-0.0001>])([H<-0.2910,2.4999,-0.8815>])[H<-0.2959,2.5005,0.8790>])[C<-2.6091,1.0094,0.0001>](=[O<-2.6401,2.2087,0.0001>])[S<-4.1199,0.0224,-0.0001>][C<-3.2319,-1.5780,-0.0001>]1([H<-3.5528,-2.1773,-0.8520>])[H<-3.4629,-2.1250,0.9149>]'