# Meta explainer network for Tox21 dataset   version 2!

## 1: load & clean Tox21 CSV

In [1]:
import pandas as pd

DATA_PATH = "Data_v3/original/tox21.csv"  
df = pd.read_csv(DATA_PATH)

label_cols = [c for c in df.columns if c.startswith(("NR-","SR-"))]
df = df.dropna(subset=["smiles"] + label_cols).reset_index(drop=True)
print("Data shape:", df.shape)


Data shape: (3079, 14)


## 2: compute RDKit Descriptors & Toxicophore Flags

In [2]:
from rdkit import Chem
from rdkit.Chem import Descriptors
import numpy as np

rows = []
for smi in df.smiles:
    m = Chem.MolFromSmiles(smi)
    rows.append({
        "MolWt": Descriptors.MolWt(m),
        "LogP":  Descriptors.MolLogP(m),
        "TPSA":  Descriptors.TPSA(m),
        "HDon":  Descriptors.NumHDonors(m),
        "HAcc":  Descriptors.NumHAcceptors(m),
        "RotB":  Descriptors.NumRotatableBonds(m),
        "RingC": Descriptors.RingCount(m),
        "AromR": Descriptors.NumAromaticRings(m),
        "nitro": int(m.HasSubstructMatch(Chem.MolFromSmarts("[N+](=O)[O-]"))),
        "phenol":int(m.HasSubstructMatch(Chem.MolFromSmarts("[OX2H]"))),
        "carbonyl":int(m.HasSubstructMatch(Chem.MolFromSmarts("[CX3]=O"))),
        "amine": int(m.HasSubstructMatch(Chem.MolFromSmarts("[NX3;H2,H1]"))),
        "halogen":int(any(a.GetSymbol() in ("Cl","Br","F","I") for a in m.GetAtoms()))
    })
df_desc = pd.DataFrame(rows)
print("Descriptor frame:", df_desc.shape)


Descriptor frame: (3079, 13)


## 3: Train / Validation Split & Tokenise SMILES

In [3]:
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

X_train, X_val, y_train, y_val, desc_tr, desc_val = train_test_split(
    df.smiles, df[label_cols], df_desc, test_size=0.2, random_state=42)

def tokenize(smiles_list):
    return tokenizer(smiles_list, padding=True, truncation=True,
                     return_tensors="pt")

enc_train = tokenize(X_train.tolist())
enc_val   = tokenize(X_val.tolist())

print("Train tokens:", enc_train.input_ids.shape)


  from .autonotebook import tqdm as notebook_tqdm


Train tokens: torch.Size([2463, 267])


## 4: Define ChemBERTa Multi-Label Classier    

In [4]:
import torch, torch.nn as nn
from transformers import AutoModel

class ChemBERTaClassifier(nn.Module):
    def __init__(self, n_labels=12):
        super().__init__()
        self.bert = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.bert.config.hidden_size, n_labels)
        )
    def forward(self, input_ids, attention_mask):
        pooled = self.bert(input_ids, attention_mask).pooler_output
        return self.classifier(pooled)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = ChemBERTaClassifier(len(label_cols)).to(device)


## 5: Fine‑Tune ChemBERTa

In [5]:
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F, torch.optim as optim, tqdm

train_ds = TensorDataset(enc_train.input_ids, enc_train.attention_mask,
                         torch.FloatTensor(y_train.values))
loader   = DataLoader(train_ds, batch_size=8, shuffle=True)

opt = optim.AdamW(model.parameters(), lr=2e-5)
model.train()
for input_ids, attn, labels in tqdm.tqdm(loader):
    input_ids, attn, labels = input_ids.to(device), attn.to(device), labels.to(device)
    logits = model(input_ids, attn)
    loss = F.binary_cross_entropy_with_logits(logits, labels)
    opt.zero_grad(); loss.backward(); opt.step()
print("Finished 1 epoch, loss:", loss.item())


100%|██████████| 308/308 [00:13<00:00, 22.27it/s]

Finished 1 epoch, loss: 0.059265751391649246





## 6: Save Model & Tokenizer 

In [6]:
import os

SAVE_DIR = "models/v3/chemberta_tox21"
model.eval()
os.makedirs(SAVE_DIR, exist_ok=True)
torch.save(model.state_dict(), f"{SAVE_DIR}/model.pt")
tokenizer.save_pretrained(SAVE_DIR)
print("✅ Model & tokenizer saved.")


✅ Model & tokenizer saved.


