# objective is to prepare our data from step 0 to step 3

## step 0 

**AIM:** 
* Multilabel classification with 3 specific ADE targets.
* Input = SMILES; Output = 3 binary labels.

**Specifying our label columns:**
* label_Gastrointestinal disorders
* label_Infections and infestations
* label_Nervous system disorders

**we need to Ensure input column (smiles) exists and is clean, thus:**

* Remove rows with missing or malformed SMILES.


**Inspect label distributions (class balance preview):**
* helps later with class weighting and threshold tuning.


**NOTE:** we are dropping everything else from the DataFrame, as they are not needed



In [1]:
import pandas as pd
from pathlib import Path

# Define input/output paths
input_dir = Path("Data/CT-ADE-SOC")
output_dir = Path("Data/clean")

# Target label columns to keep
target_cols = [
    "label_Gastrointestinal disorders",
    "label_Infections and infestations",
    "label_Nervous system disorders"
]

# Input and output mapping
file_map = {
    "train": ("train.csv", "clean_train.csv"),
    "val": ("val.csv", "clean_val.csv"),
    "test": ("test.csv", "clean_test.csv")
}

# Process each file
for split_name, (in_file, out_file) in file_map.items():
    print(f"\n🔹 Processing {split_name.upper()} set...")
    
    # Load and filter
    df = pd.read_csv(input_dir / in_file)
    df = df[['smiles'] + target_cols]
    df = df.dropna(subset=['smiles']).reset_index(drop=True)
    df[target_cols] = df[target_cols].astype('uint8')

    # Basic stats
    print(f"✅ Rows after cleaning: {len(df)}")
    print("Label distribution:\n", df[target_cols].sum())

    # Save cleaned file
    df.to_csv(output_dir / out_file, index=False)
    print(f"📁 Saved to: {output_dir / out_file}")



🔹 Processing TRAIN set...
✅ Rows after cleaning: 12419
Label distribution:
 label_Gastrointestinal disorders     4683
label_Infections and infestations    3303
label_Nervous system disorders       4206
dtype: uint64
📁 Saved to: Data\clean\clean_train.csv

🔹 Processing VAL set...
✅ Rows after cleaning: 1518
Label distribution:
 label_Gastrointestinal disorders     629
label_Infections and infestations    410
label_Nervous system disorders       514
dtype: uint64
📁 Saved to: Data\clean\clean_val.csv

🔹 Processing TEST set...
✅ Rows after cleaning: 1260
Label distribution:
 label_Gastrointestinal disorders     513
label_Infections and infestations    367
label_Nervous system disorders       522
dtype: uint64
📁 Saved to: Data\clean\clean_test.csv


## Step 1:

* Load every cleaned split from Data/clean/.

* Drop “all-zero” rows:, why? 
If a compound has none of the three labels, it doesn't help the classifier unless you need extra negatives. We’ll drop them now for a more balanced training set.


* Compute label counts per split — needed later for pos_weight in the loss.

* Persist curated files to Data/interim/ for downstream tokenisation.

* Save a YAML or JSON summary of the label statistics so you never lose track.



In [2]:
import pandas as pd
from pathlib import Path
import yaml

# Paths
clean_dir = Path("Data/clean")          # input from Step 0
interim_dir = Path("Data/interim")
interim_dir.mkdir(parents=True, exist_ok=True)

splits = {
    "train": "clean_train.csv",
    "val":   "clean_val.csv",
    "test":  "clean_test.csv"
}

target_cols = [
    "label_Gastrointestinal disorders",
    "label_Infections and infestations",
    "label_Nervous system disorders"
]

stats = {}  # hold label counts for YAML summary

for split, fname in splits.items():
    df = pd.read_csv(clean_dir / fname)

    # 1️⃣  Remove rows where all three labels are zero
    mask_nonzero = df[target_cols].sum(axis=1) > 0
    dropped = len(df) - mask_nonzero.sum()
    df = df[mask_nonzero].reset_index(drop=True)

    # 2️⃣  Capture label counts
    label_counts = df[target_cols].sum().to_dict()
    stats[split] = {
        "rows_after_drop": len(df),
        "rows_dropped_all_zero": int(dropped),
        "label_counts": {k: int(v) for k, v in label_counts.items()}
    }

    # 3️⃣  Save curated split
    out_path = interim_dir / f"ade_3lbl_{split}.csv"
    df.to_csv(out_path, index=False)
    print(f"✅ {split.upper()} saved to {out_path}  ({len(df)} rows)")

