## Phase 5 (Inference)

### 1: Inference (calibrated specialist ensemble) + test export

In [None]:
# === Cold-start Inference (checkpoint name-compatible) ===
import os, json, math
from pathlib import Path
from typing import List, Dict, Optional

import numpy as np
import torch
import torch.nn as nn

# ---------------- Paths & basics ----------------
BASE       = Path("v7")
PREP_DIR   = BASE / "data" / "prepared"
DESC_DIR   = BASE / "data" / "descriptors"  # NOTE(review): not read anywhere in the visible cells
MODEL_DIR  = BASE / "model"
CKPT_BEST  = MODEL_DIR / "checkpoints" / "shared" / "best.pt"   # shared fusion model weights
ENS_DIR    = MODEL_DIR / "ensembles"     # specialist heads: <label>/seed*/{metrics.json,best.pt}
CAL_DIR    = MODEL_DIR / "calibration"   # temps.json, thresholds.json (+ blend artifacts from Cell 2)

# Fail fast, before any model is instantiated, if the core artifacts are missing.
assert CKPT_BEST.exists(), f"Missing shared checkpoint: {CKPT_BEST}"
assert (PREP_DIR / "dataset_manifest.json").exists(), "Missing dataset manifest."

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------- Labels, temps, thresholds ----------------
# The manifest fixes the label order used by every loop below and the descriptor width.
ds_manifest = json.loads((PREP_DIR / "dataset_manifest.json").read_text())
LABEL_NAMES = ds_manifest["labels"]
DESC_IN_DIM = ds_manifest["n_features"]  # 208

# Specialist calibration, keyed by label:
#   temps[label]      -> temperature T (calibrated prob = sigmoid(logit / T))
#   thresholds[label] -> {"th_f1": ..., "th_fbeta15": ...}
temps      = json.loads((CAL_DIR / "temps.json").read_text())
thresholds = json.loads((CAL_DIR / "thresholds.json").read_text())

# ---------------- Text encoder (ChemBERTa) ----------------
from transformers import AutoTokenizer, AutoModel

class ChemBERTaEncoder(nn.Module):
    """ChemBERTa token encoder projected to the fusion width.

    Tokenizes a batch of SMILES strings, runs the pretrained backbone and maps
    every token embedding to ``fusion_dim`` via Dropout -> Linear -> LayerNorm.
    Submodule names (``backbone``, ``proj``, ``ln``) must stay unchanged so the
    shared checkpoint loads with ``strict=True``.
    """

    def __init__(self, ckpt_name="seyonec/ChemBERTa-zinc-base-v1", fusion_dim=256, dropout_p=0.1):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
        self.backbone  = AutoModel.from_pretrained(ckpt_name)
        # Fall back to EOS only when the tokenizer has *no* pad token. A pad id
        # of 0 is valid and must not be treated as missing (which `or` did).
        pad_id = self.tokenizer.pad_token_id
        self.pad_token_id = pad_id if pad_id is not None else self.tokenizer.eos_token_id
        self.proj = nn.Sequential(nn.Dropout(dropout_p), nn.Linear(self.backbone.config.hidden_size, fusion_dim))
        self.ln = nn.LayerNorm(fusion_dim)

    def forward(self, smiles_list: List[str], max_length=256, add_special_tokens=True):
        """Encode a batch of SMILES.

        Returns:
            toks: (B, L, fusion_dim) projected token embeddings.
            mask: (B, L) int32 attention mask (1 = real token, 0 = padding).
        """
        enc = self.tokenizer(list(smiles_list), padding=True, truncation=True,
                             max_length=max_length, add_special_tokens=add_special_tokens,
                             return_tensors="pt")
        # Put inputs on the device holding this module's weights (rather than the
        # notebook-global `device`), so the encoder also works if moved elsewhere.
        dev = next(self.parameters()).device
        input_ids = enc["input_ids"].to(dev)
        attention_mask = enc["attention_mask"].to(dev)
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state  # (B,L,H)
        toks = self.ln(self.proj(out))  # (B,L,fusion_dim)
        return toks, attention_mask.to(dtype=torch.int32)

# ---------------- Graph encoder (match checkpoint names) ----------------
from rdkit import Chem as _Chem

# Element vocabulary for the atom one-hot in _atom_feat; any other element maps
# to the extra "other" slot. Order determines feature positions, so it presumably
# has to match the featurizer used to train the shared checkpoint.
ATOM_LIST = ["H","C","N","O","F","P","S","Cl","Br","I"]

def _one_hot(v, choices):
    z = [0]*len(choices)
    if v in choices:
        z[choices.index(v)] = 1
    return z

def _bucket_oh(v, lo, hi):
    buckets = list(range(lo, hi+1))
    o = [0]*(len(buckets)+1)
    idx = v - lo
    o[idx if 0 <= idx < len(buckets) else -1] = 1
    return o

def _atom_feat(atom):
    """Return the node feature list for one RDKit atom (51 values).

    Layout [size]:
      element one-hot over ATOM_LIST + "other"              [11]
      degree 0..5 + overflow                                [7]
      formal charge -2..2 + overflow                        [6]
      hybridization S..SP3D2 + "other" slot                 [7]
      aromatic flag                                         [1]
      in-ring flag                                          [1]
      chiral tag one-hot                                    [4]
      total H count (includeNeighbors=True) 0..4 + overflow [6]
      total valence 0..5 + overflow                         [7]
      atomic mass / 200                                     [1]

    GraphGINEncoder defaults to node_in_dim=51. The layout presumably has to
    match the featurizer used when the shared checkpoint was trained, so the
    quirks noted below are intentionally left as-is.
    """
    hybs = [
        _Chem.rdchem.HybridizationType.S, _Chem.rdchem.HybridizationType.SP,
        _Chem.rdchem.HybridizationType.SP2, _Chem.rdchem.HybridizationType.SP3,
        _Chem.rdchem.HybridizationType.SP3D, _Chem.rdchem.HybridizationType.SP3D2
    ]
    chir = [
        _Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        _Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        _Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        _Chem.rdchem.ChiralType.CHI_OTHER
    ]
    sym = atom.GetSymbol()
    feat = _one_hot(sym if sym in ATOM_LIST else "other", ATOM_LIST+["other"])
    feat += _bucket_oh(atom.GetDegree(), 0, 5)
    feat += _bucket_oh(atom.GetFormalCharge(), -2, 2)
    # NOTE: the appended "other" slot is always 0, so a hybridization not in
    # `hybs` (e.g. UNSPECIFIED) yields an all-zero 7-value block.
    feat += (_one_hot(atom.GetHybridization(), hybs)+[0])  # +other
    feat += [int(atom.GetIsAromatic())]
    feat += [int(atom.IsInRing())]
    # No overflow slot here: a chiral tag not in `chir` gives all zeros.
    feat += _one_hot(atom.GetChiralTag(), chir)
    feat += _bucket_oh(atom.GetTotalNumHs(includeNeighbors=True), 0, 4)
    feat += _bucket_oh(atom.GetTotalValence(), 0, 5)
    feat += [atom.GetMass()/200.0]  # continuous, roughly scaled mass
    return feat  # 51 values (see layout above)

def _smiles_to_graph(smi, max_nodes=128):
    """Parse ``smi`` into (node features, dense adjacency) numpy arrays.

    Returns ``(x, adj)``: x is (N, 51) float32 from _atom_feat and adj is the
    symmetric (N, N) float32 bond matrix (1.0 per bond), both truncated to the
    first ``max_nodes`` atoms. Unparseable or atom-less molecules yield two
    (0, 0) float32 arrays.
    """
    mol = _Chem.MolFromSmiles(smi)
    if mol is None or mol.GetNumAtoms() == 0:
        return np.zeros((0, 0), dtype=np.float32), np.zeros((0, 0), dtype=np.float32)

    n_atoms = mol.GetNumAtoms()
    node_rows = [_atom_feat(mol.GetAtomWithIdx(k)) for k in range(n_atoms)]
    x = np.asarray(node_rows, dtype=np.float32)

    adj = np.zeros((n_atoms, n_atoms), dtype=np.float32)
    for bond in mol.GetBonds():
        a, b = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        adj[a, b] = adj[b, a] = 1.0

    if n_atoms > max_nodes:
        return x[:max_nodes], adj[:max_nodes, :max_nodes]
    return x, adj

def _collate_graphs(smiles_batch, max_nodes=128):
    """Featurize and zero-pad a batch of SMILES into dense graph tensors.

    Args:
        smiles_batch: iterable of SMILES strings.
        max_nodes: per-molecule atom cap, forwarded to ``_smiles_to_graph``.

    Returns:
        X: (B, Nmax, F) float32 node features.
        A: (B, Nmax, Nmax) float32 adjacency.
        M: (B, Nmax) int64 node mask (1 = real atom, 0 = padding or invalid SMILES).
        All three are moved to the global ``device``. Nmax is at least 1, so a
        batch of only invalid SMILES still yields well-formed, fully masked tensors.
    """
    # Bug fix: forward the caller's max_nodes (previously ignored, so the
    # default of 128 was always used).
    graphs = [_smiles_to_graph(s, max_nodes=max_nodes) for s in smiles_batch]
    n_max = max([x.shape[0] for x, _ in graphs] + [1])
    # Take the feature width from the first *valid* molecule. Invalid SMILES
    # come back as (0, 0) arrays, so fall back to the 51-dim atom featurizer.
    f_node = next((x.shape[1] for x, _ in graphs if x.size > 0), 51)
    B = len(graphs)
    X = np.zeros((B, n_max, f_node), dtype=np.float32)
    A = np.zeros((B, n_max, n_max), dtype=np.float32)
    M = np.zeros((B, n_max), dtype=np.int64)
    for i, (x, a) in enumerate(graphs):
        n = x.shape[0]
        if n == 0:
            continue  # invalid molecule: keep all-zero features and mask
        X[i, :n, :] = x
        A[i, :n, :n] = a
        M[i, :n] = 1
    return torch.from_numpy(X).to(device), torch.from_numpy(A).to(device), torch.from_numpy(M).to(device)

class GINLayer(nn.Module):
    """Dense GIN message passing: MLP((1 + eps) * h + A @ h), re-masked to real nodes.

    Parameter names (``eps``, ``mlp.*``) must match the shared checkpoint.
    """

    def __init__(self, h=256, p=0.1):
        super().__init__()
        self.eps = nn.Parameter(torch.tensor(0.0))
        stages = [nn.Linear(h, h), nn.GELU(), nn.LayerNorm(h), nn.Dropout(p)]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x, adj, mask):
        """x: (B, N, h) node states; adj: (B, N, N); mask: (B, N), non-zero for real nodes."""
        self_term = x * (1.0 + self.eps)
        neighbour_sum = adj @ x
        updated = self.mlp(self_term + neighbour_sum)
        keep = mask.unsqueeze(-1).to(updated.dtype)
        return updated * keep

