In [1]:
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import torch
from transformers import OPTForCausalLM, AutoModelForCausalLM, AutoTokenizer

from datasets import load_dataset
from datasets.iterable_dataset import IterableDataset
from transformers import OPTForCausalLM
from chemlactica.utils.model_utils import load_model
from chemlactica.utils.utils import get_tokenizer
import scipy
import numpy as np
import numpy
import gc

import random
from sklearn.metrics import root_mean_squared_error
from rdkit import Chem
from sklearn.metrics import roc_auc_score


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.manual_seed(42)
random.seed(42)
numpy.random.seed(42)

In [3]:
chemlactica_tokenizer = AutoTokenizer.from_pretrained("/auto/home/menuab/code/ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66")
chemma_tokenizer = AutoTokenizer.from_pretrained("/auto/home/menuab/code/ChemLactica/chemlactica/tokenizer/GemmaTokenizer")
galactica_tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-125m")
len(chemlactica_tokenizer), len(chemma_tokenizer), len(galactica_tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


(50066, 256000, 50000)

In [5]:
model_125m_18k_1f28 = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/6f2dbee5a74548b9ad509462"
chemlactcia_model = AutoModelForCausalLM.from_pretrained(
    model_125m_18k_1f28 + "/last").to("cuda:0").eval()
chemlactcia_model.device, chemlactcia_model.dtype

(device(type='cuda', index=0), torch.float32)

In [6]:
data = [
    '[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C([H])([H])[H])[H])[H])Cl)[H]',
    '[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C([H])(C([H])([H])[H])C([H])([H])[H])[H])[H])Cl)[H]',
    '[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C([H])([H])C3(C(C3([H])[H])([H])[H])[H])[H])[H])Cl)[H]',
    '[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C3(C(C3([H])[H])([H])[H])[H])[H])[H])Cl)[H]',
    '[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C3(C(C(C(C3([H])[H])([H])[H])([H])[H])([H])[H])[H])[H])[H])Cl)[H]',
    '[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C([H])([H])O[H])[H])[H])Cl)[H]',
    '[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)N([H])C([H])([H])C([H])([H])[H])[H])[H])Cl)[H]',
    '[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)OC([H])([H])[H])[H])[H])Cl)[H]',
    '[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)[C@@]3([C@](C3([H])[H])([H])F)[H])[H])[H])Cl)[H]',
    '[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)[C@@]3([C@](C3([H])[H])([H])Cl)[H])[H])[H])Cl)[H]'
]

In [7]:
mol = Chem.MolFromSmiles('[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C([H])([H])[H])[H])[H])Cl)[H]')
std_smiles1 = Chem.MolToSmiles(mol)
std_smiles2 = Chem.MolToSmiles(mol, doRandom=True)
kek_smiles = Chem.MolToSmiles(mol, kekuleSmiles=True)
can_smiles = Chem.MolToSmiles(mol, canonical=True)
iso_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
kek_iso_smiles = Chem.MolToSmiles(mol, isomericSmiles=True, kekuleSmiles=True)
(std_smiles1, std_smiles2, can_smiles, iso_smiles, kek_smiles, kek_iso_smiles)

('CC(=O)Nc1cc(NC(=O)c2c(Cl)cccc2Cl)ccn1',
 'c1cc(c(C(=O)Nc2cc(NC(C)=O)ncc2)c(c1)Cl)Cl',
 'CC(=O)Nc1cc(NC(=O)c2c(Cl)cccc2Cl)ccn1',
 'CC(=O)Nc1cc(NC(=O)c2c(Cl)cccc2Cl)ccn1',
 'CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1',
 'CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1')

In [20]:
for sample in data:
    mol = Chem.MolFromSmiles('[H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C([H])([H])[H])[H])[H])Cl)[H]')
    std_smiles1 = Chem.MolToSmiles(mol)
    kek_smiles = Chem.MolToSmiles(mol, kekuleSmiles=True)
    print(std_smiles1, kek_smiles, sample)

CC(=O)Nc1cc(NC(=O)c2c(Cl)cccc2Cl)ccn1 CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1 [H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C([H])([H])[H])[H])[H])Cl)[H]
CC(=O)Nc1cc(NC(=O)c2c(Cl)cccc2Cl)ccn1 CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1 [H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C([H])(C([H])([H])[H])C([H])([H])[H])[H])[H])Cl)[H]
CC(=O)Nc1cc(NC(=O)c2c(Cl)cccc2Cl)ccn1 CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1 [H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C([H])([H])C3(C(C3([H])[H])([H])[H])[H])[H])[H])Cl)[H]
CC(=O)Nc1cc(NC(=O)c2c(Cl)cccc2Cl)ccn1 CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1 [H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C3(C(C3([H])[H])([H])[H])[H])[H])[H])Cl)[H]
CC(=O)Nc1cc(NC(=O)c2c(Cl)cccc2Cl)ccn1 CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1 [H]c1c(c(c(c(c1[H])Cl)C(=O)N([H])c2c(c(nc(c2[H])N([H])C(=O)C3(C(C(C(C3([H])[H])([H])[H])([H])[H])([H])[H])[H])[H])[H])Cl)[H]
CC(=O)Nc1cc(NC(=O)c2c(Cl)cccc2Cl)ccn1 CC(=O)NC1=