# 4️⃣  Persist stats for future reference
with open(interim_dir / "label_stats.yaml", "w") as fp:
    yaml.dump(stats, fp, default_flow_style=False)

print("\n📊 Label stats written to", interim_dir / "label_stats.yaml")


✅ TRAIN saved to Data\interim\ade_3lbl_train.csv  (6542 rows)
✅ VAL saved to Data\interim\ade_3lbl_val.csv  (803 rows)
✅ TEST saved to Data\interim\ade_3lbl_test.csv  (710 rows)

📊 Label stats written to Data\interim\label_stats.yaml


## Step 2: 

* Load curated CSVs from Data/interim/ade_3lbl_{train,val,test}.csv.

* Initialise ChemBERTa tokenizer

* Model: seyonec/ChemBERTa-zinc-250k-pubchem-1m.

* Parameters: max_length = 128, padding='max_length', truncation=True.

* Convert SMILES to token IDs (input_ids, optional attention_mask if desired).

* Extract label matrix (y = N × 3, dtype uint8).

* Save arrays to Data/processed/

* X_train.npy, y_train.npy, X_val.npy, y_val.npy, X_test.npy, y_test.npy.

* Optionally save attn_train.npy, etc. for attention masks.

* Log shapes so you can sanity-check later.

In [4]:
import pandas as pd
import numpy as np
from pathlib import Path
from transformers import AutoTokenizer

# --- Config ---
interim_dir   = Path("Data/interim")
processed_dir = Path("Data/processed")
processed_dir.mkdir(parents=True, exist_ok=True)

splits = ["train", "val", "test"]
target_cols = [
    "label_Gastrointestinal disorders",
    "label_Infections and infestations",
    "label_Nervous system disorders"
]

# 🔄  Using the open-access checkpoint
model_name = "seyonec/ChemBERTa_zinc250k_v2_40k"
tokenizer  = AutoTokenizer.from_pretrained(model_name, trust_remote_code=False)

max_len = 128  # keep sequences short for speed

for split in splits:
    print(f"\n🔹 Tokenising {split.upper()} set …")

    # 1) Load curated CSV
    df = pd.read_csv(interim_dir / f"ade_3lbl_{split}.csv")

    # 2) Tokenise SMILES strings
    toks = tokenizer(
        df["smiles"].tolist(),
        max_length=max_len,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )

    X     = toks["input_ids"].astype(np.int32)     # (N, 128)
    attn  = toks["attention_mask"].astype(np.int8) # (N, 128) OPTIONAL
    y     = df[target_cols].values.astype(np.uint8)# (N, 3)

    # 3) Save arrays
    np.save(processed_dir / f"X_{split}.npy",     X)
    np.save(processed_dir / f"y_{split}.npy",     y)
    np.save(processed_dir / f"attn_{split}.npy",  attn)  # drop if not used

    print(f"✅ Saved: X_{split}.npy   {X.shape}")
    print(f"✅ Saved: y_{split}.npy   {y.shape}")
    print(f"✅ Saved: attn_{split}.npy {attn.shape}")

print("\n📁 All processed tensors now reside in", processed_dir.resolve())


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development



🔹 Tokenising TRAIN set …
✅ Saved: X_train.npy   (6542, 128)
✅ Saved: y_train.npy   (6542, 3)
✅ Saved: attn_train.npy (6542, 128)

🔹 Tokenising VAL set …
✅ Saved: X_val.npy   (803, 128)
✅ Saved: y_val.npy   (803, 3)
✅ Saved: attn_val.npy (803, 128)

🔹 Tokenising TEST set …
✅ Saved: X_test.npy   (710, 128)
✅ Saved: y_test.npy   (710, 3)
✅ Saved: attn_test.npy (710, 128)

📁 All processed tensors now reside in D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\Data\processed


## step 3: 

**Load tensors**