class GraphGINEncoder(nn.Module):
    """Graph branch: SMILES -> dense graph -> input MLP -> stacked GIN layers -> LayerNorm.

    Submodule names (``inp``, ``layers``, ``out_ln``) must match the shared
    checkpoint for strict loading.
    """

    def __init__(self, node_in_dim=51, hidden_dim=256, n_layers=4, p=0.1):
        super().__init__()
        self.inp = nn.Sequential(nn.Linear(node_in_dim, hidden_dim), nn.GELU(), nn.Dropout(p))
        self.layers = nn.ModuleList(GINLayer(hidden_dim, p) for _ in range(n_layers))
        # IMPORTANT: name must be 'out_ln' to match checkpoint
        self.out_ln = nn.LayerNorm(hidden_dim)

    def forward(self, smiles_list: List[str], max_nodes=128):
        """Return (node embeddings (B, N, hidden_dim), node mask (B, N) as int32)."""
        node_x, adjacency, node_mask = _collate_graphs(smiles_list, max_nodes=max_nodes)
        hidden = self.inp(node_x)
        for gin in self.layers:
            hidden = gin(hidden, adjacency, node_mask)
        return self.out_ln(hidden), node_mask.to(dtype=torch.int32)

# ---------------- Fusion blocks ----------------
def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    """Average ``x`` over ``dim``, counting only positions where ``mask`` is non-zero.

    Used here with x (B, L, D), mask (B, L) and dim=1; the mask is broadcast
    over the feature axis. The count is clamped to 1, so fully masked rows
    return zeros instead of NaN.
    """
    weights = mask.to(dtype=x.dtype, device=x.device)
    total = (x * weights.unsqueeze(-1)).sum(dim=dim)
    count = torch.clamp(weights.sum(dim=dim, keepdim=True), min=1.0)
    return total / count

class CrossAttentionBlock(nn.Module):
    """Text-to-graph cross attention with a residual connection and LayerNorm.

    Text tokens (queries) attend over graph node embeddings (keys/values).
    Padded nodes are excluded via ``key_padding_mask``. Submodule names
    (``mha``, ``ln``, ``do``) must match the shared checkpoint.
    """

    def __init__(self, dim=256, n_heads=4, p=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, n_heads, dropout=p, batch_first=False)
        self.ln  = nn.LayerNorm(dim)
        self.do  = nn.Dropout(p)

    def forward(self, text_tokens, text_mask, graph_nodes, graph_mask):
        """Return the text tokens (B, L, D) enriched with graph context.

        Args:
            text_tokens: (B, L, D) query tokens.
            text_mask: not used here (query padding is handled later by
                pooling); kept for interface compatibility.
            graph_nodes: (B, N, D) node embeddings.
            graph_mask: (B, N), non-zero for real nodes.

        A molecule with no valid node at all (e.g. an unparseable SMILES)
        would mask every key, and the attention softmax would then produce
        NaN that poisons the whole prediction. For such rows the attention
        update is zeroed, so the output falls back to LayerNorm(text_tokens).
        """
        Q = text_tokens.transpose(0, 1)   # (L,B,D)
        K = graph_nodes.transpose(0, 1)   # (N,B,D)
        V = graph_nodes.transpose(0, 1)
        kpm = (graph_mask == 0)           # (B,N) True where pad
        no_nodes = kpm.all(dim=1)         # (B,) molecules without any valid node
        has_empty = bool(no_nodes.any())
        if has_empty:
            # Unmask those rows so the softmax stays finite; their result is discarded below.
            kpm = kpm.masked_fill(no_nodes.unsqueeze(1), False)
        attn, _ = self.mha(Q, K, V, key_padding_mask=kpm)
        attn = attn.transpose(0, 1)       # (B,L,D)
        if has_empty:
            attn = attn.masked_fill(no_nodes.view(-1, 1, 1), 0.0)
        return self.ln(text_tokens + self.do(attn))

