# Meta explainer network for Tox21 dataset   version 2!

## 1: load & clean Tox21 CSV

In [2]:
import pandas as pd

DATA_PATH = "Data_v3/original/tox21.csv"  
df = pd.read_csv(DATA_PATH)

label_cols = [c for c in df.columns if c.startswith(("NR-","SR-"))]
df = df.dropna(subset=["smiles"] + label_cols).reset_index(drop=True)
print("Data shape:", df.shape)


Data shape: (3079, 14)


## 2: compute RDKit Descriptors & Toxicophore Flags

In [3]:
from rdkit import Chem
from rdkit.Chem import Descriptors
import numpy as np

rows = []
for smi in df.smiles:
    m = Chem.MolFromSmiles(smi)
    rows.append({
        "MolWt": Descriptors.MolWt(m),
        "LogP":  Descriptors.MolLogP(m),
        "TPSA":  Descriptors.TPSA(m),
        "HDon":  Descriptors.NumHDonors(m),
        "HAcc":  Descriptors.NumHAcceptors(m),
        "RotB":  Descriptors.NumRotatableBonds(m),
        "RingC": Descriptors.RingCount(m),
        "AromR": Descriptors.NumAromaticRings(m),
        "nitro": int(m.HasSubstructMatch(Chem.MolFromSmarts("[N+](=O)[O-]"))),
        "phenol":int(m.HasSubstructMatch(Chem.MolFromSmarts("[OX2H]"))),
        "carbonyl":int(m.HasSubstructMatch(Chem.MolFromSmarts("[CX3]=O"))),
        "amine": int(m.HasSubstructMatch(Chem.MolFromSmarts("[NX3;H2,H1]"))),
        "halogen":int(any(a.GetSymbol() in ("Cl","Br","F","I") for a in m.GetAtoms()))
    })
df_desc = pd.DataFrame(rows)
print("Descriptor frame:", df_desc.shape)


Descriptor frame: (3079, 13)


## 3: Train / Validation Split & Tokenise SMILES

In [4]:
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

X_train, X_val, y_train, y_val, desc_tr, desc_val = train_test_split(
    df.smiles, df[label_cols], df_desc, test_size=0.2, random_state=42)

def tokenize(smiles_list):
    return tokenizer(smiles_list, padding=True, truncation=True,
                     return_tensors="pt")

enc_train = tokenize(X_train.tolist())
enc_val   = tokenize(X_val.tolist())

print("Train tokens:", enc_train.input_ids.shape)


  from .autonotebook import tqdm as notebook_tqdm


Train tokens: torch.Size([2463, 267])


## 4: Define ChemBERTa Multi-Label Classier    

In [5]:
import torch, torch.nn as nn
from transformers import AutoModel

class ChemBERTaClassifier(nn.Module):
    def __init__(self, n_labels=12):
        super().__init__()
        self.bert = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.bert.config.hidden_size, n_labels)
        )
    def forward(self, input_ids, attention_mask):
        pooled = self.bert(input_ids, attention_mask).pooler_output
        return self.classifier(pooled)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = ChemBERTaClassifier(len(label_cols)).to(device)


## 5: Fine‑Tune ChemBERTa

In [9]:
import torch
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU name:", torch.cuda.get_device_name(0))

CUDA available: True
GPU name: NVIDIA GeForce RTX 4070 Ti


In [None]:
# ---- Simple training loop: 20 epochs, early stopping, metrics printed ----
import torch, torch.nn.functional as F, torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import roc_auc_score, f1_score
import numpy as np, tqdm, time

EPOCHS      = 20
BATCH_SIZE  = 16
LR          = 2e-5
PATIENCE    = 3          # stop if macro AUC doesn't improve for PATIENCE epochs
THRESH      = 0.5        # for F1/accuracy
use_amp     = (device.type == "cuda")

# Data loaders
train_ds = TensorDataset(enc_train.input_ids,
                         enc_train.attention_mask,
                         torch.FloatTensor(y_train.values))
val_ds   = TensorDataset(enc_val.input_ids,
                         enc_val.attention_mask,
                         torch.FloatTensor(y_val.values))

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE)

# Optimizer
opt = optim.AdamW(model.parameters(), lr=LR)