* X_train.npy, y_train.npy, attn_train.npy (and corresponding val).
* Shapes: X = (N, 128) token IDs, attn = (N, 128) masks, y = (N, 3). (it's in the code above)

**Compute class weights (to fight imbalance)**

* For each label: pos_weight = (N − pos) / pos.
* Feed into BCEWithLogitsLoss.

**Build model**

* Base: seyonec/ChemBERTa_zinc250k_v2_40k. it's our base line, could use other ones too
* Replace head with nn.Linear(hidden_size, 3), followed by sigmoid at inference.

**Training hyper-params**

* LR = 3 e-5, batch = 32, epochs = 5–8, weight-decay = 0.01.
* Scheduler: cosine with 10 % warm-up steps.
* Metric: per-label F1 + “none-positive” accuracy (all logits < thr).

**Validation each epoch**

* Track loss + macro F1; early-stop if no improvement ≥ 2 epochs.

**Save artefacts**

* Best checkpoint → models/chemberta_3lbl/.
* Store config.json, pytorch_model.bin, tokenizer files.

**(Optional) Threshold sweep on validation set (Step 4 later).**




In [5]:
import torch
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU name:", torch.cuda.get_device_name(0))


CUDA available: True
GPU name: NVIDIA GeForce RTX 4070 Ti


In [10]:
import torch, torch.nn as nn
from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
import numpy as np
from sklearn.metrics import f1_score

# ---------- Device ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("🖥️  Using device:", device)

# ---------- Data ----------
proc_dir = Path("Data/processed")
X_train = np.load(proc_dir / "X_train.npy")
y_train = np.load(proc_dir / "y_train.npy")
attn_train = np.load(proc_dir / "attn_train.npy")

X_val = np.load(proc_dir / "X_val.npy")
y_val = np.load(proc_dir / "y_val.npy")
attn_val = np.load(proc_dir / "attn_val.npy")

to_t = lambda arr, dtype: torch.tensor(arr, dtype=dtype)

train_ds = TensorDataset(
    to_t(X_train, torch.long),  to_t(attn_train, torch.long),  to_t(y_train, torch.float32)
)
val_ds = TensorDataset(
    to_t(X_val, torch.long),    to_t(attn_val, torch.long),    to_t(y_val, torch.float32)
)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False)

# ---------- Model ----------
model_name = "seyonec/ChemBERTa_zinc250k_v2_40k"
base = AutoModel.from_pretrained(model_name)
hidden = base.config.hidden_size
classifier = nn.Linear(hidden, 3)

class ChemClassifier(nn.Module):
    def __init__(self, base, head):
        super().__init__()
        self.base = base
        self.head = head
    def forward(self, ids, attn):
        out = self.base(input_ids=ids, attention_mask=attn)
        cls = out.last_hidden_state[:, 0, :]          # CLS token
        return self.head(cls)

model = ChemClassifier(base, classifier).to(device)

