# NRELのAD判定を行う
- 学習したGNN(`NREL_multi_task.ipynb`)のcheckpointを読み込んで、AD判定を行う

## ライブラリ読み込み

In [1]:
import os
import json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Any, Dict, List, Optional, Tuple, Callable

from rdkit import Chem
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_add_pool

## 前処理・モデル構造再定義

In [2]:
# ---------- (A) Featurization constants (must equal the values used for training) ----------
# Upper clip values for the degree / atomic-number atom features. The embedding
# tables in MPNNRegressor are sized from them.
MAX_DEGREE = 5
MAX_ATOMIC_NUM = 100

# Formal charge is clipped to [MIN_FC, MAX_FC] and shifted by FC_OFFSET so that
# the stored index is non-negative (0 .. NUM_FC - 1).
MIN_FC, MAX_FC = -2, 2
NUM_FC = (MAX_FC - MIN_FC + 1)
FC_OFFSET = -MIN_FC

# RDKit bond type -> integer index.
_BondType = Chem.rdchem.BondType
BOND_TYPES = {
    _BondType.SINGLE: 0,
    _BondType.DOUBLE: 1,
    _BondType.TRIPLE: 2,
    _BondType.AROMATIC: 3,
}
NUM_BOND_TYPES = 4

# RDKit hybridization -> integer index; anything not listed maps to HYB_UNKNOWN.
_Hybridization = Chem.rdchem.HybridizationType
HYB_MAP = {
    _Hybridization.SP: 0,
    _Hybridization.SP2: 1,
    _Hybridization.SP3: 2,
    _Hybridization.SP3D: 3,
    _Hybridization.SP3D2: 4,
}
HYB_UNKNOWN = 5
NUM_HYB = 6

def atom_features(atom: Chem.rdchem.Atom) -> torch.Tensor:
    """Encode one RDKit atom as five integer category indices.

    Returns a ``torch.long`` tensor holding, in order:
    atomic number (capped at ``MAX_ATOMIC_NUM``), degree (capped at
    ``MAX_DEGREE``), aromatic flag (0/1), formal charge clipped to
    ``[MIN_FC, MAX_FC]`` and shifted by ``FC_OFFSET``, and the hybridization
    index from ``HYB_MAP`` (``HYB_UNKNOWN`` when the hybridization is not in
    the map).
    """
    # Clip first, then shift, so the stored charge index is never negative.
    clipped_charge = min(max(atom.GetFormalCharge(), MIN_FC), MAX_FC)

    features = [
        min(atom.GetAtomicNum(), MAX_ATOMIC_NUM),
        min(atom.GetDegree(), MAX_DEGREE),
        int(atom.GetIsAromatic()),
        clipped_charge + FC_OFFSET,
        HYB_MAP.get(atom.GetHybridization(), HYB_UNKNOWN),
    ]
    return torch.tensor(features, dtype=torch.long)

def bond_features(bond: Chem.rdchem.Bond) -> torch.Tensor:
    """Encode one RDKit bond as three integer indices.

    Returns a ``torch.long`` tensor ``[bond_type, conjugated, in_ring]``.
    A bond type missing from ``BOND_TYPES`` falls back to index 0.
    """
    return torch.tensor(
        [
            BOND_TYPES.get(bond.GetBondType(), 0),
            int(bond.GetIsConjugated()),
            int(bond.IsInRing()),
        ],
        dtype=torch.long,
    )

def canonicalize_smiles(smi: str) -> Optional[str]:
    """Return RDKit's canonical SMILES for ``smi``, or ``None`` if RDKit cannot parse it."""
    parsed = Chem.MolFromSmiles(smi)
    return Chem.MolToSmiles(parsed, canonical=True) if parsed is not None else None

def smiles_to_pyg_discrete_v2(smiles: str):
    """Convert a SMILES string into a PyG ``Data`` graph with integer features.

    Parameters
    ----------
    smiles:
        SMILES string to parse with RDKit.

    Returns
    -------
    torch_geometric.data.Data or None
        ``None`` when RDKit cannot parse ``smiles`` or the parsed molecule has
        no atoms. Otherwise a ``Data`` object with

        * ``x``: ``[N, 5]`` long tensor, one ``atom_features`` row per atom,
        * ``edge_index``: ``[2, 2 * n_bonds]`` long tensor; every bond is
          stored in both directions,
        * ``edge_attr``: ``[2 * n_bonds, 3]`` long tensor of ``bond_features``,
        * ``smiles``: the input string.
    """
    mol = Chem.MolFromSmiles(smiles)
    # A molecule without atoms (RDKit yields one for e.g. an empty string)
    # would make torch.stack fail on an empty list below; report it through
    # the same "None" channel callers already handle for unparseable input.
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    x = torch.stack([atom_features(atom) for atom in mol.GetAtoms()], dim=0)  # [N, 5]

    pairs, attrs = [], []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        feat = bond_features(bond)
        # Store both directions so message passing is symmetric.
        pairs += [[begin, end], [end, begin]]
        attrs += [feat, feat]

    if pairs:
        edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
        edge_attr = torch.stack(attrs, dim=0)
    else:
        # Bond-free molecule (e.g. a single atom): keep well-formed empty tensors.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 3), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    data.smiles = smiles
    return data

