In [None]:
import os
from google.colab import userdata

# Copy the Colab secret "HF_TOKEN" into the process environment, presumably so the
# Hugging Face tokenizer/model downloads further down can authenticate — TODO confirm.
os.environ.update(HF_TOKEN=userdata.get("HF_TOKEN"))

# Install dependencies

In [None]:
! pip install --upgrade --quiet bitsandbytes datasets peft transformers trl rdkit

# Load the TxGemma model from Hugging Face

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model = "google/txgemma-"
CHAT_VARIANT = "9b-chat" # @param ["9b-chat", "27b-chat"]

# Full Hub repo id, e.g. "google/txgemma-9b-chat".
model_id = f"{base_model}{CHAT_VARIANT}"

# 4-bit NF4 weights with bfloat16 compute, to reduce memory usage.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# device_map={"": 0} places the entire model on GPU 0.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map={"": 0},
    quantization_config=quantization_config,
    attn_implementation="eager",
)

# Load Dataset and Clean It

## Known HIF binders (taken from BindingDB, curated by UCSD)

In [None]:
import pandas as pd

# Known HIF binders exported from BindingDB; the relative path is resolved
# against the current working directory.
df = pd.read_csv(filepath_or_buffer="Known_HIF_Binders.csv")

In [None]:
# Keep only the SMILES and IC50 columns, discarding rows missing either value.
clean_df = df.loc[:, ["Ligand SMILES", "IC50 (nM)"]].dropna()

# Flag censored IC50 entries (containing '<' or '>') for removal.
has_censor = clean_df["IC50 (nM)"].astype(str).str.contains(r"[<>]")

# Report how many censored rows are about to be discarded.
dropped_count = has_censor.sum()
print(f"Dropping {dropped_count} rows with '<' or '>' in IC50")

# Retain the uncensored rows under a fresh 0..n-1 index.
clean_df = clean_df[~has_censor].reset_index(drop=True)
clean_df

In [None]:
import re
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors, Crippen, Lipinski

# --- 2) Clean & standardize IC50, compute pIC50 ---
def parse_ic50_to_pic50(ic50_str):
    """Convert an IC50 value in nM to pIC50 = -log10(IC50 in M).

    ``ic50_str`` may be a string (surrounding whitespace is ignored) or a
    number. The function returns NaN in two cases:

    - the value cannot be parsed as a float, including censored strings such
      as ``'<5'`` or ``'>1000'``;
    - the concentration is zero, negative, infinite or NaN, for which pIC50
      is undefined.

    Examples: ``'100'`` -> 7.0, ``' 1000 '`` -> 6.0, ``'<5'`` -> nan.
    """
    s = str(ic50_str).strip()
    try:
        nm = float(s)
    except ValueError:
        return np.nan  # unparseable (e.g. censored '<5')

    # log10 is undefined for non-positive or non-finite concentrations; without
    # this guard 0 nM would map to pIC50 = inf and be classed as "strong".
    if not np.isfinite(nm) or nm <= 0:
        return np.nan

    # -log10(nm * 1e-9) == 9 - log10(nm); this form skips the extra rounding
    # introduced by multiplying by 1e-9.
    return 9.0 - np.log10(nm)

clean_df["pIC50"] = clean_df["IC50 (nM)"].apply(parse_ic50_to_pic50)

# --- 3) Bin into activity classes ---
# strong binder if pIC50 ≥ 7 (IC50 ≤ 100 nM), else weak/non-binder
threshold = 7.0
clean_df["activity_class"] = np.where(clean_df["pIC50"] >= threshold, "strong", "weak")
# NaN >= threshold is False, so np.where alone would label rows with an
# unparseable IC50 as "weak"; leave their class missing instead.
clean_df.loc[clean_df["pIC50"].isna(), "activity_class"] = np.nan