class DescriptorMLP(nn.Module):
    """Two-stage Linear -> GELU -> Dropout MLP that embeds descriptors to ``out_dim``.

    Keeps the ``net.0`` / ``net.3`` Linear layout expected by the checkpoint.
    """

    def __init__(self, in_dim, out_dim=256, hidden=256, p=0.1):
        super().__init__()
        stages = []
        for fan_in, fan_out in ((in_dim, hidden), (hidden, out_dim)):
            stages += [nn.Linear(fan_in, fan_out), nn.GELU(), nn.Dropout(p)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Map (B, in_dim) descriptors to (B, out_dim)."""
        return self.net(x)

# IMPORTANT: name must be 'mlp' to match checkpoint ('shared_head.mlp.*')
class FusionClassifier(nn.Module):
    """Shared multi-label head over the concatenated [text | graph | descriptor] vector.

    The fused input is 3 * dim wide. Parameter names ``mlp.0`` / ``mlp.3``
    must match the checkpoint (``shared_head.mlp.*``).
    """

    def __init__(self, dim=256, n_labels=12, p=0.1):
        super().__init__()
        fused_width, hidden_width = 3 * dim, 2 * dim
        self.mlp = nn.Sequential(
            nn.Linear(fused_width, hidden_width),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(hidden_width, n_labels),
        )

    def forward(self, fused_vec):
        """Map (B, 3*dim) fused vectors to (B, n_labels) logits."""
        return self.mlp(fused_vec)

class V7FusionModel(nn.Module):
    """Shared fusion model: cross-attended text pooling + graph pooling + descriptor MLP -> shared head.

    Attribute names (text_encoder, graph_encoder, cross, desc_mlp, shared_head)
    mirror the checkpoint keys and must not change.
    """

    def __init__(self, text_encoder, graph_encoder, desc_in_dim=208, dim=256, n_labels=12, n_heads=4, p=0.1):
        super().__init__()
        self.text_encoder = text_encoder
        self.graph_encoder = graph_encoder
        self.cross = CrossAttentionBlock(dim, n_heads, p)
        self.desc_mlp = DescriptorMLP(desc_in_dim, out_dim=dim, hidden=256, p=p)
        self.shared_head = FusionClassifier(dim, n_labels, p)

    def forward(self, smiles_list, desc_feats, return_intermediates=False):
        """Score a batch of SMILES.

        Returns the (B, n_labels) logits, or ``(logits, fused)`` with the
        (B, 3*dim) fused vector when ``return_intermediates`` is true.
        """
        text_tok, text_mask = self.text_encoder(smiles_list, max_length=256)
        node_emb, node_mask = self.graph_encoder(smiles_list, max_nodes=128)
        text_tok, text_mask = text_tok.to(device), text_mask.to(device)
        node_emb, node_mask = node_emb.to(device), node_mask.to(device)
        desc_feats = desc_feats.to(device)

        attended = self.cross(text_tok, text_mask, node_emb, node_mask)
        desc_emb = self.desc_mlp(desc_feats)
        fused = torch.cat(
            [masked_mean(attended, text_mask, 1), masked_mean(node_emb, node_mask, 1), desc_emb],
            dim=-1,
        )  # (B,768)
        logits = self.shared_head(fused)
        return (logits, fused) if return_intermediates else logits

# ---------------- Build & load ----------------
# Build the architecture; ChemBERTa weights first come from `from_pretrained`
# and are then overwritten by the strict checkpoint load below.
text_encoder = ChemBERTaEncoder().to(device)
graph_encoder= GraphGINEncoder().to(device)
v7_shared    = V7FusionModel(text_encoder, graph_encoder, desc_in_dim=DESC_IN_DIM, n_labels=len(LABEL_NAMES)).to(device)
# The checkpoint is a dict whose "model" entry is the state_dict.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True; if the checkpoint
# holds non-tensor objects, this may need weights_only=False (trusted files only).
ckpt = torch.load(CKPT_BEST, map_location=device)
# strict=True: every parameter name must match, hence the naming constraints on the modules above.
v7_shared.load_state_dict(ckpt["model"], strict=True)
v7_shared.eval()  # disable dropout for inference
print("✅ Loaded shared fusion model.")

# ---------------- Specialist heads (match boosted Cell 2) ----------------
class LabelHead(nn.Module):
    """Binary specialist head for one label, applied to the fused vector.

    Three Linear -> GELU -> LayerNorm -> Dropout blocks (in_dim -> h1 -> h2 -> h3)
    are added to a linear shortcut from the input, then projected to one
    logit. Attribute names must match the specialist checkpoints.
    """

    @staticmethod
    def _block(fan_in, fan_out, p):
        return nn.Sequential(nn.Linear(fan_in, fan_out), nn.GELU(), nn.LayerNorm(fan_out), nn.Dropout(p))

    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1 = self._block(in_dim, h1, p)
        self.block2 = self._block(h1, h2, p)
        self.block3 = self._block(h2, h3, p)
        self.out    = nn.Linear(h3, 1)
        self.short  = nn.Linear(in_dim, h3)

    def forward(self, x):
        """Return one logit per row: (B, in_dim) -> (B,)."""
        deep = self.block3(self.block2(self.block1(x)))
        return self.out(deep + self.short(x)).squeeze(-1)

def _load_best_head(label: str) -> nn.Module:
    """Load the specialist ``LabelHead`` for ``label`` from its best seed run.

    Scans ``ENS_DIR/<label>/seed*/`` and picks the directory whose
    ``metrics.json`` reports the highest ``best_ap``. A missing or NaN AP
    ranks below every real AP, and ties keep the first directory in sorted
    order. Only seed directories containing both ``metrics.json`` and
    ``best.pt`` are considered, so a run that never saved weights cannot be
    selected. Missing keys in a stored ``config`` fall back to the defaults.

    Raises:
        FileNotFoundError: if no usable seed directory exists for ``label``.
    """
    default_cfg = {"in_dim": 768, "h1": 512, "h2": 256, "h3": 128, "dropout": 0.30}

    def _rank(cand):
        ap = cand[0]
        return -1.0 if math.isnan(ap) else ap

    cands = []
    for sd in sorted((ENS_DIR / label).glob("seed*")):
        mfile, wfile = sd / "metrics.json", sd / "best.pt"
        if not (sd.is_dir() and mfile.exists() and wfile.exists()):
            continue
        try:
            ap = float(json.loads(mfile.read_text()).get("best_ap", float("nan")))
        except (OSError, ValueError, TypeError, AttributeError) as err:
            # Unreadable or malformed metrics: skip this seed, but say so instead of hiding it.
            print(f"⚠️ Skipping {sd} for {label}: bad metrics.json ({err!r})")
            continue
        cands.append((ap, sd))
    if not cands:
        raise FileNotFoundError(f"No trained heads for label {label}")

    best_dir = max(cands, key=_rank)[1]
    ck = torch.load(best_dir / "best.pt", map_location=device)
    # Merge with defaults so a partial saved config does not raise KeyError.
    cfg = {**default_cfg, **(ck.get("config") or {})}
    head = LabelHead(in_dim=cfg["in_dim"], h1=cfg["h1"], h2=cfg["h2"], h3=cfg["h3"], p=cfg["dropout"]).to(device)
    head.load_state_dict(ck["model"], strict=True)
    head.eval()
    return head

# One specialist head per label, loaded eagerly; this raises (e.g. FileNotFoundError)
# if any label has no trained seed directory.
HEADS: Dict[str, nn.Module] = {lbl: _load_best_head(lbl) for lbl in LABEL_NAMES}
print("✅ Loaded specialist heads for all labels.")

# ---------------- Descriptors for ad-hoc SMILES ----------------
# For quick testing without the exact 208-d extractor, use standardized zero vector for descriptors.
def prepare_desc_matrix(smiles_list: List[str]) -> torch.Tensor:
    """Placeholder descriptor matrix of shape (len(smiles_list), DESC_IN_DIM), all zeros.

    Per the original note, zeros stand for the mean of each standardized
    descriptor. Pass real descriptors to ``fused_from_smiles(desc_tensor=...)``
    when they are available.
    """
    return torch.zeros((len(smiles_list), DESC_IN_DIM), dtype=torch.float32, device=device)

# ---------------- Fused feature builder ----------------
@torch.no_grad()
def _fuse_batch(smiles_batch: List[str], desc_batch: torch.Tensor) -> torch.Tensor:
    """Fused vectors for one batch: [pooled cross-attended text | pooled graph | descriptor embedding]."""
    tt, tm = v7_shared.text_encoder(smiles_batch, max_length=256)
    gn, gm = v7_shared.graph_encoder(smiles_batch, max_nodes=128)
    tt, tm = tt.to(device), tm.to(device)
    gn, gm = gn.to(device), gm.to(device)
    de = v7_shared.desc_mlp(desc_batch.to(device))
    # cross-attend & pool
    tta = v7_shared.cross(tt, tm, gn, gm)
    text_pool  = masked_mean(tta, tm, 1)
    graph_pool = masked_mean(gn,  gm, 1)
    return torch.cat([text_pool, graph_pool, de], dim=-1)  # (B,768)


@torch.no_grad()
def fused_from_smiles(smiles_list: List[str], desc_tensor: Optional[torch.Tensor] = None,
                      batch_size: Optional[int] = None) -> torch.Tensor:
    """Compute the (B, 768) fused feature matrix that the heads consume.

    Args:
        smiles_list: SMILES strings (any iterable; it is materialised once so
            the text and graph branches see the same sequence).
        desc_tensor: optional (B, DESC_IN_DIM) descriptors; zeros via
            ``prepare_desc_matrix`` when omitted.
        batch_size: if given, run the encoders in chunks of this size to bound
            memory on large inputs. The default (None) processes everything in
            one batch, as before.

    Returns:
        (B, 3*dim) tensor on ``device``; an empty (0, 3*dim) tensor for empty input.

    Raises:
        ValueError: if ``batch_size`` is given and < 1.
    """
    smiles_list = list(smiles_list)
    n = len(smiles_list)
    if n == 0:
        width = v7_shared.shared_head.mlp[0].in_features
        return torch.zeros((0, width), dtype=torch.float32, device=device)
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if desc_tensor is None:
        desc_tensor = prepare_desc_matrix(smiles_list)
    step = n if batch_size is None else int(batch_size)
    parts = [
        _fuse_batch(smiles_list[start:start + step], desc_tensor[start:start + step])
        for start in range(0, n, step)
    ]
    return parts[0] if len(parts) == 1 else torch.cat(parts, dim=0)

# ---------------- Public API ----------------
def _stable_sigmoid(z: float) -> float:
    """Logistic function for Python floats that never overflows.

    ``math.e ** (-z)`` raises OverflowError once -z exceeds ~709, which a
    strongly negative logit divided by a small temperature can reach.
    """
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def predict_smiles(smiles_list: List[str], threshold_mode: str = "fbeta15"):
    """Score SMILES with the calibrated specialist heads.

    Args:
        smiles_list: SMILES strings to score.
        threshold_mode: "fbeta15" or "f1"; selects which tuned threshold in
            ``thresholds`` decides the positive call.

    Returns:
        list[dict]: one per SMILES, mapping
          label -> {logit, prob_raw, prob_cal, decision}
        where prob_raw = sigmoid(logit), prob_cal = sigmoid(logit / T_label)
        and decision = prob_cal >= threshold.
    """
    assert threshold_mode in ("f1", "fbeta15")
    th_key = "th_fbeta15" if threshold_mode == "fbeta15" else "th_f1"
    fused = fused_from_smiles(smiles_list)  # (B,768)

    # Run each specialist once over the whole batch, giving a (B,) logit
    # vector per label, instead of B separate single-row forward passes.
    with torch.no_grad():
        logits_by_label = {label: HEADS[label](fused).reshape(-1).tolist() for label in LABEL_NAMES}

    # Per-label constants, hoisted out of the per-row loop.
    temp_by_label = {label: max(float(temps.get(label, 1.0)), 1e-3) for label in LABEL_NAMES}
    th_by_label = {label: float(thresholds[label][th_key]) for label in LABEL_NAMES}

    out = []
    for i in range(fused.size(0)):
        row = {}
        for label in LABEL_NAMES:
            logit = float(logits_by_label[label][i])
            p_r = _stable_sigmoid(logit)
            p_c = _stable_sigmoid(logit / temp_by_label[label])
            row[label] = {"logit": logit, "prob_raw": float(p_r), "prob_cal": float(p_c),
                          "decision": bool(p_c >= th_by_label[label])}
        out.append(row)
    return out

# Cell 1 done: predict_smiles / fused_from_smiles / HEADS are now in the notebook namespace.
print("✅ Inference is ready: call predict_smiles(['CCO'], threshold_mode='fbeta15' or 'f1').")

✅ Loaded shared fusion model.
✅ Loaded specialist heads for all labels.
✅ Inference is ready: call predict_smiles(['CCO'], threshold_mode='fbeta15' or 'f1').


In [None]:
# my_smiles = ["CCOc1ccc2nc(S(N)(=O)=O)sc2c1"]
# mode = "f1"  # or "f1" fbeta15

# results = predict_smiles(my_smiles, threshold_mode=mode)

# from operator import itemgetter
# for smi, rec in zip(my_smiles, results):
#     print("\nSMILES:", smi)
#     top = sorted([(lbl, d["prob_cal"], d["decision"]) for lbl, d in rec.items()],
#                  key=itemgetter(1), reverse=True)[:5]
#     for lbl, p, dec in top:
#         th = thresholds[lbl]["th_fbeta15"] if mode=="fbeta15" else thresholds[lbl]["th_f1"]
#         print(f"  {lbl:12s}  prob={p:.3f}  th={th:.3f}  → pred={int(dec)}")




# Ad-hoc evaluation on Excel truth labels (simple)
import pandas as pd
import numpy as np
from pathlib import Path
from operator import itemgetter
import math, json, os

# ----------- CONFIG -----------
EXCEL_PATH = Path("tox21_dualenc_v1/data/raw/Truth Lables.xlsx")  # path kept verbatim ("Lables") to match the file on disk
MODE = "f1"            # "f1" or "fbeta15"
N_DISPLAY = 5          # how many rows to pretty-print (set to None to print all)
# Name the output after MODE so an "fbeta15" run neither overwrites nor is
# mislabelled as the "f1" results (the name was previously hard-coded to f1.csv).
OUT_CSV = Path("v7/results/inference") / f"{MODE}.csv"
OUT_CSV.parent.mkdir(parents=True, exist_ok=True)

# ----------- Checks -----------
assert 'predict_smiles' in globals(), "predict_smiles() not found. Run the cold-start inference cell first."
assert 'LABEL_NAMES' in globals(), "LABEL_NAMES not found. Run the cold-start inference cell first."
assert 'thresholds' in globals(), "thresholds not found. Run Phase 4 calibration cell first."
assert EXCEL_PATH.exists(), f"Cannot find: {EXCEL_PATH}"

# ----------- Load Excel -----------
df = pd.read_excel(EXCEL_PATH)
# lower-cased header -> original header, for case-insensitive lookup.
# str() because headerless/numeric Excel columns are not strings.
cols_lower = {str(c).lower(): c for c in df.columns}
# find smiles col (case-insensitive): exact "smiles" / "smile" first ...
smiles_col = next((cols_lower[k] for k in ("smiles", "smile") if k in cols_lower), None)
if smiles_col is None:
    # ... then fall back to the first column whose name starts with 'smile'
    # (e.g. "SMILES_canonical", "smile_str").
    smiles_col = next((c for c in df.columns if str(c).lower().startswith("smile")), None)
assert smiles_col is not None, "Could not locate a SMILES column in the Excel file."

# ----------- Match label columns (case/spacing/hyphen-insensitive) -----------
def _norm(s: str) -> str:
    return "".join(ch for ch in str(s).lower() if ch.isalnum())

# normalized label name -> canonical label from the manifest
label_norm = { _norm(lbl): lbl for lbl in LABEL_NAMES }
col_for_label = {}  # label -> column name in df (if present)

for col in df.columns:
    if col == smiles_col: 
        continue
    n = _norm(col)
    if n in label_norm:
        # NOTE(review): if two columns normalize to the same label, the later one silently wins.
        col_for_label[label_norm[n]] = col

# Labels with a truth column are scored; the rest are only predicted.
available_labels = [lbl for lbl in LABEL_NAMES if lbl in col_for_label]
missing_labels = [lbl for lbl in LABEL_NAMES if lbl not in col_for_label]
print(f"Found {len(available_labels)}/{len(LABEL_NAMES)} label columns in the Excel.")
if missing_labels:
    print("Missing label columns (will be skipped in scoring):", ", ".join(missing_labels))

# ----------- Parse truth values -----------
def parse_truth(v):
    """Interpret one Excel truth cell as True (positive), False (negative) or None (unknown).

    * Missing values (None / NaN / pd.NA) -> None.
    * Booleans (Python or NumPy) -> as-is.
    * Numbers (Python or NumPy ints/floats): 1 -> True, 0 -> False. Any other
      value, e.g. 2, 0.7 or inf, is not a valid binary label -> None. Such
      values used to be scored as negatives, and inf raised OverflowError.
    * Strings (case/whitespace-insensitive): "1", "1.0", "y", "yes", "true",
      "t", "pos", "positive" -> True; "0", "0.0", "n", "no", "false", "f",
      "neg", "negative" -> False; anything else -> None.
    """
    if pd.isna(v):
        return None
    if isinstance(v, (bool, np.bool_)):
        return bool(v)
    if isinstance(v, (int, float, np.integer, np.floating)):
        if v == 1:
            return True
        if v == 0:
            return False
        return None
    s = str(v).strip().lower()
    if s in ("1", "1.0", "y", "yes", "true", "t", "pos", "positive"):
        return True
    if s in ("0", "0.0", "n", "no", "false", "f", "neg", "negative"):
        return False
    # anything else → None (unknown)
    return None

# ----------- Run predictions -----------
# Empty/NaN SMILES cells become the literal string "nan" after astype(str).
# NOTE(review): such rows still get predictions (the graph branch presumably sees an
# unparseable molecule); consider dropping them before predicting.
smiles_list = df[smiles_col].astype(str).tolist()
preds = predict_smiles(smiles_list, threshold_mode=MODE)  # list[dict[label -> details]]

# ----------- Build a simple evaluation table -----------
rows = []
micro_tp = micro_fp = micro_fn = 0
th_key = "th_fbeta15" if MODE == "fbeta15" else "th_f1"
# Column positions of the truth labels. Rows are read positionally (iat) so they
# stay aligned with `preds` even if df does not carry a 0..n-1 index.
truth_pos = {lbl: df.columns.get_loc(col_for_label[lbl]) for lbl in available_labels}

for i, (smi, rec) in enumerate(zip(smiles_list, preds)):
    # Parse every available truth cell once: True / False / None (unknown, not scored).
    truth = {lbl: parse_truth(df.iat[i, pos]) for lbl, pos in truth_pos.items()}
    true_pos = {lbl for lbl, val in truth.items() if val is True}

    # predicted positives at chosen threshold
    pred_pos = {lbl for lbl, d in rec.items() if d["decision"]}
    # accumulate micro counts only on labels where truth is known
    for lbl, val in truth.items():
        if val is None:
            continue
        predicted = lbl in pred_pos
        if predicted and val:
            micro_tp += 1
        elif predicted:
            micro_fp += 1
        elif val:
            micro_fn += 1

    # top-5 by calibrated probability (for pretty print)
    top5 = sorted([(lbl, d["prob_cal"], d["decision"]) for lbl, d in rec.items()],
                  key=itemgetter(1), reverse=True)[:5]

    # save a row for CSV: include probs & preds, and truths if present
    row = {"smiles": smi}
    for lbl, det in rec.items():
        row[f"{lbl}_prob"] = det["prob_cal"]
        row[f"{lbl}_pred"] = int(det["decision"])
        if lbl in truth:
            tv = truth[lbl]
            row[f"{lbl}_true"] = (None if tv is None else int(tv))
    rows.append(row)

    # pretty print a few rows
    if N_DISPLAY is None or i < N_DISPLAY:
        print("\nSMILES:", smi)
        for lbl, p, dec in top5:
            th = thresholds[lbl][th_key]
            print(f"  {lbl:12s}  prob={p:.3f}  th={float(th):.3f}  → pred={int(dec)}")
        if available_labels:
            print("  True positives:", ", ".join(sorted(true_pos)) if true_pos else "—")
            chosen = ", ".join(sorted(pred_pos)) if pred_pos else "—"
            print(f"  Pred positives ({MODE}): {chosen}")

# ----------- Micro summary -----------
# Micro-averaged precision / recall / F1 over every (row, label) pair with known truth.
# Note: this rebinds `rec`, which held the last per-row prediction dict from the loop above.
prec = micro_tp / (micro_tp + micro_fp) if (micro_tp + micro_fp) > 0 else 0.0
rec  = micro_tp / (micro_tp + micro_fn) if (micro_tp + micro_fn) > 0 else 0.0
f1   = (2*prec*rec)/(prec+rec) if (prec+rec) > 0 else 0.0

print("\n=== Summary (micro over labels with truth present) ===")
print(f"TP={micro_tp} FP={micro_fp} FN={micro_fn}")
print(f"Precision={prec:.3f} Recall={rec:.3f} F1={f1:.3f}")

# ----------- Save CSV -----------
# One row per SMILES: <label>_prob, <label>_pred and, where a truth column exists, <label>_true.
pd.DataFrame(rows).to_csv(OUT_CSV, index=False)
print(f"\nSaved detailed results → {OUT_CSV}")


Found 12/12 label columns in the Excel.

SMILES: CCOc1ccc2nc(S(N)(=O)=O)sc2c1
  NR-AhR        prob=0.594  th=0.681  → pred=0
  SR-ARE        prob=0.533  th=0.532  → pred=1
  NR-ER         prob=0.529  th=0.542  → pred=0
  SR-ATAD5      prob=0.529  th=0.567  → pred=0
  NR-PPAR-gamma  prob=0.521  th=0.521  → pred=1
  True positives: NR-AhR, SR-ARE
  Pred positives (f1): NR-PPAR-gamma, SR-ARE

SMILES: CCN1C(=O)NC(c2ccccc2)C1=O
  NR-AhR        prob=0.616  th=0.681  → pred=0
  SR-MMP        prob=0.554  th=0.600  → pred=0
  SR-ATAD5      prob=0.537  th=0.567  → pred=0
  NR-PPAR-gamma  prob=0.536  th=0.521  → pred=1
  SR-p53        prob=0.534  th=0.546  → pred=0
  True positives: —
  Pred positives (f1): NR-PPAR-gamma

SMILES: O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1
  NR-AhR        prob=0.670  th=0.681  → pred=0
  SR-MMP        prob=0.641  th=0.600  → pred=1
  NR-ER-LBD     prob=0.580  th=0.650  → pred=0
  NR-Aromatase  prob=0.569  th=0.604  → pred=0
  SR-p53        prob=0.560  th=0.546  → pred

### 2: Calibrate shared head, create blended ensemble, refit thresholds

In [None]:
# Phase 5 — Cell 2 (optional): shared+specialist blend with calibration and new thresholds
import json, math
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import precision_recall_curve, average_precision_score

BASE      = Path("v7")
FUSED_DIR = BASE / "data" / "fused"
CAL_DIR   = BASE / "model" / "calibration"
ENS_DIR   = BASE / "model" / "ensembles"
CAL_DIR.mkdir(parents=True, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Expect these in memory from earlier cold-start cell:
# v7_shared (with .shared_head), HEADS (specialists), LABEL_NAMES, temps (specialist temps)
assert 'v7_shared' in globals() and 'HEADS' in globals() and 'LABEL_NAMES' in globals() and 'temps' in globals()

# ---- load val fused + labels/mask ----
Xva = np.load(FUSED_DIR / "val_fused.npy")     # (N,768)
Yva = np.load(FUSED_DIR / "val_Y.npy")         # (N,12)
# Force a boolean mask. The code below selects labelled rows with `~Mva[:, j]`,
# and `~` on an integer mask is a bitwise NOT (0 -> -1, 1 -> -2) that would then act
# as integer fancy indices instead of a boolean selector. This matches how Cell 3
# loads the test mask.
Mva = np.load(FUSED_DIR / "val_mask.npy").astype(bool)      # (N,12) True where missing

Xva_t = torch.tensor(Xva, dtype=torch.float32, device=device)

# ---- helper: fit per-label temperature (on logits) ----
def fit_temperature(logits: np.ndarray, y: np.ndarray, max_iter=200, lr=0.05, device=None) -> float:
    """Fit one temperature T for sigmoid(logits / T) by minimizing binary cross-entropy.

    Args:
        logits: 1-D array of raw logits for the labelled samples.
        y: matching 0/1 targets.
        max_iter: number of Adam steps.
        lr: Adam learning rate.
        device: torch device for the optimization. Defaults to CUDA when
            available, else CPU (the same rule as the notebook-level ``device``).

    Returns:
        The fitted temperature clamped to >= 1e-3, i.e. exactly the value the
        loss is evaluated with. The raw parameter can drift below the clamp,
        even going negative, because the clamp zeroes its gradient while Adam
        momentum keeps moving it.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    t = torch.tensor([1.0], dtype=torch.float32, requires_grad=True, device=device)
    x = torch.tensor(logits, dtype=torch.float32, device=device)
    target = torch.tensor(y, dtype=torch.float32, device=device)
    opt = torch.optim.Adam([t], lr=lr)
    for _ in range(max_iter):
        opt.zero_grad(set_to_none=True)
        p = torch.sigmoid(x / t.clamp(min=1e-3)).clamp(1e-6, 1 - 1e-6)
        loss = -(target * torch.log(p) + (1 - target) * torch.log(1 - p)).mean()
        loss.backward()
        opt.step()
    return float(t.detach().clamp(min=1e-3).cpu().item())

def best_thresholds(y_true: np.ndarray, probs: np.ndarray):
    """Pick the F1- and F1.5-optimal decision thresholds from the PR curve.

    Args:
        y_true: 0/1 targets (labelled rows only).
        probs: predicted probabilities for the same rows.

    Returns:
        {"th_f1": float, "th_fbeta15": float, "ap_val": float}. Thresholds are
        0.5 if the curve has no thresholds, and ap_val is NaN when average
        precision cannot be computed.
    """
    prec, rec, th = precision_recall_curve(y_true, probs)
    if th.size == 0:
        th_f1 = th_fb = 0.5
    else:
        # precision_recall_curve returns len(th) + 1 points: prec[i] / rec[i]
        # belong to th[i], and the final point (P=1, R=0) has no threshold.
        # Drop that last point so argmax indexes `th` directly. The previous
        # `[1:]` slice was off by one and returned the threshold just below
        # the optimum.
        p, r = prec[:-1], rec[:-1]
        eps = 1e-8
        f1 = (2 * p * r) / np.maximum(p + r, eps)
        beta = 1.5
        fb = ((1 + beta**2) * p * r) / np.maximum((beta**2) * p + r, eps)
        th_f1 = th[np.nanargmax(f1)]
        th_fb = th[np.nanargmax(fb)]
    try:
        ap = float(average_precision_score(y_true, probs))
    except Exception:  # best-effort metric; report NaN rather than abort calibration
        ap = float("nan")
    return {"th_f1": float(th_f1), "th_fbeta15": float(th_fb), "ap_val": ap}

# ---- 1) Calibrate SHARED head per label on val ----
print("Calibrating shared head temperatures on val...")
# Inference only: avoid building an autograd graph over the whole validation set.
with torch.no_grad():
    logits_shared = v7_shared.shared_head(Xva_t).cpu().numpy()  # (N,12)
temps_shared = {}
for j, lbl in enumerate(LABEL_NAMES):
    valid = ~Mva[:, j]
    if valid.sum() == 0 or np.all(Yva[valid, j] == Yva[valid, j][0]):
        # No labelled rows, or only one class: the temperature is not identifiable.
        temps_shared[lbl] = 1.0
        continue
    T = fit_temperature(logits_shared[valid, j], Yva[valid, j])
    temps_shared[lbl] = T
    print(f"  {lbl}: T_shared={T:.3f}")
(CAL_DIR / "temps_shared.json").write_text(json.dumps(temps_shared, indent=2))
print("Saved →", CAL_DIR / "temps_shared.json")

# ---- 2) Build BLENDED probs on val (alpha specialist, (1-alpha) shared) ----
ALPHA = 0.8  # weight on specialist; tweak if desired
print(f"\nBlending probs on val with alpha={ALPHA:.2f} (specialist weight)")

# specialist logits on val
# (one full-batch forward per label; the heads were put in eval mode when loaded)
# NOTE(review): assumes Yva/Mva columns follow LABEL_NAMES order — confirm against the fusing step.
spec_logits = np.zeros_like(logits_shared)
with torch.no_grad():
    for j, lbl in enumerate(LABEL_NAMES):
        head = HEADS[lbl]
        spec_logits[:, j] = head(Xva_t).detach().cpu().numpy()

# calibrate both streams
# Temperature-scaled sigmoid per label; temperatures are floored at 1e-3.
# np.exp may overflow to inf for very negative logit/T (RuntimeWarning); the
# probability then correctly evaluates to 0.
p_spec_val   = np.zeros_like(spec_logits)
p_shared_val = np.zeros_like(logits_shared)
for j, lbl in enumerate(LABEL_NAMES):
    T_spec   = max(float(temps.get(lbl, 1.0)), 1e-3)
    T_shared = max(float(temps_shared.get(lbl, 1.0)), 1e-3)
    p_spec_val[:, j]   = 1. / (1. + np.exp(-spec_logits[:, j]   / T_spec))
    p_shared_val[:, j] = 1. / (1. + np.exp(-logits_shared[:, j] / T_shared))

# For 0 <= ALPHA <= 1 this is a convex combination, so the clip is only a safeguard.
p_blend_val = ALPHA * p_spec_val + (1-ALPHA) * p_shared_val
p_blend_val = np.clip(p_blend_val, 0.0, 1.0)

# ---- 3) Refit thresholds for BLEND on val ----
thresholds_blend = {}
for j, lbl in enumerate(LABEL_NAMES):
    valid = ~Mva[:, j]
    if valid.sum() == 0 or np.all(Yva[valid, j] == Yva[valid, j][0]):
        # No labelled rows or a single class: the PR curve is undefined, so fall back to 0.5.
        thresholds_blend[lbl] = {"th_f1": 0.5, "th_fbeta15": 0.5, "ap_val": float("nan")}
        continue
    thresholds_blend[lbl] = best_thresholds(Yva[valid, j], p_blend_val[valid, j])
    print(f"  {lbl}: AP_val={thresholds_blend[lbl]['ap_val']:.3f} th_f1={thresholds_blend[lbl]['th_f1']:.3f} th_fb15={thresholds_blend[lbl]['th_fbeta15']:.3f}")

# Persist alpha together with the thresholds so later cells reproduce the same blend.
# NOTE(review): json.dumps writes float("nan") as the bare token NaN, which Python's
# json.loads accepts but strict JSON parsers reject.
(Path(CAL_DIR / "thresholds_blend.json")).write_text(json.dumps({
    "alpha": ALPHA,
    "thresholds": thresholds_blend
}, indent=2))
print("\nSaved →", CAL_DIR / "thresholds_blend.json")

# ---- 4) Provide a convenience predictor using the BLEND (keep specialist predictor unchanged) ----
def predict_smiles_blend(smiles_list, mode: str = "fbeta15", alpha: float = ALPHA):
    """Score SMILES with the calibrated specialist + shared-head blend.

    Args:
        smiles_list: SMILES strings to score.
        mode: "fbeta15" or "f1"; which blended threshold in ``thresholds_blend``
            decides the positive call.
        alpha: weight on the specialist probability (1 - alpha goes to the shared head).

    Returns:
        list[dict]: per SMILES, label -> {prob_spec, prob_shared, prob_blend, decision}.
    """
    assert mode in ("f1","fbeta15")
    th_key = "th_fbeta15" if mode == "fbeta15" else "th_f1"

    def _sigmoid(z: float) -> float:
        # Stable logistic: `math.e ** (-z)` raises OverflowError for very
        # negative logits divided by a small temperature.
        if z >= 0:
            return 1.0 / (1.0 + math.exp(-z))
        ez = math.exp(z)
        return ez / (1.0 + ez)

    # fused features from shared encoders (desc branch is already wired)
    fused = fused_from_smiles(smiles_list)  # (B,768)
    with torch.no_grad():
        logits_shared = v7_shared.shared_head(fused).cpu().numpy()  # (B, n_labels)
        # One batched forward per specialist instead of one per row.
        logits_spec = {lbl: HEADS[lbl](fused).reshape(-1).tolist() for lbl in LABEL_NAMES}

    # Per-label constants, hoisted out of the per-row loop.
    t_spec = {lbl: max(float(temps.get(lbl, 1.0)), 1e-3) for lbl in LABEL_NAMES}
    t_shared = {lbl: max(float(temps_shared.get(lbl, 1.0)), 1e-3) for lbl in LABEL_NAMES}
    # threshold (use blended thresholds we just computed)
    th_by_label = {lbl: float(thresholds_blend[lbl][th_key]) for lbl in LABEL_NAMES}

    out = []
    for i in range(fused.size(0)):
        row = {}
        for j, lbl in enumerate(LABEL_NAMES):
            p_spec = _sigmoid(float(logits_spec[lbl][i]) / t_spec[lbl])
            p_shared = _sigmoid(float(logits_shared[i, j]) / t_shared[lbl])
            p_blend = alpha * p_spec + (1 - alpha) * p_shared
            row[lbl] = {
                "prob_spec": float(p_spec),
                "prob_shared": float(p_shared),
                "prob_blend": float(p_blend),
                "decision": bool(p_blend >= th_by_label[lbl]),
            }
        out.append(row)
    return out

# Cell 2 done: temps_shared / thresholds_blend are saved and predict_smiles_blend is defined.
print("\n✅ Blend ready: use predict_smiles_blend([...], mode='fbeta15' or 'f1').")

Calibrating shared head temperatures on val...
  NR-AR: T_shared=0.134
  NR-AR-LBD: T_shared=0.132
  NR-AhR: T_shared=0.167
  NR-Aromatase: T_shared=0.126
  NR-ER: T_shared=0.134
  NR-ER-LBD: T_shared=0.110
  NR-PPAR-gamma: T_shared=0.167
  SR-ARE: T_shared=0.260
  SR-ATAD5: T_shared=0.146
  SR-HSE: T_shared=0.100
  SR-MMP: T_shared=0.250
  SR-p53: T_shared=0.119
Saved → v7\model\calibration\temps_shared.json

Blending probs on val with alpha=0.80 (specialist weight)
  NR-AR: AP_val=0.171 th_f1=0.653 th_fb15=0.653
  NR-AR-LBD: AP_val=0.253 th_f1=0.621 th_fb15=0.621
  NR-AhR: AP_val=0.524 th_f1=0.709 th_fb15=0.642
  NR-Aromatase: AP_val=0.295 th_f1=0.564 th_fb15=0.474
  NR-ER: AP_val=0.253 th_f1=0.547 th_fb15=0.480
  NR-ER-LBD: AP_val=0.139 th_f1=0.589 th_fb15=0.589
  NR-PPAR-gamma: AP_val=0.063 th_f1=0.441 th_fb15=0.427
  SR-ARE: AP_val=0.344 th_f1=0.528 th_fb15=0.528
  SR-ATAD5: AP_val=0.171 th_f1=0.483 th_fb15=0.483
  SR-HSE: AP_val=0.196 th_f1=0.472 th_fb15=0.459
  SR-MMP: AP_val=0.

### 3: Evaluate on test set & export CSV (choose specialist or blend)

In [None]:
# Phase 5 — Cell 3: Test export + quick metrics
import json, math
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.metrics import average_precision_score, precision_recall_curve

BASE       = Path("v7")
PREP_DIR   = BASE / "data" / "prepared"
FUSED_DIR  = BASE / "data" / "fused"
RESULTS_DIR= BASE / "results" / "inference"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
CAL_DIR    = BASE / "model" / "calibration"

# Choose which predictor to use:
USE_BLEND = True     # True → use predict_smiles_blend; False → use specialist-only predict_smiles
MODE      = "fbeta15"  # "fbeta15" or "f1"

# Load test blobs
# NOTE(review): allow_pickle=True unpickles object arrays (presumably the SMILES array);
# only load trusted files this way.
blob = np.load(PREP_DIR / "test.npz", allow_pickle=True)
smiles = [str(s) for s in blob["smiles"].tolist()]
Yte    = blob["Y"].astype(np.float32)
Mte    = blob["y_missing_mask"].astype(bool)   # True where the label is missing

# Also load fused for test to speed shared head for blend
# NOTE(review): Xte_fused is not used below; predict_smiles_blend recomputes fused features from SMILES.
Xte_fused = np.load(FUSED_DIR / "test_fused.npy") if (FUSED_DIR / "test_fused.npy").exists() else None

# Ensure thresholds for selected path
if USE_BLEND:
    data = json.loads((CAL_DIR / "thresholds_blend.json").read_text())
    # Rebinds the global that predict_smiles_blend reads its thresholds from.
    thresholds_blend = data["thresholds"]
else:
    # NOTE(review): predict_smiles reads the global `thresholds` (loaded in Cell 1 from the
    # same file), so thresholds_spec itself is not used.
    thresholds_spec = json.loads((CAL_DIR / "thresholds.json").read_text())

rows = []
probs_mat = np.zeros((len(smiles), len(LABEL_NAMES)), dtype=np.float32)

# Both predictors return list[dict[label -> details]]. Only the probability key
# and the output file name differ between the blend and specialist paths.
if USE_BLEND:
    # Compute via blend predictor
    preds = predict_smiles_blend(smiles, mode=MODE)
    prob_key = "prob_blend"
    out_csv = RESULTS_DIR / f"predictions_test_blend_{MODE}.csv"
else:
    # Specialist-only
    preds = predict_smiles(smiles, threshold_mode=MODE)
    prob_key = "prob_cal"
    out_csv = RESULTS_DIR / f"predictions_test_specialist_{MODE}.csv"

for i, (smi, rec) in enumerate(zip(smiles, preds)):
    row = {"smiles": smi}
    for j, lbl in enumerate(LABEL_NAMES):
        p = rec[lbl][prob_key]
        row[f"{lbl}_prob"] = p
        row[f"{lbl}_pred"] = int(rec[lbl]["decision"])
        probs_mat[i, j] = p  # kept for the AP metrics below
    rows.append(row)

pd.DataFrame(rows).to_csv(out_csv, index=False)
print("✅ Saved:", out_csv)

# ---- Tiny metrics (test) ----
# Per-label average precision over rows with a known label; NaN when AP is
# undefined (no labelled rows, a single class, or a scoring error).
per_label_ap = {}
for j, lbl in enumerate(LABEL_NAMES):
    valid = ~Mte[:, j]
    if valid.sum() == 0 or np.all(Yte[valid, j] == Yte[valid, j][0]):
        per_label_ap[lbl] = float("nan"); continue
    try:
        per_label_ap[lbl] = float(average_precision_score(Yte[valid, j], probs_mat[valid, j]))
    except Exception:
        per_label_ap[lbl] = float("nan")

# Macro PR-AUC = mean AP ignoring NaN labels (np.nanmean warns and returns NaN if all are NaN).
macro_pr = float(np.nanmean([v for v in per_label_ap.values()]))

# micro P/R/F1 using chosen thresholds
# Pooled over every (molecule, label) pair whose truth is known; rows[i] lines up with
# smiles[i] because the export loop appends one row per SMILES in order.
tp = fp = fn = 0
for i in range(len(smiles)):
    for j, lbl in enumerate(LABEL_NAMES):
        if Mte[i, j]: 
            continue
        truth = int(Yte[i, j])
        pred  = rows[i][f"{lbl}_pred"]
        tp += int(pred == 1 and truth == 1)
        fp += int(pred == 1 and truth == 0)
        fn += int(pred == 0 and truth == 1)

prec = tp/(tp+fp) if (tp+fp)>0 else 0.0
rec  = tp/(tp+fn) if (tp+fn)>0 else 0.0
f1   = (2*prec*rec)/(prec+rec) if (prec+rec)>0 else 0.0

# NOTE(review): NaN entries in per_label_ap are written as the bare JSON token NaN.
report = {
    "mode": ("blend" if USE_BLEND else "specialist"),
    "threshold_mode": MODE,
    "macro_pr_auc": macro_pr,
    "micro_precision": prec,
    "micro_recall": rec,
    "micro_f1": f1,
    "per_label_ap": per_label_ap
}
report_path = RESULTS_DIR / f"test_report_{'blend' if USE_BLEND else 'specialist'}_{MODE}.json"
report_path.write_text(json.dumps(report, indent=2))
# Console summary: floats rounded to 4 places, per-label AP omitted (it is in the JSON).
print("\nSummary (test):")
print(json.dumps({k: (round(v,4) if isinstance(v, float) else v) for k,v in report.items() if k!='per_label_ap'}, indent=2))
print("Per-label AP saved in report JSON.")

✅ Saved: v7\results\inference\predictions_test_blend_fbeta15.csv

Summary (test):
{
  "mode": "blend",
  "threshold_mode": "fbeta15",
  "macro_pr_auc": 0.3208,
  "micro_precision": 0.2079,
  "micro_recall": 0.5734,
  "micro_f1": 0.3052
}
Per-label AP saved in report JSON.


### 4: Test rig (run after cells 2 & 3; gave very strong results!)

In [None]:
# === V7: Single-SMILES/SMARTS Test Rig (BLENDED: specialist + shared) ===
# Uses:
#   v7/model/checkpoints/shared/best.pt
#   v7/model/ensembles/<label>/seed*/best.pt
#   v7/model/calibration/temps.json           (specialist temps)
#   v7/model/calibration/temps_shared.json    (shared temps)
#   v7/model/calibration/thresholds_blend.json (alpha + per-label thresholds)

import os, json, math
from pathlib import Path
from typing import List, Dict, Optional

import numpy as np
import torch
import torch.nn as nn

BASE       = Path("v7")
PREP_DIR   = BASE / "data" / "prepared"
DESC_DIR   = BASE / "data" / "descriptors"
MODEL_DIR  = BASE / "model"
CKPT_BEST  = MODEL_DIR / "checkpoints" / "shared" / "best.pt"
ENS_DIR    = MODEL_DIR / "ensembles"
CAL_DIR    = MODEL_DIR / "calibration"

# Fail fast with one message listing every missing artifact this cell reads.
# An explicit raise is used instead of `assert`, which is stripped under `python -O`.
_required_artifacts = [
    CKPT_BEST,
    PREP_DIR / "dataset_manifest.json",
    CAL_DIR / "temps.json",
    CAL_DIR / "temps_shared.json",
    CAL_DIR / "thresholds_blend.json",
]
_missing_artifacts = [str(p) for p in _required_artifacts if not p.exists()]
if _missing_artifacts:
    raise FileNotFoundError("Missing required artifact(s): " + ", ".join(_missing_artifacts))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Labels & calibration artifacts ---
ds_manifest = json.loads((PREP_DIR / "dataset_manifest.json").read_text())
LABEL_NAMES: List[str] = ds_manifest["labels"]
DESC_IN_DIM = ds_manifest["n_features"]  # 208

temps_spec    = json.loads((CAL_DIR / "temps.json").read_text())           # specialist
temps_shared  = json.loads((CAL_DIR / "temps_shared.json").read_text())    # shared
blend_payload = json.loads((CAL_DIR / "thresholds_blend.json").read_text())
ALPHA         = float(blend_payload.get("alpha", 0.8))
# ALPHA weights the specialist probability in a convex blend, so it must lie in [0, 1]
# (this comparison also rejects NaN).
if not 0.0 <= ALPHA <= 1.0:
    raise ValueError(f"Blend alpha must be in [0, 1], got {ALPHA}")
thr_blend     = blend_payload["thresholds"]  # label -> {th_f1, th_fbeta15, ap_val}
# Surface missing per-label thresholds now rather than as a KeyError mid-prediction.
_labels_without_thr = [lbl for lbl in LABEL_NAMES if lbl not in thr_blend]
if _labels_without_thr:
    raise KeyError(f"thresholds_blend.json has no thresholds for labels: {_labels_without_thr}")

# --- Text encoder (ChemBERTa) ---
from transformers import AutoTokenizer, AutoModel
class ChemBERTaEncoder(nn.Module):
    """Token-level SMILES encoder: ChemBERTa backbone -> (dropout, linear) projection -> LayerNorm.

    The submodule names ``backbone``, ``proj`` and ``ln`` (and the index layout of
    ``proj``) must stay as they are so the shared checkpoint loads with ``strict=True``.
    """

    def __init__(self, ckpt_name="seyonec/ChemBERTa-zinc-base-v1", fusion_dim=256, dropout_p=0.1):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
        self.backbone = AutoModel.from_pretrained(ckpt_name)
        backbone_width = self.backbone.config.hidden_size
        self.proj = nn.Sequential(
            nn.Dropout(dropout_p),
            nn.Linear(backbone_width, fusion_dim),
        )
        self.ln = nn.LayerNorm(fusion_dim)

    def forward(self, smiles_list: List[str], max_length=256, add_special_tokens=True):
        """Tokenize a batch and return ``(tokens, mask)``.

        ``tokens`` is (B, L, fusion_dim) on ``device``. ``mask`` is the tokenizer's
        attention mask (B, L), cast to int32.
        """
        batch = self.tokenizer(
            list(smiles_list),
            padding=True,
            truncation=True,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            return_tensors="pt",
        )
        token_ids = batch["input_ids"].to(device)
        pad_mask = batch["attention_mask"].to(device)
        hidden_states = self.backbone(input_ids=token_ids, attention_mask=pad_mask).last_hidden_state
        projected = self.ln(self.proj(hidden_states))
        return projected, pad_mask.to(dtype=torch.int32)

# --- Graph encoder (names matched to checkpoint) ---
from rdkit import Chem
# Element symbols with a dedicated one-hot slot; anything else maps to "other" in _atom_feat.
ATOM_LIST = "H C N O F P S Cl Br I".split()

def _one_hot(v, choices):
    z = [0]*len(choices)
    if v in choices: z[choices.index(v)] = 1
    return z

def _bucket_oh(v, lo, hi):
    buckets = list(range(lo, hi+1))
    o = [0]*(len(buckets)+1)
    idx = v - lo
    o[idx if 0 <= idx < len(buckets) else -1] = 1
    return o

def _atom_feat(atom):
    # Build the per-atom feature vector consumed by GraphGINEncoder (node_in_dim=51).
    # The order and width of every segment presumably must match what the shared
    # checkpoint was trained on, so nothing here should be reordered or resized.
    # Segment widths, in order:
    #   element one-hot   ATOM_LIST + "other"            11
    #   degree            buckets 0..5 + overflow         7
    #   formal charge     buckets -2..2 + overflow        6
    #   hybridization     6 listed types + constant 0     7
    #   aromatic flag                                     1
    #   in-ring flag                                      1
    #   chirality tag     4 listed tags (no overflow)     4
    #   total H count     buckets 0..4 + overflow         6
    #   total valence     buckets 0..5 + overflow         7
    #   mass / 200                                        1
    #                                            total = 51
    hybs = [Chem.rdchem.HybridizationType.S, Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2]
    chir = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER]
    sym = atom.GetSymbol()
    # Elements outside ATOM_LIST set the trailing "other" slot.
    feat = _one_hot(sym if sym in ATOM_LIST else "other", ATOM_LIST+["other"])
    feat += _bucket_oh(atom.GetDegree(), 0, 5)
    feat += _bucket_oh(atom.GetFormalCharge(), -2, 2)
    # NOTE(review): the appended [0] is never set, so hybridizations not in `hybs`
    # encode as all zeros rather than in a real "other" slot. Left as-is to preserve
    # the trained feature layout.
    feat += (_one_hot(atom.GetHybridization(), hybs)+[0])
    feat += [int(atom.GetIsAromatic())]
    feat += [int(atom.IsInRing())]
    # Chirality tags not in `chir` encode as all zeros (there is no overflow slot).
    feat += _one_hot(atom.GetChiralTag(), chir)
    feat += _bucket_oh(atom.GetTotalNumHs(includeNeighbors=True), 0, 4)
    feat += _bucket_oh(atom.GetTotalValence(), 0, 5)
    # Mass is scaled by 1/200 with no clipping.
    feat += [atom.GetMass()/200.0]
    return feat  # 51 dims