# ---------- Loss (with class weights) ----------
pos = y_train.sum(axis=0)
neg = len(y_train) - pos
pos_weight = torch.tensor(neg / pos, dtype=torch.float32, device=device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# ---------- Optimizer & scheduler ----------
optim = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
total_steps = len(train_loader) * 6
sched = get_cosine_schedule_with_warmup(
    optim, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps
)

# ---------- Training ----------
best_f1, patience = 0, 0
for epoch in range(1, 10):
    model.train()
    for ids, attn, labels in train_loader:
        ids, attn, labels = ids.to(device), attn.to(device), labels.to(device)
        loss = criterion(model(ids, attn), labels)
        loss.backward()
        optim.step(); sched.step(); optim.zero_grad()

    # ---- Validation ----
    model.eval(); p_all, y_all = [], []
    with torch.no_grad():
        for ids, attn, labels in val_loader:
            ids, attn = ids.to(device), attn.to(device)
            p_all.append(torch.sigmoid(model(ids, attn)).cpu())
            y_all.append(labels)
    preds = torch.cat(p_all).numpy()
    gts   = torch.cat(y_all).numpy()
    f1 = f1_score(gts, preds > 0.5, average="macro")
    print(f"Epoch {epoch} | val macro-F1={f1:.4f}")

    # ---- Early-stopping ----
    if f1 > best_f1:
        best_f1, patience = f1, 0
        torch.save(model.state_dict(), "models/chemberta_3lbl/pytorch_model.bin")
        print("  ✅  New best, model saved.")
    else:
        patience += 1
        if patience == 2:
            print("  ⏹️  Early stop.")
            break

print(f"\n🎯  Best val macro-F1: {best_f1:.4f}")


🖥️  Using device: cuda
Epoch 1 | val macro-F1=0.5534
  ✅  New best, model saved.
Epoch 2 | val macro-F1=0.6410
  ✅  New best, model saved.
Epoch 3 | val macro-F1=0.5882
Epoch 4 | val macro-F1=0.6233
  ⏹️  Early stop.

🎯  Best val macro-F1: 0.6410


## Step 4:

* Reload the best checkpoint (weights + tokenizer).

* Run inference on the validation tensors to collect raw logits (before sigmoid).

* Convert logits → probabilities with sigmoid.

* Sweep thresholds from 0.05 to 0.50 (step 0.01) for each label independently.
    * Compute balanced accuracy = ½ ( TPR + TNR ).
    * Keep the threshold giving the highest balanced accuracy.

* Persist thresholds to models/chemberta_3lbl/thresholds.json.

* Quick report: print the chosen threshold and val metrics (F1, bal-acc) for each label.

* These thresholds will be used at inference; if all three probs < their thresholds we return “none-of-three”.

In [16]:
import torch, json, numpy as np, torch.nn as nn
from sklearn.metrics import balanced_accuracy_score, f1_score
from transformers import AutoModel
from pathlib import Path

# --- paths ---
proc_dir  = Path("Data/processed")
model_dir = Path("models/chemberta_3lbl")
ckpt_path = model_dir / "pytorch_model.bin"
thresh_out = model_dir / "thresholds.json"

# --- data ---
X_val   = np.load(proc_dir / "X_val.npy")
attn_val = np.load(proc_dir / "attn_val.npy")
y_val   = np.load(proc_dir / "y_val.npy")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- rebuild architecture exactly as during training ---
base_name = "seyonec/ChemBERTa_zinc250k_v2_40k"
base = AutoModel.from_pretrained(base_name)
hidden = base.config.hidden_size
classifier = nn.Linear(hidden, 3)
class ChemClassifier(nn.Module):
    def __init__(self, base, classifier):
        super().__init__()
        self.base = base
        self.classifier = classifier
    def forward(self, ids, attn):
        cls = self.base(input_ids=ids, attention_mask=attn).last_hidden_state[:,0,:]
        return self.classifier(cls)

model = ChemClassifier(base, classifier).to(device)

# --- load state dict, renaming head.* -> classifier.* ---
state = torch.load(ckpt_path, map_location="cpu")
if "head.weight" in state:
    state["classifier.weight"] = state.pop("head.weight")
    state["classifier.bias"]   = state.pop("head.bias")
model.load_state_dict(state, strict=True)
model.eval()

# --- collect logits on val ---
batch = 256
logits = []
with torch.no_grad():
    for i in range(0, len(X_val), batch):
        ids  = torch.tensor(X_val[i:i+batch], dtype=torch.long, device=device)
        att  = torch.tensor(attn_val[i:i+batch], dtype=torch.long, device=device)
        logits.append(model(ids, att).cpu())
logits = torch.cat(logits).numpy()
probs  = 1/(1+np.exp(-logits))

# --- sweep thresholds ---
thr_range = np.arange(0.05,0.51,0.01)
best_thr, best_bal, best_f1 = [], [], []
for c in range(3):
    bal = [balanced_accuracy_score(y_val[:,c], probs[:,c]>=t) for t in thr_range]
    f1  = [f1_score(y_val[:,c], probs[:,c]>=t)               for t in thr_range]
    idx = int(np.argmax(bal))
    best_thr.append(float(thr_range[idx])); best_bal.append(float(bal[idx])); best_f1.append(float(f1[idx]))

labels = ["Gastrointestinal","Infections","NervousSystem"]
json.dump({k:v for k,v in zip(labels,best_thr)}, open(thresh_out,"w"), indent=2)

for l,t,ba,f in zip(labels,best_thr,best_bal,best_f1):
    print(f"{l:15s}  thr={t:.2f} | bal-acc={ba:.3f} | F1={f:.3f}")
print("\n💾 Thresholds saved to", thresh_out)


Gastrointestinal  thr=0.48 | bal-acc=0.584 | F1=0.775
Infections       thr=0.33 | bal-acc=0.593 | F1=0.673
NervousSystem    thr=0.50 | bal-acc=0.523 | F1=0.685

💾 Thresholds saved to models\chemberta_3lbl\thresholds.json


## Step 5: 

* Reload the frozen ChemBERTa classifier (weights + classifier renaming trick).

* Pick a compact background set for SHAP DeepExplainer
    * Randomly select 100 diverse SMILES from train (stratified by label).

* Build the SHAP explainer
    * shap.DeepExplainer(model, background_embeddings) where
      background_embeddings = base(**tokenised_bg).last_hidden_state[:,0,:].
    * We pass embeddings instead of token IDs to cut compute time.

* Batch-compute SHAP values for each split
    * Loop over X_* tensors (batch 128).
    * For each batch:
        * Feed token IDs & masks → logits.
        * explainer.shap_values((ids, attn)) returns 3 × B × 128 arrays.
    * Store alongside y_* and token IDs into .npz files: data/shap/shap_train.npz, shap_val.npz, shap_test.npz.

**Memory tips** 
* Process on GPU but move finished SHAP arrays to CPU before saving.
* Save in float16 (astype(np.float16)) to shrink disk usage.


In [24]:
import shap, torch, numpy as np
from transformers import AutoModel
from pathlib import Path
import random

# ---------------- paths ----------------
proc_dir = Path("Data/processed")
shap_dir = Path("Data/shap_clean"); shap_dir.mkdir(exist_ok=True, parents=True)

# ---------------- rebuild model ----------------
base_name = "seyonec/ChemBERTa_zinc250k_v2_40k"
base = AutoModel.from_pretrained(base_name)
head = torch.nn.Linear(base.config.hidden_size, 3)

class Net(torch.nn.Module):
    def __init__(self, b, h): super().__init__(); self.base, self.head = b, h
    def forward(self, ids, att): return self.head(
        self.base(input_ids=ids, attention_mask=att).last_hidden_state[:,0,:])

model = Net(base, head)
ckpt = torch.load("models/chemberta_3lbl/pytorch_model.bin", map_location="cpu")
model.load_state_dict(ckpt, strict=False)
model.eval().cuda()

@torch.no_grad()
def embed(ids, att):
    return base(input_ids=ids, attention_mask=att).last_hidden_state[:,0,:]

# ---------------- background ----------------
Xtr = np.load(proc_dir/"X_train.npy"); Att = np.load(proc_dir/"attn_train.npy")
bg = embed(torch.tensor(Xtr[random.sample(range(len(Xtr)),100)],
                        device="cuda"),
           torch.tensor(Att[random.sample(range(len(Xtr)),100)],
                        device="cuda"))
explainer = shap.GradientExplainer(model.head, [bg])

# ---------------- cache exactly (3,N,768) ----------------
def cache(name):
    ids = np.load(proc_dir/f"X_{name}.npy")
    att = np.load(proc_dir/f"attn_{name}.npy")
    y   = np.load(proc_dir/f"y_{name}.npy")

    shap_full = np.empty((3, len(ids), 768), dtype=np.float16)  # pre-alloc
    batch = 128; idx = 0
    for start in range(0, len(ids), batch):
        end = start + batch
        emb = embed(torch.tensor(ids[start:end], device="cuda"),
                    torch.tensor(att[start:end], device="cuda"))
        sv  = explainer.shap_values([emb])          # list len=3, each (B,768)
        for lbl in range(3):
            shap_full[lbl, idx:idx+emb.size(0)] = sv[lbl].astype(np.float16)
        idx += emb.size(0)

    np.savez(shap_dir/f"shap_{name}.npz", shap=shap_full, y=y)
    print(f"✅ saved {name}: {shap_full.shape}")

for split in ["train","val","test"]:
    cache(split)


✅ saved train: (3, 6542, 768)
✅ saved val: (3, 803, 768)
✅ saved test: (3, 710, 768)


## Step 6: 

* Normalise SHAP per sample
    * For each label ℓ, divide the |SHAP| vector by its L1 sum → relative importance.

* Select top-k CLS dimensions (k = 5)
    * For every label pick the k indices with highest mean |SHAP| across train.
    * Fix this index list so train/val/test line up.

* Build feature matrix
    * For every sample: concatenate the 3 labels × k values → 15-dim vector.
    * Shape after stack: Xmeta_train = (Ntrain, 15).

**Save artefacts**
* Data/meta/Xmeta_{split}.npy – float32
* Data/meta/ymeta_{split}.npy – same y matrix (uint8)
* Data/meta/topk_indices.json – the 3 × k CLS dimension IDs.

Why this works

* CLS dims capture global chemical context; top-k dims act like high-level “topics”.
* The Meta-MLP learns consistent patterns across splits.



In [29]:
import numpy as np
from pathlib import Path

meta = Path("Data/meta")

for split in ["train", "val", "test"]:
    X = np.load(meta / f"Xmeta_{split}.npy")          # (5,  N×3)
    y = np.load(meta / f"ymeta_{split}.npy")          # (N, 3)

    if X.shape[0] == 5:                               # the “wrong” layout
        N = y.shape[0]                                # true sample count
        X_fixed = X.T.reshape(N, 3, 5).reshape(N, 15) # (N,15)
        np.save(meta / f"Xmeta_{split}.npy", X_fixed.astype(np.float32))
        print(f"🔄 fixed {split}: {X_fixed.shape}")
    else:
        print(f"✅ {split} already OK: {X.shape}")


🔄 fixed train: (6542, 15)
🔄 fixed val: (803, 15)
🔄 fixed test: (710, 15)


## Step 7: 

* Standardise the 15-dim features (z-score) and persist the scaler (scaler.pkl).

* Define a compact MLP: 15 → 512 → Dropout(0.2) → 3 (sigmoid).

* Use class-weighted BCEWithLogitsLoss to balance positives vs negatives.

* Train for up to 100 epochs with early-stopping (patience = 10) on val macro-F1.

* Save best weights to models/meta_mlp.pt plus the scaler for inference.

* Report final val F1 and a quick test F1.


In [31]:
import numpy as np, torch, torch.nn as nn, pickle, os
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
from pathlib import Path

# ---------- paths ----------
meta_dir  = Path("Data/meta")
model_dir = Path("models"); model_dir.mkdir(exist_ok=True)

# ---------- load data ----------
X_tr = np.load(meta_dir/"Xmeta_train.npy")
y_tr = np.load(meta_dir/"ymeta_train.npy")
X_va = np.load(meta_dir/"Xmeta_val.npy")
y_va = np.load(meta_dir/"ymeta_val.npy")
X_te = np.load(meta_dir/"Xmeta_test.npy")
y_te = np.load(meta_dir/"ymeta_test.npy")

# ---------- standardise ----------
scaler = StandardScaler().fit(X_tr)
X_tr = scaler.transform(X_tr).astype(np.float32)
X_va = scaler.transform(X_va).astype(np.float32)
X_te = scaler.transform(X_te).astype(np.float32)
pickle.dump(scaler, open(model_dir/"scaler.pkl", "wb"))

# ---------- tensors ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
to_t  = lambda x: torch.tensor(x, device=device)
Xtr_t, ytr_t = to_t(X_tr), to_t(y_tr).float()
Xva_t, yva_t = to_t(X_va), to_t(y_va).float()
Xte_t, yte_t = to_t(X_te), to_t(y_te).float()

train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(Xtr_t, ytr_t),
    batch_size=128, shuffle=True
)