# --- 4) Compute 2D descriptors via RDKit ---
def compute_descriptors(smiles):
    """Compute a fixed set of RDKit 2D descriptors for one SMILES string.

    Returns a dict with keys MolWt, TPSA, HBD, HBA, RotBonds and LogP, in that
    order. If RDKit cannot parse the SMILES (``MolFromSmiles`` returns None),
    every value is NaN.
    """
    # Output key -> RDKit function; dict order fixes the column order downstream.
    calculators = {
        "MolWt": Descriptors.MolWt,
        "TPSA": Descriptors.TPSA,
        "HBD": Lipinski.NumHDonors,
        "HBA": Lipinski.NumHAcceptors,
        "RotBonds": Descriptors.NumRotatableBonds,
        "LogP": Crippen.MolLogP,
    }
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return dict.fromkeys(calculators, np.nan)
    return {name: calc(molecule) for name, calc in calculators.items()}

# apply and expand into separate columns (dtype=float matches the float64
# columns that .apply(pd.Series) produced, and always yields a DataFrame)
desc_df = pd.DataFrame(
    clean_df["Ligand SMILES"].map(compute_descriptors).tolist(),
    index=clean_df.index,
    dtype=float,
)
# Drop descriptor columns left over from an earlier run of this cell, so
# re-executing it replaces them instead of appending duplicate columns.
clean_df = pd.concat(
    [clean_df.drop(columns=desc_df.columns, errors="ignore"), desc_df], axis=1
)
clean_df["is_known_binder"] = True

# --- 5) View the table ---
clean_df.head()

## Duds: ESR1 binders taken from BindingDB (labeled `is_known_binder = False`)

In [None]:
# ESR1 binders exported from BindingDB. This rebinds `df`, so the known-binder
# cleaning cell (which reads `df`) must not be re-run after this one.
df = pd.read_csv(filepath_or_buffer="ESR1_Binders.csv")

In [None]:
esr1_df = df[["Ligand SMILES", "IC50 (nM)"]].dropna()

# remove rows that contain '<' or '>' (censored values such as '<5' give only a bound)
has_censor = esr1_df["IC50 (nM)"] \
    .astype(str) \
    .str.contains(r"[<>]")

# count how many rows will be dropped
dropped_count = has_censor.sum()
print(f"Dropping {dropped_count} rows with '<' or '>' in IC50")

# keep only the rows *without* '<' or '>'; reset_index here already yields a
# fresh 0..n-1 index, so no second (in-place) reset is needed
esr1_df = esr1_df.loc[~has_censor].reset_index(drop=True)
esr1_df

In [None]:
esr1_df["pIC50"] = esr1_df["IC50 (nM)"].apply(parse_ic50_to_pic50)

# Bin with the same `threshold` as the known binders. NaN >= threshold is
# False, so rows with an unparseable IC50 would otherwise be labelled "weak";
# leave their class missing instead.
esr1_df["activity_class"] = np.where(esr1_df["pIC50"] >= threshold, "strong", "weak")
esr1_df.loc[esr1_df["pIC50"].isna(), "activity_class"] = np.nan


# apply and expand into separate columns (dtype=float matches the float64
# columns that .apply(pd.Series) produced, and always yields a DataFrame)
temp_df = pd.DataFrame(
    esr1_df["Ligand SMILES"].map(compute_descriptors).tolist(),
    index=esr1_df.index,
    dtype=float,
)
# Drop descriptor columns from an earlier run of this cell so re-executing it
# replaces them instead of appending duplicate columns.
esr1_df = pd.concat(
    [esr1_df.drop(columns=temp_df.columns, errors="ignore"), temp_df], axis=1
)
esr1_df["is_known_binder"] = False


esr1_df.head()

## Joining the known binders and the duds

In [None]:
# Stack the two labelled sets row-wise; ignore_index gives the result a fresh 0..n-1 index.
all_binders = pd.concat((clean_df, esr1_df), ignore_index=True)

In [None]:
# Shuffle so known binders and ESR1 duds are interleaved. A seeded Generator
# makes the row order, and therefore all_binders.csv, reproducible across runs.
# Note: re-running this cell alone shuffles the already-shuffled frame again;
# re-run the concat cell first to reproduce the saved file.
SHUFFLE_SEED = 42
shuffle_rng = np.random.default_rng(SHUFFLE_SEED)
perm = shuffle_rng.permutation(len(all_binders))
all_binders = all_binders.iloc[perm].reset_index(drop=True)

all_binders.to_csv("all_binders.csv", index=False)

# Fine-tuning the model (finally 😱)