# AMP helpers
try:
    scaler  = torch.amp.GradScaler("cuda") if use_amp else None
    autocast_ctx = lambda: torch.amp.autocast("cuda", enabled=use_amp)
except AttributeError:
    from torch.cuda.amp import GradScaler, autocast
    scaler  = GradScaler(enabled=use_amp)
    autocast_ctx = lambda: autocast(enabled=use_amp)

def eval_loop(loader):
    model.eval()
    ys, ps = [], []
    with torch.no_grad(), autocast_ctx():
        for ids, attn, yb in loader:
            ids, attn, yb = ids.to(device), attn.to(device), yb.to(device)
            logits = model(ids, attn)
            probs  = torch.sigmoid(logits).cpu().numpy()
            ys.append(yb.cpu().numpy())
            ps.append(probs)
    ys, ps = np.vstack(ys), np.vstack(ps)
    # metrics
    try:
        macro_auc = np.nanmean([roc_auc_score(ys[:,i], ps[:,i]) 
                                if len(np.unique(ys[:,i]))>1 else np.nan
                                for i in range(ps.shape[1])])
    except ValueError:
        macro_auc = np.nan
    f1_macro = f1_score(ys, (ps>THRESH).astype(int), average='macro', zero_division=0)
    acc = ( (ps>THRESH) == ys ).mean()
    return macro_auc, f1_macro, acc

best_auc = -np.inf
no_improve = 0
model.train()

for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    epoch_loss = 0.0
    seen = 0
    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    opt.zero_grad(set_to_none=True)

    for ids, attn, yb in pbar:
        ids, attn, yb = ids.to(device), attn.to(device), yb.to(device)
        with autocast_ctx():
            logits = model(ids, attn)
            loss   = F.binary_cross_entropy_with_logits(logits, yb)
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        else:
            loss.backward(); opt.step()
        opt.zero_grad(set_to_none=True)

        epoch_loss += loss.item() * ids.size(0)
        seen += ids.size(0)
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    # ---- validation ----
    val_auc, val_f1, val_acc = eval_loop(val_loader)
    print(f"Epoch {epoch} done in {time.time()-t0:.1f}s | "
          f"train_loss={epoch_loss/seen:.4f} | "
          f"valAUC={val_auc:.3f} | valF1={val_f1:.3f} | valAcc={val_acc:.3f}")

    # early stopping check
    if val_auc > best_auc + 1e-4:
        best_auc = val_auc
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= PATIENCE:
            print("⛔ Early stopping (no AUC improvement).")
            break

model.eval()
print("Finished training.")


✅ Saved:
  • Weights        → models/v4\model.pt
  • Full checkpoint→ models/v4\checkpoints\full_checkpoint.pt
  • Tokenizer files→ models/v4
  • Metrics        → models/v4\metrics
  • Config         → models/v4\train_config.json


## 6: Save Model & Tokenizer 

In [None]:
import os

SAVE_DIR = "models/v4"
model.eval()
os.makedirs(SAVE_DIR, exist_ok=True)
torch.save(model.state_dict(), f"{SAVE_DIR}/model.pt")
tokenizer.save_pretrained(SAVE_DIR)
print("✅ Model & tokenizer saved.")

✅ Model & tokenizer saved.


## 7:  Compute SHAP Mean‑Abs Features (All 12 Classes)

In [14]:
# Cell 7 – Wide SHAP (GPU, chunked, timed)
import shap, torch, numpy as np, torch.nn as nn, os, time

# ----------------- config -----------------
SUB_N     = None      # None = all validation mols; or int like 2000
BG_N      = 32        # background size (trade-off speed/quality)
CHUNK     = 64        # how many mols at once through SHAP
SAVE_DIR  = "Data_v3/SHAP_val_full"
SAVE_TOKENS = False   # True → also save per-token SHAP (big files!)
os.makedirs(SAVE_DIR, exist_ok=True)

model.to(device).eval()

# 1) Slice validation set & build embeddings on GPU
ids_full = enc_val.input_ids.to(device)
if SUB_N: ids_full = ids_full[:SUB_N]
with torch.no_grad():
    embed_full = model.bert.embeddings(ids_full).float()     # [N, S, 768]

N, S, E = embed_full.shape
bg      = embed_full[:BG_N]                                  # background
n_cls   = len(label_cols)