def serializable_to_scalers(d: Dict[str, Dict[str, float]]) -> Dict[str, Tuple[float, float]]:
    """Turn ``{name: {"median": m, "iqr": q}}`` into ``{name: (float(m), float(q))}``.

    Only the ``"median"`` and ``"iqr"`` entries of each inner mapping are read;
    a missing one raises ``KeyError``.
    """
    scalers: Dict[str, Tuple[float, float]] = {}
    for name, stats in d.items():
        scalers[name] = (float(stats["median"]), float(stats["iqr"]))
    return scalers

def inverse_robust_scalar(x_scaled: float, med: float, iqr: float) -> float:
    """Undo robust scaling: return ``x_scaled * iqr + med`` as a Python float."""
    restored = med + x_scaled * iqr
    return float(restored)

# ---------- (B) Model definition (must match training) ----------
class EdgeCondLinearLayer(MessagePassing):
    """Message-passing layer with one weight matrix per bond type and a GRU update.

    Each edge transforms its source-node state with the
    ``[hidden_dim, hidden_dim]`` matrix selected by the edge's bond type.
    Messages are summed per target node (``aggr="add"``) and passed to a
    ``GRUCell`` as input, with the previous node state as hidden state.
    """

    def __init__(self, hidden_dim: int, num_bond_types: int):
        super().__init__(aggr="add")
        self.num_bond_types = num_bond_types
        self.hidden_dim = hidden_dim

        # Stack of per-bond-type matrices: [num_bond_types, H, H].
        weight = torch.empty(num_bond_types, hidden_dim, hidden_dim)
        self.W = nn.Parameter(weight)
        nn.init.xavier_uniform_(self.W)
        self.gru = nn.GRUCell(hidden_dim, hidden_dim)

    def forward(self, x, edge_index, bond_type):
        """Run one propagation step and return the updated node states."""
        aggregated = self.propagate(edge_index, x=x, bond_type=bond_type)
        return self.gru(aggregated, x)

    def message(self, x_j, bond_type):
        """Return ``W[bond_type] @ x_j`` for every edge."""
        # Clamp so an out-of-range bond index cannot index past self.W.
        idx = bond_type.clamp(0, self.num_bond_types - 1)
        per_edge_weight = self.W[idx]  # [E, H, H]
        return torch.bmm(per_edge_weight, x_j.unsqueeze(-1)).squeeze(-1)