In [22]:
ground_truths, gens, diffs = [],[],[]
invalids = 0
for sample in data:
    # ground_truth = round(sample['activity'], 2)
    mol = Chem.MolFromSmiles(sample)
    # std_smiles1 = Chem.MolToSmiles(mol)
    mol = Chem.MolToSmiles(mol, kekuleSmiles=True)
    prompt = f"</s>[START_SMILES]{mol}[END_SMILES][PROPERTY]activity"
    len_prompt = len(prompt)
    prompt = chemlactica_tokenizer(prompt, return_tensors="pt").to(chemlactcia_model.device)
    out = chemlactcia_model.generate(prompt.input_ids, do_sample=False, eos_token_id=chemlactica_tokenizer.encode('[/PROPERTY]')[0], max_new_tokens=100)
    out = chemlactica_tokenizer.batch_decode(out)[0]
    print(out)

    gen = float(out[out.find("activity ") + len("activity "):out.find("[/PROPERTY]")])
    print(gen, kek_smiles)

</s>[START_SMILES]CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1[END_SMILES][PROPERTY]activity -1.64[/PROPERTY]
-1.64 CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1
</s>[START_SMILES]CC(C)C(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1[END_SMILES][PROPERTY]activity -1.64[/PROPERTY]
-1.64 CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1
</s>[START_SMILES]O=C(CC1CC1)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1[END_SMILES][PROPERTY]activity -1.64[/PROPERTY]
-1.64 CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1
</s>[START_SMILES]O=C(NC1=CC=NC(NC(=O)C2CC2)=C1)C1=C(Cl)C=CC=C1Cl[END_SMILES][PROPERTY]activity -1.58[/PROPERTY]
-1.58 CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1
</s>[START_SMILES]O=C(NC1=CC=NC(NC(=O)C2CCCC2)=C1)C1=C(Cl)C=CC=C1Cl[END_SMILES][PROPERTY]activity -1.41[/PROPERTY]
-1.41 CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1
</s>[START_SMILES]O=C(CO)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1[END_SMILES][PROPERTY]activity -1.64[/PROPERTY]
-1.64 CC(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1
</s>[START_SMILES]CCNC(=O)

In [23]:
out="</s>[START_SMILES]CC(C)C(=O)NC1=CC(NC(=O)C2=C(Cl)C=CC=C2Cl)=CC=N1[END_SMILES][PROPERTY]activity -1.64[/PROPERTY]"

In [26]:
out[out.find("activity ") + len("activity "):out.find("[/PROPERTY]")]

'-1.64'