# ---------- model ----------
mlp = nn.Sequential(
    nn.Linear(15, 512),
    nn.GELU(),
    nn.Dropout(0.2),
    nn.Linear(512, 3)           # 3 explanation classes (one per ADE label for now)
).to(device)

# ---------- loss ----------
pos = y_tr.sum(axis=0); neg = len(y_tr) - pos
pos_weight = torch.tensor(neg/pos, dtype=torch.float32, device=device)
criterion  = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
opt = torch.optim.AdamW(mlp.parameters(), lr=1e-3)

# ---------- training ----------
best_f1, patience = 0, 0
for epoch in range(1, 101):
    mlp.train()
    for xb, yb in train_loader:
        loss = criterion(mlp(xb), yb)
        loss.backward(); opt.step(); opt.zero_grad()

    mlp.eval()
    with torch.no_grad():
        pred_va = torch.sigmoid(mlp(Xva_t)).cpu().numpy()
    f1_va = f1_score(y_va, pred_va > 0.5, average="macro")

    print(f"Epoch {epoch:03d} | val F1 = {f1_va:.3f}")

    if f1_va > best_f1:
        best_f1, patience = f1_va, 0
        torch.save(mlp.state_dict(), model_dir/"meta_mlp.pt")
        print("  ✅ new best saved")
    else:
        patience += 1
        if patience == 10:
            print("  ⏹️ early stop")
            break