# outputs
shap_means = torch.zeros((N, n_cls), device=device)
if SAVE_TOKENS:
    token_means = np.zeros((N, n_cls, S), dtype=np.float32)   # large!

t_all = time.time()
for cls in range(n_cls):
    head = nn.Sequential(nn.Identity(), nn.Linear(E, 1)).to(device)
    head[1].weight.data = model.classifier[1].weight.data[cls:cls+1]
    head[1].bias.data   = model.classifier[1].bias.data[cls:cls+1]

    expl = shap.DeepExplainer(head, bg)   # runs gradients on GPU internally

    per_class_vals = []
    t0 = time.time()
    for start in range(0, N, CHUNK):
        chunk = embed_full[start:start+CHUNK]
        vals  = expl.shap_values(chunk)[0]             # numpy [chunk,S,768]
        # mean over embedding dim (768) ⇒ token-level importance
        tok_imp = np.abs(vals).mean(axis=2)            # [chunk,S]
        if SAVE_TOKENS:
            token_means[start:start+CHUNK, cls, :tok_imp.shape[1]] = tok_imp
        # mean over tokens & embed ⇒ single scalar per mol
        mol_imp = tok_imp.mean(axis=1)                 # [chunk]
        per_class_vals.append(mol_imp)

    shap_means[:, cls] = torch.from_numpy(np.concatenate(per_class_vals)).to(device)

    print(f"Class {cls+1}/{n_cls} ({label_cols[cls]}): {time.time()-t0:.1f}s")

print(f"⏱ Total SHAP time: {(time.time()-t_all)/60:.1f} min for N={N}")

# 2) Save
np.save(os.path.join(SAVE_DIR, "shap_means.npy"), shap_means.cpu().numpy())
if SAVE_TOKENS:
    np.save(os.path.join(SAVE_DIR, "token_means.npy"), token_means)

print("✅ Saved:",
      os.path.join(SAVE_DIR, "shap_means.npy"),
      "(and token_means.npy)" if SAVE_TOKENS else "")


Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.


Class 1/12 (NR-AR): 392.9s
Class 2/12 (NR-AR-LBD): 390.5s
Class 3/12 (NR-AhR): 390.3s
Class 4/12 (NR-Aromatase): 389.5s
Class 5/12 (NR-ER): 389.3s
Class 6/12 (NR-ER-LBD): 388.8s
Class 7/12 (NR-PPAR-gamma): 391.0s
Class 8/12 (SR-ARE): 391.0s
Class 9/12 (SR-ATAD5): 390.7s
Class 10/12 (SR-HSE): 389.7s
Class 11/12 (SR-MMP): 387.8s
Class 12/12 (SR-p53): 390.5s
⏱ Total SHAP time: 78.0 min for N=616
✅ Saved: Data_v3/SHAP_val_full\shap_means.npy 


## 8: Build Meta‑Explainer Dataset

In [15]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

# 1️⃣ Load SHAP-means (produced in Cell 7-lite)
shap_means = np.load("Data_v3/SHAP_val_full/shap_means.npy")           # shape (SUB_N, 12)
N_sub = shap_means.shape[0]

# 2️⃣ Take the matching slice of descriptors & labels
desc_val_sub = desc_val.iloc[:N_sub].reset_index(drop=True)       # (SUB_N, d)
y_val_sub    = y_val.iloc[:N_sub].reset_index(drop=True)          # (SUB_N, 12)

# 3️⃣ Concatenate descriptor features + SHAP features
meta_X = np.hstack([desc_val_sub.values, shap_means])             # (SUB_N, d+12)
meta_y = y_val_sub.values                                         # (SUB_N, 12)

# 4️⃣ Train / test split
X_tr, X_te, y_tr, y_te = train_test_split(
    meta_X, meta_y, test_size=0.2, random_state=42
)

# 5️⃣ Build DataLoaders
tr_ds = TensorDataset(torch.FloatTensor(X_tr), torch.FloatTensor(y_tr))
te_ds = TensorDataset(torch.FloatTensor(X_te), torch.FloatTensor(y_te))
tr_loader = DataLoader(tr_ds, batch_size=16, shuffle=True)
te_loader = DataLoader(te_ds, batch_size=16)

print(f"✅ Meta-Explainer tensors ready — X_tr: {X_tr.shape}, X_te: {X_te.shape}")

