In [1]:
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import torch
from transformers import OPTForCausalLM, AutoModelForCausalLM, AutoTokenizer

from datasets import load_dataset
from datasets.iterable_dataset import IterableDataset
from transformers import OPTForCausalLM
from chemlactica.utils.model_utils import load_model
from chemlactica.utils.utils import get_tokenizer
import scipy
import numpy as np
import numpy
import random
from sklearn.metrics import root_mean_squared_error
from rdkit import Chem

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
torch.manual_seed(42)
random.seed(42)
numpy.random.seed(42)

In [5]:
gal_tokenizer = AutoTokenizer.from_pretrained("/auto/home/menuab/code/ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66")
gem_tokenizer = AutoTokenizer.from_pretrained("/auto/home/menuab/code/ChemLactica/chemlactica/tokenizer/GemmaTokenizer")
len(gal_tokenizer), len(gem_tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


(50066, 256000)

In [3]:
model_125m_20k_9954 = "/nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/9954e52e400b43d18d3a40f6/checkpoint-20480"
model_2b_11k_d6e6 = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/d6e6a76e91814ad68d5fa264/checkpoint-11000"
model_125m_18k_1f28 = "/nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/1f289ff103034364bd27e1c3/checkpoint-18000/"
model_2b_12k_0717 = "/nfs/dgx/raid/chem/checkpoints/h100/google/gemma-2b/0717d445bcf44e31b2887892/checkpoint-12000"


In [4]:
# chem 125 on paper
hPPB = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/b1b66dbf0b834f1c9da0a444"
rPPB = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/b9da010c01844cefb8b03426"
HLM = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/803746806706444c8f36d6e9"
MD1 = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/3261ee682d58403bb8b4b29a"
Sol = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/f9218c28afc64175b07813fd"
RLM = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/310a879774b842d4af6c3822"
freesolv = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/07677b1fa3014ab1bf445f16"
esol = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/421bf81ba8754fe9ac651eb0"

In [5]:
# gemma on paper
HLM = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/c247f1a414ae4bb9ad855748"
RLM = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/6a1f3428e7004e35a58e85c3"
MD1 = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/27190d95359b4232ad86a67b"
rPPB = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/27190d95359b4232ad86a67b"
hPPB = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/8898500c69594b6880765375"
Sol = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/1f37207f5fc048f8990def13"
freesolv = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/4275d646fbdb4cb1bc117047"
esol = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/09bc2ca81f2248a9940a9cc6"
lipo = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/02792a595f0d4a938f9512bd"

In [23]:
# chem 125 after tuning
hPPB = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/de959a11aec4462f9c696bbe"
rPPB = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/68e37bd1adcc4354ab2437d5"
HLM = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/b01e655d4a574c80b2ee198f"
MD1 = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/2681011a699a439ab96f41c3"
Sol = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/faf5d9872128422abefe3118"
RLM = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/e17580ef6fdc44c59d3e026c"
freesolv = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/0f3cdd732ecf4ce3a206a888"
esol = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/96c54bbc02a147819e53f150"
lipo = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/967ac14a399740e7bd82a8f3"

In [4]:
path = "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/02792a595f0d4a938f9512bd"
!ls "/nfs/dgx/raid/chem/checkpoints/google/gemma-2b/02792a595f0d4a938f9512bd"


last


In [6]:
# gemma after tuning
hPPB = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/google/best_0717/d2a948257d424b8d8cd8017c"
rPPB = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/google/best_0717/6c5b04cb7ef24993ba7c2f86"
HLM = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/google/best_0717/5ab970c4f78249059fa72502"
MD1 = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/google/best_0717/adf06bb2d1784b4992791447"
Sol = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/google/best_0717/c6dd3dbe377c447ebb2e82b6"
RLM = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/google/best_0717/22e823e19ef74c33afd4735d"
freesolv = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/google/best_0717/084da1cf0bfb4cea87845b7e"
esol = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/96c54bbc02a147819e53f150"
lipo = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/google/best_0717c1da52ccb23c499baf3598de"

In [None]:
# chem 1b after tuning
hPPB = "/nfs/dgx/raid/chem/experiments/checkpoints/facebook/galactica-1.3b/4ce3b6e46f934edda37b8689"
HLM = "/nfs/dgx/raid/chem/experiments/checkpoints/facebook/galactica-1.3b/336e63d54ab24182a2313eb0"
MD1 = "/nfs/dgx/raid/chem/experiments/checkpoints/facebook/galactica-1.3b/14d0a075f9bb4b0aa03a3b4b"
RLM = "/nfs/dgx/raid/chem/experiments/checkpoints/facebook/galactica-1.3b/75d543a26e9a47b28f1e2181"
Sol = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/facebook/galactica-1.3b/eb2a0797c503418aaeae141e"
rPPB = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/facebook/galactica-1.3b/c6304620e3a5496c989be694"
freesolv = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/facebook/galactica-1.3b/ed0a7af4d7b2438ba10db96e"
esol = "/nfs/dgx/raid/chem/checkpoints/galactica-125m/1f289ff103034364bd27e1c3/96c54bbc02a147819e53f150"
lipo = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/facebook/galactica-1.3b/51dfdd64f4c0423f855d55b9"

In [2]:
path = "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/google/gemma-2b/d45771071dee4cc8b7b6c367"
!ls "/nfs/ap/mnt/sxtn2/chem/experiments/checkpoints/google/gemma-2b/d45771071dee4cc8b7b6c367"


last


In [6]:
dataset = load_dataset(
                "/auto/home/menuab/code/sft_data/ADME_Sol"
                # "gayane/esol"
            )

In [7]:
dataset

DatasetDict({
    train: Dataset({
        features: ['smiles', 'activity'],
        num_rows: 1391
    })
    validation: Dataset({
        features: ['smiles', 'activity'],
        num_rows: 347
    })
    test: Dataset({
        features: ['smiles', 'activity'],
        num_rows: 435
    })
})

In [9]:
model = AutoModelForCausalLM.from_pretrained(
    path + "/last" ,torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
).to("cuda:0").eval()
model.device, model.dtype


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards: 100%|██████████| 2/2 [00:26<00:00, 13.06s/it]


(device(type='cuda', index=0), torch.bfloat16)

In [10]:
tokenizer = gem_tokenizer
ground_truths, gens, diffs = [],[],[]
invalids = 0
for sample in dataset['test']:
    ground_truth = round(sample['activity'], 2)
    prompt = f"</s>[START_SMILES]{sample['smiles']}[END_SMILES][PROPERTY]activity"
    len_prompt = len(prompt)
    prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(prompt.input_ids, do_sample=False, eos_token_id=tokenizer.encode('[/PROPERTY]')[0], max_new_tokens=100)
    out = tokenizer.batch_decode(out)[0]
    try:
        gen = float(out[out.find("activity ") + len("activity "):out.find("[/PROPERTY]")])
        diff = abs(ground_truth - gen)
        print("GT:", ground_truth, "Gen:", gen, "diff:", round(diff,2), out )
        ground_truths.append(ground_truth)
        gens.append(gen)
        diffs.append(diff)
    except:
        print('***')
        print("GT:", ground_truth, out)
        invalids += 1
        pass

GT: 0.51 Gen: 0.69 diff: 0.18 </s>[START_SMILES]C1(=O)NC(=O)NC(=O)C1(O)C2(O)C(=O)NC(=O)NC2(=O)[END_SMILES][PROPERTY]activity 0.69[/PROPERTY]
GT: -1.42 Gen: -0.89 diff: 0.53 </s>[START_SMILES]CCCCCC1CCCC1[END_SMILES][PROPERTY]activity -0.89[/PROPERTY]
GT: 0.34 Gen: 0.55 diff: 0.21 </s>[START_SMILES]CCCCCC(=O)OCC[END_SMILES][PROPERTY]activity 0.55[/PROPERTY]
GT: 0.18 Gen: 0.34 diff: 0.16 </s>[START_SMILES]CCC1(C(=O)NC(=O)NC1=O)C2=CCC3CCC2C3[END_SMILES][PROPERTY]activity 0.34[/PROPERTY]
GT: -0.35 Gen: -0.25 diff: 0.1 </s>[START_SMILES]CN(C)C(=O)Nc1ccc(Cl)c(Cl)c1[END_SMILES][PROPERTY]activity -0.25[/PROPERTY]
GT: 0.07 Gen: 0.11 diff: 0.04 </s>[START_SMILES]Clc1ccc(cc1)N(=O)=O[END_SMILES][PROPERTY]activity 0.11[/PROPERTY]
GT: -0.68 Gen: -1.02 diff: 0.34 </s>[START_SMILES]c1(Br)c(Br)cc(Br)cc1[END_SMILES][PROPERTY]activity -1.02[/PROPERTY]
GT: 1.23 Gen: 1.17 diff: 0.06 </s>[START_SMILES]CCC(C)CO[END_SMILES][PROPERTY]activity 1.17[/PROPERTY]
GT: -0.39 Gen: -0.67 diff: 0.28 </s>[START_SMILES]CC

In [11]:
r, p = scipy.stats.pearsonr(ground_truths, gens)
rmse = root_mean_squared_error(ground_truths, gens)
r, p, rmse

(0.939608205696396, 1.6491045008987505e-53, 0.3409175541990497)

In [26]:
r, p = scipy.stats.pearsonr(ground_truths, gens)
rmse = root_mean_squared_error(ground_truths, gens)
r, p, rmse

(0.6563720844639327, 1.4584098168426824e-76, 0.5955841319789665)

In [12]:
Chem.MolToSmiles(Chem.MolFromSmiles("CN(C(=O)Cc1ccc(S(C)(=O)=O)cc1)C1CCN(Cc2ccc(C(F)(F)F)cc2)CC1"))

'CN(C(=O)Cc1ccc(S(C)(=O)=O)cc1)C1CCN(Cc2ccc(C(F)(F)F)cc2)CC1'

In [None]:
plt.scatter(ground_truths, gens, alpha=0.1)
plt.plot((0,3),(0,3))

NameError: name 'plt' is not defined

In [None]:
gens

[0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 1.21,
 0.68,
 0.68,
 1.01,
 0.68,
 0.68,
 0.68,
 1.22,
 0.68,
 1.22,
 0.68,
 0.68,
 0.68,
 1.01,
 0.68,
 0.68,
 0.68,
 1.11,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 1.91,
 0.68,
 0.68,
 1.22,
 0.68,
 0.68,
 0.68,
 1.22,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 1.21,
 0.68,
 0.68,
 1.02,
 0.68,
 2.22,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 1.11,
 0.68,
 1.95,
 1.21,
 1.21,
 0.68,
 1.01,
 0.68,
 0.68,
 0.68,
 1.22,
 0.68,
 0.68,
 0.68,
 0.68,
 1.11,
 1.91,
 0.68,
 1.22,
 0.68,
 0.68,
 1.11,
 0.68,
 0.68,
 0.68,
 1.91,
 1.95,
 1.21,
 0.68,
 0.68,
 0.68,
 0.68,
 1.01,
 1.11,
 1.95,
 0.68,
 0.68,
 0.68,
 1.11,
 0.68,
 1.11,
 1.95,
 1.11,
 1.22,
 1.11,
 0.68,
 0.68,
 2.22,
 1.22,
 0.68,
 1.21,
 0.68,
 1.11,
 1.92,
 0.68,
 0.68,
 1.95,
 1.91,
 0.68,
 0.68,
 0.68,
 0.68,
 1.95,
 0.68,
 1.01,
 0.68,
 0.68,
 0.68,
 1.11,
 1.01,
 0.68,
 0.68,
 0.68,
 0.68,
 1.91,
 0.68,
 1.01,
 1.95,
 0.68,
 0.68,
 0.68,
 0.68,
 1.21,
 0.68,
 0.68,
 0.68,
 0.68,
 1.11,

In [None]:
ground_truths

[0.68,
 0.68,
 1.39,
 0.68,
 1.07,
 2.83,
 0.68,
 0.68,
 1.53,
 0.68,
 1.64,
 0.68,
 0.68,
 1.05,
 0.68,
 0.68,
 1.38,
 0.9,
 2.18,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 1.08,
 0.68,
 1.16,
 0.68,
 1.06,
 0.68,
 0.68,
 1.99,
 1.41,
 1.06,
 0.68,
 2.31,
 2.25,
 0.98,
 0.68,
 0.68,
 1.03,
 1.11,
 1.02,
 1.85,
 0.68,
 1.29,
 1.47,
 0.68,
 3.37,
 0.68,
 0.77,
 1.1,
 1.01,
 0.68,
 0.68,
 1.44,
 0.68,
 2.58,
 2.34,
 0.68,
 1.52,
 1.84,
 0.68,
 1.1,
 1.4,
 1.85,
 0.89,
 1.16,
 0.68,
 0.68,
 2.51,
 1.52,
 1.32,
 2.24,
 0.68,
 2.37,
 1.03,
 0.96,
 1.85,
 0.68,
 2.6,
 1.39,
 1.91,
 1.81,
 2.36,
 0.68,
 0.68,
 2.18,
 0.68,
 2.81,
 0.68,
 1.51,
 0.78,
 2.12,
 1.02,
 1.26,
 0.94,
 1.85,
 2.7,
 1.27,
 0.68,
 1.26,
 1.76,
 1.84,
 0.68,
 1.88,
 1.54,
 1.93,
 0.68,
 0.68,
 0.68,
 2.32,
 0.93,
 0.68,
 0.68,
 0.68,
 0.68,
 1.71,
 0.68,
 1.59,
 0.68,
 1.16,
 0.68,
 1.25,
 2.75,
 0.68,
 0.68,
 1.62,
 1.38,
 2.01,
 0.68,
 0.68,
 1.69,
 0.68,
 1.7,
 0.68,
 2.21,
 2.64,
 0.68,
 0.68,
 0.68,
 1.0,
 0.68,
 1.38,


In [None]:
np.random.normal(0, 0.1)

NameError: name 'np' is not defined