# ---------- quick test score ----------
mlp.load_state_dict(torch.load(model_dir/"meta_mlp.pt"))
mlp.eval()
with torch.no_grad():
    pred_te = torch.sigmoid(mlp(Xte_t)).detach().cpu().numpy()

f1_te = f1_score(y_te, pred_te > 0.5, average="macro")
print(f"\n🎯 best val F1: {best_f1:.3f} | test F1: {f1_te:.3f}")


Epoch 001 | val F1 = 0.499
  ✅ new best saved
Epoch 002 | val F1 = 0.540
  ✅ new best saved
Epoch 003 | val F1 = 0.586
  ✅ new best saved
Epoch 004 | val F1 = 0.455
Epoch 005 | val F1 = 0.578
Epoch 006 | val F1 = 0.584
Epoch 007 | val F1 = 0.539
Epoch 008 | val F1 = 0.516
Epoch 009 | val F1 = 0.508
Epoch 010 | val F1 = 0.534
Epoch 011 | val F1 = 0.549
Epoch 012 | val F1 = 0.558
Epoch 013 | val F1 = 0.545
  ⏹️ early stop

🎯 best val F1: 0.586 | test F1: 0.604


## Step 8: Testing!


In [41]:
# ===== Interactive ADE Predictor with Ontology-Driven Explanations =====
import torch, pickle, json, yaml, numpy as np
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
from rdkit import Chem
from rdkit.Chem import MolToSmiles
import shap, warnings; warnings.filterwarnings("ignore")