## 7:  Compute SHAP Mean‑Abs Features (All 12 Classes)

In [12]:
import shap, torch, numpy as np, torch.nn as nn, os, time

# ▶︎  choose a very small subset for prototyping
SUB_N = 32                 # ← change to 256 or None later
BG_N  = 8                  # background size (must be < SUB_N)
SAVE  = "Data_v3/SHAP_val"
os.makedirs(SAVE, exist_ok=True)

model.to(device).eval()

# 1️⃣  Slice validation set
ids_sub = enc_val.input_ids[:SUB_N].to(device)          # [SUB_N, S]

# 2️⃣  Get embeddings (GPU)
with torch.no_grad():
    embed_val = model.bert.embeddings(ids_sub).float()  # FP32 to save RAM
N, S, E = embed_val.shape

shap_means = torch.zeros((N, len(label_cols)), device=device)

# 3️⃣  SHAP per class (tiny background, chunked)
start = time.time()
bg = embed_val[:BG_N]                                   # background tensor

for cls in range(len(label_cols)):
    head = nn.Sequential(nn.Identity(), nn.Linear(E, 1)).to(device)
    head[1].weight.data = model.classifier[1].weight.data[cls:cls+1]
    head[1].bias.data   = model.classifier[1].bias.data[cls:cls+1]
    explainer = shap.DeepExplainer(head, bg)

    # run SHAP on this tiny batch
    vals = explainer.shap_values(embed_val)[0]          # NumPy [N, S, 768]
    shap_means[:, cls] = torch.from_numpy(np.abs(vals).mean(axis=(1, 2))).to(device)

print(f"⏱️ SHAP for {N} molecules finished in {(time.time()-start):.1f}s")

# 4️⃣  Save
np.save(f"{SAVE}/shap_means.npy", shap_means.cpu().numpy())
print("✅ Saved shap_means:", shap_means.shape, "→", f'{SAVE}/shap_means.npy')

⏱️ SHAP for 32 molecules finished in 126.5s
✅ Saved shap_means: torch.Size([32, 12]) → Data_v3/SHAP_val/shap_means.npy


## 8: Build Meta‑Explainer Dataset

In [13]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

# 1️⃣ Load SHAP-means (produced in Cell 7-lite)
shap_means = np.load("Data_v3/SHAP_val/shap_means.npy")           # shape (SUB_N, 12)
N_sub = shap_means.shape[0]

# 2️⃣ Take the matching slice of descriptors & labels
desc_val_sub = desc_val.iloc[:N_sub].reset_index(drop=True)       # (SUB_N, d)
y_val_sub    = y_val.iloc[:N_sub].reset_index(drop=True)          # (SUB_N, 12)

# 3️⃣ Concatenate descriptor features + SHAP features
meta_X = np.hstack([desc_val_sub.values, shap_means])             # (SUB_N, d+12)
meta_y = y_val_sub.values                                         # (SUB_N, 12)

# 4️⃣ Train / test split
X_tr, X_te, y_tr, y_te = train_test_split(
    meta_X, meta_y, test_size=0.2, random_state=42
)

# 5️⃣ Build DataLoaders
tr_ds = TensorDataset(torch.FloatTensor(X_tr), torch.FloatTensor(y_tr))
te_ds = TensorDataset(torch.FloatTensor(X_te), torch.FloatTensor(y_te))
tr_loader = DataLoader(tr_ds, batch_size=16, shuffle=True)
te_loader = DataLoader(te_ds, batch_size=16)

print(f"✅ Meta-Explainer tensors ready — X_tr: {X_tr.shape}, X_te: {X_te.shape}")

✅ Meta-Explainer tensors ready — X_tr: (25, 25), X_te: (7, 25)


## 9: Train Meta‑Explainer MLP

In [14]:
class MetaExplainer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(64, out_dim), nn.Sigmoid()
        )
    def forward(self, x): return self.net(x)

meta = MetaExplainer(in_dim=X_tr.shape[1], out_dim=y_tr.shape[1]).to(device)
opt = torch.optim.Adam(meta.parameters(), lr=1e-3)
loss_fn = nn.BCELoss()

meta.train()
for xb, yb in tr_loader:
    xb, yb = xb.to(device), yb.to(device)
    pred = meta(xb)
    loss = loss_fn(pred, yb)
    opt.zero_grad(); loss.backward(); opt.step()
print("✅ Meta‑Explainer trained 1 epoch, loss:", loss.item())