class MPNNRegressor(nn.Module):
    """Edge-conditioned MPNN that regresses ``num_targets`` values per graph.

    Pipeline: the five integer atom features are embedded separately,
    concatenated and projected to ``hidden_dim``; ``num_layers``
    ``EdgeCondLinearLayer`` steps update the node states; node states are
    mapped to ``fp_dim`` (1024) and sum-pooled per graph; a two-layer MLP
    (Linear -> BatchNorm -> ReLU -> optional dropout) feeds the output layer.

    Sub-module attribute names determine the ``state_dict`` keys and are also
    read directly by ``embed_smiles``, so they must not be renamed.
    """

    def __init__(self, hidden_dim: int = 128, num_layers: int = 3, num_targets: int = 3, dropout: float = 0.0):
        super().__init__()

        # One embedding table per discrete atom feature.
        self.emb_atomic = nn.Embedding(MAX_ATOMIC_NUM + 1, 64)
        self.emb_degree = nn.Embedding(MAX_DEGREE + 1, 16)
        self.emb_aroma  = nn.Embedding(2, 8)
        self.emb_fc     = nn.Embedding(NUM_FC, 8)
        self.emb_hyb    = nn.Embedding(NUM_HYB, 8)

        # Width of the concatenated embeddings: 64 + 16 + 8 + 8 + 8 = 104.
        node_in_dim = sum(
            table.embedding_dim
            for table in (self.emb_atomic, self.emb_degree, self.emb_aroma, self.emb_fc, self.emb_hyb)
        )
        self.node_proj = nn.Linear(node_in_dim, hidden_dim)

        message_layers = [
            EdgeCondLinearLayer(hidden_dim=hidden_dim, num_bond_types=NUM_BOND_TYPES)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(message_layers)

        # Graph fingerprint: node states -> fp_dim, then summed over nodes.
        self.fp_dim = 1024
        self.node_to_fp = nn.Linear(hidden_dim, self.fp_dim)

        # Prediction head.
        self.fc1 = nn.Linear(self.fp_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)

        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)

        self.out = nn.Linear(256, num_targets)

        self.dropout = dropout

    def _graph_fingerprint(self, data):
        """Return the sum-pooled ``[num_graphs, fp_dim]`` embedding of a batch."""
        feats = data.x
        # (embedding table, column in data.x, largest valid index)
        columns = (
            (self.emb_atomic, 0, MAX_ATOMIC_NUM),
            (self.emb_degree, 1, MAX_DEGREE),
            (self.emb_aroma, 2, 1),
            (self.emb_fc, 3, NUM_FC - 1),
            (self.emb_hyb, 4, NUM_HYB - 1),
        )
        # Clamp every index so it always falls inside its embedding table.
        node_emb = torch.cat(
            [table(feats[:, col].clamp(0, upper)) for table, col, upper in columns],
            dim=-1,
        )
        h = self.node_proj(node_emb)

        # Only the bond-type column of edge_attr conditions the messages.
        bond_type = data.edge_attr[:, 0].clamp(0, NUM_BOND_TYPES - 1)
        for layer in self.layers:
            h = layer(h, data.edge_index, bond_type)

        return global_add_pool(self.node_to_fp(h), data.batch)

    def _dense_block(self, inputs, linear, norm):
        """Apply Linear -> BatchNorm -> ReLU, then dropout when enabled."""
        activated = F.relu(norm(linear(inputs)))
        if self.dropout > 0:
            activated = F.dropout(activated, p=self.dropout, training=self.training)
        return activated

    def forward(self, data):
        """Return the ``[num_graphs, num_targets]`` predictions for a PyG batch."""
        hidden = self._dense_block(self._graph_fingerprint(data), self.fc1, self.bn1)
        hidden = self._dense_block(hidden, self.fc2, self.bn2)
        return self.out(hidden)

class GenericGNNPredictor:
    """Inference wrapper that rebuilds ``MPNNRegressor`` from a saved checkpoint.

    The checkpoint payload is read for the keys ``"model_config"``,
    ``"preprocess_config"``, ``"scalers"`` and ``"state_dict"`` (required) and
    ``"dataset_name"`` (optional, defaults to ``"unknown"``).
    """

    def __init__(self, ckpt_path: str, device: Optional[torch.device] = None):
        """Load the checkpoint at ``ckpt_path`` and put the model in eval mode.

        ``device`` defaults to CUDA when available, otherwise CPU.
        """
        self.ckpt_path = ckpt_path
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        payload = torch.load(ckpt_path, map_location="cpu")

        self.dataset_name = payload.get("dataset_name", "unknown")
        self.model_config = payload["model_config"]
        self.preprocess_config = payload["preprocess_config"]
        self.scalers = serializable_to_scalers(payload["scalers"])
        self.target_cols = list(self.model_config["target_cols"])
        self.num_targets = int(self.model_config["num_targets"])

        # Rebuild the architecture with the hyper-parameters stored at training time.
        cfg = self.model_config
        network = MPNNRegressor(
            hidden_dim=int(cfg["hidden_dim"]),
            num_layers=int(cfg["num_layers"]),
            num_targets=int(cfg["num_targets"]),
            dropout=float(cfg.get("dropout", 0.0)),
        )
        self.model = network.to(self.device)

        self.model.load_state_dict(payload["state_dict"])
        self.model.eval()

        self.do_canonicalize = bool(self.preprocess_config.get("smiles_canonicalize", True))

    @torch.no_grad()
    def predict_smiles(self, smiles: str) -> Dict[str, Any]:
        """Predict all targets for one SMILES and return them on the raw scale.

        The model output is inverse-transformed per target with the stored
        (median, IQR) robust-scaling parameters.

        Returns
        -------
        dict
            On success::

                {"ok": True, "smiles_in": <input>, "smiles_used": <SMILES fed
                 to the featurizer>, "pred_raw": {target: value}}

            On failure ``"ok"`` is ``False`` and ``"error"`` holds the reason;
            ``"smiles_used"`` is present only when graph construction failed.
        """
        used = smiles

        # --- canonicalize the SMILES when the checkpoint asks for it ---
        if self.do_canonicalize:
            used = canonicalize_smiles(smiles)
            if used is None:
                # RDKit could not parse the input.
                return {
                    "ok": False,
                    "smiles_in": smiles,
                    "error": "Invalid SMILES",
                }

        # --- SMILES -> PyG Data ---
        graph = smiles_to_pyg_discrete_v2(used)
        if graph is None:
            return {
                "ok": False,
                "smiles_in": smiles,
                "smiles_used": used,
                "error": "Failed to build graph",
            }

        # --- wrap the single graph in a batch (the model reads data.batch) ---
        loader = DataLoader([graph], batch_size=1, shuffle=False)
        batch = next(iter(loader)).to(self.device)

        # --- forward pass; outputs are still in the scaled space ---
        scaled = self.model(batch).detach().cpu().view(-1)

        # --- invert the robust scaling target by target ---
        pred_raw = {
            target: inverse_robust_scalar(float(scaled[idx].item()), *self.scalers[target])
            for idx, target in enumerate(self.target_cols)
        }

        return {
            "ok": True,
            "smiles_in": smiles,
            "smiles_used": used,
            "pred_raw": pred_raw,
        }