✅ Meta-Explainer tensors ready — X_tr: (492, 25), X_te: (124, 25)


## 9: Train Meta‑Explainer MLP

In [17]:
%pip install iterstrat

Note: you may need to restart the kernel to use updated packages.


ERROR: Could not find a version that satisfies the requirement iterstrat (from versions: none)
ERROR: No matching distribution found for iterstrat

[notice] A new release of pip is available: 24.0 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [22]:
# Cell 9 – Meta‑Explainer with multilabel stratified split, class weighting, robust metrics
import numpy as np, torch, torch.nn as nn, torch.nn.functional as F, time
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

# ---------- 1) Stratified split (iterative multilabel); fallback if not installed ----------
try:
    from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
    mskf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    idx_tr, idx_te = next(mskf.split(meta_X, meta_y))
except ImportError:
    # very simple fallback: try to keep at least one positive per label in test
    rng = np.random.default_rng(42)
    idx = np.arange(len(meta_X)); rng.shuffle(idx)
    idx_tr, idx_te = [], []
    need = set(np.where(meta_y.sum(axis=0) > 0)[0])
    target_te = int(len(idx) * 0.2)
    for i in idx:
        if len(idx_te) < target_te and any(meta_y[i, j] == 1 for j in need):
            idx_te.append(i)
            need -= {j for j in np.where(meta_y[i] == 1)[0]}
        else:
            idx_tr.append(i)
    idx_te += list(need)
    idx_te = np.unique(idx_te).tolist()
    idx_tr = [i for i in idx if i not in idx_te]

X_tr, X_te = meta_X[idx_tr], meta_X[idx_te]
y_tr, y_te = meta_y[idx_tr], meta_y[idx_te]
print(f"Stratified shapes -> train: {X_tr.shape}, test: {X_te.shape}")

# ---------- 2) Drop labels with zero positives in whole split (optional but avoids NaNs) ----------
pos_counts = y_tr.sum(axis=0) + y_te.sum(axis=0)
keep_cols  = np.where(pos_counts > 0)[0]
if len(keep_cols) < y_tr.shape[1]:
    dropped = [label_cols[i] for i in range(len(label_cols)) if i not in keep_cols]
    print("⚠️ Dropping labels with no positives:", dropped)
    y_tr = y_tr[:, keep_cols]; y_te = y_te[:, keep_cols]
    label_cols_kept = [label_cols[i] for i in keep_cols]
else:
    label_cols_kept = label_cols

IN_DIM  = X_tr.shape[1]
OUT_DIM = y_tr.shape[1]

# ---------- 3) DataLoaders ----------
BATCH = 32
tr_ds = TensorDataset(torch.FloatTensor(X_tr), torch.FloatTensor(y_tr))
te_ds = TensorDataset(torch.FloatTensor(X_te), torch.FloatTensor(y_te))
tr_loader = DataLoader(tr_ds, batch_size=BATCH, shuffle=True)
te_loader = DataLoader(te_ds, batch_size=BATCH)

# ---------- 4) Model ----------
class MetaMLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(64, out_dim)
        )
    def forward(self, x): return self.net(x)

meta = MetaMLP(IN_DIM, OUT_DIM).to(device)

# ---------- 5) Loss with class imbalance handling ----------
pos = y_tr.sum(axis=0)
neg = y_tr.shape[0] - pos
pos_weight = torch.tensor((neg / np.clip(pos, 1, None)), dtype=torch.float32).to(device)
crit = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

opt      = torch.optim.Adam(meta.parameters(), lr=1e-3)
EPOCHS   = 200
PATIENCE = 12
THRESH   = 0.5

best_pr  = -np.inf
stalls   = 0
best_state   = None
best_metrics = None

# ---------- eval helper ----------
def eval_meta(loader):
    meta.eval()
    Ys, Ps = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            probs = torch.sigmoid(meta(xb)).cpu().numpy()
            Ys.append(yb.cpu().numpy()); Ps.append(probs)
    Ys, Ps = np.vstack(Ys), np.vstack(Ps)

    aucs, prs = [], []
    for i in range(OUT_DIM):
        y_i, p_i = Ys[:, i], Ps[:, i]
        if len(np.unique(y_i)) < 2:
            aucs.append(np.nan); prs.append(np.nan); continue
        aucs.append(roc_auc_score(y_i, p_i))
        prs.append(average_precision_score(y_i, p_i))
    macro_auc = np.nanmean(aucs)
    macro_pr  = np.nanmean(prs)
    f1_macro  = f1_score(Ys, (Ps > THRESH).astype(int), average='macro', zero_division=0)
    acc       = ((Ps > THRESH) == Ys).mean()
    return macro_auc, macro_pr, f1_macro, acc, aucs, prs

