In [1]:
!pip install --quiet deepchem rdkit transformers accelerate bitsandbytes peft

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m552.4/552.4 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.7/36.7 MB[0m [31m52.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.7/60.7 MB[0m [31m30.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h

In [2]:
!pip install --quiet trl datasets

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m540.5/540.5 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0mta [36m0:00:01[0m
[?25h

In [3]:
import os, gc, torch, numpy as np, pandas as pd
import deepchem as dc
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from rdkit import Chem

# clear vram
torch.cuda.empty_cache()
gc.collect()

# --- 1. DATA LOAD (Direct download to bypass deepchem loader issues) ---
print("Fetching ClinTox csv...")
try:
    # try direct link first
    df = pd.read_csv("https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz")
except:
    # fallback local
    print("Download failed, using local file...")
    df = pd.read_csv("clintox.csv.gz")

print(f"Raw rows: {len(df)}")

# Sanitize SMILES - drop anything RDKit hates
valid_idxs = []
for idx, row in df.iterrows():
    try:
        mol = Chem.MolFromSmiles(row['smiles'])
        if mol: valid_idxs.append(idx)
    except: pass

df_clean = df.iloc[valid_idxs].reset_index(drop=True)
print(f"Valid rows: {len(df_clean)}")

# --- 2. SCAFFOLD SPLIT ---
print("Running scaffold split...")
# Target columns: FDA_APPROVED, CT_TOX
dataset = dc.data.NumpyDataset(
    X=df_clean['smiles'].values, 
    y=df_clean[['FDA_APPROVED', 'CT_TOX']].values, 
    ids=df_clean['smiles'].values
)
splitter = dc.splits.ScaffoldSplitter()
train_dc, valid_dc, test_dc = splitter.train_valid_test_split(dataset)

print(f"Train: {len(train_dc)}, Test: {len(test_dc)}")

# --- 3. DATASET CLASS & AUGMENTATION ---
def random_smiles(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol: return smiles
        return Chem.MolToSmiles(mol, doRandom=True, canonical=False)
    except: return smiles

class ClinToxDataset(Dataset):
    def __init__(self, dc_data, tokenizer, max_len=256, augment=False):
        self.data = []
        self.tokenizer = tokenizer
        
        ids = dc_data.ids # smiles strings
        y = dc_data.y     # [fda, tox]
        
        indices = list(range(len(ids)))
        
        # Heavy upsampling for toxic class (imbalance fix)
        if augment:
            tox_idxs = [i for i in indices if y[i][1] == 1]
            # x5 multiplier for minority class
            indices += tox_idxs * 5 
            np.random.shuffle(indices)

        for i in indices:
            smi_canon = str(ids[i])
            
            # Focus on CT_TOX (index 1)
            label = int(y[i][1]) 
            label_str = "Yes" if label == 1 else "No"
            
            smi_variants = [smi_canon]
            
            # Add noise (random smiles) for training toxic samples
            if augment and label == 1: 
                for _ in range(2):
                    smi_variants.append(random_smiles(smi_canon))
            
            for smi in smi_variants:
                # Prompt format
                txt = f"Task: Clinical Toxicity | SMILES: {smi} | Toxic: {label_str}" + tokenizer.eos_token
                
                enc = tokenizer(
                    txt, 
                    max_length=max_len, 
                    padding="max_length", 
                    truncation=True, 
                    return_tensors="pt"
                )
                self.data.append({
                    "ids": enc["input_ids"][0],
                    "mask": enc["attention_mask"][0],
                    "label": label
                })

    def __len__(self): return len(self.data)
    def __getitem__(self, idx): return self.data[idx]

# --- 4. MODEL SETUP ---
print("Loading Mistral-7B...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

model_id = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

peft_config = LoraConfig(
    r=16, lora_alpha=32, target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], 
    lora_dropout=0.05, bias="none", task_type=TaskType.CAUSAL_LM
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# --- 5. TRAINING ---
train_ds = ClinToxDataset(train_dc, tokenizer, augment=True)
test_ds = ClinToxDataset(test_dc, tokenizer, augment=False)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)

opt = torch.optim.AdamW(model.parameters(), lr=1e-4) # low lr for stability

EPOCHS = 3
ACCUM_STEPS = 8 

print(f"Starting loop ({EPOCHS} epochs)...")
best_score = 0.0

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    pbar = tqdm(train_loader, desc=f"Ep {epoch+1}")
    
    for step, batch in enumerate(pbar):
        ids = batch["ids"].to("cuda")
        mask = batch["mask"].to("cuda")
        
        out = model(input_ids=ids, attention_mask=mask, labels=ids)
        loss = out.loss / ACCUM_STEPS
        loss.backward()
        
        if (step + 1) % ACCUM_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            opt.step()
            opt.zero_grad()
            
        epoch_loss += loss.item() * ACCUM_STEPS
        pbar.set_postfix({"loss": f"{epoch_loss/(step+1):.4f}"})

    # Validate
    print("Validating...")
    model.eval()
    preds, acts = [], []
    id_yes = tokenizer.encode("Yes", add_special_tokens=False)[0]
    id_no = tokenizer.encode("No", add_special_tokens=False)[0]
    
    with torch.no_grad():
        for item in tqdm(test_ds.data, desc="Testing"):
            full_txt = tokenizer.decode(item["ids"], skip_special_tokens=True)
            # split at prompt end
            query = full_txt.split("Toxic:")[0] + "Toxic:"
            
            inp = tokenizer(query, return_tensors="pt").to("cuda")
            out = model(**inp)
            logits = out.logits[0, -1, [id_no, id_yes]]
            probs = torch.nn.functional.softmax(logits.float(), dim=-1)
            
            preds.append(probs[1].item())
            acts.append(item["label"])
            
    auc = roc_auc_score(acts, preds)
    print(f"Epoch {epoch+1} ROC-AUC: {auc:.4f}")
    
    if auc > 0.94:
        print(f"High score ({auc:.4f}), saving adapter...")
        model.save_pretrained(f"mistral_clintox_epoch_{epoch+1}")
        best_score = auc

print(f"Best Score: {best_score:.4f}")

2026-02-16 22:48:20.499597: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1771282100.740486      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771282100.811037      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1771282101.412889      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771282101.412956      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771282101.412963      55 computation_placer.cc:177] computation placer alr

Instructions for updating:
experimental_relax_shapes is deprecated, use reduce_retracing instead




Fetching ClinTox csv...
Raw rows: 1484


[22:48:48] Explicit valence for atom # 0 N, 4, is greater than permitted
[22:48:48] Can't kekulize mol.  Unkekulized atoms: 9
[22:48:48] Can't kekulize mol.  Unkekulized atoms: 4
[22:48:48] Can't kekulize mol.  Unkekulized atoms: 4


Valid rows: 1480
Running scaffold split...
Train: 1184, Test: 148
Loading Mistral-7B...


tokenizer_config.json:   0%|          | 0.00/996 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

trainable params: 13,631,488 || all params: 7,255,363,584 || trainable%: 0.1879
Starting loop (3 epochs)...


Ep 1:   0%|          | 0/1400 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)
Ep 1: 100%|██████████| 1400/1400 [43:04<00:00,  1.85s/it, loss=5.7488]


Validating...


Testing: 100%|██████████| 148/148 [00:32<00:00,  4.60it/s]


Epoch 1 ROC-AUC: 0.9768
High score (0.9768), saving adapter...


  return fn(*args, **kwargs)
Ep 2: 100%|██████████| 1400/1400 [43:14<00:00,  1.85s/it, loss=5.6603]


Validating...


Testing: 100%|██████████| 148/148 [00:32<00:00,  4.60it/s]


Epoch 2 ROC-AUC: 0.9913
High score (0.9913), saving adapter...


  return fn(*args, **kwargs)
Ep 3: 100%|██████████| 1400/1400 [43:14<00:00,  1.85s/it, loss=5.6288]


Validating...


Testing: 100%|██████████| 148/148 [00:32<00:00,  4.61it/s]


Epoch 3 ROC-AUC: 0.9877
High score (0.9877), saving adapter...
Best Score: 0.9877