def load_predictor_generic(
    dataset_name: str,
    save_root: str = "models",
    ckpt_name: str = "checkpoint.pt",
    device: Optional[torch.device] = None
) -> GenericGNNPredictor:
    """Build a ``GenericGNNPredictor`` from ``{save_root}/{dataset_name}/{ckpt_name}``.

    Thin convenience wrapper meant to give GUI, notebook and CLI callers one
    shared entry point.
    """
    return GenericGNNPredictor(
        ckpt_path=os.path.join(save_root, dataset_name, ckpt_name),
        device=device,
    )

# Reaching this line means every import and definition in this cell ran without raising.
print("Imports/definitions OK")

Imports/definitions OK


In [3]:
# Smoke test: build the predictor from the NREL checkpoint and run one prediction.
predictor = load_predictor_generic("nrel", save_root="models", ckpt_name="checkpoint.pt")

print("Loaded predictor for dataset:", predictor.dataset_name)
print("Targets:", predictor.target_cols)

# Ethanol is used as a small, valid test molecule.
res = predictor.predict_smiles("CCO")

print("Prediction result:", res, sep="\n")

Loaded predictor for dataset: nrel
Targets: ['gap', 'homo', 'lumo', 'spectral_overlap', 'homo_extrapolated', 'lumo_extrapolated', 'gap_extrapolated', 'optical_lumo_extrapolated']
Prediction result:
{'ok': True, 'smiles_in': 'CCO', 'smiles_used': 'CCO', 'pred_raw': {'gap': 5.1179975324630735, 'homo': -6.481625464229649, 'lumo': -0.7308945234565236, 'spectral_overlap': -422.84538336668584, 'homo_extrapolated': -5.616125647804598, 'lumo_extrapolated': -2.1236128026149608, 'gap_extrapolated': 3.174334164434671, 'optical_lumo_extrapolated': -2.436118149549048}}


In [None]:
# Confirm that the preprocessing recorded in the checkpoint is the one this notebook re-implements.
ckpt_path = os.path.join("models", "nrel", "checkpoint.pt")
payload = torch.load(ckpt_path, map_location="cpu")

model_cfg = payload.get("model_config", {})
preprocess_cfg = payload.get("preprocess_config", {})
preprocess_name = preprocess_cfg.get("preprocess_name")

print("dataset_name:", payload.get("dataset_name"))
print("model_name  :", model_cfg.get("model_name"))
print("preprocess_name:", preprocess_name)

# Double-check that SMILES canonicalization was enabled during training.
print("canonicalize:", preprocess_cfg.get("smiles_canonicalize", None))

# Preprocessing identifier this notebook expects.
expected = "discrete_atom_bond_v2"
print("matches expected?", preprocess_name == expected)

dataset_name: nrel
model_name  : mpnn_edgecond_gru
preprocess_name: discrete_atom_bond_v2
canonicalize: True
matches expected? True


## グラフ埋め込みを抽出・AD範囲決定・AD判定
- readout関数の出力(global_add_pool後の出力)、`g(1024, )`を取り出して、ここからADの範囲を作る
- 手順として、

1. 学習データにおけるgを取り出し、マハラノビス距離を計算する（実装では平方根を取らない二乗距離 $d^\top \Sigma^{-1} d$ をそのままスコアとして用いる）
2. テストデータのgを取り出し、マハラノビス距離を計算する
3. **学習データにおけるマハラノビス距離のうち、上位2%をAD外として、テストデータに当てはめ、AD内外を判定する**

- マハラノビス距離を使う理由として、ユークリッド距離に比べ、データの分布（`共分散行列`）を考慮することで、データのばらつき具合や相関係数を反映した距離を計算できるため
- OCSVMでも良いが、こちらはハイパーパラメータに敏感