# ───────────────────── 1.  Load artefacts ─────────────────────
base_name = "seyonec/ChemBERTa_zinc250k_v2_40k"
tokenizer = AutoTokenizer.from_pretrained(base_name)

base  = AutoModel.from_pretrained(base_name)
head  = torch.nn.Linear(base.config.hidden_size, 3)

class ADEModel(torch.nn.Module):
    def __init__(self, backbone, classifier):
        super().__init__(); self.backbone, self.classifier = backbone, classifier
    def forward(self, ids, attn):
        cls = self.backbone(input_ids=ids,
                            attention_mask=attn).last_hidden_state[:,0,:]
        return self.classifier(cls)

net = ADEModel(base, head)
net.load_state_dict(torch.load("models/chemberta_3lbl/pytorch_model.bin",
                               map_location="cpu"), strict=False)
net.eval()

thr_json = json.load(open("models/chemberta_3lbl/thresholds.json"))
thr_vec  = np.array([thr_json["Gastrointestinal"],
                     thr_json["Infections"],
                     thr_json["NervousSystem"]], dtype=np.float32)

scaler = pickle.load(open("models/scaler.pkl","rb"))
topk   = json.load(open("Data/meta/topk_indices.json"))

meta_mlp = torch.nn.Sequential(
    torch.nn.Linear(15,512), torch.nn.GELU(), torch.nn.Dropout(0.2),
    torch.nn.Linear(512,3)
)
meta_mlp.load_state_dict(torch.load("models/meta_mlp.pt", map_location="cpu"))
meta_mlp.eval()

# SHAP explainer (head wrt CLS embedding)
proc = Path("Data/processed")
bg_ids  = np.load(proc/"X_train.npy")[:100]
bg_attn = np.load(proc/"attn_train.npy")[:100]
with torch.no_grad():
    bg_emb = base(input_ids=torch.tensor(bg_ids),
                  attention_mask=torch.tensor(bg_attn)).last_hidden_state[:,0,:]
explainer = shap.GradientExplainer(head, [bg_emb])

labels = ["Gastrointestinal","Infections","NervousSystem"]

# ─── NEW: load feature-ontology map ───
ont_map = yaml.safe_load(open("ontology_map.yaml"))

