In [33]:
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import torch
from datasets import load_dataset
from datasets.iterable_dataset import IterableDataset
from transformers import OPTForCausalLM
from chemlactica.utils.model_utils import load_model
from chemlactica.utils.utils import get_tokenizer
import scipy
import numpy as np
from sklearn.metrics import root_mean_squared_error
from rdkit import Chem

In [3]:
data = "RLM"
dataset = load_dataset(
                "csv",
                data_files={
                    "train": f"/auto/home/menuab/code/sft_data/ADME_{data}/train/*.csv",
                    "validation": f"/auto/home/menuab/code/sft_data/ADME_{data}/test/*.csv",
                },
            )

In [10]:
dataset=load_dataset("gayane/freesolv", download_mode='force_redownload')

Downloading readme:   0%|          | 0.00/598 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.81k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.79k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/513 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/64 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/65 [00:00<?, ? examples/s]

In [11]:
dataset

DatasetDict({
    train: Dataset({
        features: ['smiles', 'activity'],
        num_rows: 513
    })
    validation: Dataset({
        features: ['smiles', 'activity'],
        num_rows: 64
    })
    test: Dataset({
        features: ['smiles', 'activity'],
        num_rows: 65
    })
})

In [15]:
tokenizer = get_tokenizer("/auto/home/menuab/code/ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Process 2810458 created a tokenizer


In [67]:
model_path_newPT = "/nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/f8caabf2b00f4662ae55939a/checkpoint-2000/"
model_path_newSFT = "/nfs/dgx/raid/chem/checkpoints/facebook/galactica-125m/59c1bdf4a7ba43d283982fe3/checkpoint-112/"
model = load_model(model_path_newSFT, use_flash_attn=True, gradient_checkpointing=False, dtype=torch.float16).to('cuda').eval()
# model = OPTForCausalLM.from_pretrained(model_path_newSFT, attn_implementation="flash_attention_2", torch_dtype=torch.float16).to('cuda').eval()
model.device, model.dtype

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


(device(type='cuda', index=0), torch.float16)

In [68]:
prompt = tokenizer("[START_SMILES]Clc1cnc(NCC2(CN3CCCC3)CC2)nc1[END_SMILES][PROPERTY]activity ", return_tensors="pt").to(model.device)
out = model.generate(prompt.input_ids, do_sample=False, max_length=100)
out = tokenizer.batch_decode(out)[0]
out

'[START_SMILES]Clc1cnc(NCC2(CN3CCCC3)CC2)nc1[END_SMILES][PROPERTY]activity -0.01[/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY]'

In [69]:
ground_truths, gens, diffs = [],[],[]
invalids = 0
for sample in dataset['validation']:
    ground_truth = round(sample['activity'], 2)
    prompt = f"[START_SMILES]{sample['smiles']}[END_SMILES][PROPERTY]activity "
    len_prompt = len(prompt)
    prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(prompt.input_ids, do_sample=False, max_length=100)
    out = tokenizer.batch_decode(out)[0]
    try:
        gen = float(out[out.find("activity ") + len("activity "):out.find("[/PROPERTY]")])
        diff = abs(ground_truth - gen)
        print("GT:", ground_truth, "Gen:", gen, "diff:", round(diff,2), out )
        ground_truths.append(ground_truth)
        gens.append(gen)
        diffs.append(diff)
    except:
        invalids += 1
        pass

GT: 0.13 Gen: 1.01 diff: 0.88 [START_SMILES]CN(C)C[END_SMILES][PROPERTY]activity 1.01[/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY]-0.01[/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY][/PROPERTY]
GT: 0.18 Gen: 0.01

In [70]:
r, p = scipy.stats.pearsonr(np.array(ground_truths), np.array(gens))
r, p

(0.4542630450433333, 0.0001629349585106703)

In [71]:
root_mean_squared_error(ground_truths, gens)

1.0757766090132281

In [36]:
Chem.MolToSmiles(Chem.MolFromSmiles("CN(C)C(=O)c1ccccc1"))

'CN(C)C(=O)c1ccccc1'

In [19]:
plt.scatter(ground_truths, gens, alpha=0.1)
plt.plot((0,3),(0,3))

NameError: name 'plt' is not defined

In [34]:
gens

[0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 1.21,
 0.68,
 0.68,
 1.01,
 0.68,
 0.68,
 0.68,
 1.22,
 0.68,
 1.22,
 0.68,
 0.68,
 0.68,
 1.01,
 0.68,
 0.68,
 0.68,
 1.11,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 1.91,
 0.68,
 0.68,
 1.22,
 0.68,
 0.68,
 0.68,
 1.22,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 1.21,
 0.68,
 0.68,
 1.02,
 0.68,
 2.22,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 1.11,
 0.68,
 1.95,
 1.21,
 1.21,
 0.68,
 1.01,
 0.68,
 0.68,
 0.68,
 1.22,
 0.68,
 0.68,
 0.68,
 0.68,
 1.11,
 1.91,
 0.68,
 1.22,
 0.68,
 0.68,
 1.11,
 0.68,
 0.68,
 0.68,
 1.91,
 1.95,
 1.21,
 0.68,
 0.68,
 0.68,
 0.68,
 1.01,
 1.11,
 1.95,
 0.68,
 0.68,
 0.68,
 1.11,
 0.68,
 1.11,
 1.95,
 1.11,
 1.22,
 1.11,
 0.68,
 0.68,
 2.22,
 1.22,
 0.68,
 1.21,
 0.68,
 1.11,
 1.92,
 0.68,
 0.68,
 1.95,
 1.91,
 0.68,
 0.68,
 0.68,
 0.68,
 1.95,
 0.68,
 1.01,
 0.68,
 0.68,
 0.68,
 1.11,
 1.01,
 0.68,
 0.68,
 0.68,
 0.68,
 1.91,
 0.68,
 1.01,
 1.95,
 0.68,
 0.68,
 0.68,
 0.68,
 1.21,
 0.68,
 0.68,
 0.68,
 0.68,
 1.11,

In [35]:
ground_truths

[0.68,
 0.68,
 1.39,
 0.68,
 1.07,
 2.83,
 0.68,
 0.68,
 1.53,
 0.68,
 1.64,
 0.68,
 0.68,
 1.05,
 0.68,
 0.68,
 1.38,
 0.9,
 2.18,
 0.68,
 0.68,
 0.68,
 0.68,
 0.68,
 1.08,
 0.68,
 1.16,
 0.68,
 1.06,
 0.68,
 0.68,
 1.99,
 1.41,
 1.06,
 0.68,
 2.31,
 2.25,
 0.98,
 0.68,
 0.68,
 1.03,
 1.11,
 1.02,
 1.85,
 0.68,
 1.29,
 1.47,
 0.68,
 3.37,
 0.68,
 0.77,
 1.1,
 1.01,
 0.68,
 0.68,
 1.44,
 0.68,
 2.58,
 2.34,
 0.68,
 1.52,
 1.84,
 0.68,
 1.1,
 1.4,
 1.85,
 0.89,
 1.16,
 0.68,
 0.68,
 2.51,
 1.52,
 1.32,
 2.24,
 0.68,
 2.37,
 1.03,
 0.96,
 1.85,
 0.68,
 2.6,
 1.39,
 1.91,
 1.81,
 2.36,
 0.68,
 0.68,
 2.18,
 0.68,
 2.81,
 0.68,
 1.51,
 0.78,
 2.12,
 1.02,
 1.26,
 0.94,
 1.85,
 2.7,
 1.27,
 0.68,
 1.26,
 1.76,
 1.84,
 0.68,
 1.88,
 1.54,
 1.93,
 0.68,
 0.68,
 0.68,
 2.32,
 0.93,
 0.68,
 0.68,
 0.68,
 0.68,
 1.71,
 0.68,
 1.59,
 0.68,
 1.16,
 0.68,
 1.25,
 2.75,
 0.68,
 0.68,
 1.62,
 1.38,
 2.01,
 0.68,
 0.68,
 1.69,
 0.68,
 1.7,
 0.68,
 2.21,
 2.64,
 0.68,
 0.68,
 0.68,
 1.0,
 0.68,
 1.38,


In [1]:
np.random.normal(0, 0.1)

NameError: name 'np' is not defined