In [None]:
@torch.no_grad()
def embed_smiles(predictor, smiles: str):
    """Return the sum-pooled graph fingerprint ``g`` of one molecule.

    The trained GNN is run up to and including ``global_add_pool``; the MLP
    prediction head is skipped. The result is used for the applicability
    domain (AD) analysis.

    Parameters
    ----------
    predictor:
        Object exposing ``model``, ``device`` and ``do_canonicalize``
        (a ``GenericGNNPredictor``).
    smiles:
        SMILES string of the molecule.

    Returns
    -------
    numpy.ndarray or None
        Flattened embedding (length 1024 for this model), or ``None`` when
        the SMILES cannot be canonicalized or converted into a graph.
    """
    used = smiles
    if predictor.do_canonicalize:
        used = canonicalize_smiles(smiles)
        if used is None:
            return None

    graph = smiles_to_pyg_discrete_v2(used)
    if graph is None:
        return None

    # Single-graph batch so that a `batch` vector exists for pooling.
    loader = DataLoader([graph], batch_size=1, shuffle=False)
    batch = next(iter(loader)).to(predictor.device)

    m = predictor.model
    m.eval()

    # --- replicate the model's forward pass up to the readout ---
    feats = batch.x
    # (embedding table, column in batch.x, largest valid index)
    columns = (
        (m.emb_atomic, 0, MAX_ATOMIC_NUM),
        (m.emb_degree, 1, MAX_DEGREE),
        (m.emb_aroma, 2, 1),
        (m.emb_fc, 3, NUM_FC - 1),
        (m.emb_hyb, 4, NUM_HYB - 1),
    )
    node_emb = torch.cat(
        [table(feats[:, col].clamp(0, upper)) for table, col, upper in columns],
        dim=-1,
    )
    h = m.node_proj(node_emb)

    bond_type = batch.edge_attr[:, 0].clamp(0, NUM_BOND_TYPES - 1)
    for layer in m.layers:
        h = layer(h, batch.edge_index, bond_type)

    node_fp = m.node_to_fp(h)                      # [N, 1024]
    pooled = global_add_pool(node_fp, batch.batch)  # [1, 1024]

    return pooled.cpu().view(-1).numpy()           # (1024,)


# ---- Smoke test ----
g = embed_smiles(predictor, "CCO")

# embed_smiles returns None on failure, so both prints are guarded by the same
# check; indexing g[:5] on None would raise TypeError.
if g is None:
    print("g shape:", None)
else:
    print("g shape:", g.shape)
    print("g[:5] =", g[:5])

g shape: (1024,)
g[:5] = [-13.123536    -0.49069118  -4.005358     6.9682674    7.9609346 ]


In [None]:
# Embed every training SMILES and save the stacked matrix G (one embedding row per molecule).

train_csv_path = 'data/NREL/nrel_train_processed.csv'
smiles_col = 'smile'

train_df = pd.read_csv(train_csv_path)
smiles_list = train_df[smiles_col].astype(str).tolist()

print("num smiles:", len(smiles_list), "example:", smiles_list[0])

# Build the embeddings; SMILES that fail to embed are collected separately.
G_list, ok_smiles, bad_smiles = [], [], []

for smi in smiles_list:
    g = embed_smiles(predictor, smi)
    if g is not None:
        G_list.append(g)
        ok_smiles.append(smi)
    else:
        bad_smiles.append(smi)

G = np.stack(G_list, axis=0)  # (N_ok, 1024)

print("embedded:", G.shape[0], "/", len(smiles_list))
print("bad:", len(bad_smiles))

# Persist the artifacts so they can be reused by other AD notebooks.
out_dir = "models/ad_artifacts/nrel"
os.makedirs(out_dir, exist_ok=True)

g_train_path = os.path.join(out_dir, "G_train.npy")
ok_csv_path = os.path.join(out_dir, "train_smiles_ok.csv")
bad_csv_path = os.path.join(out_dir, "train_smiles_bad.csv")

np.save(g_train_path, G)

# Also store the OK / failed SMILES so rows of G can be mapped back to molecules later.
pd.Series(ok_smiles, name="smiles").to_csv(ok_csv_path, index=False)
pd.Series(bad_smiles, name="smiles").to_csv(bad_csv_path, index=False)

print("Saved:")
for saved_path in (g_train_path, ok_csv_path, bad_csv_path):
    print(" -", saved_path)

num smiles: 43468 example: COC(=O)c1nc2c(-c3ccc(-c4cccs4)s3)sc(-c3cccs3)c2o1
embedded: 43468 / 43468
bad: 0
Saved:
 - models/ad_artifacts/nrel/G_train.npy
 - models/ad_artifacts/nrel/train_smiles_ok.csv
 - models/ad_artifacts/nrel/train_smiles_bad.csv