# ---------- training loop ----------
t0_all = time.time()
for ep in range(1, EPOCHS + 1):
    meta.train()
    loss_sum = 0.0; n = 0
    t0 = time.time()

    for xb, yb in tr_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = meta(xb)
        loss   = crit(logits, yb)
        opt.zero_grad(); loss.backward(); opt.step()
        loss_sum += loss.item() * xb.size(0); n += xb.size(0)

    val_auc, val_pr, val_f1, val_acc, aucs, prs = eval_meta(te_loader)
    print(f"Ep {ep:03d} | loss={loss_sum/n:.4f} | vPR={val_pr:.3f} vAUC={val_auc:.3f} "
          f"vF1={val_f1:.3f} vAcc={val_acc:.3f} | {time.time()-t0:.1f}s")

    if val_pr > best_pr + 1e-4:
        best_pr = val_pr; stalls = 0
        best_state   = meta.state_dict()
        best_metrics = (val_auc, val_pr, val_f1, val_acc, aucs, prs)
    else:
        stalls += 1
        if stalls >= PATIENCE:
            print("⛔ Early stopping (PR stalled).")
            break

print(f"Total meta training: {(time.time()-t0_all)/60:.1f} min")

# load best
if best_state is not None:
    meta.load_state_dict(best_state)
    meta.eval()
    val_auc, val_pr, val_f1, val_acc, aucs, prs = best_metrics
    print("\nBest macro metrics:",
          f"AUC={val_auc:.3f} PR={val_pr:.3f} F1={val_f1:.3f} Acc={val_acc:.3f}")
    for i, cls in enumerate(label_cols_kept):
        print(f"{cls:15s} AUC={aucs[i]:.3f}  PR={prs[i]:.3f}")
else:
    print("No improvement recorded; using last epoch weights.")


Stratified shapes -> train: (604, 25), test: (12, 25)
Ep 001 | loss=9.7579 | vPR=0.254 vAUC=0.423 vF1=0.117 vAcc=0.701 | 0.1s
Ep 002 | loss=6.1086 | vPR=0.267 vAUC=0.420 vF1=0.167 vAcc=0.597 | 0.1s
Ep 003 | loss=6.0653 | vPR=0.273 vAUC=0.442 vF1=0.202 vAcc=0.444 | 0.0s
Ep 004 | loss=3.8568 | vPR=0.259 vAUC=0.410 vF1=0.173 vAcc=0.500 | 0.0s
Ep 005 | loss=3.0642 | vPR=0.250 vAUC=0.395 vF1=0.173 vAcc=0.431 | 0.1s
Ep 006 | loss=3.9062 | vPR=0.261 vAUC=0.411 vF1=0.150 vAcc=0.493 | 0.0s
Ep 007 | loss=3.1602 | vPR=0.281 vAUC=0.451 vF1=0.257 vAcc=0.444 | 0.0s
Ep 008 | loss=3.5525 | vPR=0.278 vAUC=0.467 vF1=0.160 vAcc=0.333 | 0.0s
Ep 009 | loss=2.7341 | vPR=0.272 vAUC=0.463 vF1=0.170 vAcc=0.306 | 0.0s
Ep 010 | loss=2.9084 | vPR=0.301 vAUC=0.499 vF1=0.206 vAcc=0.410 | 0.0s
Ep 011 | loss=2.2476 | vPR=0.280 vAUC=0.494 vF1=0.230 vAcc=0.403 | 0.0s
Ep 012 | loss=2.5586 | vPR=0.291 vAUC=0.510 vF1=0.242 vAcc=0.424 | 0.0s
Ep 013 | loss=2.1041 | vPR=0.326 vAUC=0.517 vF1=0.241 vAcc=0.431 | 0.0s
Ep 014 | l

In [23]:
## save the MLP 
# Cell 10 – Save Meta‑Explainer (MLP) weights, metrics, and config
import os, json, numpy as np, torch
from datetime import datetime

