# Zinc_modelのAD判定を行う

## ライブラリ読み込み

In [4]:
import os, gc
import json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Any, Dict, List, Optional, Tuple, Callable

from rdkit import Chem
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_add_pool

## 前処理・モデル構造再定義

In [2]:
# ---------- (A) same constants as training ----------
MAX_ATOMIC_NUM = 100
MAX_DEGREE = 5

BOND_TYPES = {
    Chem.rdchem.BondType.SINGLE: 0,
    Chem.rdchem.BondType.DOUBLE: 1,
    Chem.rdchem.BondType.TRIPLE: 2,
    Chem.rdchem.BondType.AROMATIC: 3,
}
NUM_BOND_TYPES = 4

MIN_FC, MAX_FC = -2, 2
FC_OFFSET = -MIN_FC
NUM_FC = (MAX_FC - MIN_FC + 1)

HYB_MAP = {
    Chem.rdchem.HybridizationType.SP: 0,
    Chem.rdchem.HybridizationType.SP2: 1,
    Chem.rdchem.HybridizationType.SP3: 2,
    Chem.rdchem.HybridizationType.SP3D: 3,
    Chem.rdchem.HybridizationType.SP3D2: 4,
}
HYB_UNKNOWN = 5
NUM_HYB = 6


def atom_features(atom: Chem.rdchem.Atom) -> torch.Tensor:
    atomic_num = min(atom.GetAtomicNum(), MAX_ATOMIC_NUM)
    degree = min(atom.GetDegree(), MAX_DEGREE)
    aromatic = int(atom.GetIsAromatic())

    formal_charge = atom.GetFormalCharge()
    formal_charge = max(MIN_FC, min(formal_charge, MAX_FC)) + FC_OFFSET  # 0..4

    hyb_idx = HYB_MAP.get(atom.GetHybridization(), HYB_UNKNOWN)

    return torch.tensor([atomic_num, degree, aromatic, formal_charge, hyb_idx], dtype=torch.long)


def bond_features(bond: Chem.rdchem.Bond) -> torch.Tensor:
    bond_type = BOND_TYPES.get(bond.GetBondType(), 0)
    conjugated = int(bond.GetIsConjugated())
    in_ring = int(bond.IsInRing())
    return torch.tensor([bond_type, conjugated, in_ring], dtype=torch.long)


def canonicalize_smiles(smi: str) -> Optional[str]:
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, canonical=True)


def smiles_to_pyg_discrete_v2(smiles: str):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    x = torch.stack([atom_features(a) for a in mol.GetAtoms()], dim=0)  # [N, 5]

    edge_index_list, edge_attr_list = [], []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        bf = bond_features(bond)
        edge_index_list += [[i, j], [j, i]]
        edge_attr_list  += [bf, bf]

    if len(edge_index_list) == 0:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr  = torch.empty((0, 3), dtype=torch.long)
    else:
        edge_index = torch.tensor(edge_index_list, dtype=torch.long).t().contiguous()
        edge_attr  = torch.stack(edge_attr_list, dim=0)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    data.smiles = smiles
    return data


def serializable_to_scalers(d: Dict[str, Dict[str, float]]) -> Dict[str, Tuple[float, float]]:
    return {k: (float(v["median"]), float(v["iqr"])) for k, v in d.items()}


def inverse_robust_scalar(x_scaled: float, med: float, iqr: float) -> float:
    return float(x_scaled * iqr + med)


# ---------- (B) Model definition (must match training) ----------
class EdgeCondLinearLayer(MessagePassing):
    def __init__(self, hidden_dim: int, num_bond_types: int):
        super().__init__(aggr="add")
        self.hidden_dim = hidden_dim
        self.num_bond_types = num_bond_types

        self.W = nn.Parameter(torch.empty(num_bond_types, hidden_dim, hidden_dim))
        nn.init.xavier_uniform_(self.W)
        self.gru = nn.GRUCell(hidden_dim, hidden_dim)

    def forward(self, x, edge_index, bond_type):
        m = self.propagate(edge_index, x=x, bond_type=bond_type)
        x_out = self.gru(m, x)
        return x_out

    def message(self, x_j, bond_type):
        bt = bond_type.clamp(0, self.num_bond_types - 1)
        W_bt = self.W[bt]  # [E,H,H]
        m = torch.bmm(W_bt, x_j.unsqueeze(-1)).squeeze(-1)
        return m