In [7]:
# Fit the Mahalanobis-distance AD model on the training embeddings and set the threshold.
# NOTE: the names defined here (median, iqr_safe, mu, inv_cov, thr, G_tr, ...)
# stay in the notebook's global namespace after this cell runs.

# 1) load embeddings
ad_dir = "models/ad_artifacts/nrel"
G_tr = np.load(os.path.join(ad_dir, "G_train.npy"))   # (N, 1024)
print("G_tr shape:", G_tr.shape)

# 2) robust scaling parameters (median / IQR), computed per embedding dimension
median = np.median(G_tr, axis=0)
q1 = np.percentile(G_tr, 25, axis=0)
q3 = np.percentile(G_tr, 75, axis=0)
iqr = q3 - q1

# safety: iqr == 0 -> 1.0 (avoids division by zero for constant dimensions)
iqr_safe = np.where(iqr == 0.0, 1.0, iqr)

# 3) robust-standardize
G_s = (G_tr - median) / iqr_safe

# 4) estimate mu/cov in standardized space
mu = G_s.mean(axis=0)
X  = G_s - mu

# Covariance matrix; eps is added to the diagonal before inversion as a
# regularization term.
cov = np.cov(X, rowvar=False)
eps = 1e-3
cov += eps * np.eye(cov.shape[0])
inv_cov = np.linalg.inv(cov)

# 5) Mahalanobis distance on train
# NOTE: this is the squared form x^T inv_cov x (no square root is taken).
md_tr = np.einsum("bi,ij,bj->b", X, inv_cov, X)

# 6) threshold: top 2% as out (same as ZINC)
out_percent = 2.0
thr = float(np.percentile(md_tr, 100.0 - out_percent))

print("Mahalanobis stats (train): min/mean/max =", md_tr.min(), md_tr.mean(), md_tr.max())
print(f"thr (top {out_percent}%):", thr)
print("train out-rate:", np.mean(md_tr > thr))

# 7) save artifacts (numpy-only)
np.save(os.path.join(ad_dir, "md_mu.npy"), mu)
np.save(os.path.join(ad_dir, "md_inv_cov.npy"), inv_cov)
np.save(os.path.join(ad_dir, "md_thr.npy"), np.array([thr]))

# The zero-guarded IQR (iqr_safe) is what gets stored as md_iqr.npy.
np.save(os.path.join(ad_dir, "md_median.npy"), median)
np.save(os.path.join(ad_dir, "md_iqr.npy"), iqr_safe)

print("Saved:")
print(" - md_mu.npy / md_inv_cov.npy / md_thr.npy")
print(" - md_median.npy / md_iqr.npy")

G_tr shape: (43468, 1024)
Mahalanobis stats (train): min/mean/max = 30.10624160556471 124.61417260630242 1962.6205267782075
thr (top 2.0%): 316.6998296135209
train out-rate: 0.020014723474740037
Saved:
 - md_mu.npy / md_inv_cov.npy / md_thr.npy
 - md_median.npy / md_iqr.npy


In [8]:
# AD judgement for the test data
# ---- load AD artifacts (NREL) ----
ad_dir = "models/ad_artifacts/nrel"


def _load_ad_artifact(file_name):
    """Load one saved ``.npy`` artifact from ``ad_dir``."""
    return np.load(os.path.join(ad_dir, file_name))


mu = _load_ad_artifact("md_mu.npy")
inv_cov = _load_ad_artifact("md_inv_cov.npy")
thr = float(_load_ad_artifact("md_thr.npy")[0])

# NOTE: `mean` holds the contents of md_median.npy and `scale` those of md_iqr.npy.
mean = _load_ad_artifact("md_median.npy")
scale = _load_ad_artifact("md_iqr.npy")

# safety: never divide by a zero scale
iqr_safe = np.where(scale == 0.0, 1.0, scale)

def mahalanobis_distance_from_g(G, mean, scale, mu, inv_cov):
    """Return the squared Mahalanobis distance of each raw embedding row.

    Each row of ``G`` is standardized as ``(G - mean) / scale``, centered by
    ``mu`` and scored as ``d^T inv_cov d``. No square root is taken, so the
    result must be compared against a threshold computed the same way.

    Parameters
    ----------
    G : array of shape (N, D) (a single (D,) vector is also accepted)
        Raw embeddings.
    mean, scale : arrays of shape (D,)
        Location and scale used for standardization.
    mu : array of shape (D,)
        Center of the standardized embeddings.
    inv_cov : array of shape (D, D)
        Inverse covariance matrix of the standardized embeddings.

    Returns
    -------
    numpy.ndarray of shape (N,)
    """
    centered = (G - mean) / scale - mu.reshape(1, -1)
    # A three-operand einsum without `optimize` is evaluated as a plain nested
    # loop. Doing the (N, D) @ (D, D) product as a matmul first and then a
    # row-wise dot product gives the same quadratic form through BLAS.
    return np.einsum("bi,bi->b", centered @ inv_cov, centered)

