In [None]:
import os
from getpass import getpass

# Never hard-code the Hugging Face token in the notebook. Reuse HF_TOKEN when it
# is already exported in the environment; otherwise prompt for it without echoing
# so the secret never ends up in the saved notebook.
if not os.environ.get("HF_TOKEN"):
    os.environ["HF_TOKEN"] = getpass("Hugging Face token: ")

# Install dependencies

In [2]:
! pip install --upgrade --quiet bitsandbytes datasets peft transformers trl rdkit tf-keras

# Load model from HF

In [3]:
!pip install --upgrade --force-reinstall "numpy<2.0"

Defaulting to user installation because normal site-packages is not writeable
Collecting numpy<2.0
  Using cached numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
[0mSuccessfully installed numpy-1.26.4


In [4]:
import numpy

# Show which NumPy version the kernel actually imports after the pin above.
installed_numpy_version = numpy.__version__
print(installed_numpy_version)

1.26.4


In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hub repository id = fixed prefix + the selected chat variant.
CHAT_VARIANT = "9b-chat" # @param ["9b-chat", "27b-chat"]
base_model = "google/txgemma-"
model_id = f"{base_model}{CHAT_VARIANT}"

# 4-bit NF4 quantization with bfloat16 compute, to reduce memory usage.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the quantized causal LM; the device map places the whole model on device 0.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="eager",
    torch_dtype="auto",
    device_map={"":0},
    quantization_config=quantization_config,
)

  from .autonotebook import tqdm as notebook_tqdm
2025-05-07 02:23:53.235807: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746584633.254094   21654 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746584633.259493   21654 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746584633.271473   21654 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746584633.271481   21654 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746584633.271483   21654

# Load dataset and prepare train test split

In [12]:
import json

# The prompts contain non-ASCII characters (e.g. "α", "Å²"), so decode the file
# as UTF-8 explicitly instead of relying on the platform's default encoding.
with open("train_hif_binding.jsonl", "r", encoding="utf-8") as f:
    # One JSON object per line; skip blank lines (such as a trailing newline),
    # which json.loads would reject.
    binders = [json.loads(line) for line in f if line.strip()]

# Keep only the two fields used for fine-tuning: the prompt text ("input") and
# the target answer ("output").
records = [
    {"input": ex["prompt"], "output": ex["bind"]}
    for ex in binders
]

def formatting_func(example):
    """Build the training text for one example.

    Args:
        example: Mapping with an "input" entry (the prompt) and an "output"
            entry (the target answer).

    Returns:
        The prompt and the answer joined by a single newline.
    """
    prompt, answer = example["input"], example["output"]
    return "{}\n{}".format(prompt, answer)

# Preview the first record after it has been run through formatting_func.
first_example_text = formatting_func(records[0])
print(first_example_text)

From the following information about a ligand, predict whether it can bind to the HIF-2α protein.

This ligand is represented by the SMILES string O[C@H]1c2c(CC1(F)F)c(Oc1cc(F)cc(F)c1)ccc2C#N, and exhibits an IC50 of 35.0 nM (pIC50 = 7.46). It has a molecular weight of 323.25 Da, a topological polar surface area of 53.25 Å², 1.0 hydrogen bond donor, 3.0 hydrogen bond acceptors, and 2.0 rotatable bonds, with a logP of 3.85.

Answer: Yes, it binds to HIF-2α<eos>


In [13]:
import pandas as pd

# One row per entry of `binders`: the prompt goes to "input", the label text to "output".
rows = [{"input": ex["prompt"], "output": ex["bind"]} for ex in binders]
data = pd.DataFrame(rows)

data

Unnamed: 0,input,output
0,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"
1,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"
2,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"
3,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"
4,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"
...,...,...
1993,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"
1994,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"
1995,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"
1996,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"


## Splitting into train and test sets

In [14]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the rows for testing (fixed seed for a reproducible split) and
# give each split a fresh 0..n-1 index.
train_data, test_data = (
    split.reset_index(drop=True)
    for split in train_test_split(data, test_size=0.1, random_state=42)
)

train_data

Unnamed: 0,input,output
0,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"
1,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"
2,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"
3,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"
4,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"
...,...,...
1793,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"
1794,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"
1795,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"
1796,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"


# Fine-tuning the model (finally 😱)

In [15]:
from peft import LoraConfig

# Rank-8 LoRA adapters for causal-LM fine-tuning, attached to the listed
# projection modules.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    target_modules="q_proj o_proj k_proj v_proj gate_proj up_proj down_proj".split(),
)