class MPNNRegressor(nn.Module):
    def __init__(self, hidden_dim: int = 128, num_layers: int = 3, num_targets: int = 3, dropout: float = 0.0):
        super().__init__()

        self.emb_atomic = nn.Embedding(MAX_ATOMIC_NUM + 1, 64)
        self.emb_degree = nn.Embedding(MAX_DEGREE + 1, 16)
        self.emb_aroma  = nn.Embedding(2, 8)
        self.emb_fc     = nn.Embedding(NUM_FC, 8)
        self.emb_hyb    = nn.Embedding(NUM_HYB, 8)

        node_in_dim = 64 + 16 + 8 + 8 + 8  # 104
        self.node_proj = nn.Linear(node_in_dim, hidden_dim)

        self.layers = nn.ModuleList([
            EdgeCondLinearLayer(hidden_dim=hidden_dim, num_bond_types=NUM_BOND_TYPES)
            for _ in range(num_layers)
        ])

        self.fp_dim = 1024
        self.node_to_fp = nn.Linear(hidden_dim, self.fp_dim)

        self.fc1 = nn.Linear(self.fp_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)

        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)

        self.out = nn.Linear(256, num_targets)

        self.dropout = dropout

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        edge_attr  = data.edge_attr
        batch      = data.batch

        atomic_num      = x[:, 0].clamp(0, MAX_ATOMIC_NUM)
        degree          = x[:, 1].clamp(0, MAX_DEGREE)
        aromatic        = x[:, 2].clamp(0, 1)
        formal_charge   = x[:, 3].clamp(0, NUM_FC - 1)
        hybridization   = x[:, 4].clamp(0, NUM_HYB - 1)

        h = torch.cat([
            self.emb_atomic(atomic_num),
            self.emb_degree(degree),
            self.emb_aroma(aromatic),
            self.emb_fc(formal_charge),
            self.emb_hyb(hybridization),
        ], dim=-1)

        h = self.node_proj(h)

        bond_type = edge_attr[:, 0].clamp(0, NUM_BOND_TYPES - 1)

        for layer in self.layers:
            h = layer(h, edge_index, bond_type)

        h_fp = self.node_to_fp(h)
        g    = global_add_pool(h_fp, batch)

        z = self.bn1(self.fc1(g))
        z = F.relu(z)
        if self.dropout > 0:
            z = F.dropout(z, p=self.dropout, training=self.training)

        z = self.bn2(self.fc2(z))
        z = F.relu(z)
        if self.dropout > 0:
            z = F.dropout(z, p=self.dropout, training=self.training)

        out = self.out(z)
        return out


