In [8]:
import pandas as pd

# Demo training split: keep only the first 1000 rows of the 40k training CSV
# and save them (without the index column) next to the source file.
df = pd.read_csv('/hpc2hdd/home/yli106/smiles2mol/GEOM/qm9_processed/train_data_40k.csv').head(1000)
df.to_csv('/hpc2hdd/home/yli106/smiles2mol/GEOM/qm9_processed/train_demo.csv', index=False)


In [48]:
# Demo validation split: keep only the first 125 rows of the 5k validation CSV
# and save them (without the index column) next to the source file.
df = pd.read_csv('/hpc2hdd/home/yli106/smiles2mol/GEOM/qm9_processed/val_data_5k.csv').head(125)
df.to_csv('/hpc2hdd/home/yli106/smiles2mol/GEOM/qm9_processed/val_demo.csv', index=False)

In [15]:
import yaml
from easydict import EasyDict

# Parse the QM9 YAML configuration and wrap the resulting mapping in an
# EasyDict so nested settings can be read with attribute access.
with open('/hpc2hdd/home/yli106/smiles2mol/config/qm9_default.yml', 'r') as file:
    config = EasyDict(yaml.safe_load(file))

In [4]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import argparse
from easydict import EasyDict
from tqdm import tqdm

# Slow (non-fast) tokenizer loaded from the local llama2_13b checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(
    "/hpc2hdd/home/yli106/smiles_3d/llama2_13b",
    trust_remote_code=True,
    use_fast=False,
)
tokenizer.pad_token = tokenizer.eos_token  # Use the EOS token to pad shorter sequences
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training

# Scan every row of `df` and keep the largest token count seen for a mol block.
max_len = 0
for i in tqdm(range(len(df))):
    token_count = len(tokenizer.encode(df['mol_block'][i], return_tensors='pt')[0])
    max_len = max(max_len, token_count)
max_len

100%|██████████| 24068/24068 [01:15<00:00, 317.69it/s]


1739

In [6]:
# Sanity check: how many tokens does one [INST]/<<SYS>>-style prompt take?
# Uses the SMILES of row 10 and a fixed 15-atom / 15-bond system prompt.
smiles = df['canonicalize_smiles'][10]
system_prompt = 'Below is a SMILES of a molecule, generate its 3D structure. The molecule has 15 atoms and 15 bonds.'
inst = ''.join([
    '<s>[INST] <<SYS>>\n',
    system_prompt,
    '\n<</SYS>>\n\n',
    smiles,
    ' [/INST] ',
    ' </s>',
])
input_ids = tokenizer.encode(inst, return_tensors='pt')
input_ids.shape[1]

105

In [2]:
def process_inst(smiles, num_atom, num_bond):
    """Assemble the prompt asking the model for a molecule's 3D structure.

    The prompt consists of a system turn that states the atom and bond
    counts, a user turn carrying the SMILES string, and an opened (still
    empty) assistant turn, each delimited by ``<|...|>`` marker tokens.

    Args:
        smiles: SMILES string placed in the user turn.
        num_atom: Atom count interpolated into the system prompt.
        num_bond: Bond count interpolated into the system prompt.

    Returns:
        str: The prompt, ending right after the assistant header so that a
        completion can be appended directly.
    """
    system_prompt = f'Below is a SMILES of a molecule, generate its 3D structure. The molecule has {num_atom} atoms and {num_bond} bonds.'
    pieces = (
        '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n',
        system_prompt,
        '<|eot_id|><|start_header_id|>user<|end_header_id|>\n',
        smiles,
        '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n',
    )
    return ''.join(pieces)

In [7]:
# One training string per row: prompt + mol block answer + end-of-text marker.
text_col = [
    process_inst(row['canonicalize_smiles'], row['num_atom'], row['num_bond'])
    + row['mol_block']
    + '<|end_of_text|>'
    for _, row in df.iterrows()
]

# Store the assembled strings as the 'text' column of the frame.
df.loc[:, 'text'] = text_col

In [8]:
# Display the frame to inspect its columns, including the assembled 'text' column.
df

