In [2]:
import torch
from transformers import OPTForCausalLM, AutoModelForCausalLM, AutoTokenizer
# from chemlactica.utils.utils import get_tokenizer
from rdkit.Chem import Descriptors

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import RDConfig, MACCSkeys, QED
import matplotlib.pyplot as plt
import numpy as np
import json
import pickle
from sklearn import metrics
from scipy.stats import spearmanr
from rdkit.Chem.rdMolDescriptors import CalcTPSA
from rdkit import Chem, DataStructs
import sys
import os
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import pandas as pd    
jsonObj = pd.read_json(path_or_buf="/nfs/ap/mnt/sxtn/rdkit_computed_rel+form/valid_rdkit_computed_rel+form/950001_start.jsonl", lines=True)

In [4]:
smiles = list(jsonObj[jsonObj!=""].dropna().SMILES.values)

In [5]:
len(smiles)

10927

In [6]:
np.random.seed(42)
smiles100 = np.random.choice(smiles, size=100)

In [7]:
smiles100

array(['CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1',
       'CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1',
       'CCOC(=O)C1C(C)NNC1S(=O)(=O)Nc1cccc(CC)c1',
       'CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2',
       'CNc1cccc(Nc2cnn(C)c2)c1[N+](=O)[O-]',
       'O=C(NCC(O)CN1CCOCC1)Nc1snc(OCc2ccc(Cl)cc2F)c1C(=O)O',
       'CCC(=O)c1sc(NC2C3C4CCC(C4)C23)c(OC)c1N',
       'CN(C)S(=O)(=O)NCCCc1ncc[nH]1',
       'Cc1ccc(F)c(OCc2cc(F)cc(C#CCO)c2)c1',
       'Cc1cc(C)c(S(=O)(=O)NCC2CCCN2c2ccccc2)c(C)c1',
       'Cc1nn(C)c(Oc2ccc(Br)cc2[N+](=O)[O-])c1CO',
       'CN=C(NCc1nc(-c2ccccc2)cs1)N1CCSC(C)(C)C1',
       'CC(C)NS(=O)(=O)[N-]Cc1ccccc1.O=C(O)c1cccc2c1cnn2-c1ccc(F)cc1',
       'CCOc1ccc(-c2noc(Cn3c(=O)n(C(C)C)c(=O)c4ccccc43)n2)cc1OC',
       'O=C(CSc1nnnn1-c1ccc(Cl)cc1)Nc1cccc(Cl)c1N1CCCCC1',
       'Cc1ccccc1C1(O)CCN(C(=O)c2ccc(-c3ccc(F)cc3)o2)CC1',
       'CCc1ccc(NC(=O)NC(C)(C)CCC(=O)O)cc1S(N)(=O)=O',
       'CCCOc1ccc(C=NNC(=O)COc2ccccc2C)cc1OC',
       'CC(C)(C)OC(

In [8]:
pubchem_stats_file = open("/auto/home/menuab/code/ChemLacticaTestSuite/src/stats_data/pubchem_stats.pkl", 'rb')
pubchem_stats = pickle.load(pubchem_stats_file)
pubchem_stats_file.close()

In [32]:
tokenizer = AutoTokenizer.from_pretrained("/auto/home/menuab/code/ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66")
# tokenizer = AutoTokenizer.from_pretrained("/auto/home/menuab/code/ChemLactica/chemlactica/tokenizer/GemmaTokenizer")
len(tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


50066

In [10]:
model_125m_20k_9954 = "/nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/9954e52e400b43d18d3a40f6/checkpoint-20480"
model_2b_11k_d6e6 = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/d6e6a76e91814ad68d5fa264/checkpoint-11000"
model_2b_2k_699e = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/699e8c6078bb4461a73b39de/checkpoint-2000"
model_125m_18k_1f28 = "/nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/1f289ff103034364bd27e1c3/checkpoint-18000/"
model_2b_12k_0717 = "/nfs/dgx/raid/chem/checkpoints/h100/google/gemma-2b/0717d445bcf44e31b2887892/checkpoint-12000"
model_2b_18k_0717 = "/nfs/dgx/raid/chem/checkpoints/h100/google/gemma-2b/0717d445bcf44e31b2887892/checkpoint-18000"

# model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval()
# model.device, model.dtype

In [30]:
model = AutoModelForCausalLM.from_pretrained(
    model_125m_18k_1f28, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
).to("cuda:1").eval()
model.device, model.dtype

(device(type='cuda', index=1), torch.bfloat16)

In [33]:
def calculate_tanim_sim(m, rel_m):
    m=Chem.MolFromSmiles(m)
    rel_m =Chem.MolFromSmiles(rel_m)
    fp = AllChem.GetMorganFingerprint(m, 2)
    rel_fp = AllChem.GetMorganFingerprint(rel_m, 2)
    return DataStructs.TanimotoSimilarity(fp, rel_fp)

In [39]:
def generate_plot(target_clean, generated_clean, diffs, test_name, rmse, mape, correlation, n_invalid_generations, n_total_gens, property_range, thickness):
    max_, min_, max_g = np.max(target_clean), np.min(target_clean), np.max(generated_clean)
    title = f'model_2b_18k_0717 {test_name} Greedy sampling\n'\
            f'{n_invalid_generations}/{n_total_gens} invalid SMILES\n'\
            f'rmse {rmse:.3f} mape {mape:.3f} corr: {correlation:.3f}\n'\

    fig, ax1 = plt.subplots()
    fig.set_figheight(6)
    fig.set_figwidth(8)
    fig.set_linewidth(4)
    if thickness != 0:    
        ax2 = ax1.twinx()
        stats = pubchem_stats[test_name.upper()]
    #     property_range = [1, 10]
        # stats_width = (property_range[1] - property_range[0]) / thickness
        ax2.bar([interval.mid for interval in stats.index], stats, width=thickness, alpha=0.3) 
    dist = max_ - min_
    margin = 0.05
    ax1.set_xlim([min_- margin*dist, max_ + margin*dist])
    ax1.scatter(target_clean, generated_clean, c='b')
    # ax1.vlines(nones, ymin=min_, ymax=max_, color='r', alpha=0.3)
    ax1.plot([min_, max_], [min_, max_], color='grey', linestyle='--', linewidth=2)
    ax1.plot(target_clean, np.convolve(np.pad(diffs, (2, 2), mode='edge'), np.ones(5)/5, mode='valid'), color='m', alpha=0.5)
    ax1.set_xlabel(f'Target {test_name}')
    ax1.set_ylabel(f'Generated {test_name}')
    # ax1.set_xlabel(f'Ground truth {test_name}')
    # ax1.set_ylabel(f'Predicted {test_name}')
    ax1.grid(True)
    plt.title(title)
    plt.tight_layout()
    fig.savefig(f'{test_name}_property.png', dpi=300, format="png")
    fig.clf()
    plt.close()

In [36]:
tic = {
        "QED": .01,
        "SAS": .1,
        "TPSA": 1,
        "WEIGHT": 12,
        "CLOGP": .25,
        }
properties = ["TPSA","WEIGHT","QED","SAS","CLOGP"]
for property in properties:
    ground_truths, gens, diffs = [],[],[]
    invalids = 0
    for s in smiles100:
        prompt = f"</s>[START_SMILES]{s}[END_SMILES][{property}]"
        len_prompt = len(prompt)
        prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
        out = model.generate(prompt.input_ids, do_sample=False, eos_token_id=tokenizer.encode(f"[/{property}]")[0], max_new_tokens=300)
        out = tokenizer.batch_decode(out)[0]
        try:
            if out.find(f"[{property}]")!=-1:
                gen_score = float(out[out.find(f"[{property}]") + len(f"[{property}]"):out.find(f"[/{property}]")])
                if property == "TPSA":
                    gt_score = AllChem.CalcTPSA(Chem.MolFromSmiles(s))
                elif property == "QED":
                    gt_score = QED.qed(Chem.MolFromSmiles(s))
                elif property == "SAS":
                    gt_score = sascorer.calculateScore(Chem.MolFromSmiles(s))
                elif property == "WEIGHT":
                    gt_score = Descriptors.ExactMolWt(Chem.MolFromSmiles(s))
                elif property == "CLOGP":
                    gt_score = Descriptors.MolLogP(Chem.MolFromSmiles(s))
                diff = abs(gt_score - gen_score)
                print("GT:", round(gt_score,2), "Gen:", gen_score, "diff:", round(diff,2), s, out )
                ground_truths.append(gt_score)
                gens.append(gen_score)
                diffs.append(diff)
            else:
                print(f"GT: {gt_score} GEN: {gen_score} {out}")
        except:
            print(f"GT: {gt_score} {out}")
            invalids += 1
            raise

    combined = list(zip(ground_truths, gens, diffs))
    combined.sort(key=lambda x: x[0])
    ground_truths, gens, diffs = zip(*combined)
    rmse = metrics.mean_squared_error(ground_truths, gens, squared=False)
    mape = metrics.mean_absolute_percentage_error(ground_truths, gens)
    correlation = correlation, pvalue = spearmanr(ground_truths, gens)
    # tic = 900 if property in ["TPSA", "WEIGHT"] else 100
    generate_plot(ground_truths, gens, diffs, property, rmse, mape, correlation, invalids, len(ground_truths), (min(ground_truths), max(ground_truths)), tic[property])

GT: 140.0 Gen: 140.0 diff: 0.0 CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 </s>[START_SMILES]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1[END_SMILES][TPSA]140.00[/TPSA]
GT: 52.55 Gen: 52.55 diff: 0.0 CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1 </s>[START_SMILES]CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1[END_SMILES][TPSA]52.55[/TPSA]
GT: 96.53 Gen: 96.53 diff: 0.0 CCOC(=O)C1C(C)NNC1S(=O)(=O)Nc1cccc(CC)c1 </s>[START_SMILES]CCOC(=O)C1C(C)NNC1S(=O)(=O)Nc1cccc(CC)c1[END_SMILES][TPSA]96.53[/TPSA]
GT: 48.99 Gen: 48.99 diff: 0.0 CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2 </s>[START_SMILES]CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2[END_SMILES][TPSA]48.99[/TPSA]
GT: 85.02 Gen: 85.02 diff: 0.0 CNc1cccc(Nc2cnn(C)c2)c1[N+](=O)[O-] </s>[START_SMILES]CNc1cccc(Nc2cnn(C)c2)c1[N+](=O)[O-][END_SMILES][TPSA]85.02[/TPSA]
GT: 133.25 Gen: 133.25 diff: 0.0 O=C(NCC(O)CN1CCOCC1)Nc1snc(OCc2ccc(Cl)cc2F)c1C(=O)O </s>[START_SMILES]O=C(NCC(O)CN1CCOCC1)Nc1snc(OCc2ccc(Cl)cc2F)c1C(=O)O[END



GT: 460.14 Gen: 460.14 diff: 0.0 CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 </s>[START_SMILES]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1[END_SMILES][WEIGHT]460.14[/WEIGHT]
GT: 389.2 Gen: 389.2 diff: 0.0 CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1 </s>[START_SMILES]CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1[END_SMILES][WEIGHT]389.20[/WEIGHT]
GT: 341.14 Gen: 341.14 diff: 0.0 CCOC(=O)C1C(C)NNC1S(=O)(=O)Nc1cccc(CC)c1 </s>[START_SMILES]CCOC(=O)C1C(C)NNC1S(=O)(=O)Nc1cccc(CC)c1[END_SMILES][WEIGHT]341.14[/WEIGHT]
GT: 299.11 Gen: 299.11 diff: 0.0 CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2 </s>[START_SMILES]CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2[END_SMILES][WEIGHT]299.11[/WEIGHT]
GT: 247.11 Gen: 247.11 diff: 0.0 CNc1cccc(Nc2cnn(C)c2)c1[N+](=O)[O-] </s>[START_SMILES]CNc1cccc(Nc2cnn(C)c2)c1[N+](=O)[O-][END_SMILES][WEIGHT]247.11[/WEIGHT]
GT: 488.09 Gen: 488.09 diff: 0.0 O=C(NCC(O)CN1CCOCC1)Nc1snc(OCc2ccc(Cl)cc2F)c1C(=O)O </s>[START_SMILES]O=C(NCC(O)CN1CCOCC1)Nc1



GT: 0.18 Gen: 0.18 diff: 0.0 CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 </s>[START_SMILES]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1[END_SMILES][QED]0.18[/QED]
GT: 0.56 Gen: 0.56 diff: 0.0 CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1 </s>[START_SMILES]CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1[END_SMILES][QED]0.56[/QED]
GT: 0.67 Gen: 0.67 diff: 0.0 CCOC(=O)C1C(C)NNC1S(=O)(=O)Nc1cccc(CC)c1 </s>[START_SMILES]CCOC(=O)C1C(C)NNC1S(=O)(=O)Nc1cccc(CC)c1[END_SMILES][QED]0.67[/QED]
GT: 0.87 Gen: 0.87 diff: 0.0 CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2 </s>[START_SMILES]CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2[END_SMILES][QED]0.87[/QED]
GT: 0.64 Gen: 0.64 diff: 0.0 CNc1cccc(Nc2cnn(C)c2)c1[N+](=O)[O-] </s>[START_SMILES]CNc1cccc(Nc2cnn(C)c2)c1[N+](=O)[O-][END_SMILES][QED]0.64[/QED]
GT: 0.42 Gen: 0.42 diff: 0.0 O=C(NCC(O)CN1CCOCC1)Nc1snc(OCc2ccc(Cl)cc2F)c1C(=O)O </s>[START_SMILES]O=C(NCC(O)CN1CCOCC1)Nc1snc(OCc2ccc(Cl)cc2F)c1C(=O)O[END_SMILES][QED]0.42[/QED]
GT: 0.



GT: 2.33 Gen: 2.32 diff: 0.01 CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 </s>[START_SMILES]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1[END_SMILES][SAS]2.32[/SAS]
GT: 3.64 Gen: 3.62 diff: 0.02 CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1 </s>[START_SMILES]CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1[END_SMILES][SAS]3.62[/SAS]
GT: 3.79 Gen: 3.81 diff: 0.02 CCOC(=O)C1C(C)NNC1S(=O)(=O)Nc1cccc(CC)c1 </s>[START_SMILES]CCOC(=O)C1C(C)NNC1S(=O)(=O)Nc1cccc(CC)c1[END_SMILES][SAS]3.81[/SAS]
GT: 2.53 Gen: 2.53 diff: 0.0 CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2 </s>[START_SMILES]CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2[END_SMILES][SAS]2.53[/SAS]
GT: 2.49 Gen: 2.5 diff: 0.01 CNc1cccc(Nc2cnn(C)c2)c1[N+](=O)[O-] </s>[START_SMILES]CNc1cccc(Nc2cnn(C)c2)c1[N+](=O)[O-][END_SMILES][SAS]2.50[/SAS]
GT: 3.17 Gen: 3.2 diff: 0.03 O=C(NCC(O)CN1CCOCC1)Nc1snc(OCc2ccc(Cl)cc2F)c1C(=O)O </s>[START_SMILES]O=C(NCC(O)CN1CCOCC1)Nc1snc(OCc2ccc(Cl)cc2F)c1C(=O)O[END_SMILES][SAS]3.20[/SAS]
GT:



GT: 3.84 Gen: 3.8 diff: 0.04 CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 </s>[START_SMILES]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1[END_SMILES][CLOGP]3.80[/CLOGP]
GT: 3.29 Gen: 3.27 diff: 0.02 CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1 </s>[START_SMILES]CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1[END_SMILES][CLOGP]3.27[/CLOGP]
GT: 0.99 Gen: 0.94 diff: 0.05 CCOC(=O)C1C(C)NNC1S(=O)(=O)Nc1cccc(CC)c1 </s>[START_SMILES]CCOC(=O)C1C(C)NNC1S(=O)(=O)Nc1cccc(CC)c1[END_SMILES][CLOGP]0.94[/CLOGP]
GT: 2.9 Gen: 2.8 diff: 0.1 CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2 </s>[START_SMILES]CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2[END_SMILES][CLOGP]2.80[/CLOGP]
GT: 2.11 Gen: 2.06 diff: 0.05 CNc1cccc(Nc2cnn(C)c2)c1[N+](=O)[O-] </s>[START_SMILES]CNc1cccc(Nc2cnn(C)c2)c1[N+](=O)[O-][END_SMILES][CLOGP]2.06[/CLOGP]
GT: 2.03 Gen: 2.0 diff: 0.03 O=C(NCC(O)CN1CCOCC1)Nc1snc(OCc2ccc(Cl)cc2F)c1C(=O)O </s>[START_SMILES]O=C(NCC(O)CN1CCOCC1)Nc1snc(OCc2ccc(Cl)cc2F)c1C(=O)O[END_SMILES][



In [37]:
ground_truths, gens, diffs = [],[],[]
invalids = 0
for s in smiles100[:10]:
    for s2 in smiles100[10:20]:
        prompt = f"</s>[START_SMILES]{s}[END_SMILES][SIMILAR]{s2} "
        len_prompt = len(s2) + 1
        prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
        out = model.generate(prompt.input_ids, do_sample=False, eos_token_id=tokenizer.encode("[/SIMILAR]")[0], max_new_tokens=300)
        out = tokenizer.batch_decode(out)[0]
        try:
            if out.find("[/SIMILAR]")!=-1:
                gen_score = round(float(out[out.find("[SIMILAR]") + len("[SIMILAR]") + len_prompt:out.find("[/SIMILAR]")]), 2)
                # tpsa_score = AllChem.CalcTPSA(Chem.MolFromSmiles(smiles))
                # gt_score = round(QED.qed(Chem.MolFromSmiles(s)), 2)
                # gt_score = round(sascorer.calculateScore(Chem.MolFromSmiles(s)), 2)
                gt_score = calculate_tanim_sim(s, s2)
                diff = abs(gt_score - gen_score)
                print("GT:", round(gt_score,2), "Gen:", gen_score, "diff:", round(diff,2), s, out )
                ground_truths.append(gt_score)
                gens.append(gen_score)
                diffs.append(diff)
            else:
                print(f"GT: {gt_score} GEN: {gen_score} {out}")
        except:
            print(f"GT: {gt_score} {out}")
            invalids += 1
            raise

GT: 0.19 Gen: 0.21 diff: 0.02 CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 </s>[START_SMILES]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1[END_SMILES][SIMILAR]Cc1nn(C)c(Oc2ccc(Br)cc2[N+](=O)[O-])c1CO 0.21[/SIMILAR]
GT: 0.12 Gen: 0.14 diff: 0.02 CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 </s>[START_SMILES]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1[END_SMILES][SIMILAR]CN=C(NCc1nc(-c2ccccc2)cs1)N1CCSC(C)(C)C1 0.14[/SIMILAR]
GT: 0.23 Gen: 0.24 diff: 0.01 CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 </s>[START_SMILES]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1[END_SMILES][SIMILAR]CC(C)NS(=O)(=O)[N-]Cc1ccccc1.O=C(O)c1cccc2c1cnn2-c1ccc(F)cc1 0.24[/SIMILAR]
GT: 0.16 Gen: 0.18 diff: 0.02 CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 </s>[START_SMILES]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1[END_SMILES][SIMILAR]CCOc1ccc(-c2

GT: 0.16 Gen: 0.16 diff: 0.0 CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1 </s>[START_SMILES]CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1[END_SMILES][SIMILAR]CC(C)NS(=O)(=O)[N-]Cc1ccccc1.O=C(O)c1cccc2c1cnn2-c1ccc(F)cc1 0.16[/SIMILAR]
GT: 0.15 Gen: 0.16 diff: 0.01 CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1 </s>[START_SMILES]CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1[END_SMILES][SIMILAR]CCOc1ccc(-c2noc(Cn3c(=O)n(C(C)C)c(=O)c4ccccc43)n2)cc1OC 0.16[/SIMILAR]
GT: 0.15 Gen: 0.16 diff: 0.01 CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1 </s>[START_SMILES]CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1[END_SMILES][SIMILAR]O=C(CSc1nnnn1-c1ccc(Cl)cc1)Nc1cccc(Cl)c1N1CCCCC1 0.16[/SIMILAR]
GT: 0.16 Gen: 0.17 diff: 0.01 CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1 </s>[START_SMILES]CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1[END_SMILES][SIMILAR]Cc1ccccc1C1(O)CCN(C(=O)c2ccc(-c3ccc(F)cc3)o2)CC1 0.17[/SIMILAR]
GT: 0.14 Gen: 0.14 diff: 0.0 CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)C1 </s>[START_SMILES]CCNC(=NCC(C)Cc1cccs1)NC1CCN(c2ncccc2F)

In [38]:
combined = list(zip(ground_truths, gens, diffs))
combined.sort(key=lambda x: x[0])
ground_truths, gens, diffs = zip(*combined)
rmse = metrics.mean_squared_error(ground_truths, gens, squared=False)
mape = metrics.mean_absolute_percentage_error(ground_truths, gens)
correlation = correlation, pvalue = spearmanr(ground_truths, gens)
generate_plot(ground_truths, gens,diffs, 'Similarity', rmse, mape, correlation, invalids, len(ground_truths), (0,1), 0)



In [41]:
ground_truths, gens, diffs = [],[],[]
invalids = 0
for s in smiles100[:20]:
    for sim in np.arange(0.2,1.05,0.05):
        prompt = f"</s>[SIMILAR]{s} {sim:.2f}[/SIMILAR]"
        prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
        out = model.generate(prompt.input_ids, do_sample=False, eos_token_id=tokenizer.encode("[END_SMILES]")[0],
        suppress_tokens=[2, 10, 44, 11, 45, 12, 46, 13, 47],
        # suppress_tokens=[2, 50002, 50035, 50003, 50036, 50000, 50033],
          repetition_penalty=1.01, renormalize_logits=True,max_length=2000)
        out = tokenizer.batch_decode(out)[0]
        if out.find("[END_SMILES]")!=-1:
            try:
                mol = out[out.find("[START_SMILES]") + len("[START_SMILES]"):out.find("[END_SMILES]")]
                # tpsa_score = AllChem.CalcTPSA(Chem.MolFromSmiles(smiles))
                # gt_score = round(QED.qed(Chem.MolFromSmiles(s)), 2)
                # gt_score = round(sascorer.calculateScore(Chem.MolFromSmiles(s)), 2)
                gen_score = calculate_tanim_sim(s, mol)
                diff = abs(sim - gen_score)
                print("GT:", round(sim,2), "Gen:", gen_score, "diff:", round(diff,2), s, out )
                ground_truths.append(sim)
                gens.append(gen_score)
                diffs.append(diff)
            except:
                print(f"GT: {sim} {out}")
                invalids += 1
        else:
            print(f"GT: {sim} {out}")
            invalids += 1
        # except:
        #     print(f"GT: {sim} {out}")
        #     invalids += 1
        #     raise

GT: 0.2 Gen: 0.2536231884057971 diff: 0.05 CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 </s>[SIMILAR]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 0.20[/SIMILAR][SIMILAR]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 0.20[/SIMILAR][SIMILAR]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 0.20[/SIMILAR][SIMILAR]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 0.20[/SIMILAR][SIMILAR]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 0.20[/SIMILAR][SIMILAR]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 0.20[/SIMILAR][SIMILAR]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 0.20[/SIMILAR][SIMILAR]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 0.20[/SIMILAR][SIMILAR]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 0.20[/SIMILAR][SIMILAR]CC(=O)Nc1ccc(C(=O)NN=Cc2cccc([N+](=O)[O-])c2OC(=O)c2cccc(C)c2)cc1 0.20[/SIMIL



GT: 0.95 Gen: 0.8412698412698413 diff: 0.11 CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2 </s>[SIMILAR]CCc1n[nH]c2c(c1=S)CN(C(=O)c1ccccc1)CC2 0.95[/SIMILAR][SIMILAR]CCc1n[nH]c2c1CN(C(=O)c1ccccc1)CC2 0.54[/SIMILAR][SIMILAR]CCc1n[nH]c2c1CN(C(=O)c1ccccc1)CC2 0.54[/SIMILAR][SIMILAR]CCc1n[nH]c2c1CN(C(=O)c1ccccc1)CC2 0.54[/SIMILAR][SIMILAR]CCc1n[nH]c2c1CN(C(=O)c1ccccc1)CC2 0.54[/SIMILAR][SIMILAR]CCc1n[nH]c2c1CN(C(=O)c1ccccc1)CC2 0.54[/SIMILAR][SIMILAR]CCc1n[nH]c2c1CN(C(=O)c1ccccc1)CC2 0.54[/SIMILAR][SIMILAR]CCc1n[nH]c2c1CN(C(=O)c1ccccc1)CC2 0.54[/SIMILAR][SIMILAR]CCc1n[nH]c2c1CN(C(=O)c1ccccc1)CC2 0.54[/SIMILAR][SIMILAR]CCc1n[nH]c2c1CN(C(=O)c1ccccc1)CC2 0.54[/SIMILAR][HEAVYATOMCOUNT]21[/HEAVYATOMCOUNT][CLOGP]2.53[/CLOGP][NUMROTATABLEBONDS]2[/NUMROTATABLEBONDS][NUMAROMATICCARBOCYCLES]1[/NUMAROMATICCARBOCYCLES][SAS]2.54[/SAS][FRACTIONCSP3]0.31[/FRACTIONCSP3][NUMALIPHATICHETEROCYCLES]1[/NUMALIPHATICHETEROCYCLES][NUMAROMATICHETEROCYCLES]1[/NUMAROMATICHETEROCYCLES][NUMSATURATEDHETEROCYCLES]0[/NUMSATURATE

In [42]:
combined = list(zip(ground_truths, gens, diffs))
combined.sort(key=lambda x: x[0])
ground_truths, gens, diffs = zip(*combined)
rmse = metrics.mean_squared_error(ground_truths, gens, squared=False)
mape = metrics.mean_absolute_percentage_error(ground_truths, gens)
correlation = correlation, pvalue = spearmanr(ground_truths, gens)
generate_plot(ground_truths, gens, diffs,'Similarity', rmse, mape, correlation, invalids, len(ground_truths), (0,1), 0)