class GenericGNNPredictor:
    def __init__(self, ckpt_path: str, device: Optional[torch.device] = None):
        self.ckpt_path = ckpt_path
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        payload = torch.load(ckpt_path, map_location="cpu")

        self.dataset_name = payload.get("dataset_name", "unknown")
        self.model_config = payload["model_config"]
        self.preprocess_config = payload["preprocess_config"]
        self.scalers = serializable_to_scalers(payload["scalers"])
        self.target_cols = list(self.model_config["target_cols"])
        self.num_targets = int(self.model_config["num_targets"])

        # build model
        self.model = MPNNRegressor(
            hidden_dim=int(self.model_config["hidden_dim"]),
            num_layers=int(self.model_config["num_layers"]),
            num_targets=int(self.model_config["num_targets"]),
            dropout=float(self.model_config.get("dropout", 0.0)),
        ).to(self.device)

        self.model.load_state_dict(payload["state_dict"])
        self.model.eval()

        self.do_canonicalize = bool(self.preprocess_config.get("smiles_canonicalize", True))

    @torch.no_grad()
    def predict_smiles(self, smiles: str) -> Dict[str, Any]:
        smi_in = smiles
        smi = smiles

        if self.do_canonicalize:
            smi2 = canonicalize_smiles(smi)
            if smi2 is None:
                return {"ok": False, "smiles_in": smi_in, "error": "Invalid SMILES"}
            smi = smi2

        data = smiles_to_pyg_discrete_v2(smi)
        if data is None:
            return {"ok": False, "smiles_in": smi_in, "smiles_used": smi, "error": "Failed to build graph"}

        batch = next(iter(DataLoader([data], batch_size=1, shuffle=False))).to(self.device)
        pred_scaled = self.model(batch).detach().cpu().view(-1)

        pred_raw_dict = {}
        for i, t in enumerate(self.target_cols):
            med, iqr = self.scalers[t]
            pred_raw_dict[t] = inverse_robust_scalar(float(pred_scaled[i].item()), med, iqr)

        return {"ok": True, "smiles_in": smi_in, "smiles_used": smi, "pred_raw": pred_raw_dict}


def load_predictor_generic(dataset_name: str, save_root: str = "models", ckpt_name: str = "checkpoint.pt",
                          device: Optional[torch.device] = None) -> GenericGNNPredictor:
    ckpt_path = os.path.join(save_root, dataset_name, ckpt_name)
    return GenericGNNPredictor(ckpt_path=ckpt_path, device=device)


print("Imports/definitions OK")


Imports/definitions OK


In [3]:
# Load predictor & sanity check

# checkpoint から predictor をロード
predictor = load_predictor_generic(
    dataset_name="zinc",     # models/zinc/checkpoint.pt
    save_root="models",
    ckpt_name="checkpoint.pt",
)

print("Loaded predictor for dataset:", predictor.dataset_name)
print("Targets:", predictor.target_cols)

# 簡単なSMILESで予測テスト
res = predictor.predict_smiles("CCO")  # ethanol

print("Prediction result:")
print(res)

Loaded predictor for dataset: zinc
Targets: ['logP', 'qed', 'SAS']
Prediction result:
{'ok': True, 'smiles_in': 'CCO', 'smiles_used': 'CCO', 'pred_raw': {'logP': 0.25158007044553754, 'qed': 0.35975909160570496, 'SAS': 2.525134175365807}}


## グラフ埋め込みを抽出・AD範囲決定・AD判定

In [5]:
@torch.no_grad()
def embed_smiles(predictor, smiles: str):
    """
    学習済みGNNから、分子fingerprint g（1024次元）を取得する。
    予測ヘッドは通さない。AD用。
    """
    smi = smiles
    if predictor.do_canonicalize:
        smi2 = canonicalize_smiles(smi)
        if smi2 is None:
            return None
        smi = smi2

    data = smiles_to_pyg_discrete_v2(smi)
    if data is None:
        return None

    # Batch of 1
    batch = next(iter(DataLoader([data], batch_size=1, shuffle=False))).to(predictor.device)

    m = predictor.model
    m.eval()

    # --- forward の readout 直前まで ---
    x = batch.x
    edge_index = batch.edge_index
    edge_attr  = batch.edge_attr
    batch_vec  = batch.batch

    atomic_num      = x[:, 0].clamp(0, MAX_ATOMIC_NUM)
    degree          = x[:, 1].clamp(0, MAX_DEGREE)
    aromatic        = x[:, 2].clamp(0, 1)
    formal_charge   = x[:, 3].clamp(0, NUM_FC - 1)
    hybridization   = x[:, 4].clamp(0, NUM_HYB - 1)

    h = torch.cat([
        m.emb_atomic(atomic_num),
        m.emb_degree(degree),
        m.emb_aroma(aromatic),
        m.emb_fc(formal_charge),
        m.emb_hyb(hybridization),
    ], dim=-1)

    h = m.node_proj(h)

    bond_type = edge_attr[:, 0].clamp(0, NUM_BOND_TYPES - 1)
    
    for layer in m.layers:
        h = layer(h, edge_index, bond_type)

    h_fp = m.node_to_fp(h)                 # [N, 1024]
    g    = global_add_pool(h_fp, batch_vec)  # [1, 1024]

    return g.cpu().view(-1).numpy()         # (1024,)