def _smiles_to_graph(smi, max_nodes=128):
    """Convert one SMILES string into ``(node_features, adjacency)`` float32 arrays.

    ``node_features`` is (N, F) with rows from ``_atom_feat``. ``adjacency`` is the
    symmetric 0/1 (N, N) bond matrix. Both are truncated to the first ``max_nodes``
    atoms. An unparseable or atom-less input yields two (0, 0) arrays.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None or mol.GetNumAtoms() == 0:
        return np.zeros((0, 0), dtype=np.float32), np.zeros((0, 0), dtype=np.float32)

    n_atoms = mol.GetNumAtoms()
    node_x = np.asarray([_atom_feat(a) for a in mol.GetAtoms()], dtype=np.float32)
    bonds = np.zeros((n_atoms, n_atoms), dtype=np.float32)
    for bond in mol.GetBonds():
        u, w = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bonds[u, w] = bonds[w, u] = 1.0

    keep = min(n_atoms, max_nodes)
    return node_x[:keep], bonds[:keep, :keep]

def _collate_graphs(smiles_batch, max_nodes=128):
    """Featurize a batch of SMILES and zero-pad it into dense tensors on ``device``.

    Args:
        smiles_batch: iterable of SMILES strings.
        max_nodes: per-molecule atom cap, forwarded to ``_smiles_to_graph``.
            Previously this argument was silently ignored.

    Returns:
        X: (B, Nmax, F) float32 node features.
        A: (B, Nmax, Nmax) float32 adjacency.
        M: (B, Nmax) int64 node mask (1 = real atom).
        Nmax is at least 1, so unparseable molecules and empty batches still
        produce well-formed tensors.
    """
    graphs = [_smiles_to_graph(s, max_nodes=max_nodes) for s in smiles_batch]
    n_max = max([x.shape[0] for x, _ in graphs] + [1])
    # Unparseable molecules come back as (0, 0), so take the feature width from the
    # first molecule that actually has atoms. 51 is the _atom_feat width.
    n_feat = next((x.shape[1] for x, _ in graphs if x.size > 0), 51)
    batch = len(graphs)

    X = np.zeros((batch, n_max, n_feat), dtype=np.float32)
    A = np.zeros((batch, n_max, n_max), dtype=np.float32)
    M = np.zeros((batch, n_max), dtype=np.int64)
    for i, (x, a) in enumerate(graphs):
        n = x.shape[0]
        if n == 0:
            continue  # leave the row fully masked
        X[i, :n, :] = x
        A[i, :n, :n] = a
        M[i, :n] = 1
    return torch.from_numpy(X).to(device), torch.from_numpy(A).to(device), torch.from_numpy(M).to(device)

class GINLayer(nn.Module):
    """Dense GIN update ``h' = MLP((1 + eps) * h + A @ h)``, re-masked so padded nodes are zero.

    Parameter names (``eps``, ``mlp``) and the ``mlp`` index layout match the shared checkpoint.
    """

    def __init__(self, h=256, p=0.1):
        super().__init__()
        self.eps = nn.Parameter(torch.tensor(0.0))
        stages = [nn.Linear(h, h), nn.GELU(), nn.LayerNorm(h), nn.Dropout(p)]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x, adj, mask):
        neighbour_sum = torch.matmul(adj, x)
        combined = x * (1.0 + self.eps) + neighbour_sum
        updated = self.mlp(combined)
        node_keep = mask.unsqueeze(-1).to(updated.dtype)
        return updated * node_keep

class GraphGINEncoder(nn.Module):
    """Molecule-graph encoder: input projection, a stack of GINLayer updates, then LayerNorm.

    Submodule names ``inp``, ``layers`` and ``out_ln`` match the shared checkpoint.
    """

    def __init__(self, node_in_dim=51, hidden_dim=256, n_layers=4, p=0.1):
        super().__init__()
        self.inp = nn.Sequential(nn.Linear(node_in_dim, hidden_dim), nn.GELU(), nn.Dropout(p))
        self.layers = nn.ModuleList(GINLayer(hidden_dim, p) for _ in range(n_layers))
        self.out_ln = nn.LayerNorm(hidden_dim)  # name matches checkpoint

    def forward(self, smiles_list: List[str], max_nodes=128):
        """Return ``(node embeddings (B, Nmax, hidden_dim), node mask (B, Nmax) as int32)``."""
        node_x, adjacency, node_mask = _collate_graphs(smiles_list, max_nodes=max_nodes)
        hidden = self.inp(node_x)
        for gin in self.layers:
            hidden = gin(hidden, adjacency, node_mask)
        return self.out_ln(hidden), node_mask.to(dtype=torch.int32)

# --- Fusion & heads ---
def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    """Average ``x`` over ``dim``, counting only positions where ``mask`` is nonzero.

    ``mask`` has ``x``'s shape minus the trailing feature axis. Rows with no valid
    positions divide by 1 (clamp), so they come out as zeros.
    """
    weights = mask.to(dtype=x.dtype, device=x.device)
    weighted_sum = (x * weights.unsqueeze(-1)).sum(dim=dim)
    count = weights.sum(dim=dim, keepdim=True).clamp(min=1.0)
    return weighted_sum / count

class CrossAttentionBlock(nn.Module):
    """Text tokens attend over graph nodes, followed by a residual connection and LayerNorm.

    Submodule names (``mha``, ``ln``, ``do``) match the shared checkpoint.
    """

    def __init__(self, dim=256, n_heads=4, p=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, n_heads, dropout=p, batch_first=False)
        self.ln  = nn.LayerNorm(dim)
        self.do  = nn.Dropout(p)

    def forward(self, text_tokens, text_mask, graph_nodes, graph_mask):
        """Return updated text tokens (B, L, D).

        ``text_mask`` is accepted for interface symmetry but is not used. Padded
        graph nodes (``graph_mask == 0``) are excluded as keys.

        If a sample has no valid graph nodes (e.g. its SMILES failed to parse),
        every key would be masked and attention would produce NaN. Such samples
        get a zero attention update instead, which reduces them to
        ``ln(text_tokens)``.
        """
        Q = text_tokens.transpose(0, 1)   # (L,B,D)
        K = graph_nodes.transpose(0, 1)   # (N,B,D)
        V = graph_nodes.transpose(0, 1)
        kpm = (graph_mask == 0)           # (B,N), True = padding
        no_keys = kpm.all(dim=1)          # (B,)
        # Temporarily unmask fully-padded rows so softmax stays finite; their output is zeroed below.
        kpm = kpm & ~no_keys.unsqueeze(1)
        attn, _ = self.mha(Q, K, V, key_padding_mask=kpm)
        attn = attn.transpose(0, 1)       # (B,L,D)
        attn = attn.masked_fill(no_keys.view(-1, 1, 1), 0.0)
        return self.ln(text_tokens + self.do(attn))

class DescriptorMLP(nn.Module):
    """Two Linear->GELU->Dropout stages mapping ``in_dim`` descriptors to ``out_dim`` features.

    The ``net`` indices (Linear at 0 and 3) match the shared checkpoint.
    """

    def __init__(self, in_dim, out_dim=256, hidden=256, p=0.1):
        super().__init__()
        stages = []
        for d_in, d_out in ((in_dim, hidden), (hidden, out_dim)):
            stages += [nn.Linear(d_in, d_out), nn.GELU(), nn.Dropout(p)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class FusionClassifier(nn.Module):
    """Shared multi-label head over the concatenated [text | graph | descriptor] vector (3*dim wide)."""

    # name 'mlp' matches checkpoint ('shared_head.mlp.*')
    def __init__(self, dim=256, n_labels=12, p=0.1):
        super().__init__()
        fused_width, hidden_width = 3 * dim, 2 * dim
        self.mlp = nn.Sequential(
            nn.Linear(fused_width, hidden_width),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(hidden_width, n_labels),
        )

    def forward(self, fused_vec):
        return self.mlp(fused_vec)

class V7FusionModel(nn.Module):
    """Shared V7 model: text x graph cross-attention plus descriptors, fused into one shared head.

    Submodule names match the shared checkpoint (``text_encoder``, ``graph_encoder``,
    ``cross``, ``desc_mlp``, ``shared_head``).
    """

    def __init__(self, text_encoder, graph_encoder, desc_in_dim=208, dim=256, n_labels=12, n_heads=4, p=0.1):
        super().__init__()
        self.text_encoder = text_encoder
        self.graph_encoder = graph_encoder
        self.cross = CrossAttentionBlock(dim, n_heads, p)
        self.desc_mlp = DescriptorMLP(desc_in_dim, out_dim=dim, hidden=256, p=p)
        self.shared_head = FusionClassifier(dim, n_labels, p)

    def forward(self, smiles_list, desc_feats):
        """Return ``(logits (B, n_labels), fused (B, 3*dim))``.

        ``fused`` concatenates the mean-pooled attended text tokens, the mean-pooled
        graph nodes and the descriptor embedding.
        """
        text_tok, text_mask = self.text_encoder(smiles_list, max_length=256)
        graph_tok, graph_mask = self.graph_encoder(smiles_list, max_nodes=128)
        text_tok, text_mask = text_tok.to(device), text_mask.to(device)
        graph_tok, graph_mask = graph_tok.to(device), graph_mask.to(device)

        attended = self.cross(text_tok, text_mask, graph_tok, graph_mask)
        desc_emb = self.desc_mlp(desc_feats.to(device))
        pooled = [
            masked_mean(attended, text_mask, 1),
            masked_mean(graph_tok, graph_mask, 1),
            desc_emb,
        ]
        fused = torch.cat(pooled, dim=-1)  # (B,768)
        return self.shared_head(fused), fused

# Build the shared fusion model and restore its trained weights. strict=True means
# every checkpoint key must match a parameter name defined above.
text_encoder = ChemBERTaEncoder().to(device)
graph_encoder = GraphGINEncoder().to(device)
v7_shared = V7FusionModel(
    text_encoder,
    graph_encoder,
    desc_in_dim=DESC_IN_DIM,
    n_labels=len(LABEL_NAMES),
).to(device)

ckpt = torch.load(CKPT_BEST, map_location=device)
v7_shared.load_state_dict(ckpt["model"], strict=True)
v7_shared.train(False)  # equivalent to .eval(): disables dropout for inference

# Specialist heads (same as trained)
class LabelHead(nn.Module):
    """Per-label specialist head that returns one logit per row.

    It has three Linear->GELU->LayerNorm->Dropout blocks and a linear skip from the
    input to the last hidden width. Names ``block1..3``, ``out`` and ``short`` match
    the trained checkpoints.
    """

    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()

        def _stage(d_in, d_out):
            return nn.Sequential(nn.Linear(d_in, d_out), nn.GELU(), nn.LayerNorm(d_out), nn.Dropout(p))

        self.block1 = _stage(in_dim, h1)
        self.block2 = _stage(h1, h2)
        self.block3 = _stage(h2, h3)
        self.out = nn.Linear(h3, 1)
        self.short = nn.Linear(in_dim, h3)

    def forward(self, x):
        deep = self.block3(self.block2(self.block1(x)))
        return self.out(deep + self.short(x)).squeeze(-1)

def _load_best_head(label: str) -> nn.Module:
    """Load the specialist head for *label* with the highest recorded validation AP.

    The function scans ``ENS_DIR/<label>/seed*/``. A seed directory is a candidate
    only if it has both ``metrics.json`` and ``best.pt``, so a seed whose checkpoint
    is missing can no longer be selected. Candidates are ranked by
    ``metrics.json["best_ap"]``; a NaN/missing AP ranks as -1.0, and ties keep the
    sorted directory order. Seeds with unreadable metrics are skipped with a
    printed warning instead of being silently ignored.

    The checkpoint's ``config`` is merged over the default architecture, so a
    partial config no longer raises KeyError.

    Raises:
        FileNotFoundError: no usable seed directory exists for *label*.
    """
    defaults = {"in_dim": 768, "h1": 512, "h2": 256, "h3": 128, "dropout": 0.30}
    cands = []
    for sd in sorted((ENS_DIR / label).glob("seed*/")):
        mfile = sd / "metrics.json"
        if not (mfile.exists() and (sd / "best.pt").exists()):
            continue
        try:
            ap = float(json.loads(mfile.read_text()).get("best_ap", float("nan")))
        except (OSError, ValueError, TypeError, AttributeError) as err:
            # ValueError covers json.JSONDecodeError. AttributeError covers metrics
            # files that are not JSON objects.
            print(f"⚠️  Skipping {sd}: unreadable metrics ({err})")
            continue
        cands.append((ap, sd))
    if not cands:
        raise FileNotFoundError(f"No trained heads for label {label}")

    # max() returns the first maximal element, which matches a stable descending sort.
    best_dir = max(cands, key=lambda c: -1.0 if math.isnan(c[0]) else c[0])[1]
    ck = torch.load(best_dir / "best.pt", map_location=device)
    cfg = {**defaults, **(ck.get("config") or {})}
    head = LabelHead(in_dim=cfg["in_dim"], h1=cfg["h1"], h2=cfg["h2"], h3=cfg["h3"], p=cfg["dropout"]).to(device)
    head.load_state_dict(ck["model"], strict=True)
    head.eval()
    return head

# One specialist head per label, loaded in LABEL_NAMES order.
HEADS: Dict[str, nn.Module] = dict(zip(LABEL_NAMES, map(_load_best_head, LABEL_NAMES)))

def prepare_desc_matrix(smiles_list: List[str]) -> torch.Tensor:
    """Return an all-zero ``(len(smiles_list), DESC_IN_DIM)`` float32 descriptor matrix on ``device``.

    Ad-hoc inputs get zeros in place of computed descriptors ("standardized zeros",
    which keeps the rig simple and robust). NOTE(review): if training descriptors
    were standardized, zero presumably corresponds to the per-feature mean. Confirm
    this against the descriptor pipeline.
    """
    return torch.zeros((len(smiles_list), DESC_IN_DIM), dtype=torch.float32, device=device)

def normalize_smiles_or_smarts(s: str) -> str:
    """Best-effort conversion of a SMILES or SMARTS string to RDKit SMILES.

    - Valid SMILES are returned in RDKit's canonical SMILES form.
    - Otherwise, if the text parses as SMARTS, RDKit's SMILES rendering of the query
      is returned. Empty output or an exception falls back to the input.
    - Otherwise the (stringified) input is returned unchanged.
    """
    text = s if isinstance(s, str) else str(s)
    mol = Chem.MolFromSmiles(text)
    if mol:
        return Chem.MolToSmiles(mol)
    query = Chem.MolFromSmarts(text)
    if not query:
        return text
    try:
        rendered = Chem.MolToSmiles(query)
    except Exception:
        return text
    return rendered or text

@torch.no_grad()
def fused_from_smiles(smiles_list: List[str]) -> torch.Tensor:
    """Normalize the inputs and return the shared model's fused embedding (B, 768).

    The shared logits from the same forward pass are discarded here.
    """
    canonical = list(map(normalize_smiles_or_smarts, smiles_list))
    _, fused = v7_shared(canonical, prepare_desc_matrix(canonical))
    return fused

def _scaled_sigmoid(logit: float, temperature) -> float:
    """Numerically stable ``sigmoid(logit / T)``, with T clamped to at least 1e-3.

    The previous ``1 / (1 + math.e ** (-z))`` raised OverflowError once ``-z``
    exceeded ~709, which a small calibrated temperature makes easy to reach.
    """
    z = logit / max(float(temperature), 1e-3)
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def predict_one_blend(smi: str, mode: str = "fbeta15", topk: int = 5):
    """
    Blended prediction for one SMILES/SMARTS using:
      prob_blend = alpha*P_spec + (1-alpha)*P_shared
    Each probability is temperature-scaled with its own calibration (temps_spec /
    temps_shared). Thresholds come from thresholds_blend.json for the chosen mode
    ("f1" or "fbeta15").
    Prints the top-``topk`` labels by blended probability plus the positives, and
    returns a dict[label] -> {prob_spec, prob_shared, prob_blend, threshold, decision}.

    Raises:
        ValueError: if ``mode`` is not "f1" or "fbeta15".
    """
    if mode not in ("f1", "fbeta15"):
        raise ValueError(f"mode must be 'f1' or 'fbeta15', got {mode!r}")
    th_key = "th_fbeta15" if mode == "fbeta15" else "th_f1"

    fused = fused_from_smiles([smi])
    x = fused[0:1]

    # Shared and specialist logits for the single input row.
    with torch.no_grad():
        logits_shared = v7_shared.shared_head(x).detach().cpu().numpy()[0]  # (n_labels,)
        logits_spec = {lbl: HEADS[lbl](x).item() for lbl in LABEL_NAMES}

    rec = {}
    for j, lbl in enumerate(LABEL_NAMES):
        p_spec = _scaled_sigmoid(logits_spec[lbl], temps_spec.get(lbl, 1.0))
        p_shared = _scaled_sigmoid(float(logits_shared[j]), temps_shared.get(lbl, 1.0))
        p_blend = ALPHA * p_spec + (1.0 - ALPHA) * p_shared
        th = float(thr_blend[lbl][th_key])
        rec[lbl] = {
            "prob_spec": float(p_spec),
            "prob_shared": float(p_shared),
            "prob_blend": float(p_blend),
            "threshold": th,
            "decision": bool(p_blend >= th),
        }

    # Pretty print. The column width follows the longest label (e.g. NR-PPAR-gamma).
    print("\nSMILES/SMARTS:", smi, f"(alpha={ALPHA:.2f}, mode={mode})")
    width = max([12] + [len(lbl) for lbl in rec])
    top = sorted(rec.items(), key=lambda kv: kv[1]["prob_blend"], reverse=True)[:topk]
    for lbl, d in top:
        print(f"  {lbl:{width}s}  prob_blend={d['prob_blend']:.3f}  th={d['threshold']:.3f}  → pred={int(d['decision'])}")
    pos = [lbl for lbl, d in rec.items() if d["decision"]]
    print("  Positives:", (", ".join(sorted(pos)) if pos else "none"))
    return rec

# Status line for the notebook; the message includes a ready-to-paste example call.
print("✅ Blend test rig ready. Example: predict_one_blend('CCO', mode='fbeta15', topk=5)")


✅ Blend test rig ready. Example:


In [None]:
# Example: all 12 labels for one molecule using the "fbeta15" thresholds.
predict_one_blend(smi="O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1", mode="fbeta15", topk=12)


SMILES/SMARTS: O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1 (alpha=0.80, mode=fbeta15)
  NR-AhR        prob_blend=0.700  th=0.642  → pred=1
  SR-MMP        prob_blend=0.680  th=0.589  → pred=1
  SR-ARE        prob_blend=0.588  th=0.528  → pred=1
  NR-ER         prob_blend=0.556  th=0.480  → pred=1
  SR-p53        prob_blend=0.551  th=0.478  → pred=1
  NR-ER-LBD     prob_blend=0.499  th=0.589  → pred=0
  NR-Aromatase  prob_blend=0.498  th=0.474  → pred=1
  SR-ATAD5      prob_blend=0.479  th=0.483  → pred=0
  NR-PPAR-gamma  prob_blend=0.478  th=0.427  → pred=1
  SR-HSE        prob_blend=0.460  th=0.459  → pred=1
  NR-AR         prob_blend=0.432  th=0.653  → pred=0
  NR-AR-LBD     prob_blend=0.423  th=0.621  → pred=0
  Positives: NR-AhR, NR-Aromatase, NR-ER, NR-PPAR-gamma, SR-ARE, SR-HSE, SR-MMP, SR-p53


{'NR-AR': {'prob_spec': 0.5297900819654926,
  'prob_shared': 0.039244673619722704,
  'prob_blend': 0.43168100029633866,
  'threshold': 0.653282642364502,
  'decision': False},
 'NR-AR-LBD': {'prob_spec': 0.527085290472785,
  'prob_shared': 0.007392770345504927,
  'prob_blend': 0.42314678644732895,
  'threshold': 0.6206690669059753,
  'decision': False},
 'NR-AhR': {'prob_spec': 0.6696848171077073,
  'prob_shared': 0.8209122313828833,
  'prob_blend': 0.6999302999627425,
  'threshold': 0.6417197585105896,
  'decision': True},
 'NR-Aromatase': {'prob_spec': 0.5691290816416109,
  'prob_shared': 0.21573499976865418,
  'prob_blend': 0.4984502652670196,
  'threshold': 0.4737248420715332,
  'decision': True},
 'NR-ER': {'prob_spec': 0.5319725845509377,
  'prob_shared': 0.6530290641575132,
  'prob_blend': 0.5561838804722528,
  'threshold': 0.4802268147468567,
  'decision': True},
 'NR-ER-LBD': {'prob_spec': 0.5796462377733185,
  'prob_shared': 0.17392469045136985,
  'prob_blend': 0.498501928308

In [None]:
# Same molecule, using the "f1" thresholds for comparison.
predict_one_blend(smi="O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1", mode="f1", topk=12)


SMILES/SMARTS: O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1 (alpha=0.80, mode=f1)
  NR-AhR        prob_blend=0.700  th=0.709  → pred=0
  SR-MMP        prob_blend=0.680  th=0.589  → pred=1
  SR-ARE        prob_blend=0.588  th=0.528  → pred=1
  NR-ER         prob_blend=0.556  th=0.547  → pred=1
  SR-p53        prob_blend=0.551  th=0.513  → pred=1
  NR-ER-LBD     prob_blend=0.499  th=0.589  → pred=0
  NR-Aromatase  prob_blend=0.498  th=0.564  → pred=0
  SR-ATAD5      prob_blend=0.479  th=0.483  → pred=0
  NR-PPAR-gamma  prob_blend=0.478  th=0.441  → pred=1
  SR-HSE        prob_blend=0.460  th=0.472  → pred=0
  NR-AR         prob_blend=0.432  th=0.653  → pred=0
  NR-AR-LBD     prob_blend=0.423  th=0.621  → pred=0
  Positives: NR-ER, NR-PPAR-gamma, SR-ARE, SR-MMP, SR-p53


{'NR-AR': {'prob_spec': 0.5297900819654926,
  'prob_shared': 0.039244673619722704,
  'prob_blend': 0.43168100029633866,
  'threshold': 0.653282642364502,
  'decision': False},
 'NR-AR-LBD': {'prob_spec': 0.527085290472785,
  'prob_shared': 0.007392770345504927,
  'prob_blend': 0.42314678644732895,
  'threshold': 0.6206690669059753,
  'decision': False},
 'NR-AhR': {'prob_spec': 0.6696848171077073,
  'prob_shared': 0.8209122313828833,
  'prob_blend': 0.6999302999627425,
  'threshold': 0.7087583541870117,
  'decision': False},
 'NR-Aromatase': {'prob_spec': 0.5691290816416109,
  'prob_shared': 0.21573499976865418,
  'prob_blend': 0.4984502652670196,
  'threshold': 0.5641032457351685,
  'decision': False},
 'NR-ER': {'prob_spec': 0.5319725845509377,
  'prob_shared': 0.6530290641575132,
  'prob_blend': 0.5561838804722528,
  'threshold': 0.547207772731781,
  'decision': True},
 'NR-ER-LBD': {'prob_spec': 0.5796462377733185,
  'prob_shared': 0.17392469045136985,
  'prob_blend': 0.49850192830