# ---- load test dataframe (NREL) ----
test_csv_path = "data/NREL/nrel_test_processed.csv"   # adjust the file name to your data
smiles_col = "smile"                                  # adjust to your SMILES column name

test_df = pd.read_csv(test_csv_path)

# Target column names of the NREL model
target_cols = predictor.target_cols

print("Targets:", target_cols)

# ---- embed test smiles ----
smiles_list = test_df[smiles_col].astype(str).tolist()

G_list = []
ok_idx = []
bad_idx = []

for i, smi in enumerate(smiles_list):
    g = embed_smiles(predictor, smi)
    if g is None:
        bad_idx.append(i)
        continue
    G_list.append(g)
    ok_idx.append(i)

G_te = np.stack(G_list, axis=0)  # (N_ok, 1024)

# ---- AD score & label ----
# Use the scaling parameters loaded from the saved artifacts in this cell:
# `mean` holds md_median.npy and `iqr_safe` is the zero-guarded md_iqr.npy.
# The previous code passed `median`, which only exists if the fitting cell
# was run earlier in the same kernel.
md_te = mahalanobis_distance_from_g(G_te, mean, iqr_safe, mu, inv_cov)
ad_label = np.where(md_te > thr, "out", "in")

print("TEST embedded:", len(ok_idx), "/", len(test_df), "bad:", len(bad_idx))
print("AD-out rate (test):", np.mean(ad_label == "out"))

# ---- GNN prediction (raw) for same ok samples ----
pred_rows = []
for i in ok_idx:
    smi = smiles_list[i]
    res = predictor.predict_smiles(smi)      # relies on predict_smiles returning "pred_raw"
    pred_rows.append(res["pred_raw"])

pred_df = pd.DataFrame(pred_rows)  # columns: target_cols

# ---- true values (raw) for ok samples ----
# Assumes the processed CSV stores raw (unscaled) target values; if it only
# holds scaled values, the errors below are not meaningful.
true_df = test_df.iloc[ok_idx][target_cols].reset_index(drop=True)

# ---- error metrics ----
# NOTE: despite its name, abs_err holds the signed error (pred - true);
# np.abs is applied wherever an absolute error is needed.
abs_err = (pred_df[target_cols].values - true_df[target_cols].values)
mae_each = np.mean(np.abs(abs_err), axis=0)   # per-target MAE
# NOTE(review): this is an unweighted mean over targets on different scales,
# so the largest-scale target dominates the per-sample value.
mae_macro = np.mean(np.abs(abs_err), axis=1)  # per sample macro-MAE

mask_in = (ad_label == "in")
mask_out = (ad_label == "out")

print("\nMacro-MAE (raw) by AD label")
print("  in : mean =", mae_macro[mask_in].mean(), "n =", int(mask_in.sum()))
print("  out: mean =", mae_macro[mask_out].mean(), "n =", int(mask_out.sum()))

print("\nPer-target MAE (raw) by AD label")
for j, t in enumerate(target_cols):
    print(f"  {t:>10s} | in: {np.mean(np.abs(abs_err[mask_in, j])):.4f} | out: {np.mean(np.abs(abs_err[mask_out, j])):.4f}")

# ---- summary table ----
summary = test_df.iloc[ok_idx][[smiles_col] + target_cols].reset_index(drop=True).copy()
for t in target_cols:
    summary[f"pred_{t}"] = pred_df[t].values
summary["md"] = md_te
summary["ad_label"] = ad_label
summary["mae_macro"] = mae_macro

display(summary.head(10))

Targets: ['gap', 'homo', 'lumo', 'spectral_overlap', 'homo_extrapolated', 'lumo_extrapolated', 'gap_extrapolated', 'optical_lumo_extrapolated']
TEST embedded: 5434 / 5434 bad: 0
AD-out rate (test): 0.020610967979389033

Macro-MAE (raw) by AD label
  in : mean = 31.900016066784897 n = 5322
  out: mean = 41.417406725555345 n = 112

Per-target MAE (raw) by AD label
         gap | in: 0.0778 | out: 0.1125
        homo | in: 0.0571 | out: 0.0965
        lumo | in: 0.0662 | out: 0.1079
  spectral_overlap | in: 254.6899 | out: 330.6000
  homo_extrapolated | in: 0.0734 | out: 0.1134
  lumo_extrapolated | in: 0.0751 | out: 0.0914
  gap_extrapolated | in: 0.0883 | out: 0.1133
  optical_lumo_extrapolated | in: 0.0722 | out: 0.1042