In [6]:
train_csv_path = 'data/zinc/zinc250k_train_processed.csv'
smiles_col = 'smiles'

out_dir = "models/ad_artifacts/zinc"
os.makedirs(out_dir, exist_ok=True)

g_tmp_path   = os.path.join(out_dir, "G_train_tmp.npy")   # まず最大行で確保（後で切り詰め）
g_final_path = os.path.join(out_dir, "G_train.npy")
ok_path      = os.path.join(out_dir, "train_smiles_ok.csv")
bad_path     = os.path.join(out_dir, "train_smiles_bad.csv")

# 0) 総行数をざっくり数える（ヘッダ除く）
with open(train_csv_path, "r") as f:
    n_total = sum(1 for _ in f) - 1
print("num smiles (rows):", n_total)

# 1) 埋め込み次元を決定（先頭チャンクから1個成功するまで探す）
# チャンク（分割）で読み込む
emb_dim = None
head = pd.read_csv(train_csv_path, usecols=[smiles_col], nrows=2000)
for smi in head[smiles_col].astype(str).tolist():
    g = embed_smiles(predictor, smi)
    if g is None:
        continue
    g = np.asarray(g).reshape(-1)
    emb_dim = int(g.shape[0])
    break
if emb_dim is None:
    raise RuntimeError("Cannot determine embedding dim: no valid smiles in first 2000 rows.")
print("embedding dim:", emb_dim)

# 2) 出力配列を「最大n_total行」でmemmap確保（float32固定）
G_out = np.lib.format.open_memmap(
    g_tmp_path, mode="w+", dtype=np.float32, shape=(n_total, emb_dim)
)

# 3) OK/NG を逐次でCSV出力（リストで溜めない）
with open(ok_path, "w") as f_ok, open(bad_path, "w") as f_bad:
    f_ok.write("smiles\n")
    f_bad.write("smiles\n")

    write_idx = 0
    bad_count = 0

    chunksize = 2000  # 安全寄り。速ければ5000~20000に増やしてOK
    for chunk in pd.read_csv(train_csv_path, usecols=[smiles_col], chunksize=chunksize):
        for smi in chunk[smiles_col].astype(str).tolist():
            g = embed_smiles(predictor, smi)
            if g is None:
                bad_count += 1
                f_bad.write(smi + "\n")
                continue

            g = np.asarray(g, dtype=np.float32).reshape(-1)
            if g.shape[0] != emb_dim:
                bad_count += 1
                f_bad.write(smi + "\n")
                continue

            G_out[write_idx, :] = g
            f_ok.write(smi + "\n")
            write_idx += 1

        # チャンク単位でフラッシュ＆GC（落ちにくくする）
        G_out.flush()
        gc.collect()

        if (write_idx + bad_count) % 20000 < chunksize:
            print(f"progress: ok={write_idx} bad={bad_count}")

# 4) OK行数ぶんだけ final に切り詰めコピー（チャンクで）
G_out.flush()
del G_out
gc.collect()

src = np.load(g_tmp_path, mmap_mode="r")
dst = np.lib.format.open_memmap(
    g_final_path, mode="w+", dtype=np.float32, shape=(write_idx, emb_dim)
)

step = 20000
for s in range(0, write_idx, step):
    e = min(write_idx, s + step)
    dst[s:e] = src[s:e]
dst.flush()

del src, dst
gc.collect()

# 5) tmp消しても良い（残すならコメントアウト）
try:
    os.remove(g_tmp_path)
except Exception as e:
    print("warn: failed to remove tmp:", e)