✅ Meta‑Explainer trained 1 epoch, loss: 2.5703611373901367


## 10: Generate an Explanation for a New SMILES

In [24]:
# Cell 10 – Interactive SMILES → Name / CID lookup → Explanation
# --------------------------------------------------------------
import requests, shap, torch, numpy as np, torch.nn as nn
from rdkit import Chem
from rdkit.Chem import Descriptors

E = model.bert.config.hidden_size  # 768

def pubchem_name_cid(smiles: str):
    """Return (name, cid) via PubChem PUG REST; (None, None) if not found."""
    url = ("https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/"
           f"{smiles}/property/Title,IUPACName,CID/JSON")
    try:
        js = requests.get(url, timeout=10).json()
        props = js["PropertyTable"]["Properties"][0]
        name = props.get("Title") or props.get("IUPACName")
        cid  = props.get("CID")
        return name, cid
    except Exception:
        return None, None

def explain_smiles(smiles: str, *, top_k=2, p_thresh=0.5):
    # 0) metadata
    name, cid = pubchem_name_cid(smiles)

    # 1) descriptors / flags
    m = Chem.MolFromSmiles(smiles)
    d = np.array([
        Descriptors.MolWt(m), Descriptors.MolLogP(m), Descriptors.TPSA(m),
        Descriptors.NumHDonors(m), Descriptors.NumHAcceptors(m),
        Descriptors.NumRotatableBonds(m), Descriptors.RingCount(m),
        Descriptors.NumAromaticRings(m),
        int(m.HasSubstructMatch(Chem.MolFromSmarts("[N+](=O)[O-]"))),
        int(m.HasSubstructMatch(Chem.MolFromSmarts("[OX2H]"))),
        int(m.HasSubstructMatch(Chem.MolFromSmarts("[CX3]=O"))),
        int(m.HasSubstructMatch(Chem.MolFromSmarts("[NX3;H2,H1]"))),
        int(any(a.GetSymbol() in ("Cl","Br","F","I") for a in m.GetAtoms()))
    ], dtype=float)

    # 2) per‑class mean‑abs SHAP (single mol → tiny tensors, runs fast)
    enc = tokenizer(smiles, return_tensors="pt").to(device)
    with torch.no_grad():
        emb = model.bert.embeddings(enc.input_ids)
    shap_vec = []
    for i in range(len(label_cols)):
        head = nn.Sequential(nn.Identity(), nn.Linear(E, 1)).to(device)
        head[1].weight.data = model.classifier[1].weight.data[i:i+1]
        head[1].bias.data   = model.classifier[1].bias.data[i:i+1]
        v = shap.DeepExplainer(head, emb).shap_values(emb)[0]
        shap_vec.append(np.abs(v).mean())
    shap_vec = np.array(shap_vec)

    # 3) meta‑explainer prediction
    feats = np.hstack([d, shap_vec])
    with torch.no_grad():
        probs = meta(torch.FloatTensor(feats).unsqueeze(0).to(device)).cpu().numpy()[0]

    # 4) text explanation
    positives = [label_cols[i] for i,p in enumerate(probs) if p > p_thresh]
    reasons = []
    if d[3] > 2:   reasons.append("high H‑bond donor count")
    if d[0] > 500: reasons.append("large MolWt")
    if d[8]:       reasons.append("nitro group")
    if d[9]:       reasons.append("phenolic hydroxyl")
    if d[12]:      reasons.append("halogen substituent")
    if not reasons:
        idxs = shap_vec.argsort()[-top_k:][::-1]
        reasons = [f"strong model signal for {label_cols[i]}" for i in idxs]

    header = ""
    if name: header += f"**{name}** "
    if cid:  header += f"(CID {cid}) "
    text = (f"{header}: Model predicts toxicity for {', '.join(positives)} "
            f"because of {' and '.join(reasons)}.")
    return text, probs

# ── Interactive prompt ────────────────────────────────────
smiles_in = input("Enter your drug SMILES: ").strip()
explanation, prob_vec = explain_smiles(smiles_in)
print("\n" + explanation)
print("Probabilities:", np.round(prob_vec, 3))



: Model predicts toxicity for NR-AR, NR-AhR, SR-HSE, SR-p53 because of strong model signal for SR-p53 and strong model signal for SR-MMP.
Probabilities: [0.79  0.408 0.887 0.    0.    0.085 0.185 0.    0.003 0.998 0.    1.   ]


## Final output!