# helper to build contextual sentence
def contextual_why(label:str, score:float)->str:
    drug, info = next(iter(yaml.safe_load(open("explanatory_map.yaml"))[label].items()))
    f1, f2 = info["synergy"]
    pathway = info["pathway"]; ref = info["reference"]

    # ontology look-ups for each feature
    o1 = ont_map.get(f1, {})
    o2 = ont_map.get(f2, {})
    tag1 = f"{o1.get('source','')}: {o1.get('id','')}"
    tag2 = f"{o2.get('source','')}: {o2.get('id','')}"

    strength = ("Strong evidence" if score>0.7 else
                "Moderate evidence" if score>0.5 else
                "Weak evidence")

    return (f"{strength}. '{f1}' ({tag1}) synergises with '{f2}' ({tag2}) "
            f"in {pathway} ({drug}; {ref}).")

# ───────────────────── 2.  feature vector ─────────────────────
def feature_vector(smiles:str):
    tok = tokenizer(smiles, return_tensors="pt", max_length=128,
                    truncation=True, padding="max_length")
    with torch.no_grad():
        emb = base(**tok).last_hidden_state[:,0,:]
    sv = np.abs(np.stack(explainer.shap_values([emb]), axis=0))[:,0,:]
    sv /= sv.sum(-1, keepdims=True)+1e-8
    feat = np.concatenate([sv[l, topk[l]] for l in range(3)])
    return scaler.transform(feat.reshape(1,-1)).astype(np.float32)

# ───────────────────── 3.  Main loop (unchanged) ─────────────────────
while True:
    smi_input = input("\nEnter SMILES (or 'q' to quit): ").strip()
    if smi_input.lower()=="q": break
    mol = Chem.MolFromSmiles(smi_input)
    if mol is None:
        print("⚠️  Invalid SMILES."); continue
    smi = MolToSmiles(mol)

    tok = tokenizer(smi, return_tensors="pt", max_length=128,
                    truncation=True, padding="max_length")
    with torch.no_grad():
        logits = net(tok["input_ids"], tok["attention_mask"]).numpy().squeeze()
    probs = 1/(1+np.exp(-logits))
    preds = (probs>=thr_vec).astype(int)

    print("\nPrediction:")
    for lab,p,pr,thr in zip(labels,preds,probs,thr_vec):
        print(f"  {lab:<15} prob={pr:.3f} → {'POS' if p else 'NEG'} (thr {thr:.2f})")

    feats = feature_vector(smi)
    with torch.no_grad():
        scores = torch.sigmoid(meta_mlp(torch.tensor(feats))).numpy().squeeze()

    print("\nMeta-explanations:")
    shown=False
    for lab,sc in zip(labels,scores):
        if sc>0.3:
            verdict=("Strong support" if sc>0.7 else
                     "Moderate support" if sc>0.5 else
                     "Weak support")
            print(f"  {lab:<15} {verdict:15} (score {sc:.2f})")
            print(f"    ↳ {contextual_why(lab,sc)}")
            shown=True
    if not shown:
        print("  (none above 0.3)")



Prediction:
  Gastrointestinal prob=0.543 → POS (thr 0.48)
  Infections      prob=0.350 → POS (thr 0.33)
  NervousSystem   prob=0.389 → NEG (thr 0.50)

Meta-explanations:
  Gastrointestinal Moderate support (score 0.53)
    ↳ Moderate evidence. 'acidic_group' (ChEBI: ChEBI:30879) synergises with 'arylacetic_core' (MeSH: MeSH:D000894) in COX-1 inhibition â†’ â†“ prostaglandin â†’ GI mucosa irritation (Diclofenac; PMID 10542302).
  Infections      Moderate support (score 0.50)
    ↳ Moderate evidence. 'glucocorticoid_core' (MeSH: MeSH:D005938) synergises with '11beta_OH' (: ) in Immunosuppression â†’ â†‘ infection susceptibility (Prednisolone; PMID 31162841).
  NervousSystem   Weak support    (score 0.40)
    ↳ Weak evidence. 'aromatic_amine' (MeSH: MeSH:D000608) synergises with 'trifluoromethyl' (ChEBI: ChEBI:51112) in Selective serotonin reuptake inhibition (Fluoxetine; PMID 33164745).