print("embedded:", write_idx, "/", n_total)
print("bad:", bad_count)

print("Saved:")
print(" -", g_final_path)
print(" -", ok_path)
print(" -", bad_path)

num smiles (rows): 399128
embedding dim: 1024
progress: ok=20000 bad=0
progress: ok=40000 bad=0
progress: ok=60000 bad=0
progress: ok=80000 bad=0
progress: ok=100000 bad=0
progress: ok=120000 bad=0
progress: ok=140000 bad=0
progress: ok=160000 bad=0
progress: ok=180000 bad=0
embedded: 199564 / 399128
bad: 0
Saved:
 - models/ad_artifacts/zinc/G_train.npy
 - models/ad_artifacts/zinc/train_smiles_ok.csv
 - models/ad_artifacts/zinc/train_smiles_bad.csv


In [7]:
# マハラノビス距離計算、閾値設定（落ちにくい版：mmap + チャンク2パス）

ad_dir = "models/ad_artifacts/zinc"
g_path = os.path.join(ad_dir, "G_train.npy")

# 1) load embeddings (mmap)
G_tr = np.load(g_path, mmap_mode="r")
N, D = G_tr.shape
print("G_tr shape:", G_tr.shape, "dtype:", G_tr.dtype)

# 2) robust scaling parameters (median / IQR)  ※結果は(1024,)なので軽い
median = np.median(G_tr, axis=0)
q1 = np.percentile(G_tr, 25, axis=0)
q3 = np.percentile(G_tr, 75, axis=0)
iqr = q3 - q1
iqr_safe = np.where(iqr == 0.0, 1.0, iqr)

# 3) mu をチャンクで計算（G_sを作らない）
chunk = 20000  # 余裕あれば 50000 などへ
mu = np.zeros(D, dtype=np.float64)

for s in range(0, N, chunk):
    e = min(N, s + chunk)
    Xs = (G_tr[s:e].astype(np.float64) - median) / iqr_safe
    mu += Xs.sum(axis=0)
mu /= N
print("mu computed")

# 4) cov をチャンクで計算（Xを作らない）
S = np.zeros((D, D), dtype=np.float64)
for s in range(0, N, chunk):
    e = min(N, s + chunk)
    Xs = (G_tr[s:e].astype(np.float64) - median) / iqr_safe
    Xc = Xs - mu
    S += Xc.T @ Xc
cov = S / N

# 数値安定化
eps = 1e-3
cov += eps * np.eye(D, dtype=np.float64)

inv_cov = np.linalg.inv(cov)
print("inv_cov computed")

# 5) Mahalanobis distance をチャンクで計算
# percentile が必要なので md_tr を memmap に保存してから percentile を取る
md_path = os.path.join(ad_dir, "md_train.npy")
md_mm = np.lib.format.open_memmap(md_path, mode="w+", dtype=np.float64, shape=(N,))

for s in range(0, N, chunk):
    e = min(N, s + chunk)
    Xs = (G_tr[s:e].astype(np.float64) - median) / iqr_safe
    Xc = Xs - mu
    md_mm[s:e] = np.einsum("bi,ij,bj->b", Xc, inv_cov, Xc)

md_mm.flush()
del md_mm
gc.collect()

md_tr = np.load(md_path, mmap_mode="r")

print("Mahalanobis stats (train):")
print("  min:", float(md_tr.min()))
print("  mean:", float(md_tr.mean()))
print("  max:", float(md_tr.max()))

# 6) threshold: top 2% as out (same as ZINC)
out_percent = 2.0
thr = float(np.percentile(md_tr, 100.0 - out_percent))

print("Mahalanobis stats (train): min/mean/max =", float(md_tr.min()), float(md_tr.mean()), float(md_tr.max()))
print(f"thr (top {out_percent}%):", thr)
print("train out-rate:", float(np.mean(md_tr > thr)))

