In [1]:
import torch
from transformers import OPTForCausalLM, AutoModelForCausalLM
from chemlactica.utils.utils import get_tokenizer
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
from rdkit.Chem.rdMolDescriptors import CalcTPSA

In [2]:
tokenizer = get_tokenizer("/auto/home/menuab/code/ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Process 1594984 created a tokenizer


In [3]:
class LinearFloat32(torch.nn.Linear):
    def forward(self, _input) -> torch.Tensor:
        return super().forward(_input).to(torch.float32)


def cast_lm_head_to_fp32_init(func):
    def inner_func(self, config, *args, **kwargs):
        func(self, config, *args, **kwargs)
        self.lm_head = LinearFloat32(
            config.word_embed_proj_dim, config.vocab_size, bias=False
        )

    return inner_func

In [4]:
model_path = "/auto/home/menuab/code/checkpoints/9954e52e400b43d18d3a40f6/125m_122k_9954"
# model_path = "/auto/home/menuab/code/checkpoints/26d322857a184fcbafda5d4a/125m_118k_26d3/"
# OPTForCausalLM.__init__ = cast_lm_head_to_fp32_init(OPTForCausalLM.__init__)
model = OPTForCausalLM.from_pretrained(model_path).eval()
# model = 

In [5]:
prompt = "</s>[TPSA]0.00[/TPSA]"
prompt = tokenizer(prompt, return_tensors="pt").to(model.device)

In [88]:
out = model.generate(prompt.input_ids, do_sample=False, eos_token_id=20, max_length=300)

In [89]:
out = tokenizer.batch_decode(out)[0]

In [90]:
out.find("[END_SMILES]")

-1

In [65]:
tpsa_score = AllChem.CalcTPSA(Chem.MolFromSmiles("CC(C)(C)OC(=O)N1CCCC1C(=O)N1CCCC1C(=O)NC(Cc1ccc(O)cc1)C(=O)O"))

In [74]:
CalcTPSA(Chem.MolFromSmiles("CC(C)(C)OC(=O)N1CCCC1C(=O)N1CCCC1C(=O)NC(Cc1ccc(O)cc1)C(=O)O")), tpsa_score

(136.48, 136.48)

In [93]:
ground_truths, gens, diffs = [],[],[]
invalids = 0
for i in np.arange(0,100, 1):
    prompt = f"</s>[TPSA]{i:.2f}[/TPSA]"
    len_prompt = len(prompt)
    prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(prompt.input_ids, do_sample=False, eos_token_id=20, max_length=300)
    out = tokenizer.batch_decode(out)[0]
    try:
        if out.find("[END_SMILES]")!=-1:
            smiles = out[out.find("[START_SMILES]") + len("[START_SMILES]"):out.find("[END_SMILES]")]
            tpsa_score = AllChem.CalcTPSA(Chem.MolFromSmiles(smiles))
            diff = abs(i - tpsa_score)
            print("GT:", i, "Gen:", tpsa_score, "diff:", round(diff,2), smiles, out )
            ground_truths.append(i)
            gens.append(tpsa_score)
            diffs.append(diff)
        else:
            print(f"GT: {i} {out}")
    except:
        print(f"GT: {i} {out}")
        invalids += 1
        pass

GT: 0 </s>[TPSA]0.00[/TPSA][NUMALIPHATICHETEROCYCLES]0[/NUMALIPHATICHETEROCYCLES][NUMSATURATEDCARBOCYCLES]0[/NUMSATURATEDCARBOCYCLES][QED]0.59[/QED][NUMHACCEPTORS]0[/NUMHACCEPTORS][NOCOUNT]0[/NOCOUNT][WEIGHT]216.19[/WEIGHT][NUMHDONORS]0[/NUMHDONORS][SIMILAR]CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
GT: 1 Gen: 3.24 diff: 2.24 CCN(CC)C(C)c1ccc(-c2ccc(F)cc2)cc1 </s>[TPSA]1.00[/TPSA][NOCOUNT]1[/NOCOUNT][NUMSATURATEDCARBOCYCLES]0[/NUMSATURATEDCARBOCYCLES][QED]0.65[/QED][NHOHCOUNT]0[/NHOHCOUNT][NUMHACCEPTORS]1[/NUMH

In [7]:
ground_truths, gens, diffs = [],[],[]
invalids = 0
for i in np.arange(30,70, .1):
    i = round(i,2)
    prompt = f"</s>[TPSA]{i:.2f}[/TPSA]"
    len_prompt = len(prompt)
    prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(prompt.input_ids, do_sample=False, eos_token_id=20, max_length=300)
    out = tokenizer.batch_decode(out)[0]
    try:
        if out.find("[END_SMILES]")!=-1:
            smiles = out[out.find("[START_SMILES]") + len("[START_SMILES]"):out.find("[END_SMILES]")]
            tpsa_score = AllChem.CalcTPSA(Chem.MolFromSmiles(smiles))
            diff = abs(i - tpsa_score)
            print("GT:", i, "Gen:", tpsa_score, "diff:", round(diff,2), smiles, out )
            ground_truths.append(i)
            gens.append(tpsa_score)
            diffs.append(diff)
        else:
            print(f"GT: {i} {out}")
    except:
        print(f"GT: {i} {out}")
        invalids += 1
        pass

GT: 30.0 Gen: 136.48 diff: 106.48 CC(C)(C)OC(=O)N1CCCC1C(=O)N1CCCC1C(=O)NC(Cc1ccc(O)cc1)C(=O)O </s>[TPSA]30.00[/TPSA][NUMALIPHATICCARBOCYCLES]0[/NUMALIPHATICCARBOCYCLES][NOCOUNT]3[/NOCOUNT][QED]0.78[/QED][FRACTIONCSP3]0.25[/FRACTIONCSP3][NUMALIPHATICHETEROCYCLES]0[/NUMALIPHATICHETEROCYCLES][NUMSATURATEDCARBOCYCLES]0[/NUMSATURATEDCARBOCYCLES][NUMHACCEPTORS]3[/NUMHACCEPTORS][NUMHETEROATOMS]3[/NUMHETEROATOMS][WEIGHT]228.14[/WEIGHT][CLOGP]2.80[/CLOGP][HEAVYATOMCOUNT]17[/HEAVYATOMCOUNT][NUMHDONORS]0[/NUMHDONORS][RINGCOUNT]2[/RINGCOUNT][NUMSATURATEDHETEROCYCLES]0[/NUMSATURATEDHETEROCYCLES][NUMAROMATICHETEROCYCLES]1[/NUMAROMATICHETEROCYCLES][NUMAROMATICCARBOCYCLES]1[/NUMAROMATICCARBOCYCLES][NHOHCOUNT]0[/NHOHCOUNT][SAS]2.19[/SAS][NUMALIPHATICRINGS]0[/NUMALIPHATICRINGS][NUMAROMATICRINGS]2[/NUMAROMATICRINGS][NUMSATURATEDRINGS]0[/NUMSATURATEDRINGS][NUMROTATABLEBONDS]4[/NUMROTATABLEBONDS]</s>[START_SMILES]CC(C)(C)OC(=O)N1CCCC1C(=O)N1CCCC1C(=O)NC(Cc1ccc(O)cc1)C(=O)O[END_SMILES]
GT: 30.1 Gen: 136.48

KeyboardInterrupt: 