In [16]:
from peft import PeftModel, prepare_model_for_kbit_training, get_peft_model

# This cell rebinds `model`, so guard it: re-running the cell must not wrap an
# already LoRA-wrapped model in a second adapter layer.
if not isinstance(model, PeftModel):
    # Preprocess quantized model for training
    model = prepare_model_for_kbit_training(model)

    # Create PeftModel from quantized model and configuration
    model = get_peft_model(model, lora_config)



In [18]:
import transformers
from trl import SFTTrainer, SFTConfig
from datasets import Dataset

# Convert the pandas splits into Hugging Face datasets (one dict per row).
train_records = train_data.to_dict(orient="records")
test_records  = test_data.to_dict(orient="records")

hf_train = Dataset.from_list(train_records)
hf_eval  = Dataset.from_list(test_records)

# `model` was already wrapped with `get_peft_model(model, lora_config)` earlier in
# the notebook, so `peft_config` is deliberately NOT passed again here: handing the
# trainer both a PeftModel and a peft_config is redundant and its handling differs
# between TRL versions.
trainer = SFTTrainer(
    model=model,
    train_dataset = hf_train,
    eval_dataset  = hf_eval,
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,  # effective batch size: 1 x 4 = 4
        warmup_steps=2,
        max_steps=50,
        learning_rate=2e-4,
        # NOTE(review): the model was loaded with bnb_4bit_compute_dtype=torch.bfloat16;
        # confirm fp16 mixed precision is intended here rather than bf16=True.
        fp16=True,
        logging_steps=5,
        max_seq_length=512,
        output_dir="./outputs",
        optim="paged_adamw_8bit",
        report_to="none",
    ),
    formatting_func=formatting_func,
)


Applying formatting function to train dataset: 100%|██████████| 1798/1798 [00:00<00:00, 17969.61 examples/s]
Converting train dataset to ChatML: 100%|██████████| 1798/1798 [00:00<00:00, 24910.35 examples/s]
Adding EOS to train dataset: 100%|██████████| 1798/1798 [00:00<00:00, 19269.72 examples/s]
Tokenizing train dataset: 100%|██████████| 1798/1798 [00:00<00:00, 2429.27 examples/s]
Truncating train dataset: 100%|██████████| 1798/1798 [00:00<00:00, 288113.03 examples/s]
Applying formatting function to eval dataset: 100%|██████████| 200/200 [00:00<00:00, 16697.73 examples/s]
Converting eval dataset to ChatML: 100%|██████████| 200/200 [00:00<00:00, 19595.43 examples/s]
Adding EOS to eval dataset: 100%|██████████| 200/200 [00:00<00:00, 15920.99 examples/s]
Tokenizing eval dataset: 100%|██████████| 200/200 [00:00<00:00, 2323.40 examples/s]
Truncating eval dataset: 100%|██████████| 200/200 [00:00<00:00, 62925.57 examples/s]
No label_names provided for model class `PeftModelForCausalLM`. Sinc

In [19]:
trainer.train()

# Save the fine-tuned model to the trainer's output directory.
trainer.save_model()

# `Trainer.tokenizer` is deprecated (the library warns about it); use
# `processing_class`, its replacement, to save the tokenizer next to the model.
trainer.processing_class.save_pretrained(
    trainer.args.output_dir
)

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Step,Training Loss
5,3.1524
10,1.0352
15,0.6802
20,0.6793
25,0.5946
30,0.594
35,0.57
40,0.5617
45,0.5418
50,0.5145


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


('./outputs/tokenizer_config.json',
 './outputs/special_tokens_map.json',
 './outputs/tokenizer.json')

# Test the fine-tuned model

In [60]:
# Row position in `test_data` of the held-out example inspected by the cells below.
i = 13

In [None]:
import torch
from torch.amp import autocast

# The 4-bit model was already placed on the GPU by `device_map` when it was
# loaded, so do not call `.to("cuda")` on it: moving bitsandbytes-quantized
# weights is redundant at best. Only switch to eval mode.
model = model.eval()

prompt = test_data.iloc[i]["input"]
# Put the inputs on whatever device the model lives on instead of hard-coding it.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Note: first argument is device type, not a keyword
with autocast("cuda", dtype=torch.bfloat16):
    outputs = model.generate(**inputs, max_new_tokens=8)

# The decoded text contains the prompt followed by the generated continuation.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
# Show the reference answer for the same test example, for comparison.
expected_output = test_data.iloc[i]["output"]
print(f"The correct output should be:\n{expected_output}")