# 7) save artifacts
np.save(os.path.join(ad_dir, "md_mu.npy"), mu.astype(np.float32))
np.save(os.path.join(ad_dir, "md_inv_cov.npy"), inv_cov.astype(np.float32))
np.save(os.path.join(ad_dir, "md_thr.npy"), np.array([thr], dtype=np.float32))

np.save(os.path.join(ad_dir, "md_median.npy"), median.astype(np.float32))
np.save(os.path.join(ad_dir, "md_iqr.npy"), iqr_safe.astype(np.float32))

print("Saved:")
print(" - md_mu.npy / md_inv_cov.npy / md_thr.npy")
print(" - md_median.npy / md_iqr.npy")

G_tr shape: (199564, 1024) dtype: float32
mu computed
inv_cov computed
Mahalanobis stats (train):
  min: 31.996361389189033
  mean: 124.53465056014444
  max: 2552.217068137432
Mahalanobis stats (train): min/mean/max = 31.996361389189033 124.53465056014444 2552.217068137432
thr (top 2.0%): 252.2586172494561
train out-rate: 0.020003607865146017
Saved:
 - md_mu.npy / md_inv_cov.npy / md_thr.npy
 - md_median.npy / md_iqr.npy


In [8]:
# テストデータでAD評価
# ---- load AD artifacts ----
ad_dir = "models/ad_artifacts/zinc"

mu = np.load(os.path.join(ad_dir, "md_mu.npy"))
inv_cov = np.load(os.path.join(ad_dir, "md_inv_cov.npy"))
thr = float(np.load(os.path.join(ad_dir, "md_thr.npy"))[0])

mean  = np.load(os.path.join(ad_dir, "md_median.npy"))
scale = np.load(os.path.join(ad_dir, "md_iqr.npy"))

# safety
iqr_safe = np.where(scale == 0.0, 1.0, scale)

def mahalanobis_distance_from_g(G, mean, scale, mu, inv_cov):
    """
    G: (N, D) raw embedding
    standardize -> subtract mu -> mahalanobis
    """
    Gs = (G - mean) / scale
    D = Gs - mu.reshape(1, -1)
    md = np.einsum("bi,ij,bj->b", D, inv_cov, D)
    return md

# ---- load test dataframe ----
test_csv_path = "data/zinc/zinc250k_test_processed.csv"   # ←あなたの実ファイル名に合わせて
smiles_col = "smiles"

test_df = pd.read_csv(test_csv_path)

# 目的変数（raw）列名：あなたのdfに合わせる
target_cols = predictor.target_cols  # ['logP', 'qed', 'SAS']

# ---- embed test smiles ----
smiles_list = test_df[smiles_col].astype(str).tolist()

G_list = []
ok_idx = []
bad_idx = []

for i, smi in enumerate(smiles_list):
    g = embed_smiles(predictor, smi)
    if g is None:
        bad_idx.append(i)
        continue
    G_list.append(g)
    ok_idx.append(i)

G_te = np.stack(G_list, axis=0)  # (N_ok, 1024)

# ---- AD score & label ----
md_te = mahalanobis_distance_from_g(G_te, median, iqr_safe, mu, inv_cov)
ad_label = np.where(md_te > thr, "out", "in")

print("TEST embedded:", len(ok_idx), "/", len(test_df), "bad:", len(bad_idx))
print("AD-out rate (test):", np.mean(ad_label == "out"))

# ---- GNN prediction (raw) for same ok samples ----
# predictor.predict_smiles は1件ずつで遅いので、ここでは簡易にループで回す（まずは検証優先）
pred_rows = []
for i in ok_idx:
    smi = smiles_list[i]
    res = predictor.predict_smiles(smi)
    pred_rows.append(res["pred_raw"])

pred_df = pd.DataFrame(pred_rows)  # columns: logP,qed,SAS

# true values (raw) for ok samples
true_df = test_df.iloc[ok_idx][target_cols].reset_index(drop=True)