Unnamed: 0,smiles,canonicalize_smiles,num_atom,num_bond,mol_block,text
0,C[C@@]1(N2CC2)C[C@H]1O,[H]O[C@]1([H])C([H])([H])[C@]1(N1C([H])([H])C1...,19,20,\n RDKit 3D\n\n 19 20 0 0 0 0...,<|begin_of_text|><|start_header_id|>system<|en...
1,C[C@@]1(N2CC2)C[C@H]1O,[H]O[C@]1([H])C([H])([H])[C@]1(N1C([H])([H])C1...,19,20,\n RDKit 3D\n\n 19 20 0 0 0 0...,<|begin_of_text|><|start_header_id|>system<|en...
2,C[C@@]1(N2CC2)C[C@H]1O,[H]O[C@]1([H])C([H])([H])[C@]1(N1C([H])([H])C1...,19,20,\n RDKit 3D\n\n 19 20 0 0 0 0...,<|begin_of_text|><|start_header_id|>system<|en...
3,C[C@@]1(N2CC2)C[C@H]1O,[H]O[C@]1([H])C([H])([H])[C@]1(N1C([H])([H])C1...,19,20,\n RDKit 3D\n\n 19 20 0 0 0 0...,<|begin_of_text|><|start_header_id|>system<|en...
4,C[C@@]1(N2CC2)C[C@H]1O,[H]O[C@]1([H])C([H])([H])[C@]1(N1C([H])([H])C1...,19,20,\n RDKit 3D\n\n 19 20 0 0 0 0...,<|begin_of_text|><|start_header_id|>system<|en...
...,...,...,...,...,...,...
199995,N[C@H]1COCOC1=O,[H]N([H])[C@]1([H])C(=O)OC([H])([H])OC1([H])[H],15,15,\n RDKit 3D\n\n 15 15 0 0 0 0...,<|begin_of_text|><|start_header_id|>system<|en...
199996,N[C@H]1COCOC1=O,[H]N([H])[C@]1([H])C(=O)OC([H])([H])OC1([H])[H],15,15,\n RDKit 3D\n\n 15 15 0 0 0 0...,<|begin_of_text|><|start_header_id|>system<|en...
199997,N[C@H]1COCOC1=O,[H]N([H])[C@]1([H])C(=O)OC([H])([H])OC1([H])[H],15,15,\n RDKit 3D\n\n 15 15 0 0 0 0...,<|begin_of_text|><|start_header_id|>system<|en...
199998,N[C@H]1COCOC1=O,[H]N([H])[C@]1([H])C(=O)OC([H])([H])OC1([H])[H],15,15,\n RDKit 3D\n\n 15 15 0 0 0 0...,<|begin_of_text|><|start_header_id|>system<|en...


In [5]:
# Persist the current frame as CSV, leaving out the index column.
df.to_csv('/hpc2hdd/home/yli106/smiles2mol/GEOM/qm9_processed/train_data_40k_llama3.csv', index=False)

# inference

In [171]:
import os
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    HfArgumentParser,
    TrainingArguments,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import re
from tqdm import tqdm
import pandas as pd
import pickle
from conf3d import dataset
import yaml
from easydict import EasyDict
from torch.cuda import empty_cache
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdDetermineBonds
from rdkit.Chem.rdmolops import RemoveHs
from rdkit.Chem import rdMolAlign as MA
from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule
from rdkit import RDLogger

# Read the QM9 YAML configuration and expose it through attribute access.
with open('/hpc2hdd/home/yli106/smiles2mol/config/qm9_default.yml', 'r') as f:
    config = EasyDict(yaml.safe_load(f))

# Force the learning rate to a float.
# NOTE(review): presumably needed because the YAML value is loaded as a string — confirm.
train_args = config.training_arguments
train_args.learning_rate = float(train_args.learning_rate)

In [2]:
# Resolve the base-model directory (<base_path>/<type>_<size>) and the adapter path.
model_dir_name = '%s_%s' % (config.model.type, config.model.size)
model_path = os.path.join(config.model.base_path, model_dir_name)
peft_path = config.model.peft_path

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token  # Use the EOS token to pad shorter sequences
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training

# Base model first, then load the PEFT adapter on top of it and merge/unload it.
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model = PeftModel.from_pretrained(model, peft_path).merge_and_unload()

# Test set (pickled) from <base_path>/<dataset>_processed/<test_set>.
load_path = os.path.join(config.data.base_path, '%s_processed' % config.data.dataset)
print('loading data from %s' % load_path)
test_set_path = os.path.join(load_path, config.data.test_set)
# NOTE: pickle.load can execute arbitrary code from the file; only load trusted pickles.
with open(test_set_path, 'rb') as f:
    test_data = pickle.load(f)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



loading data from /hpc2hdd/home/yli106/smiles2mol/GEOM/qm9_processed


In [12]:
def get_atom(mol):
    """Return the mol-block header plus the tail of every atom line.

    Args:
        mol: Molecule exposing ``GetNumAtoms()`` and accepted by
            ``Chem.MolToMolBlock``.

    Returns:
        list[str]: The first four lines of the mol block, followed by one
        entry per atom containing that atom's line from character 31 onward.
        NOTE(review): presumably that tail is the element symbol plus the
        trailing fields, i.e. everything after the x/y/z coordinates —
        confirm against the mol-block column layout.

    Raises:
        IndexError: If the mol block has fewer atom lines than
            ``mol.GetNumAtoms()`` reports.
    """
    atom_num = mol.GetNumAtoms()
    mol_block = Chem.MolToMolBlock(mol).split('\n')
    # Lines 0-3 are copied verbatim; atom lines start right after them.
    order_list = mol_block[:4]
    order_list.extend(mol_block[4 + atom_idx][31:] for atom_idx in range(atom_num))
    return order_list

In [33]:
def get_bond(mol):
    """Return every mol-block line that follows the atom lines.

    Args:
        mol: Molecule exposing ``GetNumAtoms()`` and accepted by
            ``Chem.MolToMolBlock``.

    Returns:
        list[str]: The lines of the mol block after the four header lines
        and the ``GetNumAtoms()`` atom lines, i.e. the bond lines and
        whatever the mol block contains after them.
    """
    atom_num = mol.GetNumAtoms()
    mol_block = Chem.MolToMolBlock(mol).split('\n')
    # Skip the 4 header lines and one line per atom.
    return mol_block[4 + atom_num:]

In [39]:
def get_diff_bond(mol_list):
    """Collect the distinct bond sections found across a list of molecules.

    Args:
        mol_list: Iterable of molecules accepted by ``get_bond``.

    Returns:
        list[list[str]]: One entry per distinct ``get_bond`` result, in order
        of first appearance. A ``set`` would give a hash-dependent order, so
        callers indexing the result (e.g. taking element 0) could get a
        different bond section from one run to the next.
    """
    # dict keys keep insertion order, giving a deterministic de-duplication.
    unique_blocks = dict.fromkeys(tuple(get_bond(mol)) for mol in mol_list)
    bond_list = [list(block) for block in unique_blocks]
    return bond_list
        

In [172]:
# Distinct bond sections across all molecules in test_data[1][3], and the
# header/atom-line template taken from the first molecule of that list.
# NOTE(review): presumably test_data[k][3] holds the reference conformers of entry k — confirm.
bond_list = get_diff_bond(test_data[1][3])
atom_list = get_atom(test_data[1][3][0])

In [89]:
# Prompt for test entry 0: instruction text followed by the first four
# entries of atom_list, then tokenized and moved to the GPU.
entry = test_data[0]
num_conf = len(entry[3])
print(num_conf)
input_text = dataset.process_inst(entry[0], entry[1], entry[2]) + ('\n').join(atom_list[:4])
print(input_text)
input_ids = tokenizer.encode(input_text, return_tensors='pt').to('cuda')

61
<s>[INST] <<SYS>>
Below is a SMILES of a molecule, generate its 3D structure. The molecule has 19 atoms and 18 bonds.
<</SYS>>

[H]C#CC(=O)[C@]([H])(O[H])C([H])([H])C([H])([H])C([H])([H])[H] [/INST] 
     RDKit          3D

 19 18  0  0  0  0  0  0  0  0999 V2000


In [None]:
# Step-wise generation for test entry `n`.
# The prompt is seeded with the first four entries of atom_list; then, for each
# atom, the model generates until the custom stopping criterion fires, the
# matching entry of atom_list is appended, and generation resumes from there.
# Whole-molecule attempts are repeated until `time` texts have been accepted.
n = 1
bond_list = get_diff_bond(test_data[n][3])
atom_list = get_atom(test_data[n][3][0])

num_conf = len(test_data[n][3])
print(num_conf)
input_text = dataset.process_inst(test_data[n][0], test_data[n][1], test_data[n][2])
input_text += ('\n').join(atom_list[:4])
print(input_text)

# Target number of accepted generations: multiplier x number of entries in test_data[n][3].
# NOTE(review): the name `time` would shadow the stdlib module if it is imported elsewhere.
time = config.inference.multiplier*num_conf
gen = 0
raw_generated_texts = []
pbar = tqdm(total=time, desc="Generating")

# No retry limit: attempts whose last step set the flag are discarded and the loop tries again.
while True:

    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    input_ids = input_ids.to('cuda')


    # One generate() call per atom; test_data[n][1] is used as the number of steps.
    for i in range(test_data[n][1]): 
        # The class is re-created on every step; __call__ reads the loop
        # variable `i` and `test_data[n][1]` at call time.
        # NOTE(review): token ids 29889 and 13 are presumably '.' and newline in
        # this tokenizer's vocabulary — confirm against the tokenizer.
        class CustomStoppingCriteria(StoppingCriteria):
            def __init__(self):
                self.stopped_by_second_condition = False
            def __call__(self, input_ids, scores):
                indices = (input_ids[0] == 29889).nonzero(as_tuple=True)[0]
                # First condition: once token 29889 has occurred at least 5+3*i
                # times, stop as soon as 4 or more tokens follow the
                # (5+3*i)-th occurrence.
                if len(indices) >= 5+3*i:
                    third_index = indices[4+3*i]
                    num_elements_after_third = input_ids[0].numel() - third_index.item() - 1
                    if num_elements_after_third >=4:
                        return True      
                    else:
                        return False
                # Second condition: token 13 has occurred at least
                # test_data[n][1]+8 times; stop and record that this path was taken.
                elif (input_ids[0] == 13).sum().item() >= test_data[n][1]+8:
                    self.stopped_by_second_condition = True  # Set the flag
                    return True
                else:
                    return False
        stopping_criteria = CustomStoppingCriteria()
        output_sequences = model.generate(
            input_ids=input_ids, 
            max_length=config.inference.max_length, 
            do_sample=config.inference.do_sample, 
            top_k=config.inference.top_k, 
            top_p=config.inference.top_p, 
            temperature=config.inference.temperature, 
            eos_token_id=tokenizer.eos_token_id, 
            stopping_criteria=[stopping_criteria], 
            num_return_sequences=1
        )   
        # Append the fixed atom_list entry for this atom and continue from the
        # extended sequence on the next step.
        # NOTE(review): tokenizer.encode may prepend special tokens (e.g. BOS) to
        # add_ids, which would then sit mid-sequence — confirm whether
        # add_special_tokens=False is wanted here.
        add_text = atom_list[4+i]+'\n'
        add_ids = tokenizer.encode(add_text, return_tensors='pt')
        add_ids = add_ids.to('cuda')
        input_ids = torch.cat((output_sequences, add_ids), dim=1)

    # NOTE(review): only the criterion instance of the LAST step is inspected;
    # flags set on earlier steps are discarded — confirm this is intended.
    if not stopping_criteria.stopped_by_second_condition:
        raw_generated_texts.append(tokenizer.decode(input_ids[0], skip_special_tokens=True))
        gen+=1
        pbar.update(1)
    
    if gen>=time:
        pbar.close()
        break


In [204]:
# Append the first bond section in bond_list to every raw generation.
bond_text = '\n'.join(bond_list[0])
generated_texts = [raw_text + bond_text for raw_text in raw_generated_texts]
print(generated_texts[0])

[INST] <<SYS>>
Below is a SMILES of a molecule, generate its 3D structure. The molecule has 22 atoms and 21 bonds.
<</SYS>>

[H]O/N=C(\C([H])([H])[H])C([H])([H])[C@]([H])(O[H])C([H])([H])C([H])([H])[H] [/INST] 
     RDKit          3D

 22 21  0  0  0  0  0  0  0  0999 V2000
   -2.7907   -0.9384   -0.1240 C   0  0  0  0  0  0  0  0  0  0  0  0
   -3.5563   -0.2106   -0.2462 C   0  0  0  0  0  0  0  0  0  0  0  0
   -4.2983    0.4780   -0.3537 H   0  0  0  0  0  0  0  0  0  0  0  0
   -3.0004   -0.4418   -1.1543 H   0  0  0  0  0  0  0  0  0  0  0  0
   -2.2697   -0.2116    0.9296 C   0  0  1  0  0  0  0  0  0  0  0  0
   -2.5876    1.2452    0.7701 O   0  0  0  0  0  0  0  0  0  0  0  0
   -2.0783    1.5805    0.0318 H   0  0  0  0  0  0  0  0  0  0  0  0
   -2.7921   -0.6282    1.7787 H   0  0  0  0  0  0  0  0  0  0  0  0
   -0.8327   -0.4341    0.7367 C   0  0  0  0  0  0  0  0  0  0  0  0
   -0.4424    0.3658   -0.3921 C   0  0  0  0  0  0  0  0  0  0  0  0
    0.8140    0.9704   -0

In [205]:
# Turn every generated text into an RDKit molecule with hydrogens removed.
# Each mol block is written to a scratch file and read back from that SAME
# path: writing to a relative path but reading from a hard-coded absolute one
# only works when the working directory happens to be that directory, and can
# otherwise read a stale or missing file.
scratch_mol_path = os.path.abspath('test.mol')

gen_mol_list = []
failed_parse_count = 0  # generations that could not be turned into a molecule
for generated_text in generated_texts:
    mol_block_text = dataset.get_mol_block(generated_text, test_data[n][0], test_data[n][1], test_data[n][2])
    with open(scratch_mol_path, 'w') as f:
        f.write(mol_block_text)
    try:
        gen_mol = Chem.MolFromMolFile(scratch_mol_path)
        # RDKit signals an unparsable file by returning None; handle it
        # explicitly instead of relying on RemoveHs(None) raising.
        if gen_mol is not None:
            gen_mol = RemoveHs(gen_mol)
    except Exception:
        # Best effort: a malformed generation is skipped, not fatal. Unlike a
        # bare `except:`, this still lets KeyboardInterrupt through.
        gen_mol = None
    if gen_mol is None:
        failed_parse_count += 1
        continue
    gen_mol_list.append(gen_mol)

In [206]:
# Number of generated texts that were successfully parsed into molecules.
len(gen_mol_list)

10

In [209]:
# Score the generated molecules against test_data[n][3] with a threshold of 0.5.
# NOTE(review): the two returned values are presumably coverage (%) and a matching score — confirm in conf3d.utils.
# NOTE(review): `utils` is only imported in a later cell (`from conf3d import dataset, utils`),
# so this cell fails on a fresh top-to-bottom run.
utils.get_cov_mat_p(gen_mol_list, test_data[n][3], threshold=0.5)

(60.0, 0.9875314359301098)

In [217]:
from conf3d import dataset, utils
# RMSD of generated molecule #7 against every entry of test_data[n][3];
# the cell output is the smallest of those values.
rmsd_list = [
    utils.GetBestRMSD(gen_mol_list[7], ref_mol)
    for ref_mol in test_data[n][3]
]
min(rmsd_list)
    

1.312874940785297

In [39]:
# Count how the generated molecules compare with the references of test entry 0:
#   valid  - the smallest RMSD against any reference is <= 0.5
#   failed - GetBestRMS raised for every reference, so no RMSD exists at all
# NOTE(review): `generated_mol` is not defined in the visible cells — confirm where it comes from.

# Strip hydrogens from the references once, not once per generated molecule.
ref_mols = [RemoveHs(raw_ref_mol) for raw_ref_mol in test_data[0][3]]

failed = 0
valid = 0
for gen_mol in generated_mol:
    min_rmsd = None  # None means "no reference could be aligned so far"
    for ref_mol in ref_mols:
        try:
            rmsd = MA.GetBestRMS(gen_mol, ref_mol)
        except Exception:
            # This pair cannot be aligned; try the next reference. Unlike a
            # bare `except:`, this still lets KeyboardInterrupt through.
            continue
        if min_rmsd is None or rmsd < min_rmsd:
            min_rmsd = rmsd

    if min_rmsd is None:
        failed += 1
    elif min_rmsd <= 0.5:
        valid += 1

In [1]:
import pickle

# Load a pickled object from the `generated` directory.
# NOTE(review): presumably previously saved inference results — confirm the file's contents.
# NOTE: pickle.load can execute arbitrary code from the file; only open pickles from a trusted source.
# NOTE(review): this rebinds `test_data`, the same name an earlier cell uses for the pickled test set.
with open('/hpc2hdd/home/yli106/smiles2mol/GEOM/generated/inference_llama2_7b_chat.pkl', 'rb') as f:
    test_data = pickle.load(f)