SAVE_DIR = "models/v4/MLP"
os.makedirs(SAVE_DIR, exist_ok=True)

# 1) Weights
torch.save(meta.state_dict(), os.path.join(SAVE_DIR, "meta_mlp.pt"))

# 2) Metrics (from best_metrics defined in Cell 9)
val_auc, val_pr, val_f1, val_acc, aucs, prs = best_metrics
metrics = {
    "timestamp": datetime.now().isoformat(timespec="seconds"),
    "macro_auc": float(val_auc),
    "macro_pr":  float(val_pr),
    "macro_f1":  float(val_f1),
    "macro_acc": float(val_acc),
    "per_class_auc": [float(x) if x==x else None for x in aucs],  # NaN→None
    "per_class_pr":  [float(x) if x==x else None for x in prs],
    "labels": label_cols_kept,
}

with open(os.path.join(SAVE_DIR, "metrics.json"), "w") as f:
    json.dump(metrics, f, indent=2)

# 3) Config
config = {
    "input_dim": int(IN_DIM),
    "output_dim": int(OUT_DIM),
    "batch_size": BATCH,
    "epochs_trained": ep,
    "learning_rate": 1e-3,
    "patience": PATIENCE,
    "threshold": 0.5,
    "device": str(device),
}
with open(os.path.join(SAVE_DIR, "config.json"), "w") as f:
    json.dump(config, f, indent=2)

# 4) Per-class arrays as .npy (optional)
np.save(os.path.join(SAVE_DIR, "per_class_auc.npy"), np.array(aucs))
np.save(os.path.join(SAVE_DIR, "per_class_pr.npy"),  np.array(prs))

print("✅ Meta-MLP saved to:", SAVE_DIR)


✅ Meta-MLP saved to: models/v4/MLP


## 10: Generate an Explanation for a New SMILES

In [25]:
# Cell 10 – Interactive SMILES → PubChem lookup → Meta‑explanation (uses NEW models)
# ----------------------------------------------------------------------------------
import os, json, requests, shap, torch, numpy as np, torch.nn as nn
from rdkit import Chem
from rdkit.Chem import Descriptors, Lipinski, Crippen, rdMolDescriptors

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- 0. Paths ----------
CHEMBERTA_DIR = "models/v4"   # where model.pt + tokenizer/ vocab.json etc. live
META_DIR      = "models/v4/MLP"               # where meta_mlp.pt & metrics.json live

# ---------- 1. Rebuild & load ChemBERTa classifier ----------
from transformers import AutoTokenizer, AutoModel

class ChemBERTaClassifier(nn.Module):
    def __init__(self, n_labels):
        super().__init__()
        self.bert = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.bert.config.hidden_size, n_labels)
        )
    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = out.pooler_output
        return self.classifier(pooled)

tokenizer = AutoTokenizer.from_pretrained(CHEMBERTA_DIR)

# Use the kept labels if you dropped some in Cell 9, otherwise original
labels_used = label_cols_kept if 'label_cols_kept' in globals() else label_cols

model = ChemBERTaClassifier(n_labels=len(labels_used)).to(device)
state = torch.load(os.path.join(CHEMBERTA_DIR, "model.pt"), map_location=device)
model.load_state_dict(state)
model.eval()
E = model.bert.config.hidden_size

# ---------- 2. Rebuild & load Meta-MLP ----------
class MetaMLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(64, out_dim), nn.Sigmoid()
        )
    def forward(self, x): return self.net(x)

# 13 descriptors + 12 SHAP means = 25 (adjust if you changed)
meta_in_dim = 25
meta = MetaMLP(meta_in_dim, len(labels_used)).to(device)
meta.load_state_dict(torch.load(os.path.join(META_DIR, "meta_mlp.pt"), map_location=device))
meta.eval()

# ---------- 3. Helpers ----------
def pubchem_name_cid(smiles: str):
    url = ("https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/"
           f"{smiles}/property/Title,IUPACName,CID/JSON")
    try:
        js = requests.get(url, timeout=10).json()
        props = js["PropertyTable"]["Properties"][0]
        return (props.get("Title") or props.get("IUPACName")), props.get("CID")
    except Exception:
        return None, None