# ---- error metrics ----
abs_err = (pred_df[target_cols].values - true_df[target_cols].values)
mae_each = np.mean(np.abs(abs_err), axis=0)              # per target
mae_macro = np.mean(np.abs(abs_err), axis=1)             # per sample macro-MAE

# ---- in/out split ----
mask_in = (ad_label == "in")
mask_out = (ad_label == "out")

print("\nMacro-MAE (raw) by AD label")
print("  in : mean =", mae_macro[mask_in].mean(), "n =", mask_in.sum())
print("  out: mean =", mae_macro[mask_out].mean(), "n =", mask_out.sum())

print("\nPer-target MAE (raw) by AD label")
for j, t in enumerate(target_cols):
    print(f"  {t:4s} | in: {np.mean(np.abs(abs_err[mask_in, j])):.4f} | out: {np.mean(np.abs(abs_err[mask_out, j])):.4f}")

# ---- attach summary table (optional) ----
summary = test_df.iloc[ok_idx][[smiles_col] + target_cols].reset_index(drop=True).copy()
for t in target_cols:
    summary[f"pred_{t}"] = pred_df[t].values
summary["md"] = md_te
summary["ad_label"] = ad_label
summary["mae_macro"] = mae_macro

display(summary.head(10))

TEST embedded: 24946 / 24946 bad: 0
AD-out rate (test): 0.021165717950773672

Macro-MAE (raw) by AD label
  in : mean = 0.07975466540045895 n = 24418
  out: mean = 0.16074348794515753 n = 528

Per-target MAE (raw) by AD label
  logP | in: 0.1017 | out: 0.2113
  qed  | in: 0.0255 | out: 0.0554
  SAS  | in: 0.1120 | out: 0.2156


Unnamed: 0,smiles,logP,qed,SAS,pred_logP,pred_qed,pred_SAS,md,ad_label,mae_macro
0,CCC1([C@H]([NH3+])c2ccc(F)c(C)c2)CCCC1\n,3.38752,0.816601,3.559824,3.313812,0.827877,3.486692,132.110855,in,0.052705
1,COc1ccc(-c2nnc(S[C@H](C(=O)NC(N)=O)C(C)C)n2[C@...,4.0165,0.60051,3.439721,3.901112,0.658551,3.379234,143.705856,in,0.077972
2,CCOc1ccc(NC(=O)N(C)C[C@H]2CCCO2)c(C(F)(F)F)c1\n,3.7468,0.85783,2.706953,3.704649,0.864768,2.786199,91.791649,in,0.042778
3,Nn1c(SCC(=O)NCCc2ccc(Cl)cc2)nnc1-c1ccc(Cl)cc1\n,3.4167,0.441151,2.052493,3.633578,0.44766,2.193356,139.434952,in,0.121417
4,CC[C@@]1(c2ccccc2)NC(=O)N(NC(=O)c2cc(C)no2)C1=O\n,1.48512,0.835476,2.971892,1.195782,0.813572,2.96791,157.378967,in,0.105074
5,Nc1nc(-c2cccc(F)c2)c(Br)s1\n,3.2939,0.850507,2.322955,3.409121,0.805382,2.371657,77.741875,in,0.069683
6,O=C1Cc2cc(NC(=O)c3ccccc3C(=O)c3ccccc3)ccc2N1\n,3.6645,0.681828,1.979296,3.736431,0.673051,1.831247,73.775932,in,0.076252
7,COCCCn1c(C(=O)Nc2ccc(C)cc2)cc2c(=O)n3ccccc3nc21\n,3.24642,0.504632,2.336396,3.28628,0.516607,2.318305,126.285828,in,0.023309
8,CC1CC[NH+](CN2C(=O)[C@@](O)([C@@H]3SC(N)=NC3=O...,-0.5502,0.684974,4.657024,-0.336659,0.682026,4.913285,222.293198,in,0.157583
9,CO[C@H]1CN(C(=O)COc2cccc(C)c2C)CC[C@H]1[NH3+]\n,0.54004,0.898171,3.463921,0.558493,0.863342,3.578106,112.2229,in,0.055822
