In [11]:
import os
from google.colab import userdata
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")

# Install dependencies

In [1]:
! pip install --upgrade --quiet bitsandbytes datasets peft transformers trl rdkit

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m37.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m348.0/348.0 kB[0m [31m28.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.3/34.3 MB[0m [31m62.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━

# Load model from HF

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model = "google/txgemma-"
CHAT_VARIANT = "9b-chat" # @param ["9b-chat", "27b-chat"]

model_id = base_model + CHAT_VARIANT

# Use 4-bit quantization to reduce memory usage
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map={"":0},
    torch_dtype="auto",
    attn_implementation="eager",
)

# Load Dataset and Clean It

## Known Binders to Q99814 · EPAS1 (taken from BindingDB curated by UCSD)

In [None]:
import pandas as pd

df = pd.read_csv("Known_HIF_Binders.csv")

In [None]:
clean_df = df[["Ligand SMILES", "IC50 (nM)"]].dropna()

# remove rows that contain '<' or '>'
has_censor = clean_df["IC50 (nM)"] \
    .astype(str) \
    .str.contains(r"[<>]")

# count how many rows will be dropped
dropped_count = has_censor.sum()
print(f"Dropping {dropped_count} rows with '<' or '>' in IC50")

# keep only the rows *without* '<' or '>'
clean_df = clean_df.loc[~has_censor].reset_index(drop=True)
clean_df

In [None]:
import re
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors, Crippen, Lipinski

# Clean & standardize IC50, compute pIC50
def parse_ic50_to_pic50(ic50_str):
    s = str(ic50_str).strip()
    try:
        nm = float(s)
    except ValueError:
        return np.nan

    # convert nM → M
    m = nm * 1e-9
    # compute pIC50
    pic50 = -np.log10(m)
    return pic50

clean_df["pIC50"] = clean_df["IC50 (nM)"].apply(parse_ic50_to_pic50)

# Bin into activity classes
# strong binder if pIC50 ≥ 7 (IC50 ≤ 100 nM), else weak/non-binder
threshold = 7.0
clean_df["activity_class"] = np.where(clean_df["pIC50"] >= threshold, "strong", "weak")

# Compute 2D descriptors via RDKit
def compute_descriptors(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return {
            "MolWt": np.nan,
            "TPSA": np.nan,
            "HBD": np.nan,
            "HBA": np.nan,
            "RotBonds": np.nan,
            "LogP": np.nan,
        }
    return {
        "MolWt": Descriptors.MolWt(mol),
        "TPSA": Descriptors.TPSA(mol),
        "HBD": Lipinski.NumHDonors(mol),
        "HBA": Lipinski.NumHAcceptors(mol),
        "RotBonds": Descriptors.NumRotatableBonds(mol),
        "LogP": Crippen.MolLogP(mol),
    }

# apply and expand into separate columns
desc_df = clean_df["Ligand SMILES"].apply(compute_descriptors).apply(pd.Series)
clean_df = pd.concat([clean_df, desc_df], axis=1)
clean_df["is_known_binder"] = True

# View the table
print(clean_df.head())

## Duds (Taken from BindingDB on Compounds that Bind to ESR1 · P03372)

In [None]:
df = pd.read_csv("ESR1_Binders.csv")

In [None]:
esr1_df = df[["Ligand SMILES", "IC50 (nM)"]].dropna()

# remove rows that contain '<' or '>'
has_censor = esr1_df["IC50 (nM)"] \
    .astype(str) \
    .str.contains(r"[<>]")

# count how many rows will be dropped
dropped_count = has_censor.sum()
print(f"Dropping {dropped_count} rows with '<' or '>' in IC50")

# keep only the rows *without* '<' or '>'
esr1_df = esr1_df.loc[~has_censor].reset_index(drop=True)
esr1_df.reset_index(drop=True, inplace=True)
esr1_df

In [None]:
esr1_df["pIC50"] = esr1_df["IC50 (nM)"].apply(parse_ic50_to_pic50)

esr1_df["activity_class"] = np.where(esr1_df["pIC50"] >= threshold, "strong", "weak")


# apply and expand into separate columns
temp_df = esr1_df["Ligand SMILES"].apply(compute_descriptors).apply(pd.Series)
esr1_df = pd.concat([esr1_df, temp_df], axis=1)
esr1_df["is_known_binder"] = False


print(esr1_df.head())

## Joining the known and unknown binders

In [None]:
all_binders = pd.concat([clean_df, esr1_df], axis=0, ignore_index=True)

In [None]:
# mix up dataset to avoid patterns
perm = np.random.permutation(len(all_binders))
all_binders = all_binders.iloc[perm].reset_index(drop=True)

all_binders.to_csv("all_binders.csv", index=False)

## Turning CSV to jsonl

In [2]:
import pandas as pd

all_binders = pd.read_csv("all_binders.csv")

test = all_binders.head(1)


In [3]:
import json

def make_example(row):
    prompt = (
        "From the following information about a ligand, predict whether it can bind to the HIF-2α protein.\n\n"
        f"This ligand is represented by the SMILES string {row['Ligand SMILES']}, and exhibits an IC50 of "
        f"{row['IC50 (nM)']} nM (pIC50 = {row['pIC50']:.2f}). It has a molecular weight of {row['MolWt']:.2f} Da, "
        f"a topological polar surface area of {row['TPSA']:.2f} Å², {row['HBD']} hydrogen bond donor"
        f"{'s' if row['HBD'] != 1 else ''}, {row['HBA']} hydrogen bond acceptor"
        f"{'s' if row['HBA'] != 1 else ''}, and {row['RotBonds']} rotatable bond"
        f"{'s' if row['RotBonds'] != 1 else ''}, with a logP of {row['LogP']:.2f}.\n"
    )
    completion = "Answer: Yes, it binds to HIF-2α<eos>" if row["is_known_binder"] else "Answer: No, it doesn't bind to HIF-2α<eos>"
    return {"prompt": prompt, "bind": completion}


with open("train_hif_binding.jsonl","w") as fout:
    for _, row in all_binders.iterrows():
        ex = make_example(row)
        fout.write(json.dumps(ex) + "\n")

In [4]:
import json

with open("train_hif_binding.jsonl","r") as f:
    binders = [json.loads(line) for line in f]

# Create formatting function for LoRA later
def formatting_func(example):
    return f"{example['prompt']}\n{example['bind']}"

print(formatting_func(binders[0]))

From the following information about a ligand, predict whether it can bind to the HIF-2α protein.

This ligand is represented by the SMILES string O[C@H]1c2c(CC1(F)F)c(Oc1cc(F)cc(F)c1)ccc2C#N, and exhibits an IC50 of 35.0 nM (pIC50 = 7.46). It has a molecular weight of 323.25 Da, a topological polar surface area of 53.25 Å², 1.0 hydrogen bond donor, 3.0 hydrogen bond acceptors, and 2.0 rotatable bonds, with a logP of 3.85.

Answer: Yes, it binds to HIF-2α<eos>


In [5]:
data = (
    all_binders
    .apply(make_example, axis=1, result_type="expand")
    .rename(columns={"prompt":"input", "bind":"output"})
)
print(data.head())


                                               input  \
0  From the following information about a ligand,...   
1  From the following information about a ligand,...   
2  From the following information about a ligand,...   
3  From the following information about a ligand,...   
4  From the following information about a ligand,...   

                                       output  
0        Answer: Yes, it binds to HIF-2α<eos>  
1        Answer: Yes, it binds to HIF-2α<eos>  
2        Answer: Yes, it binds to HIF-2α<eos>  
3  Answer: No, it doesn't bind to HIF-2α<eos>  
4  Answer: No, it doesn't bind to HIF-2α<eos>  


## Spliting train test

In [9]:
from sklearn.model_selection import train_test_split

train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

train_data

Unnamed: 0,input,output
240,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"
812,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"
1240,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"
1280,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"
1084,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"
...,...,...
1130,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"
1294,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"
860,"From the following information about a ligand,...","Answer: No, it doesn't bind to HIF-2α<eos>"
1459,"From the following information about a ligand,...","Answer: Yes, it binds to HIF-2α<eos>"


# Fine tuning the model (finally 😱)