def calc_desc_flags(mol):
    return np.array([
        Descriptors.MolWt(mol),                       # 0
        Descriptors.MolLogP(mol),                     # 1
        Descriptors.TPSA(mol),                        # 2
        Lipinski.NumHDonors(mol),                     # 3
        Lipinski.NumHAcceptors(mol),                  # 4
        Lipinski.NumRotatableBonds(mol),              # 5
        Descriptors.RingCount(mol),                   # 6
        Descriptors.NumAromaticRings(mol),            # 7
        int(mol.HasSubstructMatch(Chem.MolFromSmarts("[N+](=O)[O-]"))),  # 8  nitro
        int(mol.HasSubstructMatch(Chem.MolFromSmarts("[OX2H]"))),        # 9  phenolic OH
        int(mol.HasSubstructMatch(Chem.MolFromSmarts("[CX3]=O"))),       # 10 carbonyl
        int(mol.HasSubstructMatch(Chem.MolFromSmarts("[NX3;H2,H1]"))),   # 11 amine
        int(any(a.GetSymbol() in ("Cl","Br","F","I") for a in mol.GetAtoms()))  # 12 halogen
    ], dtype=float)

def explain_smiles(smiles: str, *, top_k=2, p_thresh=0.5):
    # 0) metadata
    name, cid = pubchem_name_cid(smiles)

    # 1) descriptors
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return "Invalid SMILES.", None
    d = calc_desc_flags(mol)

    # 2) per-class mean |SHAP| via DeepExplainer (fast for a single mol)
    enc = tokenizer(smiles, return_tensors="pt").to(device)
    with torch.no_grad():
        emb = model.bert.embeddings(enc.input_ids)
    shap_vec = []
    for i in range(len(labels_used)):
        head = nn.Sequential(nn.Identity(), nn.Linear(E, 1)).to(device)
        head[1].weight.data = model.classifier[1].weight.data[i:i+1]
        head[1].bias.data   = model.classifier[1].bias.data[i:i+1]
        v = shap.DeepExplainer(head, emb).shap_values(emb)[0]  # [1,S,768]
        shap_vec.append(np.abs(v).mean())
    shap_vec = np.array(shap_vec)

    # 3) meta prediction
    feats = np.hstack([d, shap_vec])  # length must match meta_in_dim
    with torch.no_grad():
        probs = meta(torch.FloatTensor(feats).unsqueeze(0).to(device)).cpu().numpy()[0]

    # 4) textual justification
    positives = [labels_used[i] for i, p in enumerate(probs) if p > p_thresh]
    reasons = []
    if d[3] > 2:   reasons.append("high H‑bond donor count")
    if d[0] > 500: reasons.append("large MolWt")
    if d[8]:       reasons.append("nitro group")
    if d[9]:       reasons.append("phenolic hydroxyl")
    if d[12]:      reasons.append("halogen substituent")
    if not reasons:
        idxs = shap_vec.argsort()[-top_k:][::-1]
        reasons = [f"strong model signal for {labels_used[i]}" for i in idxs]

    header = ""
    if name: header += f"**{name}** "
    if cid:  header += f"(CID {cid}) "
    txt = f"{header}Model predicts toxicity for {', '.join(positives) if positives else 'no endpoints'} because of {' and '.join(reasons) if reasons else 'model signal patterns'}."
    return txt, probs

# ---------- 4. Run interactively ----------
smiles_in = input("Enter your drug SMILES: ").strip()
explanation, prob_vec = explain_smiles(smiles_in)

print("\n" + explanation)
if prob_vec is not None:
    import pandas as pd
    df_probs = pd.DataFrame({"Endpoint": labels_used, "Prob": np.round(prob_vec, 3)})
    print(df_probs.to_string(index=False))


Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.



Model predicts toxicity for NR-AhR, NR-Aromatase, NR-ER, SR-ARE because of strong model signal for SR-p53 and strong model signal for SR-MMP.
     Endpoint  Prob
        NR-AR 0.452
    NR-AR-LBD 0.327
       NR-AhR 0.553
 NR-Aromatase 0.538
        NR-ER 0.519
    NR-ER-LBD 0.244
NR-PPAR-gamma 0.096
       SR-ARE 0.544
     SR-ATAD5 0.000
       SR-HSE 0.283
       SR-MMP 0.452
       SR-p53 0.062