Unnamed: 0,smile,gap,homo,lumo,spectral_overlap,homo_extrapolated,lumo_extrapolated,gap_extrapolated,optical_lumo_extrapolated,pred_gap,pred_homo,pred_lumo,pred_spectral_overlap,pred_homo_extrapolated,pred_lumo_extrapolated,pred_gap_extrapolated,pred_optical_lumo_extrapolated,md,ad_label,mae_macro
0,CN1C(=O)c2cn(C)c(-c3cc4c(s3)c3sc(-c5cc6c(ccc7n...,1.9039,-5.058052,-2.896652,2891.449538,-4.927982,-3.174208,1.4155,-3.512482,1.883357,-5.048064,-2.851986,2881.379974,-4.82899,-3.134547,1.367921,-3.462917,82.689315,in,1.29757
1,Cn1c(=O)c2cc(F)c3c(=O)n(C)c(=O)c4c(-c5scc6c5C(...,1.9486,-5.993579,-3.588365,279.676231,-5.940245,-3.635169,1.872,-4.068245,2.120014,-6.247611,-3.511788,274.469807,-6.07417,-3.61637,1.894259,-4.184565,97.256361,in,0.749968
2,Cn1cc2c(c1-c1sc3c(c(F)c(F)c4c(C(F)(F)F)csc43)c...,3.3019,-5.529081,-1.778808,211.826445,-5.380507,-2.481406,2.5025,-2.878007,3.078325,-5.484511,-1.881059,247.634294,-5.369314,-2.405537,2.596289,-2.774814,92.462092,in,4.557786
3,CN1C(=O)c2cc(C(F)(F)F)c3c4c(c(-c5cccc6nsnc56)c...,2.1428,-6.582162,-3.847962,191.814424,-6.417261,-3.859935,2.0126,-4.404661,2.078816,-6.598798,-3.948571,413.752173,-6.45363,-3.929871,2.069413,-4.386049,139.515832,in,27.787588
4,CN1C(=O)C2=C(c3ccc(-c4ccc(-c5ccc(-c6ccc7c(c6)C...,1.8553,-4.672467,-2.64277,10019.222905,-4.65614,-2.658008,1.7825,-2.87364,1.755233,-4.677318,-2.712527,9642.344788,-4.632778,-2.745082,1.594149,-3.038425,135.155273,in,47.189545
5,Cn1cc2c(c1-c1cccc3nc4c5ccccc5c5ccccc5c4nc13)OC...,2.2701,-4.973425,-2.171196,290.695315,-4.294228,-2.264803,1.5951,-2.699128,2.116642,-4.977225,-2.335914,407.860522,-4.42908,-2.453977,1.622339,-2.806809,101.929163,in,14.743266
6,Cc1csc(-c2ccc(-c3cc4c5nn(C)nc5c5cc(-c6cccs6)sc...,2.6844,-4.960363,-2.007928,2302.958874,-4.602261,-2.526033,1.6816,-2.920661,2.635558,-4.95036,-1.990014,2254.460934,-4.590778,-2.545426,1.687141,-2.901878,96.88523,in,6.078738
7,Cc1c2nc(-c3cccs3)n(C)c2c(C)c2c1nc(-c1ccc(-c3cc...,1.2892,-5.277104,-3.704286,1414.824725,-5.2586,-3.825648,1.3164,-3.9422,1.289674,-5.263706,-3.702511,1536.217899,-5.230779,-3.816379,1.249489,-3.982984,146.564151,in,15.194201
8,Cc1nc2c(-c3ccc(-c4cc5c(C)c6sccc6c(C)c5s4)n3C)s...,2.7696,-4.731515,-1.559212,3062.791642,-4.425659,-1.974458,2.0952,-2.330459,2.803644,-4.619782,-1.467004,3308.841276,-4.561262,-1.735517,2.435672,-2.122544,162.475675,in,30.901319
9,C/C=C1\C(=O)c2c(c3cc(-c4ccc(-c5ccc(-c6ccc7c(c6...,2.4775,-5.127713,-2.344805,3580.303419,-5.001997,-2.511339,2.1341,-2.867897,2.490217,-5.114024,-2.319049,3087.596986,-4.966068,-2.477586,2.165468,-2.799376,85.978384,in,61.616021


In [9]:
# Compare the Mahalanobis-score percentiles of train (first line) and test (second line).
md_percentiles = [50, 90, 95, 98, 99]

md_tr = mahalanobis_distance_from_g(G_tr, mean, scale, mu, inv_cov)
for md_values in (md_tr, md_te):
    print(np.percentile(md_values, md_percentiles))

[103.87428113 207.11319999 251.38964802 316.69982961 367.71003102]
[103.98644558 204.83774985 256.05018122 317.97274284 384.11709076]
