# V7

## Phase 1 (Data)

### Load & Inspect the Dataset (V7 Initial Checks)

In [None]:
import pandas as pd
import numpy as np
import os
import json
from pathlib import Path
from datetime import datetime

# ------------------------------
# Step 1: Setup paths
# ------------------------------
DATA_PATH = Path("v7/data/tox21.csv").resolve()
META_DIR = DATA_PATH.parent / "meta"
META_DIR.mkdir(exist_ok=True)

# ------------------------------
# Step 2: Load the dataset
# ------------------------------
df = pd.read_csv(DATA_PATH)
print(f"✅ Loaded dataset with shape: {df.shape}")

# ------------------------------
# Step 3: Infer column roles
# ------------------------------
# Assume 12 labels, then mol_id and smiles
expected_labels = [
    'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase',
    'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma',
    'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53'
]

# Confirm labels exist
missing_labels = [c for c in expected_labels if c not in df.columns]
assert not missing_labels, f"❌ Missing label columns: {missing_labels}"

# Detect mol_id and smiles
assert 'mol_id' in df.columns, "❌ Missing 'mol_id' column"
assert 'smiles' in df.columns, "❌ Missing 'smiles' column"

# ------------------------------
# Step 4: Sanity checks
# ------------------------------
print("🔎 Running sanity checks...")

# Check for missing or empty SMILES
num_missing_smiles = df['smiles'].isna().sum()
num_empty_smiles = (df['smiles'].astype(str).str.strip() == "").sum()
assert num_missing_smiles == 0, f"❌ {num_missing_smiles} missing SMILES"
assert num_empty_smiles == 0, f"❌ {num_empty_smiles} empty SMILES"

# Check label values are 0, 1, or NaN
bad_values = {}
for col in expected_labels:
    unique_vals = df[col].dropna().unique()
    bad_vals = [v for v in unique_vals if v not in [0, 1, 0.0, 1.0]]
    if bad_vals:
        bad_values[col] = bad_vals

assert not bad_values, f"❌ Invalid label values detected: {bad_values}"

# Check duplicates
num_dup_mol = df['mol_id'].duplicated().sum()
num_dup_smiles = df['smiles'].duplicated().sum()
print(f"🔁 Duplicates → mol_id: {num_dup_mol}, smiles: {num_dup_smiles}")

# ------------------------------
# Step 5: Save metadata outputs
# ------------------------------
# Preview CSV
df.head(5).to_csv(META_DIR / "preview.csv", index=False)

# Label stats
label_stats = []
for col in expected_labels:
    total = df[col].notna().sum()
    pos = int((df[col] == 1).sum())
    neg = int((df[col] == 0).sum())
    missing = int(df[col].isna().sum())
    prevalence = pos / total if total > 0 else 0
    label_stats.append({
        "label": col,
        "n_samples": total,
        "n_positive": pos,
        "n_negative": neg,
        "n_missing": missing,
        "positive_rate": round(prevalence, 5)
    })

pd.DataFrame(label_stats).to_csv(META_DIR / "label_stats.csv", index=False)

# Schema summary
schema = {
    "timestamp": datetime.utcnow().isoformat() + "Z",
    "n_rows": len(df),
    "n_cols": df.shape[1],
    "mol_id_column": "mol_id",
    "smiles_column": "smiles",
    "label_columns": expected_labels,
    "has_duplicates": {
        "mol_id": bool(num_dup_mol),
        "smiles": bool(num_dup_smiles)
    }
}
with open(META_DIR / "schema.json", "w") as f:
    json.dump(schema, f, indent=2)

print("✅ Sanity checks passed. Metadata saved to:")
print(f"  • Preview: {META_DIR / 'preview.csv'}")
print(f"  • Label stats: {META_DIR / 'label_stats.csv'}")
print(f"  • Schema: {META_DIR / 'schema.json'}")

✅ Loaded dataset with shape: (7831, 14)
🔎 Running sanity checks...
🔁 Duplicates → mol_id: 0, smiles: 0
✅ Sanity checks passed. Metadata saved to:
  • Preview: D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\v7\data\meta\preview.csv
  • Label stats: D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\v7\data\meta\label_stats.csv
  • Schema: D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\v7\data\meta\schema.json


### 2: RDKit Descriptor Generation and Cleaning

This cell computes RDKit descriptors (~300 per molecule) from SMILES strings, then:
- Drops molecules where descriptor generation fails
- Replaces infinite/extreme values with NaN
- Applies median imputation and StandardScaler
- Saves:
  - Cleaned descriptors → `X_rdkit.npy`
  - Binary labels → `Y.npy`
  - Metadata: `smiles.npy`, `mol_ids.npy`
  - Feature names → `feature_names.txt`
  - Fitted imputer and scaler (for reuse in training/inference)


In [None]:
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.ML.Descriptors import MoleculeDescriptors

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

import numpy as np
import pandas as pd
import joblib
from pathlib import Path

# -------------------------
# Paths
# -------------------------
DATA_PATH = Path("v7/data/tox21.csv").resolve()
META_DIR = DATA_PATH.parent / "meta"
DESC_DIR = DATA_PATH.parent / "descriptors"
DESC_DIR.mkdir(exist_ok=True)

# -------------------------
# Load the dataset
# -------------------------
df = pd.read_csv(DATA_PATH)
expected_labels = [
    'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase',
    'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma',
    'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53'
]

# -------------------------
# Setup RDKit descriptor calculator (~200–300 descriptors)
# -------------------------
desc_names = [desc[0] for desc in Descriptors._descList]
desc_calculator = MoleculeDescriptors.MolecularDescriptorCalculator(desc_names)

# -------------------------
# Compute descriptors from SMILES
# -------------------------
features, labels = [], []
smiles_list, mol_ids = [], []
failed_count = 0

for i, row in df.iterrows():
    smi = row['smiles']
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        failed_count += 1
        continue
    try:
        desc = desc_calculator.CalcDescriptors(mol)
        features.append(desc)
        labels.append(row[expected_labels].values)
        mol_ids.append(row['mol_id'])
        smiles_list.append(smi)
    except:
        failed_count += 1

print(f"✅ RDKit descriptors computed. Failures: {failed_count}")

X = np.array(features, dtype=np.float64)
Y = np.array(labels).astype(float)
mol_ids = np.array(mol_ids)
smiles_list = np.array(smiles_list)

# -------------------------
# Clean descriptor matrix: replace inf/extreme with NaN
# -------------------------
X[~np.isfinite(X)] = np.nan
X[np.abs(X) > 1e6] = np.nan

n_total_nan = np.isnan(X).sum()
print(f"🔍 Total NaNs after sanitization: {n_total_nan}")

# -------------------------
# Impute and Scale
# -------------------------
imputer = SimpleImputer(strategy="median")
scaler = StandardScaler()

X_imputed = imputer.fit_transform(X)
X_scaled = scaler.fit_transform(X_imputed)

print("✅ RDKit descriptors: Imputed and scaled")
print(f"🧩 Descriptor shape: {X_scaled.shape}")
print(f"🧬 Label shape: {Y.shape}")

# -------------------------
# Save outputs
# -------------------------
np.save(DESC_DIR / "X_rdkit.npy", X_scaled)
np.save(DESC_DIR / "Y.npy", Y)
np.save(DESC_DIR / "smiles.npy", smiles_list)
np.save(DESC_DIR / "mol_ids.npy", mol_ids)

with open(DESC_DIR / "feature_names.txt", "w") as f:
    for name in desc_names:
        f.write(name + "\n")

joblib.dump(imputer, DESC_DIR / "imputer.joblib")
joblib.dump(scaler, DESC_DIR / "scaler.joblib")

print("\n📁 Saved:")
print(f"• Features       → {DESC_DIR / 'X_rdkit.npy'}")
print(f"• Labels         → {DESC_DIR / 'Y.npy'}")
print(f"• SMILES         → {DESC_DIR / 'smiles.npy'}")
print(f"• mol_ids        → {DESC_DIR / 'mol_ids.npy'}")
print(f"• Feature names  → {DESC_DIR / 'feature_names.txt'}")
print(f"• Scaler         → {DESC_DIR / 'scaler.joblib'}")
print(f"• Imputer        → {DESC_DIR / 'imputer.joblib'}")




✅ RDKit descriptors computed. Failures: 0
🔍 Total NaNs after sanitization: 2966
✅ RDKit descriptors: Imputed and scaled
🧩 Descriptor shape: (7831, 208)
🧬 Label shape: (7831, 12)

📁 Saved:
• Features       → D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\v7\data\descriptors\X_rdkit.npy
• Labels         → D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\v7\data\descriptors\Y.npy
• SMILES         → D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\v7\data\descriptors\smiles.npy
• mol_ids        → D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\v7\data\descriptors\mol_ids.npy
• Feature names  → D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\v7\data\descriptors\feature_names.txt
• Scaler         → D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\v7\data\descriptors\scaler.joblib
• Imputer        → D:\Coding Projects\Predicting-

#### 2b) Save *sanitized raw* RDKit descriptors (no scaling)

Recompute RDKit descriptors and only sanitize (replace inf/±inf/extremes with NaN).  
This gives us `X_rdkit_raw.npy` so we can fit imputer/scaler **on the train split only** in the next cell.


In [None]:
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.ML.Descriptors import MoleculeDescriptors

import numpy as np
import pandas as pd
from pathlib import Path

DATA_PATH = Path("v7/data/tox21.csv").resolve()
DESC_DIR = DATA_PATH.parent / "descriptors"
DESC_DIR.mkdir(exist_ok=True)

df = pd.read_csv(DATA_PATH)
expected_labels = [
    'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase',
    'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma',
    'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53'
]

# RDKit descriptor setup
desc_names = [d[0] for d in Descriptors._descList]
calc = MoleculeDescriptors.MolecularDescriptorCalculator(desc_names)

features, labels, smiles_list, mol_ids = [], [], [], []
fail = 0
for _, row in df.iterrows():
    smi = row["smiles"]
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        fail += 1
        continue
    try:
        desc = calc.CalcDescriptors(mol)
        features.append(desc)
        labels.append(row[expected_labels].values)
        smiles_list.append(smi)
        mol_ids.append(row["mol_id"])
    except:
        fail += 1

print(f"✅ RDKit descriptors recomputed. Failures: {fail}")

X_raw = np.array(features, dtype=np.float64)
Y = np.array(labels, dtype=np.float64)
smiles_arr = np.array(smiles_list)
mol_ids_arr = np.array(mol_ids)

# Sanitize only (no impute/scale)
X_raw[~np.isfinite(X_raw)] = np.nan
X_raw[np.abs(X_raw) > 1e6] = np.nan
print(f"🔍 NaNs after sanitization: {np.isnan(X_raw).sum()} | Shape: {X_raw.shape}")

# Save sanitized raw
np.save(DESC_DIR / "X_rdkit_raw.npy", X_raw)
np.save(DESC_DIR / "Y.npy", Y)  # overwrite same labels to keep in sync
np.save(DESC_DIR / "smiles.npy", smiles_arr)
np.save(DESC_DIR / "mol_ids.npy", mol_ids_arr)
print("📁 Saved sanitized raw →", DESC_DIR / "X_rdkit_raw.npy")




✅ RDKit descriptors recomputed. Failures: 0
🔍 NaNs after sanitization: 2966 | Shape: (7831, 208)
📁 Saved sanitized raw → D:\Coding Projects\Predicting-Drug-Response-Using-Multi-Omics-Data-with-XAI\v7\data\descriptors\X_rdkit_raw.npy


### 3: Deterministic Scaffold-Based Train/Val/Test Split (80/10/10)

This cell:
- Computes RDKit **Bemis–Murcko scaffolds** for each SMILES.
- **Fallback:** If a scaffold is empty (acyclic molecules), uses the molecule’s **canonical SMILES** as a pseudo-scaffold so every molecule is assigned.
- Groups molecules by scaffold and performs an **80/10/10** split **by scaffold**, deterministic and balanced by group size.
- Saves:
  - Index masks: `v7/data/splits/train.npy`, `val.npy`, `test.npy`
  - Metadata: `v7/data/splits/scaffold_split.csv` (mol_id, smiles, scaffold, split)
  - Label distribution per split: `v7/data/splits/label_distribution.csv`
  - Split summary: `v7/data/splits/split_summary.json`

Sanity checks ensure:
- All molecules are assigned and covered exactly once
- Split sizes match the target proportions (±1 due to rounding)


In [None]:
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold

import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict, OrderedDict
import json
import random

# ------------------------
# Paths & constants
# ------------------------
DATA_DIR = Path("v7/data")
DESC_DIR = DATA_DIR / "descriptors"
SPLIT_DIR = DATA_DIR / "splits"
SPLIT_DIR.mkdir(exist_ok=True)

mol_ids = np.load(DESC_DIR / "mol_ids.npy", allow_pickle=True)
smiles = np.load(DESC_DIR / "smiles.npy", allow_pickle=True)
Y = np.load(DESC_DIR / "Y.npy")  # shape: [N, 12]

# For reporting label-wise stats with names
expected_labels = [
    'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase',
    'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma',
    'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53'
]

# Split ratios and seed
train_frac, val_frac, test_frac = 0.80, 0.10, 0.10
rng_seed = 42
random.seed(rng_seed)

# ------------------------
# Helper: compute scaffold or fallback to canonical SMILES
# ------------------------
def scaffold_key_from_smiles(smi: str) -> str:
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return "__INVALID__"
    scaf = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=False)
    if not scaf or scaf.strip() == "":
        # Fallback: canonical SMILES as pseudo-scaffold for acyclic molecules
        scaf = Chem.MolToSmiles(mol, isomericSmiles=False)
    return scaf

# ------------------------
# Group indices by scaffold
# ------------------------
scaffold_to_indices = defaultdict(list)
for idx, smi in enumerate(smiles):
    key = scaffold_key_from_smiles(str(smi))
    scaffold_to_indices[key].append(idx)

# Shuffle groups of equal size in a deterministic way
size_to_groups = defaultdict(list)
for scaf, idxs in scaffold_to_indices.items():
    size_to_groups[len(idxs)].append(idxs)

for size in size_to_groups:
    random.shuffle(size_to_groups[size])  # deterministic shuffle due to rng_seed

# Build ordered list of groups: large -> small, with shuffled ties
ordered_groups = []
for size in sorted(size_to_groups.keys(), reverse=True):
    ordered_groups.extend(size_to_groups[size])

# ------------------------
# Allocate groups to splits (80/10/10) by total assigned molecules
# ------------------------
N = len(smiles)
N_grouped = sum(len(g) for g in ordered_groups)
assert N_grouped == N, f"Grouping mismatch: grouped {N_grouped} != total {N}"

target_train = int(round(train_frac * N))
target_val = int(round(val_frac * N))
target_test = N - target_train - target_val  # ensure sums to N

train_idx, val_idx, test_idx = [], [], []
cnt_train = cnt_val = cnt_test = 0

for group in ordered_groups:
    gsize = len(group)

    # Greedy fill towards targets
    if cnt_train + gsize <= target_train:
        train_idx.extend(group); cnt_train += gsize
    elif cnt_val + gsize <= target_val:
        val_idx.extend(group); cnt_val += gsize
    else:
        test_idx.extend(group); cnt_test += gsize

# If due to rounding we still have leftovers to meet exact target sizes, rebalance
remaining = set(range(N)) - set(train_idx) - set(val_idx) - set(test_idx)
if remaining:
    # Fill val, then test, then train (in that order) to hit targets
    for idx in list(remaining):
        if cnt_val < target_val:
            val_idx.append(idx); cnt_val += 1
        elif cnt_test < target_test:
            test_idx.append(idx); cnt_test += 1
        else:
            train_idx.append(idx); cnt_train += 1

# Final sanity checks
all_assigned = set(train_idx) | set(val_idx) | set(test_idx)
assert len(all_assigned) == N, "Not all molecules assigned to a split."
assert len(set(train_idx) & set(val_idx)) == 0, "Overlap between train and val."
assert len(set(train_idx) & set(test_idx)) == 0, "Overlap between train and test."
assert len(set(val_idx) & set(test_idx)) == 0, "Overlap between val and test."

# Sort indices for neatness
train_idx = np.array(sorted(train_idx))
val_idx = np.array(sorted(val_idx))
test_idx = np.array(sorted(test_idx))

print(f"✅ Split complete → Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")
print(f"🎯 Targets       → Train: {target_train}, Val: {target_val}, Test: {target_test}")

# ------------------------
# Save index masks
# ------------------------
np.save(SPLIT_DIR / "train.npy", train_idx)
np.save(SPLIT_DIR / "val.npy", val_idx)
np.save(SPLIT_DIR / "test.npy", test_idx)

# ------------------------
# Save metadata CSV
# ------------------------
def compute_scaffolds_for_indices(idxs):
    return [scaffold_key_from_smiles(str(smiles[i])) for i in idxs]

meta_df = pd.DataFrame({
    "idx": np.concatenate([train_idx, val_idx, test_idx]),
    "mol_id": mol_ids[np.concatenate([train_idx, val_idx, test_idx])],
    "smiles": smiles[np.concatenate([train_idx, val_idx, test_idx])],
    "scaffold": compute_scaffolds_for_indices(np.concatenate([train_idx, val_idx, test_idx])),
    "split": (["train"] * len(train_idx)) + (["val"] * len(val_idx)) + (["test"] * len(test_idx)),
})
meta_df.to_csv(SPLIT_DIR / "scaffold_split.csv", index=False)

# ------------------------
# Label distribution per split (positives and non-missing)
# ------------------------
def split_label_stats(Y, idxs, split_name):
    sub = Y[idxs]
    pos = np.nansum(sub == 1, axis=0).astype(int)
    non_missing = np.sum(~np.isnan(sub), axis=0).astype(int)
    prev = np.divide(pos, np.maximum(non_missing, 1))  # avoid div0
    return pd.DataFrame({
        "split": [split_name] * len(expected_labels),
        "label": expected_labels,
        "n_non_missing": non_missing,
        "n_positive": pos,
        "prevalence": np.round(prev, 5),
    })

stats_df = pd.concat([
    split_label_stats(Y, train_idx, "train"),
    split_label_stats(Y, val_idx, "val"),
    split_label_stats(Y, test_idx, "test"),
], ignore_index=True)
stats_df.to_csv(SPLIT_DIR / "label_distribution.csv", index=False)

# ------------------------
# Split summary JSON
# ------------------------
summary = {
    "seed": rng_seed,
    "N_total": int(N),
    "sizes": {
        "train": int(len(train_idx)),
        "val": int(len(val_idx)),
        "test": int(len(test_idx)),
    },
    "targets": {
        "train": int(target_train),
        "val": int(target_val),
        "test": int(target_test),
    },
    "paths": {
        "train_idx": str(SPLIT_DIR / "train.npy"),
        "val_idx": str(SPLIT_DIR / "val.npy"),
        "test_idx": str(SPLIT_DIR / "test.npy"),
        "scaffold_meta": str(SPLIT_DIR / "scaffold_split.csv"),
        "label_distribution": str(SPLIT_DIR / "label_distribution.csv"),
    }
}
(Path(SPLIT_DIR) / "split_summary.json").write_text(json.dumps(summary, indent=2))

print("📁 Saved:")
print(f"• Index masks       → {SPLIT_DIR}")
print(f"• Scaffold metadata → {SPLIT_DIR / 'scaffold_split.csv'}")
print(f"• Label distribution→ {SPLIT_DIR / 'label_distribution.csv'}")
print(f"• Summary JSON      → {SPLIT_DIR / 'split_summary.json'}")




✅ Split complete → Train: 6265, Val: 783, Test: 783
🎯 Targets       → Train: 6265, Val: 783, Test: 783




📁 Saved:
• Index masks       → v7\data\splits
• Scaffold metadata → v7\data\splits\scaffold_split.csv
• Label distribution→ v7\data\splits\label_distribution.csv
• Summary JSON      → v7\data\splits\split_summary.json


### 4: Train-only impute/scale → package train/val/test

- Load `X_rdkit_raw.npy` (sanitized, unscaled)
- Fit `SimpleImputer(median)` and `StandardScaler` **on train only**
- Transform train/val/test
- Save:
  - NPZ bundles: `v7/data/prepared/{train,val,test}.npz`
  - Train-fitted artifacts: `imputer_train.joblib`, `scaler_train.joblib`
  - `dataset_manifest.json`
- Sanity checks: coverage, overlap, finiteness, label validity


In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
import json
import joblib

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

DATA_DIR = Path("v7/data")
DESC_DIR = DATA_DIR / "descriptors"
SPLIT_DIR = DATA_DIR / "splits"
PREP_DIR  = DATA_DIR / "prepared"
PREP_DIR.mkdir(exist_ok=True)

# Load sanitized RAW descriptors and metadata
X_raw   = np.load(DESC_DIR / "X_rdkit_raw.npy")   # sanitized, may contain NaNs
Y       = np.load(DESC_DIR / "Y.npy")
smiles  = np.load(DESC_DIR / "smiles.npy", allow_pickle=True)
mol_ids = np.load(DESC_DIR / "mol_ids.npy", allow_pickle=True)

train_idx = np.load(SPLIT_DIR / "train.npy")
val_idx   = np.load(SPLIT_DIR / "val.npy")
test_idx  = np.load(SPLIT_DIR / "test.npy")

N, F = X_raw.shape
assert Y.shape[0] == N == smiles.shape[0] == mol_ids.shape[0], "Row mismatch across arrays."

# --- Sanity: coverage & overlap
assigned = set(train_idx.tolist() + val_idx.tolist() + test_idx.tolist())
assert len(assigned) == N, f"Not all rows assigned: {len(assigned)} != {N}"
assert not (set(train_idx) & set(val_idx)),  "Overlap train/val"
assert not (set(train_idx) & set(test_idx)), "Overlap train/test"
assert not (set(val_idx) & set(test_idx)),   "Overlap val/test"

# --- Fit imputer & scaler on TRAIN ONLY
imputer_tr = SimpleImputer(strategy="median")
scaler_tr  = StandardScaler()

X_tr_imp = imputer_tr.fit_transform(X_raw[train_idx])
X_tr     = scaler_tr.fit_transform(X_tr_imp)

def transform_split(indices: np.ndarray, name: str):
    X_imp = imputer_tr.transform(X_raw[indices])
    X_s   = scaler_tr.transform(X_imp)
    Y_s   = Y[indices].astype(np.float32, copy=False)
    s_s   = smiles[indices]
    m_s   = mol_ids[indices]
    y_missing = np.isnan(Y_s)
    # checks
    assert np.isfinite(X_s).all(), f"Non-finite features in {name} split after transform."
    return X_s.astype(np.float32), Y_s, s_s, m_s, y_missing

X_train, Y_train, smi_train, mid_train, miss_train = transform_split(train_idx, "train")
X_val,   Y_val,   smi_val,   mid_val,   miss_val   = transform_split(val_idx,   "val")
X_test,  Y_test,  smi_test,  mid_test,  miss_test  = transform_split(test_idx,  "test")

print("✅ Train-only impute/scale complete.")
print("Shapes →",
      "train", X_train.shape, 
      "val",   X_val.shape, 
      "test",  X_test.shape)

# --- Save bundles
def save_npz(name, Xs, Ys, smi, mids, idxs, miss):
    out = PREP_DIR / f"{name}.npz"
    np.savez_compressed(out, X=Xs, Y=Ys, smiles=smi, mol_id=mids, indices=idxs, y_missing_mask=miss)
    return str(out)

p_train = save_npz("train", X_train, Y_train, smi_train, mid_train, train_idx, miss_train)
p_val   = save_npz("val",   X_val,   Y_val,   smi_val,   mid_val,   val_idx,   miss_val)
p_test  = save_npz("test",  X_test,  Y_test,  smi_test,  mid_test,  test_idx,  miss_test)

# --- Save train-fitted artifacts
joblib.dump(imputer_tr, DESC_DIR / "imputer_train.joblib")
joblib.dump(scaler_tr,  DESC_DIR / "scaler_train.joblib")

# --- Manifest
expected_labels = [
    'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase',
    'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma',
    'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53'
]
manifest = {
    "timestamp": datetime.utcnow().isoformat() + "Z",
    "n_total": int(N),
    "n_features": int(F),
    "labels": expected_labels,
    "artifacts": {
        "prepared_dir": str(PREP_DIR),
        "train_npz": p_train, "val_npz": p_val, "test_npz": p_test,
        "imputer_train": str(DESC_DIR / "imputer_train.joblib"),
        "scaler_train": str(DESC_DIR / "scaler_train.joblib"),
        "splits_dir": str(SPLIT_DIR),
        "descriptors_dir": str(DESC_DIR),
        "raw_descriptors": str(DESC_DIR / "X_rdkit_raw.npy"),
    },
    "splits": {
        "train": {"size": int(len(train_idx))},
        "val":   {"size": int(len(val_idx))},
        "test":  {"size": int(len(test_idx))}
    }
}
(PREP_DIR / "dataset_manifest.json").write_text(json.dumps(manifest, indent=2))
print("📁 Saved bundles & manifest →", PREP_DIR)


✅ Train-only impute/scale complete.
Shapes → train (6265, 208) val (783, 208) test (783, 208)
📁 Saved bundles & manifest → v7\data\prepared


## Phase 2 (Model)

### 1 : Cross-Attention Fusion Core + Sanity Test

This defines the **novel fusion** block:
- SMILES token embeddings (queries) attend over graph node embeddings (keys/values)
- Residual + LayerNorm
- Masked mean pooling for text and graph streams (fixed broadcasting)
- Descriptor MLP to align RDKit features with the model dimension
- Classifier head → 12 toxicity logits

A synthetic sanity test checks tensor shapes and masks end-to-end (no encoders yet).


In [None]:
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------
# Repro & device
# -----------------------------
def seed_everything(seed: int = 42):
    import random, numpy as np
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

seed_everything(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Load descriptor dim & labels from manifest
# -----------------------------
PREP_DIR = Path("v7/data/prepared")
manifest_path = PREP_DIR / "dataset_manifest.json"
assert manifest_path.exists(), f"Missing manifest at {manifest_path}. Run Phase 1, Cell 4 first."

with open(manifest_path) as f:
    manifest = json.load(f)

LABEL_NAMES = manifest["labels"]
N_LABELS = len(LABEL_NAMES)
DESC_IN_DIM = manifest["n_features"]  # RDKit feature count (e.g., 208)

print(f"Loaded manifest. N_LABELS={N_LABELS}, DESC_IN_DIM={DESC_IN_DIM}")

# -----------------------------
# Mask helpers (fixed broadcasting)
# -----------------------------
def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    """
    x:    (B, L, D)
    mask: (B, L) with 1 for valid, 0 for pad
    returns: (B, D)
    """
    # ensure same device/dtype
    mask = mask.to(dtype=x.dtype, device=x.device)
    # keepdim=True to make denom shape (B,1), so it broadcasts over D
    denom = mask.sum(dim=dim, keepdim=True).clamp(min=1.0)  # (B,1)
    num = (x * mask.unsqueeze(-1)).sum(dim=dim)              # (B,D)
    return num / denom                                       # (B,D)

def lengths_from_mask(mask: torch.Tensor) -> torch.Tensor:
    return mask.long().sum(dim=1)

# -----------------------------
# Modules
# -----------------------------
class DescriptorMLP(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden: int = 256, p: float = 0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(hidden, out_dim),
            nn.GELU(),
            nn.Dropout(p),
        )

    def forward(self, x):  # (B, in_dim)
        return self.net(x) # (B, out_dim)

class CrossAttentionBlock(nn.Module):
    """
    Single cross-attention layer: text queries attend to graph keys/values.
    Uses PyTorch MultiheadAttention. Expects masks for graph nodes.
    """
    def __init__(self, dim: int, n_heads: int = 4, p: float = 0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, dropout=p, batch_first=False)
        self.dropout = nn.Dropout(p)
        self.ln = nn.LayerNorm(dim)

    def forward(self, text_tokens, text_mask, graph_nodes, graph_mask):
        """
        text_tokens: (B, L, D)
        text_mask:   (B, L)  1=valid, 0=pad
        graph_nodes: (B, N, D)
        graph_mask:  (B, N)  1=valid, 0=pad
        Returns:
          text_out: (B, L, D) after cross-attn + residual + LN
        """
        B, L, D = text_tokens.shape
        N = graph_nodes.size(1)

        # Convert to (S, B, D) for MHA
        Q = text_tokens.transpose(0, 1)   # (L, B, D)
        K = graph_nodes.transpose(0, 1)   # (N, B, D)
        V = graph_nodes.transpose(0, 1)   # (N, B, D)

        # key_padding_mask: (B, N) with True for positions to ignore
        key_padding_mask = (graph_mask == 0)  # bool
        attn_out, _ = self.mha(Q, K, V, key_padding_mask=key_padding_mask)  # (L, B, D)

        # Residual + LN
        attn_out = attn_out.transpose(0, 1)   # (B, L, D)
        text_out = self.ln(text_tokens + self.dropout(attn_out))
        return text_out

class FusionClassifier(nn.Module):
    """
    Pools attended text and graph streams, fuses with descriptor embedding, predicts 12 labels.
    """
    def __init__(self, dim: int, n_labels: int, p: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p)
        self.mlp = nn.Sequential(
            nn.Linear(dim * 3, dim * 2),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(dim * 2, n_labels),
        )

    def forward(self, text_tokens, text_mask, graph_nodes, graph_mask, desc_embed):
        # Masked mean pools (fixed)
        text_pool = masked_mean(text_tokens, text_mask, dim=1)   # (B, D)
        graph_pool = masked_mean(graph_nodes, graph_mask, dim=1) # (B, D)

        fused = torch.cat([text_pool, graph_pool, desc_embed], dim=-1)  # (B, 3D)
        fused = self.dropout(fused)
        logits = self.mlp(fused)  # (B, n_labels)
        return logits

class V7FusionCore(nn.Module):
    """
    Novel fusion core (no encoders here).
    Expects:
      - text_tokens: (B, L, D)
      - text_mask:   (B, L) 1/0 (valid/pad)
      - graph_nodes: (B, N, D)
      - graph_mask:  (B, N) 1/0
      - desc_feats:  (B, DESC_IN_DIM)
    """
    def __init__(self, dim: int = 256, n_heads: int = 4, n_labels: int = 12,
                 desc_in_dim: int = DESC_IN_DIM, desc_hidden: int = 256,
                 p: float = 0.1):
        super().__init__()
        self.dim = dim
        self.cross = CrossAttentionBlock(dim=dim, n_heads=n_heads, p=p)
        self.desc_mlp = DescriptorMLP(in_dim=desc_in_dim, out_dim=dim, hidden=desc_hidden, p=p)
        self.classifier = FusionClassifier(dim=dim, n_labels=n_labels, p=p)

    def forward(self, text_tokens, text_mask, graph_nodes, graph_mask, desc_feats):
        text_attn = self.cross(text_tokens, text_mask, graph_nodes, graph_mask)  # (B,L,D)
        desc_embed = self.desc_mlp(desc_feats)                                   # (B,D)
        logits = self.classifier(text_attn, text_mask, graph_nodes, graph_mask, desc_embed)  # (B, n_labels)
        return logits

# -----------------------------
# 🔎 Sanity test with synthetic tensors (no encoders yet)
# -----------------------------
def _sanity_test():
    B, L, N, D = 4, 64, 48, 256
    desc_in = DESC_IN_DIM

    text_tokens = torch.randn(B, L, D, device=device)
    graph_nodes = torch.randn(B, N, D, device=device)

    # Build masks with at least 1 valid token/node per item
    text_mask = (torch.rand(B, L, device=device) > 0.1).int()
    graph_mask = (torch.rand(B, N, device=device) > 0.1).int()
    for b in range(B):
        if text_mask[b].sum() == 0:
            text_mask[b, 0] = 1
        if graph_mask[b].sum() == 0:
            graph_mask[b, 0] = 1

    desc_feats = torch.randn(B, desc_in, device=device)

    model = V7FusionCore(dim=256, n_heads=4, n_labels=N_LABELS, desc_in_dim=desc_in, desc_hidden=256, p=0.1).to(device)
    with torch.no_grad():
        logits = model(text_tokens, text_mask, graph_nodes, graph_mask, desc_feats)
    print(f"[Sanity] logits shape: {tuple(logits.shape)} (expected: {B} x {N_LABELS})")

_sanity_test()
print("✅ Fusion core defined & sanity-checked.")


Loaded manifest. N_LABELS=12, DESC_IN_DIM=208
[Sanity] logits shape: (4, 12) (expected: 4 x 12)
✅ Fusion core defined & sanity-checked.


### 2: ChemBERTa Text Encoder Wrapper (+ config saved under `v7/model`)

This cell defines a lightweight wrapper around a ChemBERTa checkpoint to produce
token-level embeddings aligned to the fusion dimension (default **256**).

**What it does**
- Loads a SMILES-aware tokenizer & model (ChemBERTa: `seyonec/ChemBERTa-zinc-base-v1`)
- Projects hidden size → `fusion_dim` (256) for compatibility with the fusion core
- Returns:
  - `text_tokens`: `(B, L, 256)` token embeddings
  - `text_mask`: `(B, L)` with 1=valid, 0=pad (compatible with fusion core)
- Utilities:
  - `freeze_backbone(n_unfrozen_layers=0)` for staged fine-tuning
  - Optional gradient checkpointing toggle
- Saves a minimal **encoder manifest** to `v7/model/text_encoder/config.json`
- Runs a **small sanity test** using a few SMILES from your prepared train split

> Notes:
> - Default max length = **256 tokens**; adjust via `max_length` when calling.
> - Ensure `transformers` is installed (>=4.30 recommended).


In [None]:
import json
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

# Prefer the "seyonec" ChemBERTa (stable & widely used)
from transformers import AutoTokenizer, AutoModel

# -----------------------------
# Paths & constants
# -----------------------------
MODEL_DIR = Path("v7/model")
TEXT_DIR  = MODEL_DIR / "text_encoder"
TEXT_DIR.mkdir(parents=True, exist_ok=True)

PREP_DIR = Path("v7/data/prepared")
train_npz = PREP_DIR / "train.npz"
assert train_npz.exists(), "Missing train split. Please run Phase 1, Cells 3–4."

# -----------------------------
# Config you can tweak
# -----------------------------
CHEMBERTA_CKPT = "seyonec/ChemBERTa-zinc-base-v1"   # or: "DeepChem/ChemBERTa-77M-MLM"
FUSION_DIM     = 256
DROPOUT_PROB   = 0.1
MAX_SEQ_LEN    = 256    # default when encoding batches

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -----------------------------
# Text encoder wrapper
# -----------------------------
class ChemBERTaEncoder(nn.Module):
    """
    Wraps a SMILES-aware ChemBERTa to produce token embeddings aligned to fusion dim.
    Returns (text_tokens, text_mask):
      - text_tokens: (B, L, FUSION_DIM)
      - text_mask:   (B, L) int {0,1}
    """
    def __init__(
        self,
        ckpt_name: str = CHEMBERTA_CKPT,
        fusion_dim: int = FUSION_DIM,
        dropout_p: float = DROPOUT_PROB,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.ckpt_name = ckpt_name

        # Tokenizer/Model
        self.tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
        self.backbone  = AutoModel.from_pretrained(ckpt_name)

        hidden_size = self.backbone.config.hidden_size
        self.proj = nn.Sequential(
            nn.Dropout(dropout_p),
            nn.Linear(hidden_size, fusion_dim),
        )
        self.ln = nn.LayerNorm(fusion_dim)

        if gradient_checkpointing and hasattr(self.backbone, "gradient_checkpointing_enable"):
            self.backbone.gradient_checkpointing_enable()

        # Keep mask semantics explicit: 1=valid, 0=pad
        self.pad_token_id = self.tokenizer.pad_token_id
        if self.pad_token_id is None:
            # Some Roberta tokenizers don't have pad by default; set to eos
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.pad_token_id = self.tokenizer.pad_token_id

    @torch.no_grad()
    def encode(
        self,
        smiles_list: List[str],
        max_length: int = MAX_SEQ_LEN,
        add_special_tokens: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Fast no-grad encode → (text_tokens, text_mask)
        """
        self.eval()
        return self.forward(smiles_list, max_length, add_special_tokens)

    def forward(
        self,
        smiles_list: List[str],
        max_length: int = MAX_SEQ_LEN,
        add_special_tokens: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass → (text_tokens, text_mask)
        """
        enc = self.tokenizer(
            list(smiles_list),
            padding=True,
            truncation=True,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            return_tensors="pt",
        )

        input_ids      = enc["input_ids"].to(device)
        attention_mask = enc["attention_mask"].to(device)  # 1=valid, 0=pad

        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state  # (B, L, H)

        tokens = self.proj(last_hidden)          # (B, L, fusion_dim)
        tokens = self.ln(tokens)                 # (B, L, fusion_dim)

        # Return mask as int {0,1} to match fusion core expectations
        mask = attention_mask.to(dtype=torch.int32)

        return tokens, mask

    def freeze_backbone(self, n_unfrozen_layers: int = 0) -> None:
        """
        Freeze full backbone; optionally unfreeze the last `n_unfrozen_layers` transformer blocks.
        """
        # Freeze all backbone parameters
        for p in self.backbone.parameters():
            p.requires_grad = False

        if n_unfrozen_layers > 0:
            assert hasattr(self.backbone, "encoder") or hasattr(self.backbone, "roberta"), \
                "Unexpected backbone structure; adjust unfreezing logic."
            # For RoBERTa-like models in HF, layers live under .encoder.layer (or .roberta.encoder.layer)
            encoder = getattr(self.backbone, "encoder", None)
            if encoder is None and hasattr(self.backbone, "roberta"):
                encoder = self.backbone.roberta.encoder

            if encoder is not None and hasattr(encoder, "layer"):
                L = len(encoder.layer)
                for idx in range(L - n_unfrozen_layers, L):
                    for p in encoder.layer[idx].parameters():
                        p.requires_grad = True

        # Always keep projection & layernorm trainable
        for p in self.proj.parameters():
            p.requires_grad = True
        for p in self.ln.parameters():
            p.requires_grad = True

# -----------------------------
# Build & save a minimal manifest
# -----------------------------
text_encoder = ChemBERTaEncoder(
    ckpt_name=CHEMBERTA_CKPT,
    fusion_dim=FUSION_DIM,
    dropout_p=DROPOUT_PROB,
    gradient_checkpointing=False,  # set True if you need memory savings
).to(device)

manifest = {
    "checkpoint": CHEMBERTA_CKPT,
    "fusion_dim": FUSION_DIM,
    "dropout_p": DROPOUT_PROB,
    "max_seq_len_default": MAX_SEQ_LEN,
    "pad_token_id": text_encoder.pad_token_id,
    "hidden_size": int(text_encoder.backbone.config.hidden_size),
    "device": str(device),
}
(TEXT_DIR / "config.json").write_text(json.dumps(manifest, indent=2))
print("📝 Saved text encoder manifest →", TEXT_DIR / "config.json")

# -----------------------------
# 🔎 Quick sanity test on real SMILES from train split
# -----------------------------
batch = np.load(train_npz, allow_pickle=True)
sample_smiles = [str(s) for s in batch["smiles"][:4].tolist()]  # small batch of real strings

with torch.no_grad():
    toks, mask = text_encoder.encode(sample_smiles, max_length=MAX_SEQ_LEN)

print("Sanity:")
print("  input batch:", len(sample_smiles))
print("  tokens:", tuple(toks.shape), "(B, L, 256)")
print("  mask:  ", tuple(mask.shape), "(B, L), 1=valid, 0=pad")

Using device: cuda
📝 Saved text encoder manifest → v7\model\text_encoder\config.json
Sanity:
  input batch: 4
  tokens: (4, 45, 256) (B, L, 256)
  mask:   (4, 45) (B, L), 1=valid, 0=pad


### 3: Graph Encoder (RDKit → GIN)

This cell builds a lightweight GIN graph encoder without external GNN libs:
- RDKit featurisation → padded tensors
- Pure-PyTorch GIN layers (sum aggregation + MLP + residual + LayerNorm)
- Outputs per-node embeddings (dim=256) + masks, ready for fusion

Saves:
- `v7/model/graph_encoder/config.json`


In [None]:
import json
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from rdkit import Chem

# -----------------------------
# Paths & constants
# -----------------------------
MODEL_DIR = Path("v7/model")
GRAPH_DIR = MODEL_DIR / "graph_encoder"
GRAPH_DIR.mkdir(parents=True, exist_ok=True)

PREP_DIR = Path("v7/data/prepared")
train_npz = PREP_DIR / "train.npz"
assert train_npz.exists(), "Missing train split. Please run Phase 1 Cells 3–4."

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

FUSION_DIM   = 256
MAX_NODES    = 128   # cap for very large molecules (rare in Tox21)
DROPOUT_PROB = 0.1
GIN_LAYERS   = 4
GIN_HIDDEN   = 256   # keep equal to fusion dim

# -----------------------------
# Atom featurisation
# -----------------------------
# Common organic set; everything else -> "other"
ATOM_LIST = ["H","C","N","O","F","P","S","Cl","Br","I"]
HYB_LIST  = [
    Chem.rdchem.HybridizationType.S,
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2,
]
CHIRAL_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER,
]

def one_hot(value, choices):
    vec = [0]*len(choices)
    if value in choices:
        vec[choices.index(value)] = 1
    return vec

def clamp_one_hot_int(value: int, lo: int, hi: int) -> List[int]:
    """One-hot for integer value clamped to [lo, hi]; extra bucket if outside range."""
    # buckets: lo..hi and an "other"
    buckets = list(range(lo, hi+1))
    if value < lo or value > hi:
        return [0]*(len(buckets)) + [1]
    out = [0]*(len(buckets)+1)
    out[value - lo] = 1
    return out

def atom_features(atom: Chem.rdchem.Atom) -> List[float]:
    sym = atom.GetSymbol()
    atom_type = one_hot(sym if sym in ATOM_LIST else "other", ATOM_LIST + ["other"])

    degree = clamp_one_hot_int(atom.GetDegree(), 0, 5)              # 7 dims (0..5 + other)
    formal = clamp_one_hot_int(atom.GetFormalCharge(), -2, 2)       # 6 dims (-2..2 + other)
    hyb    = one_hot(atom.GetHybridization(), HYB_LIST) + [0]       # +1 "other"
    aromatic = [1 if atom.GetIsAromatic() else 0]
    in_ring  = [1 if atom.IsInRing() else 0]
    chiral   = one_hot(atom.GetChiralTag(), CHIRAL_LIST)

    total_h  = clamp_one_hot_int(atom.GetTotalNumHs(includeNeighbors=True), 0, 4)  # 6 dims
    valence  = clamp_one_hot_int(atom.GetTotalValence(), 0, 5)                     # 7 dims
    mass     = [atom.GetMass() / 200.0]  # scale roughly into [0, 1]

    feat = atom_type + degree + formal + hyb + aromatic + in_ring + chiral + total_h + valence + mass
    return feat

def smiles_to_graph(smi: str, max_nodes: int = MAX_NODES) -> Tuple[np.ndarray, np.ndarray]:
    """
    Returns:
      x:   (N, F_node)
      adj: (N, N) binary adjacency (no self loops)
    Truncates to max_nodes if needed.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None or mol.GetNumAtoms() == 0:
        return np.zeros((0,0), dtype=np.float32), np.zeros((0,0), dtype=np.float32)

    # Atom features
    feats = [atom_features(mol.GetAtomWithIdx(i)) for i in range(mol.GetNumAtoms())]
    x = np.asarray(feats, dtype=np.float32)

    # Adjacency (no self-loops here; we’ll handle in GIN)
    N = mol.GetNumAtoms()
    adj = np.zeros((N, N), dtype=np.float32)
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        adj[i, j] = 1.0
        adj[j, i] = 1.0

    # Truncate if needed
    if N > max_nodes:
        x = x[:max_nodes]
        adj = adj[:max_nodes, :max_nodes]

    return x, adj

def collate_graphs(smiles_list: List[str], max_nodes: int = MAX_NODES) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build a padded batch:
      X:     (B, N_max, F_node)
      A:     (B, N_max, N_max)
      mask:  (B, N_max) 1=valid, 0=pad
    """
    graphs = [smiles_to_graph(s, max_nodes) for s in smiles_list]
    N_max = max([g[0].shape[0] for g in graphs] + [1])

    # Node feature dim
    F_node = graphs[0][0].shape[1] if graphs[0][0].size > 0 else len(atom_features(Chem.MolFromSmiles("C").GetAtomWithIdx(0)))

    B = len(graphs)
    X = np.zeros((B, N_max, F_node), dtype=np.float32)
    A = np.zeros((B, N_max, N_max), dtype=np.float32)
    M = np.zeros((B, N_max), dtype=np.int64)

    for i, (x, adj) in enumerate(graphs):
        n = x.shape[0]
        if n == 0:
            continue
        X[i, :n, :] = x
        A[i, :n, :n] = adj
        M[i, :n] = 1

    # to tensors
    X = torch.from_numpy(X)
    A = torch.from_numpy(A)
    M = torch.from_numpy(M)
    return X, A, M

# -----------------------------
# Lightweight GIN (pure torch)
# -----------------------------
class GINLayer(nn.Module):
    def __init__(self, hidden_dim: int, eps_init: float = 0.0, dropout: float = 0.1):
        super().__init__()
        self.eps = nn.Parameter(torch.tensor(eps_init, dtype=torch.float32))
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, adj: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        x:    (B, N, D)
        adj:  (B, N, N)  binary adjacency (no self loops)
        mask: (B, N)     1=valid, 0=pad
        """
        # Add self term explicitly: (1+eps) * x  +  sum_neighbors(x)
        neigh = torch.matmul(adj, x)  # (B, N, D)
        out = (1.0 + self.eps) * x + neigh
        out = self.mlp(out)

        # Zero-out padded nodes
        out = out * mask.unsqueeze(-1).to(out.dtype)
        return out

class GraphGINEncoder(nn.Module):
    def __init__(self, node_in_dim: int, hidden_dim: int = GIN_HIDDEN, n_layers: int = GIN_LAYERS, dropout: float = DROPOUT_PROB):
        super().__init__()
        self.inp = nn.Sequential(
            nn.Linear(node_in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.layers = nn.ModuleList([GINLayer(hidden_dim, eps_init=0.0, dropout=dropout) for _ in range(n_layers)])
        self.out_ln = nn.LayerNorm(hidden_dim)

    @torch.no_grad()
    def encode(self, smiles_list: List[str], max_nodes: int = MAX_NODES) -> Tuple[torch.Tensor, torch.Tensor]:
        self.eval()
        return self.forward(smiles_list, max_nodes)

    def forward(self, smiles_list: List[str], max_nodes: int = MAX_NODES) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
          node_embeddings: (B, N, hidden_dim)
          mask:            (B, N) int {0,1}
        """
        X, A, M = collate_graphs(smiles_list, max_nodes)  # (B,N,F), (B,N,N), (B,N)
        X = X.to(device)
        A = A.to(device)
        M = M.to(device)

        h = self.inp(X)  # (B,N,D)
        for layer in self.layers:
            h = layer(h, A, M)
        h = self.out_ln(h)

        # final mask as int {0,1}
        mask = M.to(dtype=torch.int32)
        return h, mask

# -----------------------------
# Build encoder & save manifest
# -----------------------------
# Infer node feature dim from a simple atom (or compute from a real SMILES)
probe_x, _ = smiles_to_graph("CCO")
node_in_dim = int(probe_x.shape[1]) if probe_x.size > 0 else 64  # fallback

graph_encoder = GraphGINEncoder(node_in_dim=node_in_dim, hidden_dim=FUSION_DIM, n_layers=GIN_LAYERS, dropout=DROPOUT_PROB).to(device)

manifest = {
    "fusion_dim": FUSION_DIM,
    "hidden_dim": FUSION_DIM,
    "n_layers": GIN_LAYERS,
    "dropout_p": DROPOUT_PROB,
    "max_nodes": MAX_NODES,
    "node_in_dim": node_in_dim,
    "atom_types": ATOM_LIST + ["other"],
    "hybridizations": [str(h) for h in HYB_LIST] + ["other"],
    "chiral_types": [int(c) for c in CHIRAL_LIST],
    "device": str(device),
}
(GRAPH_DIR / "config.json").write_text(json.dumps(manifest, indent=2))
print("📝 Saved graph encoder manifest →", GRAPH_DIR / "config.json")

# -----------------------------
# 🔎 Sanity test: 4 real SMILES from train split
# -----------------------------
batch = np.load(train_npz, allow_pickle=True)
sample_smiles = [str(s) for s in batch["smiles"][:4].tolist()]

with torch.no_grad():
    nodes, mask = graph_encoder.encode(sample_smiles, max_nodes=MAX_NODES)

print("Sanity:")
print("  input batch:", len(sample_smiles))
print("  nodes:", tuple(nodes.shape), "(B, N, 256)")
print("  mask: ", tuple(mask.shape), "(B, N), 1=valid, 0=pad")
print("  valid node counts:", mask.sum(dim=1).tolist())

Using device: cuda
📝 Saved graph encoder manifest → v7\model\graph_encoder\config.json
Sanity:
  input batch: 4
  nodes: (4, 21, 256) (B, N, 256)
  mask:  (4, 21) (B, N), 1=valid, 0=pad
  valid node counts: [16, 15, 21, 20]


### 4: Full V7 Fusion Model + Label-Specialist Heads

This cell assembles the full V7 model:
- Text encoder: ChemBERTa (from Cell 2)
- Graph encoder: lightweight GIN (from Cell 3)
- Descriptors branch: MLP (inside fusion)
- Fusion: **cross-attention** (SMILES tokens query graph nodes), masked pooling
- Heads:
  - **Shared multi-label head** (12 logits)
  - **Optionally**: 12 **label-specialist** heads (one-vs-rest heads) for ensembling

What it saves/creates:
- Model manifest: `v7/model/v7_fusion/config.json`
- Ensemble folders (empty, ready for training): `v7/model/ensembles/<LABEL>/`

Sanity test (no training):
- Takes a tiny batch from `v7/data/prepared/train.npz`
- Runs both **shared** and **specialist** forward passes
- Prints tensor shapes and checks for NaNs

> You can freeze/unfreeze encoders via helper methods for staged training.


In [None]:
import os
import json
from pathlib import Path
from typing import List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn

# Reuse modules from Phase 2 — Cell 1 and Cell 2/3
# Expect these classes/functions already defined in your kernel:
# - ChemBERTaEncoder (Cell 2)
# - GraphGINEncoder  (Cell 3)
# - CrossAttentionBlock, DescriptorMLP, FusionClassifier, masked_mean (Cell 1)
# If you restarted, re-run those cells first.

# -----------------------------
# Paths & constants
# -----------------------------
MODEL_DIR = Path("v7/model")
FUSION_DIR = MODEL_DIR / "v7_fusion"
FUSION_DIR.mkdir(parents=True, exist_ok=True)

ENSEMBLE_DIR = MODEL_DIR / "ensembles"
ENSEMBLE_DIR.mkdir(parents=True, exist_ok=True)

TEXT_DIR = MODEL_DIR / "text_encoder"
GRAPH_DIR = MODEL_DIR / "graph_encoder"

PREP_DIR = Path("v7/data/prepared")
train_npz_path = PREP_DIR / "train.npz"
manifest_path = PREP_DIR / "dataset_manifest.json"
assert train_npz_path.exists() and manifest_path.exists(), "Missing prepared data or manifest."

with open(manifest_path) as f:
    ds_manifest = json.load(f)
LABEL_NAMES = ds_manifest["labels"]
N_LABELS = len(LABEL_NAMES)
DESC_IN_DIM = ds_manifest["n_features"]    # 208 features from RDKit
FUSION_DIM = 256
MAX_SEQ_LEN = 256
MAX_NODES = 128
DROPOUT = 0.1
N_HEADS = 4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -----------------------------
# Label-specialist head
# -----------------------------
class LabelHead(nn.Module):
    """Small MLP head for a single label (binary logit). Input: fused (3*D)."""
    def __init__(self, fused_dim: int, hidden: int = 256, p: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(fused_dim, hidden),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(hidden, 1),
        )
    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (B, fused_dim) -> (B, 1)
        return self.net(z)

# -----------------------------
# Full Fusion Model
# -----------------------------
class V7FusionModel(nn.Module):
    """
    Full model that composes:
      - text_encoder: ChemBERTaEncoder
      - graph_encoder: GraphGINEncoder
      - cross-attn + desc MLP + pooling to produce fused vector
      - either shared multi-label head OR 12 specialist heads

    Modes:
      specialist=False (default): shared multi-label head → (B,12)
      specialist=True:  12 label heads (one-vs-rest) → concat → (B,12)
    """
    def __init__(
        self,
        text_encoder: "ChemBERTaEncoder",
        graph_encoder: "GraphGINEncoder",
        desc_in_dim: int = DESC_IN_DIM,
        dim: int = FUSION_DIM,
        n_heads: int = N_HEADS,
        n_labels: int = N_LABELS,
        dropout: float = DROPOUT,
        specialist: bool = False,
    ):
        super().__init__()
        self.text_encoder = text_encoder
        self.graph_encoder = graph_encoder
        self.dim = dim
        self.n_labels = n_labels
        self.specialist = specialist

        # Fusion components (reusing the exact modules from Cell 1 design)
        self.cross = CrossAttentionBlock(dim=dim, n_heads=n_heads, p=dropout)
        self.desc_mlp = DescriptorMLP(in_dim=desc_in_dim, out_dim=dim, hidden=256, p=dropout)

        fused_dim = dim * 3  # [text_pool ; graph_pool ; desc_embed]

        if specialist:
            # 12 label-wise heads
            self.label_heads = nn.ModuleList([LabelHead(fused_dim=fused_dim, hidden=dim, p=dropout) for _ in range(n_labels)])
            self.shared_head = None
        else:
            # Shared multi-label head
            self.shared_head = FusionClassifier(dim=dim, n_labels=n_labels, p=dropout)
            self.label_heads = None

    def freeze_text_backbone(self, n_unfrozen_layers: int = 0):
        """Freeze ChemBERTa backbone; keep proj/LN trainable; optionally unfreeze last N layers."""
        self.text_encoder.freeze_backbone(n_unfrozen_layers=n_unfrozen_layers)

    def freeze_graph(self, freeze: bool = True):
        for p in self.graph_encoder.parameters():
            p.requires_grad = not freeze

    def forward(
        self,
        smiles_list: List[str],
        desc_feats: torch.Tensor,
        max_seq_len: int = MAX_SEQ_LEN,
        max_nodes: int = MAX_NODES,
        return_intermediates: bool = False,
    ) -> Tuple[torch.Tensor, Optional[dict]]:
        """
        Inputs:
          smiles_list: list[str] length B
          desc_feats:  (B, DESC_IN_DIM) float tensor (already imputed/scaled)
        Returns:
          logits: (B, 12)
          intermediates (optional dict)
        """
        # Encode text & graph
        text_tokens, text_mask = self.text_encoder(smiles_list, max_length=max_seq_len)  # (B,L,D), (B,L)
        graph_nodes, graph_mask = self.graph_encoder(smiles_list, max_nodes=max_nodes)    # (B,N,D), (B,N)

        # Ensure tensors on same device
        text_tokens = text_tokens.to(device)
        text_mask   = text_mask.to(device)
        graph_nodes = graph_nodes.to(device)
        graph_mask  = graph_mask.to(device)
        desc_feats  = desc_feats.to(device)

        # Cross-attention update of text tokens with graph context
        text_attn = self.cross(text_tokens, text_mask, graph_nodes, graph_mask)  # (B,L,D)

        # Descriptor embedding
        desc_embed = self.desc_mlp(desc_feats)  # (B,D)

        # Masked pools
        text_pool  = masked_mean(text_attn,   text_mask,  dim=1)  # (B,D)
        graph_pool = masked_mean(graph_nodes, graph_mask, dim=1)  # (B,D)
        fused = torch.cat([text_pool, graph_pool, desc_embed], dim=-1)  # (B, 3D)

        # Heads
        if self.specialist:
            logits_list = [head(fused) for head in self.label_heads]  # list of (B,1)
            logits = torch.cat(logits_list, dim=1)                    # (B,12)
        else:
            logits = self.shared_head(text_attn, text_mask, graph_nodes, graph_mask, desc_embed)  # (B,12)

        aux = None
        if return_intermediates:
            aux = {
                "text_tokens": text_tokens,
                "text_attended": text_attn,
                "graph_nodes": graph_nodes,
                "desc_embed": desc_embed,
                "text_pool": text_pool,
                "graph_pool": graph_pool,
                "fused": fused,
            }
        return logits, aux

# -----------------------------
# Build text/graph encoders from previous cells
# -----------------------------
# These objects should exist if you ran Cell 2 and Cell 3; otherwise re-instantiate:
try:
    text_encoder
except NameError:
    # fallback: rebuild with defaults
    text_encoder = ChemBERTaEncoder().to(device)

try:
    graph_encoder
except NameError:
    # probe node_in_dim like in Cell 3 if needed
    from rdkit import Chem
    def _probe_node_in_dim():
        from rdkit.Chem.Scaffolds import MurckoScaffold
        mol = Chem.MolFromSmiles("CCO")
        from math import isfinite
        return 51  # fallback from previous cell config
    graph_encoder = GraphGINEncoder(node_in_dim=_probe_node_in_dim(), hidden_dim=FUSION_DIM, n_layers=4, dropout=0.1).to(device)

# -----------------------------
# Instantiate both variants (shared & specialist)
# -----------------------------
v7_shared = V7FusionModel(
    text_encoder=text_encoder,
    graph_encoder=graph_encoder,
    desc_in_dim=DESC_IN_DIM,
    dim=FUSION_DIM,
    n_heads=N_HEADS,
    n_labels=N_LABELS,
    dropout=DROPOUT,
    specialist=False,
).to(device)

v7_specialist = V7FusionModel(
    text_encoder=text_encoder,
    graph_encoder=graph_encoder,
    desc_in_dim=DESC_IN_DIM,
    dim=FUSION_DIM,
    n_heads=N_HEADS,
    n_labels=N_LABELS,
    dropout=DROPOUT,
    specialist=True,
).to(device)

# -----------------------------
# Save fusion config & create ensemble folders
# -----------------------------
fusion_manifest = {
    "labels": LABEL_NAMES,
    "n_labels": N_LABELS,
    "desc_in_dim": DESC_IN_DIM,
    "fusion_dim": FUSION_DIM,
    "n_heads": N_HEADS,
    "dropout": DROPOUT,
    "max_seq_len": MAX_SEQ_LEN,
    "max_nodes": MAX_NODES,
    "modes": ["shared", "specialist"],
    "paths": {
        "text_encoder_config": str(TEXT_DIR / "config.json"),
        "graph_encoder_config": str(GRAPH_DIR / "config.json"),
        "ensemble_root": str(ENSEMBLE_DIR),
    }
}
(FUSION_DIR / "config.json").write_text(json.dumps(fusion_manifest, indent=2))
print("📝 Saved fusion manifest →", FUSION_DIR / "config.json")

# Create per-label ensemble directories (empty for now)
for label in LABEL_NAMES:
    (ENSEMBLE_DIR / label).mkdir(parents=True, exist_ok=True)
print("📁 Ensemble label folders ready under:", ENSEMBLE_DIR)

# -----------------------------
# 🔎 Sanity: small forward on real data (no training)
# -----------------------------
batch = np.load(train_npz_path, allow_pickle=True)
smiles_batch = [str(s) for s in batch["smiles"][:4].tolist()]
desc_batch = torch.tensor(batch["X"][:4], dtype=torch.float32, device=device)  # imputed+scaled

with torch.no_grad():
    # Shared head
    logits_shared, aux_shared = v7_shared(smiles_batch, desc_batch, return_intermediates=True)
    # Specialist heads
    logits_spec,  aux_spec   = v7_specialist(smiles_batch, desc_batch, return_intermediates=False)

print("Sanity:")
print("  shared logits:     ", tuple(logits_shared.shape), " (B, 12)")
print("  specialist logits: ", tuple(logits_spec.shape),   " (B, 12)")
assert torch.isfinite(logits_shared).all() and torch.isfinite(logits_spec).all(), "Found non-finite logits."

# Optional peek at fused vector shape (for debugging/ablations)
print("  fused vector shape:", tuple(aux_shared["fused"].shape), " (B, 768)")

Using device: cuda
📝 Saved fusion manifest → v7\model\v7_fusion\config.json
📁 Ensemble label folders ready under: v7\model\ensembles
Sanity:
  shared logits:      (4, 12)  (B, 12)
  specialist logits:  (4, 12)  (B, 12)
  fused vector shape: (4, 768)  (B, 768)


## Phase 3 (Train)

### 0: Hardware & Throughput Probe (fp32 vs AMP)

This cell profiles:
- GPU name/VRAM/CUDA/torch versions
- Full-model forward throughput at batch sizes `[8, 16, 24, 32]`
- Mixed precision (AMP) vs fp32
- A quick fwd+backward step (1 iter) to estimate step time
- Saves results:
  - `v7/results/meta/hw_probe.json`
  - `v7/results/meta/throughput_probe.csv`

**Output:** A recommended batch size & precision mode for training.


In [None]:
import os
import json
import math
import time
from pathlib import Path
from contextlib import nullcontext

import numpy as np
import torch
import torch.nn as nn

# Expect these to be defined from Phase 2:
# - v7_shared (full model w/ shared head)
# - text_encoder, graph_encoder
# If not (e.g., after a restart), we'll try to rebuild minimal defaults.
try:
    v7_shared
except NameError:
    print("⚠️ v7_shared not found in memory — rebuilding minimal defaults.")
    # Minimal rebuild (assumes Phase 2 cells are available; else raise)
    try:
        text_encoder
    except NameError:
        from transformers import AutoTokenizer, AutoModel
        class ChemBERTaEncoder(nn.Module):
            def __init__(self, ckpt_name="seyonec/ChemBERTa-zinc-base-v1", fusion_dim=256, dropout_p=0.1):
                super().__init__()
                self.tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
                self.backbone  = AutoModel.from_pretrained(ckpt_name)
                self.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
                self.proj = nn.Sequential(nn.Dropout(dropout_p), nn.Linear(self.backbone.config.hidden_size, fusion_dim))
                self.ln = nn.LayerNorm(fusion_dim)
            def forward(self, smiles_list, max_length=256, add_special_tokens=True):
                enc = self.tokenizer(list(smiles_list), padding=True, truncation=True, max_length=max_length, add_special_tokens=add_special_tokens, return_tensors="pt")
                input_ids, attention_mask = enc["input_ids"].to(device), enc["attention_mask"].to(device)
                out = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
                toks = self.ln(self.proj(out))
                return toks, attention_mask.to(dtype=torch.int32)

    try:
        graph_encoder
    except NameError:
        from rdkit import Chem
        # Minimal graph encoder using same interfaces as before
        ATOM_LIST = ["H","C","N","O","F","P","S","Cl","Br","I"]
        def one_hot(v, choices): 
            z=[0]*len(choices); 
            if v in choices: z[choices.index(v)]=1
            return z
        def clamp_oh(v, lo, hi):
            buckets=list(range(lo,hi+1))
            if v<lo or v>hi: return [0]*len(buckets)+[1]
            o=[0]*(len(buckets)+1); o[v-lo]=1; return o
        def atom_features(atom):
            hybs=[Chem.rdchem.HybridizationType.S, Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2]
            chir=[Chem.rdchem.ChiralType.CHI_UNSPECIFIED, Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, Chem.rdchem.ChiralType.CHI_OTHER]
            sym = atom.GetSymbol()
            feat = one_hot(sym if sym in ATOM_LIST else "other", ATOM_LIST+["other"])
            feat += clamp_oh(atom.GetDegree(),0,5)
            feat += clamp_oh(atom.GetFormalCharge(),-2,2)
            feat += (one_hot(atom.GetHybridization(), hybs)+[0])
            feat += [int(atom.GetIsAromatic())]
            feat += [int(atom.IsInRing())]
            feat += one_hot(atom.GetChiralTag(), chir)
            feat += clamp_oh(atom.GetTotalNumHs(includeNeighbors=True),0,4)
            feat += clamp_oh(atom.GetTotalValence(),0,5)
            feat += [atom.GetMass()/200.0]
            return feat
        def smiles_to_graph(smi, max_nodes=128):
            mol = Chem.MolFromSmiles(smi)
            if mol is None or mol.GetNumAtoms()==0:
                return np.zeros((0,0),dtype=np.float32), np.zeros((0,0),dtype=np.float32)
            feats = [atom_features(mol.GetAtomWithIdx(i)) for i in range(mol.GetNumAtoms())]
            x = np.asarray(feats, dtype=np.float32)
            N = mol.GetNumAtoms()
            adj = np.zeros((N,N), dtype=np.float32)
            for b in mol.GetBonds():
                i,j=b.GetBeginAtomIdx(), b.GetEndAtomIdx()
                adj[i,j]=1.0; adj[j,i]=1.0
            if N>128: x=x[:128]; adj=adj[:128,:128]
            return x, adj
        def collate_graphs(smiles_batch):
            graphs=[smiles_to_graph(s) for s in smiles_batch]
            Nmax=max([g[0].shape[0] for g in graphs] + [1])
            Fnode=graphs[0][0].shape[1] if graphs[0][0].size>0 else 51
            B=len(graphs)
            X=np.zeros((B,Nmax,Fnode),dtype=np.float32)
            A=np.zeros((B,Nmax,Nmax),dtype=np.float32)
            M=np.zeros((B,Nmax),dtype=np.int64)
            for i,(x,a) in enumerate(graphs):
                n=x.shape[0]; 
                if n==0: continue
                X[i,:n,:]=x; A[i,:n,:n]=a; M[i,:n]=1
            return torch.from_numpy(X).to(device), torch.from_numpy(A).to(device), torch.from_numpy(M).to(device)
        class GINLayer(nn.Module):
            def __init__(self, h=256, p=0.1):
                super().__init__()
                self.eps = nn.Parameter(torch.tensor(0.0))
                self.mlp = nn.Sequential(nn.Linear(h,h), nn.GELU(), nn.LayerNorm(h), nn.Dropout(p))
            def forward(self, x, adj, mask):
                out=(1.0+self.eps)*x + torch.matmul(adj,x)
                out=self.mlp(out)
                return out*mask.unsqueeze(-1).to(out.dtype)
        class GraphGINEncoder(nn.Module):
            def __init__(self, node_in_dim=51, hidden_dim=256, n_layers=4, p=0.1):
                super().__init__()
                self.inp=nn.Sequential(nn.Linear(node_in_dim,hidden_dim), nn.GELU(), nn.Dropout(p))
                self.layers=nn.ModuleList([GINLayer(hidden_dim,p) for _ in range(n_layers)])
                self.ln=nn.LayerNorm(hidden_dim)
            def forward(self, smiles_list, max_nodes=128):
                X,A,M=collate_graphs(smiles_list)
                h=self.inp(X)
                for layer in self.layers: h=layer(h,A,M)
                return self.ln(h), M.to(dtype=torch.int32)

    # Fusion pieces (masked_mean) minimal copy
    def masked_mean(x, mask, dim=1):
        mask = mask.to(dtype=x.dtype, device=x.device)
        denom = mask.sum(dim=dim, keepdim=True).clamp(min=1.0)
        return (x * mask.unsqueeze(-1)).sum(dim=dim) / denom

    class CrossAttentionBlock(nn.Module):
        def __init__(self, dim=256, n_heads=4, p=0.1):
            super().__init__()
            self.mha = nn.MultiheadAttention(dim, n_heads, dropout=p, batch_first=False)
            self.ln  = nn.LayerNorm(dim)
            self.do  = nn.Dropout(p)
        def forward(self, text_tokens, text_mask, graph_nodes, graph_mask):
            Q=text_tokens.transpose(0,1); K=graph_nodes.transpose(0,1); V=graph_nodes.transpose(0,1)
            kpm=(graph_mask==0)
            attn,_=self.mha(Q,K,V, key_padding_mask=kpm)
            attn=attn.transpose(0,1)
            return self.ln(text_tokens + self.do(attn))

    class DescriptorMLP(nn.Module):
        def __init__(self, in_dim, out_dim=256, hidden=256, p=0.1):
            super().__init__()
            self.net=nn.Sequential(nn.Linear(in_dim,hidden), nn.GELU(), nn.Dropout(p), nn.Linear(hidden,out_dim), nn.GELU(), nn.Dropout(p))
        def forward(self,x): return self.net(x)

    class FusionClassifier(nn.Module):
        def __init__(self, dim=256, n_labels=12, p=0.1):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(p), nn.Linear(dim*2, n_labels))
        def forward(self, text_tokens, text_mask, graph_nodes, graph_mask, desc_embed):
            text_pool = masked_mean(text_tokens, text_mask, 1)
            graph_pool= masked_mean(graph_nodes, graph_mask, 1)
            return self.net(torch.cat([text_pool, graph_pool, desc_embed], dim=-1))

    class V7FusionModel(nn.Module):
        def __init__(self, text_encoder, graph_encoder, desc_in_dim=208, dim=256, n_labels=12, n_heads=4, p=0.1):
            super().__init__()
            self.text_encoder=text_encoder
            self.graph_encoder=graph_encoder
            self.cross=CrossAttentionBlock(dim, n_heads, p)
            self.desc_mlp=DescriptorMLP(desc_in_dim, out_dim=dim, hidden=256, p=p)
            self.shared_head=FusionClassifier(dim, n_labels, p)
        def forward(self, smiles_list, desc_feats, max_seq_len=256, max_nodes=128):
            tt, tm = self.text_encoder(smiles_list, max_length=max_seq_len)
            gn, gm = self.graph_encoder(smiles_list, max_nodes=max_nodes)
            tt, tm, gn, gm, desc_feats = tt.to(device), tm.to(device), gn.to(device), gm.to(device), desc_feats.to(device)
            tta = self.cross(tt, tm, gn, gm)
            de  = self.desc_mlp(desc_feats)
            logits = self.shared_head(tta, tm, gn, gm, de)
            return logits

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    text_encoder = ChemBERTaEncoder().to(device)
    graph_encoder= GraphGINEncoder().to(device)
    v7_shared    = V7FusionModel(text_encoder, graph_encoder).to(device)

# ---------------- Paths & data ----------------
PREP_DIR = Path("v7/data/prepared")
RES_DIR  = Path("v7/results/meta"); RES_DIR.mkdir(parents=True, exist_ok=True)
train_npz = np.load(PREP_DIR / "train.npz", allow_pickle=True)

smiles_train = [str(s) for s in train_npz["smiles"].tolist()]
X_train = torch.tensor(train_npz["X"], dtype=torch.float32)

# ---------------- HW info --------------------
hw = {
    "torch_version": torch.__version__,
    "cuda_available": torch.cuda.is_available(),
    "cuda_version": torch.version.cuda,
    "device_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu",
    "total_vram_gb": round(torch.cuda.get_device_properties(0).total_memory/1024**3,2) if torch.cuda.is_available() else None,
}
(Path(RES_DIR / "hw_probe.json")).write_text(json.dumps(hw, indent=2))
print("HW:", json.dumps(hw, indent=2))

# --------------- benchmark helpers -----------
# --- Patch: unwrap model outputs to logits ---
def _to_logits(out):
    return out[0] if isinstance(out, (tuple, list)) else out

def bench_forward(model, smiles_batch, desc_batch, amp=False, iters=10, warmup=3):
    from contextlib import nullcontext
    model.eval()
    use_amp = amp and torch.cuda.is_available()
    ctx = torch.cuda.amp.autocast(dtype=torch.float16) if use_amp else nullcontext()

    with torch.no_grad(), ctx:
        for _ in range(warmup):
            _ = _to_logits(model(smiles_batch, desc_batch))
    if torch.cuda.is_available(): torch.cuda.synchronize()

    t0 = time.time()
    with torch.no_grad(), ctx:
        for _ in range(iters):
            _ = _to_logits(model(smiles_batch, desc_batch))
    if torch.cuda.is_available(): torch.cuda.synchronize()
    return (time.time() - t0) / iters

def bench_step(model, smiles_batch, desc_batch, amp=False):
    from contextlib import nullcontext
    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    use_amp = amp and torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    ctx = torch.cuda.amp.autocast(dtype=torch.float16) if use_amp else nullcontext()

    # dummy targets for timing only
    y = torch.rand((desc_batch.size(0), 12), device=desc_batch.device)
    loss_fn = nn.BCEWithLogitsLoss()

    if torch.cuda.is_available(): torch.cuda.synchronize()
    t0 = time.time()
    opt.zero_grad(set_to_none=True)
    with ctx:
        logits = _to_logits(model(smiles_batch, desc_batch))
        loss = loss_fn(logits, y)
    if scaler.is_enabled():
        scaler.scale(loss).backward()
        scaler.step(opt); scaler.update()
    else:
        loss.backward(); opt.step()
    if torch.cuda.is_available(): torch.cuda.synchronize()
    return float(loss.item()), (time.time() - t0)

# --- Re-run the driver loop & save artifacts again ---
batch_sizes = [8, 16, 24, 32]
results = []
max_nodes_seen = 0

for bs in batch_sizes:
    try:
        smiles_bs = smiles_train[:bs]
        X_bs = X_train[:bs].to(next(v7_shared.parameters()).device, non_blocking=True)

        with torch.no_grad():
            gn, gm = graph_encoder(smiles_bs, max_nodes=128)
            max_nodes_seen = max(max_nodes_seen, int(gm.sum(dim=1).max().item()))

        t_f32 = bench_forward(v7_shared, smiles_bs, X_bs, amp=False, iters=10, warmup=3)
        t_amp = bench_forward(v7_shared, smiles_bs, X_bs, amp=True,  iters=10, warmup=3)

        loss_f32, step_f32 = bench_step(v7_shared, smiles_bs, X_bs, amp=False)
        loss_amp,  step_amp = bench_step(v7_shared, smiles_bs, X_bs, amp=True)

        mem_alloc = torch.cuda.max_memory_allocated()/1024**3 if torch.cuda.is_available() else None
        if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats()

        results.append({
            "batch_size": bs,
            "forward_fp32_s_per_batch": round(t_f32, 4),
            "forward_amp_s_per_batch":  round(t_amp, 4),
            "trainstep_fp32_s": round(step_f32, 4),
            "trainstep_amp_s":  round(step_amp, 4),
            "approx_samples_per_s_fp32": round(bs / t_f32, 2),
            "approx_samples_per_s_amp":  round(bs / t_amp, 2),
            "peak_mem_gb": round(mem_alloc, 2) if mem_alloc is not None else None,
        })
        print(f"BS={bs} | fwd fp32 {t_f32:.4f}s | fwd amp {t_amp:.4f}s | step fp32 {step_f32:.4f}s | step amp {step_amp:.4f}s | peak {results[-1]['peak_mem_gb']} GB")
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            print(f"OOM at batch size {bs}.")
            results.append({"batch_size": bs, "error": "OOM"})
            if torch.cuda.is_available(): torch.cuda.empty_cache()
        else:
            raise

# Save again
df = pd.DataFrame(results)
df.to_csv(RES_DIR / "throughput_probe.csv", index=False)
(Path(RES_DIR / "throughput_probe.json")).write_text(json.dumps(results, indent=2))

valid = [r for r in results if r.get("error") is None]
if valid:
    succ_amp = [r for r in valid if r.get("forward_amp_s_per_batch") is not None]
    if succ_amp:
        rec_bs = max(succ_amp, key=lambda r: r["batch_size"])["batch_size"]
        rec_prec = "amp"
    else:
        rec_bs = max(valid, key=lambda r: r["batch_size"])["batch_size"]
        rec_prec = "fp32"
else:
    rec_bs, rec_prec = 8, "amp"

summary = {
    "hw": hw,
    "max_nodes_seen": int(max_nodes_seen),
    "results": results,
    "recommendation": {
        "batch_size": int(rec_bs),
        "precision": rec_prec,
        "note": "Use grad accumulation to reach higher effective batch if needed."
    }
}
(Path(RES_DIR / "hw_probe_summary.json")).write_text(json.dumps(summary, indent=2))

print("\n=== Probe Summary ===")
print(json.dumps(summary["recommendation"], indent=2))
print(f"Results saved to: {RES_DIR}")

HW: {
  "torch_version": "2.6.0+cu124",
  "cuda_available": true,
  "cuda_version": "12.4",
  "device_name": "NVIDIA GeForce RTX 4070 Ti",
  "total_vram_gb": 11.99
}


  ctx = torch.cuda.amp.autocast(dtype=torch.float16) if use_amp else nullcontext()
  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
  ctx = torch.cuda.amp.autocast(dtype=torch.float16) if use_amp else nullcontext()


BS=8 | fwd fp32 0.0060s | fwd amp 0.0068s | step fp32 0.0507s | step amp 0.0336s | peak 0.87 GB
BS=16 | fwd fp32 0.0100s | fwd amp 0.0095s | step fp32 0.0349s | step amp 0.0291s | peak 0.88 GB
BS=24 | fwd fp32 0.0134s | fwd amp 0.0125s | step fp32 0.0439s | step amp 0.0400s | peak 0.98 GB
BS=32 | fwd fp32 0.0167s | fwd amp 0.0133s | step fp32 0.0529s | step amp 0.0462s | peak 1.24 GB

=== Probe Summary ===
{
  "batch_size": 32,
  "precision": "amp",
  "note": "Use grad accumulation to reach higher effective batch if needed."
}
Results saved to: v7\results\meta


### 1: Shared-Head Training Loop

This cell trains the **shared multi-label V7 fusion model** with strong settings:

- Windows-safe DataLoaders (`num_workers=0`)
- Batch size 32, **grad accumulation 4** (effective 128)
- **ASL** (gamma_neg=5.0, gamma_pos=1.0) + **class-balanced per-label weights**
- **AMP** mixed precision, **EMA**, cosine LR with warmup
- **Stage A** (8 epochs): ChemBERTa backbone **frozen**
- **Stage B** (20 epochs): unfreeze **last 2** transformer blocks
- Early stopping on **val macro PR-AUC**
- Checkpoints:
  - best → `v7/model/checkpoints/shared/best.pt`
  - per-epoch → `v7/model/checkpoints/shared/epoch_{stage}{epoch:02d}.pt`
- Logs:
  - training rows → `v7/results/logs/train_log.jsonl`
  - val summaries → `v7/results/artifacts/val_metrics.jsonl`
  - run config → `v7/results/meta/train_run_config.json`

> Run this cell to start training. We’ll add plots + specialist heads next.


In [None]:
import os, json, math, time, random, platform
from pathlib import Path
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from sklearn.metrics import average_precision_score, roc_auc_score

# ---------------------------
# Env & reproducibility
# ---------------------------
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # quieter HF tokenizer
def seed_everything(seed=42):
    import numpy as _np, random as _r, torch as _t
    _r.seed(seed); _np.random.seed(seed)
    _t.manual_seed(seed); _t.cuda.manual_seed_all(seed)
    _t.backends.cudnn.deterministic = False
    _t.backends.cudnn.benchmark = True

seed_everything(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------------------------
# Paths & setup
# ---------------------------
BASE_DIR   = Path("v7")
DATA_PREP  = BASE_DIR / "data" / "prepared"
RESULTS_DIR= BASE_DIR / "results"
LOGS_DIR   = RESULTS_DIR / "logs"
ARTIF_DIR  = RESULTS_DIR / "artifacts"
META_DIR   = RESULTS_DIR / "meta"
CKPT_DIR   = BASE_DIR / "model" / "checkpoints" / "shared"
for d in [RESULTS_DIR, LOGS_DIR, ARTIF_DIR, META_DIR, CKPT_DIR]:
    d.mkdir(parents=True, exist_ok=True)

with open(DATA_PREP / "dataset_manifest.json") as f:
    ds_manifest = json.load(f)
LABEL_NAMES  = ds_manifest["labels"]
N_LABELS     = len(LABEL_NAMES)
DESC_IN_DIM  = ds_manifest["n_features"]

# Expect v7_shared in memory (from Phase 2 — Cell 4)
try:
    v7_shared
except NameError:
    raise RuntimeError("v7_shared model not found. Please re-run Phase 2 (Cells 1–4) in this kernel.")

# ---------------------------
# Data: Dataset & Loader (Windows-safe)
# ---------------------------
class Tox21NPZDataset(torch.utils.data.Dataset):
    def __init__(self, npz_path: Path):
        b = np.load(npz_path, allow_pickle=True)
        self.X = b["X"].astype(np.float32)             # (N, F)
        self.Y = b["Y"].astype(np.float32)             # (N, 12), may contain NaN
        self.mask_y = b["y_missing_mask"].astype(bool) # True where NaN
        self.smiles = b["smiles"].tolist()
        self.mol_ids= b["mol_id"].tolist()
        assert self.X.shape[0] == self.Y.shape[0] == len(self.smiles)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i):
        return {
            "smiles": self.smiles[i],
            "desc": torch.from_numpy(self.X[i]),
            "y": torch.from_numpy(self.Y[i]),
            "y_mask": torch.from_numpy(self.mask_y[i]),
            "mol_id": self.mol_ids[i],
        }

def collate(batch):
    smiles = [b["smiles"] for b in batch]
    desc   = torch.stack([b["desc"] for b in batch], dim=0)
    y      = torch.stack([b["y"] for b in batch], dim=0)
    y_mask = torch.stack([b["y_mask"] for b in batch], dim=0)
    mol_id = [b["mol_id"] for b in batch]
    return {"smiles": smiles, "desc": desc, "y": y, "y_mask": y_mask, "mol_id": mol_id}

train_ds = Tox21NPZDataset(DATA_PREP / "train.npz")
val_ds   = Tox21NPZDataset(DATA_PREP / "val.npz")

IS_WINDOWS = platform.system() == "Windows"
NUM_WORKERS = 0 if IS_WINDOWS else max(0, min(4, (os.cpu_count() or 2)//2))
PERSISTENT = False

# Batch & precision (from your probe)
BATCH_SIZE  = 32
GRAD_ACCUM  = 4      # effective batch 128
USE_AMP     = True

train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=(device.type=="cuda"),
    collate_fn=collate, drop_last=True,
    persistent_workers=(PERSISTENT if NUM_WORKERS>0 else False),
)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=(device.type=="cuda"),
    collate_fn=collate,
    persistent_workers=(PERSISTENT if NUM_WORKERS>0 else False),
)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

# ---------------------------
# Label stats & weights
# ---------------------------
def compute_label_stats(y: np.ndarray, mask_missing: np.ndarray) -> Dict[str, np.ndarray]:
    valid = ~mask_missing
    pos = np.nansum((y == 1) & valid, axis=0)
    neg = np.nansum((y == 0) & valid, axis=0)
    total = pos + neg
    prevalence = np.divide(pos, np.maximum(total, 1))
    return {"pos": pos.astype(int), "neg": neg.astype(int), "total": total.astype(int), "prevalence": prevalence}

train_blob = np.load(DATA_PREP / "train.npz", allow_pickle=True)
train_stats = compute_label_stats(train_blob["Y"], train_blob["y_missing_mask"])
print("Train label prevalence:", np.round(train_stats["prevalence"], 3))

def effective_number_weights(pos_counts: np.ndarray, beta: float = 0.999) -> np.ndarray:
    eps = 1e-8
    eff_num = (1 - np.power(beta, pos_counts + eps)) / (1 - beta)
    alpha = 1.0 / np.maximum(eff_num, eps)
    alpha = alpha / (np.mean(alpha) + eps)
    return alpha.astype(np.float32)

alpha_cb = effective_number_weights(train_stats["pos"])
print("Class-balanced alpha:", np.round(alpha_cb, 3))
alpha_cb_t = torch.tensor(alpha_cb, device=device)

# ---------------------------
# Loss: Asymmetric Loss (stronger)
# ---------------------------
class AsymmetricLossCB(nn.Module):
    def __init__(self, gamma_neg=5.0, gamma_pos=1.0, clip=0.05, alpha: Optional[torch.Tensor]=None):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.alpha = alpha  # (n_labels,) on device

    def forward(self, logits: torch.Tensor, targets: torch.Tensor, missing_mask: torch.Tensor):
        # mask out missing labels
        valid = ~missing_mask
        if not valid.any():
            return logits.new_tensor(0.0)

        logits_v  = logits[valid]
        targets_v = targets[valid]

        pred = torch.sigmoid(logits_v)
        if self.clip:
            pred = torch.clamp(pred, self.clip, 1 - self.clip)

        anti_targets = 1 - targets_v
        pt = pred * targets_v + (1 - pred) * anti_targets
        one_sided_gamma = self.gamma_pos * targets_v + self.gamma_neg * anti_targets
        focal_weight = torch.pow(1 - pt, one_sided_gamma)

        loss = - (targets_v * torch.log(pred) + (1 - targets_v) * torch.log(1 - pred))
        loss = loss * focal_weight

        if self.alpha is not None:
            n_labels = logits.size(1)
            label_idx_full = torch.arange(n_labels, device=logits.device).unsqueeze(0).expand_as(missing_mask)
            label_idx_valid = label_idx_full[valid]
            alpha_vec = self.alpha[label_idx_valid]
            loss = loss * alpha_vec

        return loss.mean()

criterion = AsymmetricLossCB(gamma_neg=5.0, gamma_pos=1.0, clip=0.05, alpha=alpha_cb_t)

# ---------------------------
# Optimizer, EMA, Scheduler
# ---------------------------
def build_optimizer(model, lr=3e-4, wd=1e-2):
    return optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=lr, weight_decay=wd)

class EMA:
    def __init__(self, model: nn.Module, decay=0.999):
        self.decay = decay
        self.shadow = {}
        for n,p in model.named_parameters():
            if p.requires_grad:
                self.shadow[n] = p.detach().clone()
    @torch.no_grad()
    def update(self, model: nn.Module):
        for n,p in model.named_parameters():
            if p.requires_grad and n in self.shadow:
                self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)
    @torch.no_grad()
    def copy_to(self, model: nn.Module):
        for n,p in model.named_parameters():
            if p.requires_grad and n in self.shadow:
                p.data.copy_(self.shadow[n])

def build_warmup_cosine(total_steps, warmup_ratio=0.1, min_lr_scale=0.1):
    def lr_lambda(step):
        warm = int(total_steps * warmup_ratio)
        if step < warm:
            return float(step) / max(1, warm)
        progress = (step - warm) / max(1, total_steps - warm)
        cosine = 0.5 * (1 + math.cos(math.pi * progress))
        return min_lr_scale + (1 - min_lr_scale) * cosine
    return lr_lambda

# ---------------------------
# Metrics (macro PR-AUC primary)
# ---------------------------
def eval_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
    pr_aucs, roc_aucs = [], []
    for j in range(y_true.shape[1]):
        mask = ~np.isnan(y_true[:, j])
        if mask.sum() < 2 or len(np.unique(y_true[mask, j])) < 2:
            pr_aucs.append(np.nan); roc_aucs.append(np.nan); continue
        try:
            pr_aucs.append(average_precision_score(y_true[mask, j], y_prob[mask, j]))
        except Exception:
            pr_aucs.append(np.nan)
        try:
            roc_aucs.append(roc_auc_score(y_true[mask, j], y_prob[mask, j]))
        except Exception:
            roc_aucs.append(np.nan)
    return {"macro_pr_auc": float(np.nanmean(pr_aucs)),
            "macro_roc_auc": float(np.nanmean(roc_aucs))}

# ---------------------------
# Stage settings (stronger)
# ---------------------------
MAX_EPOCHS_STAGE_A = 8
MAX_EPOCHS_STAGE_B = 20
UNFREEZE_LAST_N    = 2
EARLY_STOP_PATIENCE= 5
EMA_DECAY          = 0.999
LR                 = 3e-4
WEIGHT_DECAY       = 1e-2
WARMUP_RATIO       = 0.1
CLIP_NORM          = 1.0

# Freeze text backbone for Stage A
v7_shared.freeze_text_backbone(n_unfrozen_layers=0)
v7_shared.freeze_graph(freeze=False)  # keep graph trainable

# Save run config for reproducibility
run_cfg = {
    "batch_size": BATCH_SIZE,
    "grad_accum": GRAD_ACCUM,
    "use_amp": USE_AMP,
    "max_epochs": {"stage_a": MAX_EPOCHS_STAGE_A, "stage_b": MAX_EPOCHS_STAGE_B},
    "optimizer": {"lr": LR, "weight_decay": WEIGHT_DECAY},
    "scheduler": {"type": "warmup_cosine", "warmup_ratio": WARMUP_RATIO},
    "asl": {"gamma_neg": 5.0, "gamma_pos": 1.0, "clip": 0.05},
    "ema_decay": EMA_DECAY,
    "unfreeze_last_n": UNFREEZE_LAST_N,
    "early_stop_patience": EARLY_STOP_PATIENCE,
    "clip_norm": CLIP_NORM,
    "num_workers": NUM_WORKERS,
}
(META_DIR / "train_run_config.json").write_text(json.dumps(run_cfg, indent=2))

# ---------------------------
# AMP scaler
# ---------------------------
scaler = torch.amp.GradScaler("cuda", enabled=(USE_AMP and device.type=="cuda"))

# ---------------------------
# One epoch
# ---------------------------
def run_epoch(model: nn.Module, loader, optimizer, scheduler, ema: Optional[EMA], stage_name="train"):
    is_train = optimizer is not None
    model.train(mode=is_train)

    running_loss = 0.0
    steps = 0
    all_probs, all_true, all_mask = [], [], []

    for step, batch in enumerate(loader):
        smiles = batch["smiles"]
        desc   = batch["desc"].to(device, non_blocking=True)
        y      = batch["y"].to(device, non_blocking=True)
        y_mask = batch["y_mask"].to(device, non_blocking=True)

        ctx = torch.amp.autocast("cuda", dtype=torch.float16) if (USE_AMP and device.type=="cuda") else torch.autocast(device_type="cpu", enabled=False)
        with ctx:
            logits, _ = v7_shared(smiles, desc, return_intermediates=False)  # (B,12)
            loss = criterion(logits, y, y_mask)

        if is_train:
            loss = loss / GRAD_ACCUM
            scaler.scale(loss).backward()

            if (step + 1) % GRAD_ACCUM == 0:
                # gradient clipping
                if CLIP_NORM is not None:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(v7_shared.parameters(), CLIP_NORM)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                if scheduler is not None:
                    scheduler.step()
                if ema is not None:
                    ema.update(model)

        running_loss += float(loss.item()) * GRAD_ACCUM
        steps += 1

        with torch.no_grad():
            probs = torch.sigmoid(logits).detach().cpu().numpy()
            y_np  = y.detach().cpu().numpy()
            m_np  = y_mask.detach().cpu().numpy()
            all_probs.append(probs); all_true.append(y_np); all_mask.append(m_np)

    all_probs = np.concatenate(all_probs, axis=0)
    all_true  = np.concatenate(all_true, axis=0)
    all_mask  = np.concatenate(all_mask, axis=0)
    all_true  = np.where(all_mask, np.nan, all_true)  # ignore missing labels

    metrics = eval_metrics(all_true, all_probs)
    return running_loss / max(1, steps), metrics

# ---------------------------
# Orchestrate both stages with early stopping
# ---------------------------
def save_epoch_ckpt(stage_tag: str, epoch: int, ema: EMA, path: Path):
    torch.save({
        "model": v7_shared.state_dict(),
        "ema": ema.shadow,
        "config": {"stage": stage_tag, "epoch": epoch},
    }, path)

def train_shared_model():
    best_metric = -1.0
    best_path   = CKPT_DIR / "best.pt"
    log_path    = LOGS_DIR / "train_log.jsonl"
    val_log_path= ARTIF_DIR / "val_metrics.jsonl"
    for p in [log_path, val_log_path]:
        if p.exists(): p.unlink()

    # ===== Stage A: warm-up (text frozen) =====
    print("\n=== Stage A: Warm-up (text backbone frozen) ===")
    opt = build_optimizer(v7_shared, lr=LR, wd=WEIGHT_DECAY)
    total_steps = math.ceil(len(train_loader) / GRAD_ACCUM) * MAX_EPOCHS_STAGE_A
    scheduler = LambdaLR(opt, build_warmup_cosine(total_steps, warmup_ratio=WARMUP_RATIO))
    ema = EMA(v7_shared, decay=EMA_DECAY)
    patience = EARLY_STOP_PATIENCE

    for epoch in range(1, MAX_EPOCHS_STAGE_A + 1):
        t0 = time.time()
        tr_loss, tr_metrics = run_epoch(v7_shared, train_loader, opt, scheduler, ema, stage_name="train")
        with torch.no_grad():
            ema.copy_to(v7_shared)
            va_loss, va_metrics = run_epoch(v7_shared, val_loader, optimizer=None, scheduler=None, ema=None, stage_name="val")

        row = {
            "stage": "A", "epoch": epoch,
            "train_loss": tr_loss, **{f"train_{k}": v for k,v in tr_metrics.items()},
            "val_loss": va_loss, **{f"val_{k}": v for k,v in va_metrics.items()},
            "epoch_time_s": round(time.time()-t0, 2),
        }
        with open(log_path, "a") as f: f.write(json.dumps(row) + "\n")
        with open(val_log_path, "a") as f: f.write(json.dumps({"stage":"A", "epoch": epoch, **va_metrics}) + "\n")
        print(row)

        # Save per-epoch checkpoint
        save_epoch_ckpt("A", epoch, ema, CKPT_DIR / f"epoch_A{epoch:02d}.pt")

        score = va_metrics["macro_pr_auc"]
        if score > best_metric:
            best_metric = score
            torch.save({"model": v7_shared.state_dict(), "ema": ema.shadow, "config": {"stage":"A"}}, best_path)
            patience = EARLY_STOP_PATIENCE
            print(f"  ✅ New best (Stage A) macro PR-AUC: {best_metric:.4f} → checkpoint saved.")
        else:
            patience -= 1
            if patience <= 0:
                print("  ⏹ Early stop in Stage A.")
                break

    # ===== Stage B: finetune (unfreeze last N) =====
    print("\n=== Stage B: Finetune (unfreeze last 2 ChemBERTa layers) ===")
    v7_shared.freeze_text_backbone(n_unfrozen_layers=UNFREEZE_LAST_N)

    opt = build_optimizer(v7_shared, lr=LR * 0.5, wd=WEIGHT_DECAY)
    total_steps = math.ceil(len(train_loader) / GRAD_ACCUM) * MAX_EPOCHS_STAGE_B
    scheduler = LambdaLR(opt, build_warmup_cosine(total_steps, warmup_ratio=WARMUP_RATIO))
    ema = EMA(v7_shared, decay=EMA_DECAY)
    patience = EARLY_STOP_PATIENCE

    for epoch in range(1, MAX_EPOCHS_STAGE_B + 1):
        t0 = time.time()
        tr_loss, tr_metrics = run_epoch(v7_shared, train_loader, opt, scheduler, ema, stage_name="train")
        with torch.no_grad():
            ema.copy_to(v7_shared)
            va_loss, va_metrics = run_epoch(v7_shared, val_loader, optimizer=None, scheduler=None, ema=None, stage_name="val")

        row = {
            "stage": "B", "epoch": epoch,
            "train_loss": tr_loss, **{f"train_{k}": v for k,v in tr_metrics.items()},
            "val_loss": va_loss, **{f"val_{k}": v for k,v in va_metrics.items()},
            "epoch_time_s": round(time.time()-t0, 2),
        }
        with open(log_path, "a") as f: f.write(json.dumps(row) + "\n")
        with open(val_log_path, "a") as f: f.write(json.dumps({"stage":"B", "epoch": epoch, **va_metrics}) + "\n")
        print(row)

        # Save per-epoch checkpoint
        save_epoch_ckpt("B", epoch, ema, CKPT_DIR / f"epoch_B{epoch:02d}.pt")

        score = va_metrics["macro_pr_auc"]
        if score > best_metric:
            best_metric = score
            torch.save({"model": v7_shared.state_dict(), "ema": ema.shadow, "config": {"stage":"B"}}, best_path)
            patience = EARLY_STOP_PATIENCE
            print(f"  ✅ New best (Stage B) macro PR-AUC: {best_metric:.4f} → checkpoint saved.")
        else:
            patience -= 1
            if patience <= 0:
                print("  ⏹ Early stop in Stage B.")
                break

    print(f"\n🎯 Training complete. Best macro PR-AUC: {best_metric:.4f} | Best ckpt: {best_path}")
    print(f"Logs → {LOGS_DIR/'train_log.jsonl'}")
    print(f"Val metrics → {ARTIF_DIR/'val_metrics.jsonl'}")
    print(f"Checkpoints → {CKPT_DIR}")

# ---------------------------
# Kick it off
# ---------------------------
train_shared_model()

Device: cuda
Train batches: 195, Val batches: 25
Train label prevalence: [0.047 0.038 0.128 0.051 0.135 0.055 0.027 0.165 0.037 0.056 0.167 0.061]
Class-balanced alpha: [1.077 1.385 0.53  1.232 0.53  0.985 1.995 0.48  1.367 1.022 0.479 0.919]

=== Stage A: Warm-up (text backbone frozen) ===




{'stage': 'A', 'epoch': 1, 'train_loss': 0.035231313520135026, 'train_macro_pr_auc': 0.15394772991859337, 'train_macro_roc_auc': 0.6320955267390821, 'val_loss': 0.16247534304857253, 'val_macro_pr_auc': 0.060688107343346155, 'val_macro_roc_auc': 0.4460254985837288, 'epoch_time_s': 5.02}
  ✅ New best (Stage A) macro PR-AUC: 0.0607 → checkpoint saved.




{'stage': 'A', 'epoch': 2, 'train_loss': 0.03287650308547876, 'train_macro_pr_auc': 0.2207926681933469, 'train_macro_roc_auc': 0.7010044653688422, 'val_loss': 0.1518526303768158, 'val_macro_pr_auc': 0.07123903216471082, 'val_macro_roc_auc': 0.48648823084936604, 'epoch_time_s': 4.77}
  ✅ New best (Stage A) macro PR-AUC: 0.0712 → checkpoint saved.




{'stage': 'A', 'epoch': 3, 'train_loss': 0.03286819191506276, 'train_macro_pr_auc': 0.2186634655024724, 'train_macro_roc_auc': 0.7048009936671273, 'val_loss': 0.14330618798732758, 'val_macro_pr_auc': 0.0828460714633456, 'val_macro_roc_auc': 0.5262641855361049, 'epoch_time_s': 4.74}
  ✅ New best (Stage A) macro PR-AUC: 0.0828 → checkpoint saved.




{'stage': 'A', 'epoch': 4, 'train_loss': 0.0327842725870701, 'train_macro_pr_auc': 0.21544627600788266, 'train_macro_roc_auc': 0.7084898348195994, 'val_loss': 0.13847267180681228, 'val_macro_pr_auc': 0.10282287714227774, 'val_macro_roc_auc': 0.5629917021669618, 'epoch_time_s': 4.75}
  ✅ New best (Stage A) macro PR-AUC: 0.1028 → checkpoint saved.




{'stage': 'A', 'epoch': 5, 'train_loss': 0.032721771142230585, 'train_macro_pr_auc': 0.217635243463428, 'train_macro_roc_auc': 0.7083933660695082, 'val_loss': 0.13518619030714035, 'val_macro_pr_auc': 0.12165443303340044, 'val_macro_roc_auc': 0.5874946897864975, 'epoch_time_s': 4.8}
  ✅ New best (Stage A) macro PR-AUC: 0.1217 → checkpoint saved.




{'stage': 'A', 'epoch': 6, 'train_loss': 0.03289635310379358, 'train_macro_pr_auc': 0.20782836636565902, 'train_macro_roc_auc': 0.7021579445824658, 'val_loss': 0.13288636565208434, 'val_macro_pr_auc': 0.13197449463464273, 'val_macro_roc_auc': 0.6008043797614259, 'epoch_time_s': 4.82}
  ✅ New best (Stage A) macro PR-AUC: 0.1320 → checkpoint saved.




{'stage': 'A', 'epoch': 7, 'train_loss': 0.033106887522034154, 'train_macro_pr_auc': 0.20516994149759463, 'train_macro_roc_auc': 0.6974444627266895, 'val_loss': 0.13136810392141343, 'val_macro_pr_auc': 0.13953489453248116, 'val_macro_roc_auc': 0.6104861465660406, 'epoch_time_s': 4.73}
  ✅ New best (Stage A) macro PR-AUC: 0.1395 → checkpoint saved.




{'stage': 'A', 'epoch': 8, 'train_loss': 0.033446282559098345, 'train_macro_pr_auc': 0.19800331001972962, 'train_macro_roc_auc': 0.6793394336970406, 'val_loss': 0.13041243970394134, 'val_macro_pr_auc': 0.14403884369407427, 'val_macro_roc_auc': 0.6171125914864409, 'epoch_time_s': 4.75}
  ✅ New best (Stage A) macro PR-AUC: 0.1440 → checkpoint saved.

=== Stage B: Finetune (unfreeze last 2 ChemBERTa layers) ===




{'stage': 'B', 'epoch': 1, 'train_loss': 0.033351336682263096, 'train_macro_pr_auc': 0.20295440301231638, 'train_macro_roc_auc': 0.6865177545804902, 'val_loss': 0.12948915004730224, 'val_macro_pr_auc': 0.14860819775385747, 'val_macro_roc_auc': 0.6244653691925603, 'epoch_time_s': 5.99}
  ✅ New best (Stage B) macro PR-AUC: 0.1486 → checkpoint saved.




{'stage': 'B', 'epoch': 2, 'train_loss': 0.0320538672976769, 'train_macro_pr_auc': 0.2409792810558551, 'train_macro_roc_auc': 0.7315953583704417, 'val_loss': 0.12786235362291337, 'val_macro_pr_auc': 0.15892553431027442, 'val_macro_roc_auc': 0.638336387427298, 'epoch_time_s': 5.89}
  ✅ New best (Stage B) macro PR-AUC: 0.1589 → checkpoint saved.




{'stage': 'B', 'epoch': 3, 'train_loss': 0.031573516235519676, 'train_macro_pr_auc': 0.25877013362413176, 'train_macro_roc_auc': 0.7455078833992309, 'val_loss': 0.1264248350262642, 'val_macro_pr_auc': 0.16992076406025358, 'val_macro_roc_auc': 0.6545090711061684, 'epoch_time_s': 5.95}
  ✅ New best (Stage B) macro PR-AUC: 0.1699 → checkpoint saved.




{'stage': 'B', 'epoch': 4, 'train_loss': 0.03169219841559728, 'train_macro_pr_auc': 0.2535870353864961, 'train_macro_roc_auc': 0.7403325204304396, 'val_loss': 0.12525496780872344, 'val_macro_pr_auc': 0.17670751058022716, 'val_macro_roc_auc': 0.664954602299176, 'epoch_time_s': 5.94}
  ✅ New best (Stage B) macro PR-AUC: 0.1767 → checkpoint saved.




{'stage': 'B', 'epoch': 5, 'train_loss': 0.0309824329872544, 'train_macro_pr_auc': 0.2783762235540493, 'train_macro_roc_auc': 0.7612776006180176, 'val_loss': 0.12415830075740814, 'val_macro_pr_auc': 0.18558657990269578, 'val_macro_roc_auc': 0.6738228522983359, 'epoch_time_s': 6.02}
  ✅ New best (Stage B) macro PR-AUC: 0.1856 → checkpoint saved.




{'stage': 'B', 'epoch': 6, 'train_loss': 0.031206133656012706, 'train_macro_pr_auc': 0.27829100621751796, 'train_macro_roc_auc': 0.755299053293554, 'val_loss': 0.1231931608915329, 'val_macro_pr_auc': 0.19293831224054891, 'val_macro_roc_auc': 0.6810440461639281, 'epoch_time_s': 5.92}
  ✅ New best (Stage B) macro PR-AUC: 0.1929 → checkpoint saved.




{'stage': 'B', 'epoch': 7, 'train_loss': 0.031034588914078017, 'train_macro_pr_auc': 0.2810264936485967, 'train_macro_roc_auc': 0.7623333892543859, 'val_loss': 0.1224514576792717, 'val_macro_pr_auc': 0.1996031092407037, 'val_macro_roc_auc': 0.6881710678940766, 'epoch_time_s': 5.98}
  ✅ New best (Stage B) macro PR-AUC: 0.1996 → checkpoint saved.




{'stage': 'B', 'epoch': 8, 'train_loss': 0.030679168958121384, 'train_macro_pr_auc': 0.29096576640307903, 'train_macro_roc_auc': 0.7672944482343568, 'val_loss': 0.12180023938417435, 'val_macro_pr_auc': 0.20525084567553417, 'val_macro_roc_auc': 0.6943129844375328, 'epoch_time_s': 5.98}
  ✅ New best (Stage B) macro PR-AUC: 0.2053 → checkpoint saved.




{'stage': 'B', 'epoch': 9, 'train_loss': 0.030749009086344488, 'train_macro_pr_auc': 0.286702373022353, 'train_macro_roc_auc': 0.769002218430392, 'val_loss': 0.12127423912286758, 'val_macro_pr_auc': 0.21246464164262072, 'val_macro_roc_auc': 0.6987706221166011, 'epoch_time_s': 5.92}
  ✅ New best (Stage B) macro PR-AUC: 0.2125 → checkpoint saved.




{'stage': 'B', 'epoch': 10, 'train_loss': 0.03075902950591766, 'train_macro_pr_auc': 0.2865881383643247, 'train_macro_roc_auc': 0.7706060849873775, 'val_loss': 0.12079172015190125, 'val_macro_pr_auc': 0.21125299214080603, 'val_macro_roc_auc': 0.7026318061256759, 'epoch_time_s': 6.08}




{'stage': 'B', 'epoch': 11, 'train_loss': 0.03074433825050409, 'train_macro_pr_auc': 0.2866102209753219, 'train_macro_roc_auc': 0.767944986693523, 'val_loss': 0.12046253174543381, 'val_macro_pr_auc': 0.2160076280215174, 'val_macro_roc_auc': 0.7052656884511049, 'epoch_time_s': 5.99}
  ✅ New best (Stage B) macro PR-AUC: 0.2160 → checkpoint saved.




{'stage': 'B', 'epoch': 12, 'train_loss': 0.03064479235177621, 'train_macro_pr_auc': 0.29587938734080016, 'train_macro_roc_auc': 0.7747545667274607, 'val_loss': 0.1201474142074585, 'val_macro_pr_auc': 0.21815972659022878, 'val_macro_roc_auc': 0.7084575461077836, 'epoch_time_s': 6.23}
  ✅ New best (Stage B) macro PR-AUC: 0.2182 → checkpoint saved.




{'stage': 'B', 'epoch': 13, 'train_loss': 0.03051889530645731, 'train_macro_pr_auc': 0.3014124548450356, 'train_macro_roc_auc': 0.7754224122799486, 'val_loss': 0.11987470299005508, 'val_macro_pr_auc': 0.22021944101515215, 'val_macro_roc_auc': 0.7103294667602594, 'epoch_time_s': 6.24}
  ✅ New best (Stage B) macro PR-AUC: 0.2202 → checkpoint saved.




{'stage': 'B', 'epoch': 14, 'train_loss': 0.03055528616771484, 'train_macro_pr_auc': 0.30036792174015225, 'train_macro_roc_auc': 0.7770674850742396, 'val_loss': 0.11969551861286164, 'val_macro_pr_auc': 0.2204572735624223, 'val_macro_roc_auc': 0.7117377024199903, 'epoch_time_s': 6.02}
  ✅ New best (Stage B) macro PR-AUC: 0.2205 → checkpoint saved.




{'stage': 'B', 'epoch': 15, 'train_loss': 0.030464942605258564, 'train_macro_pr_auc': 0.3004765438036446, 'train_macro_roc_auc': 0.777889729862784, 'val_loss': 0.11952996402978897, 'val_macro_pr_auc': 0.22635373887372232, 'val_macro_roc_auc': 0.7136480190734761, 'epoch_time_s': 5.93}
  ✅ New best (Stage B) macro PR-AUC: 0.2264 → checkpoint saved.




{'stage': 'B', 'epoch': 16, 'train_loss': 0.030600831323327162, 'train_macro_pr_auc': 0.2948815446342145, 'train_macro_roc_auc': 0.7766768345697753, 'val_loss': 0.11940244868397713, 'val_macro_pr_auc': 0.22757343295334945, 'val_macro_roc_auc': 0.7147992546108844, 'epoch_time_s': 5.95}
  ✅ New best (Stage B) macro PR-AUC: 0.2276 → checkpoint saved.




{'stage': 'B', 'epoch': 17, 'train_loss': 0.0306711228707662, 'train_macro_pr_auc': 0.29773487707514557, 'train_macro_roc_auc': 0.7771101181372019, 'val_loss': 0.11931319043040275, 'val_macro_pr_auc': 0.2284273391126127, 'val_macro_roc_auc': 0.7155680955816607, 'epoch_time_s': 5.92}
  ✅ New best (Stage B) macro PR-AUC: 0.2284 → checkpoint saved.




{'stage': 'B', 'epoch': 18, 'train_loss': 0.030714609569463973, 'train_macro_pr_auc': 0.29627989863976345, 'train_macro_roc_auc': 0.7806201270680816, 'val_loss': 0.11927412450313568, 'val_macro_pr_auc': 0.23048184628064763, 'val_macro_roc_auc': 0.7158172353803324, 'epoch_time_s': 5.94}
  ✅ New best (Stage B) macro PR-AUC: 0.2305 → checkpoint saved.




{'stage': 'B', 'epoch': 19, 'train_loss': 0.030844925281902153, 'train_macro_pr_auc': 0.29332838681471207, 'train_macro_roc_auc': 0.7770113536906883, 'val_loss': 0.11922193139791488, 'val_macro_pr_auc': 0.23035495448753893, 'val_macro_roc_auc': 0.7166446565031673, 'epoch_time_s': 5.98}




{'stage': 'B', 'epoch': 20, 'train_loss': 0.030895793051100694, 'train_macro_pr_auc': 0.28553649395747516, 'train_macro_roc_auc': 0.7788041242291293, 'val_loss': 0.11917937502264976, 'val_macro_pr_auc': 0.2308804674395504, 'val_macro_roc_auc': 0.7172780654968326, 'epoch_time_s': 6.15}
  ✅ New best (Stage B) macro PR-AUC: 0.2309 → checkpoint saved.

🎯 Training complete. Best macro PR-AUC: 0.2309 | Best ckpt: v7\model\checkpoints\shared\best.pt
Logs → v7\results\logs\train_log.jsonl
Val metrics → v7\results\artifacts\val_metrics.jsonl
Checkpoints → v7\model\checkpoints\shared


#### 1b) fine-tune the shared model

In [None]:
# === Phase 3 — Cell 1c: Optional extended finetune (Stage C) ===
import json, math, time
from pathlib import Path
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import numpy as np

BASE_DIR   = Path("v7")
DATA_PREP  = BASE_DIR / "data" / "prepared"
RESULTS_DIR= BASE_DIR / "results"
LOGS_DIR   = RESULTS_DIR / "logs"
ARTIF_DIR  = RESULTS_DIR / "artifacts"
CKPT_DIR   = BASE_DIR / "model" / "checkpoints" / "shared"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reuse objects from previous cell:
# - v7_shared, train_loader, val_loader, criterion, scaler, GRAD_ACCUM, USE_AMP,
#   build_optimizer, EMA, build_warmup_cosine, eval_metrics, run_epoch (already defined)

BEST_PATH = CKPT_DIR / "best.pt"
assert BEST_PATH.exists(), "Best checkpoint not found from Stage B."

# ---- Load best weights (EMA weights baked in via copy_to flow) ----
ckpt = torch.load(BEST_PATH, map_location=device)
v7_shared.load_state_dict(ckpt["model"], strict=True)

# ---- Stage C config (slightly smaller LR, unfreeze more) ----
MAX_EPOCHS_STAGE_C   = 15
UNFREEZE_LAST_N_MORE = 4   # total unfreezing depth for Stage C
LR_C                 = 1e-4
WEIGHT_DECAY         = 1e-2
WARMUP_RATIO         = 0.1
EARLY_STOP_PATIENCE  = 5
EMA_DECAY            = 0.999
CLIP_NORM            = 1.0

print("\n=== Stage C: Extended finetune (unfreeze last 4 ChemBERTa layers) ===")
v7_shared.freeze_text_backbone(n_unfrozen_layers=UNFREEZE_LAST_N_MORE)

opt = build_optimizer(v7_shared, lr=LR_C, wd=WEIGHT_DECAY)
total_steps = (len(train_loader) // max(1, GRAD_ACCUM)) * MAX_EPOCHS_STAGE_C
scheduler = LambdaLR(opt, build_warmup_cosine(total_steps, warmup_ratio=WARMUP_RATIO))
ema = EMA(v7_shared, decay=EMA_DECAY)

# Retrieve current best to compare
best_metric = -1.0
try:
    with open(ARTIF_DIR / "val_metrics.jsonl", "r") as f:
        for line in f:
            m = json.loads(line)
            best_metric = max(best_metric, m.get("macro_pr_auc", -1.0))
except FileNotFoundError:
    pass

patience = EARLY_STOP_PATIENCE
for epoch in range(1, MAX_EPOCHS_STAGE_C + 1):
    t0 = time.time()
    tr_loss, tr_metrics = run_epoch(v7_shared, train_loader, opt, scheduler, ema, stage_name="train")
    with torch.no_grad():
        ema.copy_to(v7_shared)
        va_loss, va_metrics = run_epoch(v7_shared, val_loader, optimizer=None, scheduler=None, ema=None, stage_name="val")
    row = {
        "stage": "C", "epoch": epoch,
        "train_loss": tr_loss, **{f"train_{k}": v for k,v in tr_metrics.items()},
        "val_loss": va_loss, **{f"val_{k}": v for k,v in va_metrics.items()},
        "epoch_time_s": round(time.time()-t0, 2),
    }
    with open(LOGS_DIR / "train_log.jsonl", "a") as f: f.write(json.dumps(row) + "\n")
    with open(ARTIF_DIR / "val_metrics.jsonl", "a") as f: f.write(json.dumps({"stage":"C","epoch":epoch, **va_metrics}) + "\n")
    print(row)

    score = va_metrics["macro_pr_auc"]
    if score > best_metric:
        best_metric = score
        torch.save({"model": v7_shared.state_dict(), "ema": ema.shadow, "config": {"stage":"C"}}, BEST_PATH)
        patience = EARLY_STOP_PATIENCE
        print(f"  ✅ New best (Stage C) macro PR-AUC: {best_metric:.4f} → checkpoint saved.")
    else:
        patience -= 1
        if patience <= 0:
            print("  ⏹ Early stop in Stage C.")
            break

print(f"\n🎯 Stage C done. Best macro PR-AUC now: {best_metric:.4f} | Best ckpt: {BEST_PATH}")


=== Stage C: Extended finetune (unfreeze last 4 ChemBERTa layers) ===




{'stage': 'C', 'epoch': 1, 'train_loss': 0.03059412927295153, 'train_macro_pr_auc': 0.2940897781197137, 'train_macro_roc_auc': 0.7738731249518018, 'val_loss': 0.11911626175045967, 'val_macro_pr_auc': 0.2313494757438764, 'val_macro_roc_auc': 0.7179151724089748, 'epoch_time_s': 7.7}
  ✅ New best (Stage C) macro PR-AUC: 0.2313 → checkpoint saved.




{'stage': 'C', 'epoch': 2, 'train_loss': 0.02976563608703705, 'train_macro_pr_auc': 0.32309141373025596, 'train_macro_roc_auc': 0.7923732547730861, 'val_loss': 0.11889535903930665, 'val_macro_pr_auc': 0.23306384630623586, 'val_macro_roc_auc': 0.7202172791739457, 'epoch_time_s': 7.43}
  ✅ New best (Stage C) macro PR-AUC: 0.2331 → checkpoint saved.




{'stage': 'C', 'epoch': 3, 'train_loss': 0.02977958160619705, 'train_macro_pr_auc': 0.3266218849580093, 'train_macro_roc_auc': 0.7896876136921706, 'val_loss': 0.11872626304626464, 'val_macro_pr_auc': 0.2347075514972754, 'val_macro_roc_auc': 0.7224862467323508, 'epoch_time_s': 7.44}
  ✅ New best (Stage C) macro PR-AUC: 0.2347 → checkpoint saved.




{'stage': 'C', 'epoch': 4, 'train_loss': 0.02964736360292404, 'train_macro_pr_auc': 0.33171891965683115, 'train_macro_roc_auc': 0.7927982693500946, 'val_loss': 0.11857076600193978, 'val_macro_pr_auc': 0.2361485101560743, 'val_macro_roc_auc': 0.7242838603441172, 'epoch_time_s': 7.34}
  ✅ New best (Stage C) macro PR-AUC: 0.2361 → checkpoint saved.




{'stage': 'C', 'epoch': 5, 'train_loss': 0.02982027809111736, 'train_macro_pr_auc': 0.31763875566657007, 'train_macro_roc_auc': 0.7874365553179355, 'val_loss': 0.1184364566206932, 'val_macro_pr_auc': 0.23835799454299564, 'val_macro_roc_auc': 0.7257010995233294, 'epoch_time_s': 7.28}
  ✅ New best (Stage C) macro PR-AUC: 0.2384 → checkpoint saved.




{'stage': 'C', 'epoch': 6, 'train_loss': 0.02947875121369576, 'train_macro_pr_auc': 0.3321220359256137, 'train_macro_roc_auc': 0.8006438802397594, 'val_loss': 0.11827784195542336, 'val_macro_pr_auc': 0.23832505697959694, 'val_macro_roc_auc': 0.727440428255865, 'epoch_time_s': 7.23}




{'stage': 'C', 'epoch': 7, 'train_loss': 0.02935470045090486, 'train_macro_pr_auc': 0.3434656936237351, 'train_macro_roc_auc': 0.8023428187366073, 'val_loss': 0.11815975934267044, 'val_macro_pr_auc': 0.2393710062395936, 'val_macro_roc_auc': 0.7287112230799639, 'epoch_time_s': 7.25}
  ✅ New best (Stage C) macro PR-AUC: 0.2394 → checkpoint saved.




{'stage': 'C', 'epoch': 8, 'train_loss': 0.029187886216319524, 'train_macro_pr_auc': 0.3496728510187987, 'train_macro_roc_auc': 0.8075722167410296, 'val_loss': 0.11802344962954521, 'val_macro_pr_auc': 0.240041162238516, 'val_macro_roc_auc': 0.7303042554528408, 'epoch_time_s': 7.22}
  ✅ New best (Stage C) macro PR-AUC: 0.2400 → checkpoint saved.




{'stage': 'C', 'epoch': 9, 'train_loss': 0.029373395767731545, 'train_macro_pr_auc': 0.3405598536923287, 'train_macro_roc_auc': 0.805858050747465, 'val_loss': 0.11795216917991638, 'val_macro_pr_auc': 0.24077963991117687, 'val_macro_roc_auc': 0.7314910443527664, 'epoch_time_s': 7.28}
  ✅ New best (Stage C) macro PR-AUC: 0.2408 → checkpoint saved.




{'stage': 'C', 'epoch': 10, 'train_loss': 0.02929941861388775, 'train_macro_pr_auc': 0.34471098714367926, 'train_macro_roc_auc': 0.8097664124196781, 'val_loss': 0.11790973782539367, 'val_macro_pr_auc': 0.24119206327093037, 'val_macro_roc_auc': 0.7322239918733962, 'epoch_time_s': 7.28}
  ✅ New best (Stage C) macro PR-AUC: 0.2412 → checkpoint saved.




{'stage': 'C', 'epoch': 11, 'train_loss': 0.02930923376041345, 'train_macro_pr_auc': 0.34093325952352727, 'train_macro_roc_auc': 0.8092587990161633, 'val_loss': 0.1178943282365799, 'val_macro_pr_auc': 0.24154978012502534, 'val_macro_roc_auc': 0.7326831399686627, 'epoch_time_s': 7.2}
  ✅ New best (Stage C) macro PR-AUC: 0.2415 → checkpoint saved.




{'stage': 'C', 'epoch': 12, 'train_loss': 0.029186946392441408, 'train_macro_pr_auc': 0.3597753005981579, 'train_macro_roc_auc': 0.8138671743287641, 'val_loss': 0.1178495578467846, 'val_macro_pr_auc': 0.2416392331610726, 'val_macro_roc_auc': 0.7333118541081923, 'epoch_time_s': 7.23}
  ✅ New best (Stage C) macro PR-AUC: 0.2416 → checkpoint saved.




{'stage': 'C', 'epoch': 13, 'train_loss': 0.029342801329226065, 'train_macro_pr_auc': 0.3456173786396817, 'train_macro_roc_auc': 0.8124665169572586, 'val_loss': 0.1178448085486889, 'val_macro_pr_auc': 0.24113489174898617, 'val_macro_roc_auc': 0.7337164190361283, 'epoch_time_s': 7.21}




{'stage': 'C', 'epoch': 14, 'train_loss': 0.029542968245461966, 'train_macro_pr_auc': 0.34503395319943714, 'train_macro_roc_auc': 0.8082895448255402, 'val_loss': 0.11786730423569679, 'val_macro_pr_auc': 0.24131952642285667, 'val_macro_roc_auc': 0.733945058303203, 'epoch_time_s': 7.21}




{'stage': 'C', 'epoch': 15, 'train_loss': 0.029530828427045772, 'train_macro_pr_auc': 0.34638721395507566, 'train_macro_roc_auc': 0.8117984862586809, 'val_loss': 0.11786027416586876, 'val_macro_pr_auc': 0.2411037459323543, 'val_macro_roc_auc': 0.7342894549870769, 'epoch_time_s': 7.17}

🎯 Stage C done. Best macro PR-AUC now: 0.2416 | Best ckpt: v7\model\checkpoints\shared\best.pt


### 2 (boosted): Label-Specialist Heads (extended + multi-seed)

We train one-vs-rest heads for each of the 12 Tox21 labels:

- Cached fused features (from the shared model): `v7/data/fused/`
- Head MLP: 768 → 512 → 256 → 128 → 1 (GELU + LayerNorm + Dropout 0.30, residual)
- Loss: **Binary ASL** (γ⁻=5.0, γ⁺=1.0) with **class-balanced α**
- Sampler: **class-balanced WeightedRandomSampler** to ensure positives per batch
- Optimiser & schedule: AdamW (lr=3e-3) + **cosine warmup** (10% warmup)
- Regularisation: **EMA**, AMP, grad-clip (L2-norm=1.0), WD=1e-2
- Early stopping: patience=10 on **val AP**
- **5 seeds** per label (can change via `SEEDS`)

Artifacts:
- `v7/model/ensembles/<LABEL>/seedXX/best.pt`
- `v7/model/ensembles/<LABEL>/seedXX/val_preds.npz` (for calibration)
- `v7/model/ensembles/ensemble_summary.json`

> Expect ~30–50 mins on a 4070 Ti (depends on your desktop load). You can lower `SEEDS` or `EPOCHS_MAX` if needed.


In [None]:
# === Phase 3 — Cell 2 (boosted) ===
import os, json, math, time, random, platform
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from sklearn.metrics import average_precision_score

# ------------------------------
# Paths & globals
# ------------------------------
BASE_DIR     = Path("v7")
DATA_PREP    = BASE_DIR / "data" / "prepared"
FUSED_DIR    = BASE_DIR / "data" / "fused"
FUSED_DIR.mkdir(parents=True, exist_ok=True)
ENSEMBLE_DIR = BASE_DIR / "model" / "ensembles"
CKPT_SHARED  = BASE_DIR / "model" / "checkpoints" / "shared" / "best.pt"
DEVICE       = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Expect shared model & helpers from earlier cells:
assert CKPT_SHARED.exists(), "Shared best checkpoint not found. Run Phase 3 Cell 1 (and 1c optional) first."

# ------------------------------
# Utilities / reproducibility
# ------------------------------
def seed_everything(seed: int):
    import numpy as _np, random as _r, torch as _t
    _r.seed(seed); _np.random.seed(seed)
    _t.manual_seed(seed); _t.cuda.manual_seed_all(seed)

try:
    LABEL_NAMES
    v7_shared
except NameError:
    raise RuntimeError("Missing v7_shared or LABEL_NAMES in memory. Please re-run Phase 2 Cells 1–4 and Phase 3 Cell 1.")

# masked_mean fallback (if not in scope)
try:
    masked_mean
except NameError:
    def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
        mask = mask.to(dtype=x.dtype, device=x.device)
        denom = mask.sum(dim=dim, keepdim=True).clamp(min=1.0)
        return (x * mask.unsqueeze(-1)).sum(dim=dim) / denom

# ------------------------------
# Restore best shared weights & set eval
# ------------------------------
ckpt = torch.load(CKPT_SHARED, map_location=DEVICE)
v7_shared.load_state_dict(ckpt["model"], strict=True)
v7_shared.eval()

# ------------------------------
# Fused cache helpers
# ------------------------------
@torch.no_grad()
def compute_fused_batch(smiles_list: List[str], desc_feats: torch.Tensor) -> torch.Tensor:
    tt, tm = v7_shared.text_encoder(smiles_list, max_length=256)
    gn, gm = v7_shared.graph_encoder(smiles_list, max_nodes=128)
    tt, tm = tt.to(DEVICE), tm.to(DEVICE)
    gn, gm = gn.to(DEVICE), gm.to(DEVICE)
    desc_feats = desc_feats.to(DEVICE)
    tta = v7_shared.cross(tt, tm, gn, gm)
    de  = v7_shared.desc_mlp(desc_feats)
    text_pool  = masked_mean(tta, tm, dim=1)
    graph_pool = masked_mean(gn,  gm, dim=1)
    fused = torch.cat([text_pool, graph_pool, de], dim=-1)  # (B, 768)
    return fused

@torch.no_grad()
def cache_fused(npz_path: Path, out_prefix: str, batch_size: int = 256):
    blob = np.load(npz_path, allow_pickle=True)
    smiles = [str(s) for s in blob["smiles"].tolist()]
    Xd    = torch.tensor(blob["X"], dtype=torch.float32, device=DEVICE)
    Y     = blob["Y"].astype(np.float32)
    M     = blob["y_missing_mask"].astype(bool)
    molid = blob["mol_id"].tolist()

    fused_list = []
    for i in range(0, len(smiles), batch_size):
        fu = compute_fused_batch(smiles[i:i+batch_size], Xd[i:i+batch_size])
        fused_list.append(fu.cpu().numpy())
    F = np.concatenate(fused_list, axis=0).astype(np.float32)

    np.save(FUSED_DIR / f"{out_prefix}_fused.npy", F)
    np.save(FUSED_DIR / f"{out_prefix}_Y.npy",    Y)
    np.save(FUSED_DIR / f"{out_prefix}_mask.npy", M)
    np.save(FUSED_DIR / f"{out_prefix}_mol_id.npy", np.array(molid, dtype=object))
    print(f"Cached {out_prefix}: fused {F.shape}")

# Build/refresh caches if missing
if not (FUSED_DIR / "train_fused.npy").exists():
    cache_fused(DATA_PREP / "train.npz", "train", batch_size=256)
if not (FUSED_DIR / "val_fused.npy").exists():
    cache_fused(DATA_PREP / "val.npz", "val", batch_size=256)

# ------------------------------
# Load caches into memory
# ------------------------------
Xtr = np.load(FUSED_DIR / "train_fused.npy")
Ytr = np.load(FUSED_DIR / "train_Y.npy")
Mtr = np.load(FUSED_DIR / "train_mask.npy")
Xva = np.load(FUSED_DIR / "val_fused.npy")
Yva = np.load(FUSED_DIR / "val_Y.npy")
Mva = np.load(FUSED_DIR / "val_mask.npy")
print("Fused shapes → train:", Xtr.shape, "| val:", Xva.shape)

# ------------------------------
# Dataset & balanced sampler
# ------------------------------
class FusedLabelDataset(torch.utils.data.Dataset):
    def __init__(self, X: np.ndarray, Y: np.ndarray, M: np.ndarray, j: int):
        valid = ~M[:, j]
        self.X = X[valid].astype(np.float32)
        self.y = Y[valid, j].astype(np.float32)
        assert self.X.shape[0] == self.y.shape[0]
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i):
        return torch.from_numpy(self.X[i]), torch.tensor(self.y[i])

def make_balanced_sampler(y_np: np.ndarray):
    # Pos/neg weights so expected sampling is ~balanced
    pos = (y_np == 1).astype(np.float32)
    neg = (y_np == 0).astype(np.float32)
    n_pos = pos.sum(); n_neg = neg.sum()
    # Avoid zero-division; if n_pos==0, fall back to uniform
    if n_pos < 1:
        w = np.ones_like(y_np, dtype=np.float32)
    else:
        w_pos = 0.5 / max(n_pos, 1.0)
        w_neg = 0.5 / max(n_neg, 1.0)
        w = pos * w_pos + neg * w_neg
    return torch.DoubleTensor(w)

def make_loaders_for_label(j: int, bs: int = 1024):
    dtr = FusedLabelDataset(Xtr, Ytr, Mtr, j)
    dva = FusedLabelDataset(Xva, Yva, Mva, j)
    # Balanced sampler for training
    w = make_balanced_sampler(dtr.y)
    sampler = torch.utils.data.WeightedRandomSampler(weights=w, num_samples=len(dtr), replacement=True)
    train_loader = torch.utils.data.DataLoader(dtr, batch_size=bs, sampler=sampler,
                                               num_workers=0, pin_memory=(DEVICE.type=="cuda"))
    val_loader   = torch.utils.data.DataLoader(dva, batch_size=bs, shuffle=False,
                                               num_workers=0, pin_memory=(DEVICE.type=="cuda"))
    return train_loader, val_loader

# ------------------------------
# Loss (Binary ASL with per-label α via effective number)
# ------------------------------
def effective_alpha(pos_count: int, beta: float = 0.999) -> float:
    # Class-Balanced factor α for positives; scale to mean ~1 across labels later if desired.
    eff = (1 - (beta ** max(pos_count, 1))) / (1 - beta)
    return float(1.0 / max(eff, 1e-8))

class BinaryASL(nn.Module):
    def __init__(self, gamma_neg=5.0, gamma_pos=1.0, clip=0.05, alpha: float = 1.0):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.alpha = float(alpha)
    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # logits/targets: (B,)
        p = torch.sigmoid(logits)
        if self.clip:
            p = torch.clamp(p, self.clip, 1 - self.clip)
        pos = targets
        neg = 1 - targets
        pt = p * pos + (1 - p) * neg
        gamma = self.gamma_pos * pos + self.gamma_neg * neg
        focal = torch.pow(1 - pt, gamma)
        loss = - (pos * torch.log(p + 1e-8) + neg * torch.log(1 - p + 1e-8))
        loss = loss * focal * self.alpha
        return loss.mean()

# ------------------------------
# Head model (stronger MLP + residual)
# ------------------------------
class LabelHead(nn.Module):
    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(in_dim, h1), nn.GELU(), nn.LayerNorm(h1), nn.Dropout(p))
        self.block2 = nn.Sequential(nn.Linear(h1, h2), nn.GELU(), nn.LayerNorm(h2), nn.Dropout(p))
        self.block3 = nn.Sequential(nn.Linear(h2, h3), nn.GELU(), nn.LayerNorm(h3), nn.Dropout(p))
        self.out    = nn.Linear(h3, 1)
        self.short  = nn.Linear(in_dim, h3)  # residual path to help optimisation
    def forward(self, x):  # x: (B,768)
        z1 = self.block1(x)
        z2 = self.block2(z1)
        z3 = self.block3(z2)
        # residual from input
        z = z3 + self.short(x)
        return self.out(z).squeeze(-1)

# ------------------------------
# Scheduler: warmup + cosine
# ------------------------------
def build_warmup_cosine(total_steps, warmup_ratio=0.1, min_lr_scale=0.1):
    def lr_lambda(step):
        warm = int(total_steps * warmup_ratio)
        if step < warm:
            return float(step) / max(1, warm)
        progress = (step - warm) / max(1, total_steps - warm)
        cosine = 0.5 * (1 + math.cos(math.pi * progress))
        return min_lr_scale + (1 - min_lr_scale) * cosine
    return lr_lambda

# ------------------------------
# Train one seed for one label
# ------------------------------
def train_label_seed(label_name: str, j: int, seed: int,
                     epochs_max: int = 50, patience: int = 10,
                     lr: float = 3e-3, wd: float = 1e-2) -> Dict[str, float]:
    seed_everything(seed)

    # Loaders & stats
    train_loader, val_loader = make_loaders_for_label(j, bs=1024)
    pos_count = int((Ytr[~Mtr[:, j], j] == 1).sum())
    alpha = effective_alpha(pos_count)  # per-label CB factor

    model = LabelHead(in_dim=768, h1=512, h2=256, h3=128, p=0.30).to(DEVICE)
    opt   = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    steps_per_epoch = len(train_loader)
    sched = LambdaLR(opt, build_warmup_cosine(total_steps=steps_per_epoch*epochs_max, warmup_ratio=0.1))

    criterion = BinaryASL(gamma_neg=5.0, gamma_pos=1.0, clip=0.05, alpha=alpha)
    scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE.type=="cuda"))
    ema_decay = 0.999
    ema = {n: p.detach().clone() for n,p in model.named_parameters() if p.requires_grad}

    def ema_update():
        for n,p in model.named_parameters():
            if p.requires_grad:
                ema[n].mul_(ema_decay).add_(p.detach(), alpha=1-ema_decay)

    def ema_copy_to():
        with torch.no_grad():
            for n,p in model.named_parameters():
                if p.requires_grad:
                    p.data.copy_(ema[n])

    # paths
    label_dir = ENSEMBLE_DIR / label_name / f"seed{seed:02d}"
    label_dir.mkdir(parents=True, exist_ok=True)
    log_path  = label_dir / "train_log.jsonl"
    if log_path.exists(): log_path.unlink()

    best_ap, best_path = -1.0, label_dir / "best.pt"
    wait = patience
    t_start = time.time()

    for epoch in range(1, epochs_max + 1):
        # ---- train
        model.train()
        epoch_loss = 0.0
        for xb, yb in train_loader:
            xb = xb.to(DEVICE, non_blocking=True); yb = yb.to(DEVICE, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=(DEVICE.type=="cuda")):
                logits = model(xb)
                loss = criterion(logits, yb)
            scaler.scale(loss).backward()
            # grad clip
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt); scaler.update()
            sched.step()
            epoch_loss += float(loss.item())
            # EMA
            ema_update()

        # ---- validate (with EMA weights)
        ema_copy_to()
        model.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(DEVICE, non_blocking=True)
                logits = model(xb)
                p = torch.sigmoid(logits).detach().cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        try:
            ap = float(average_precision_score(gts, preds))
        except Exception:
            ap = float("nan")

        row = {"epoch": epoch, "train_loss": epoch_loss/max(1,len(train_loader)), "val_ap": ap, "time_min": round((time.time()-t_start)/60,2)}
        with open(log_path, "a") as f: f.write(json.dumps(row) + "\n")
        print(f"[{label_name} | seed {seed:02d}] ep {epoch:02d}  loss {row['train_loss']:.4f}  val AP {ap:.4f}")

        # save val preds for calibration (overwrite each epoch; best saved below)
        np.savez_compressed(label_dir / "val_preds.npz", preds=preds, y=gts)

        if np.isnan(ap):
            continue
        if ap > best_ap:
            best_ap = ap; wait = patience
            torch.save({"model": model.state_dict(),
                        "config": {"in_dim":768,"h1":512,"h2":256,"h3":128,"dropout":0.30},
                        "label": label_name, "seed": seed}, best_path)
            # also save best preds snapshot
            np.savez_compressed(label_dir / "val_preds_best.npz", preds=preds, y=gts)
            print(f"  ✅ New best AP: {best_ap:.4f} → {best_path}")
        else:
            wait -= 1
            if wait <= 0:
                print(f"  ⏹ Early stop (no improve for {patience} epochs). Best AP: {best_ap:.4f}")
                break

    # write seed summary
    with open(ENSEMBLE_DIR / label_name / f"seed{seed:02d}" / "metrics.json", "w") as f:
        f.write(json.dumps({"best_ap": best_ap, "epochs": epoch}, indent=2))
    return {"label": label_name, "seed": seed, "best_ap": best_ap}

# ------------------------------
# Driver: all labels, multi-seed
# ------------------------------
SEEDS = [13, 29, 47, 61, 83]    # 5 seeds (tweak as desired)
EPOCHS_MAX = 50
PATIENCE   = 10

summary = []
for j, name in enumerate(LABEL_NAMES):
    print("\n==============================")
    print(f"Training specialist heads for: {name} (label {j})")
    print("==============================")
    for seed in SEEDS:
        res = train_label_seed(name, j, seed, epochs_max=EPOCHS_MAX, patience=PATIENCE, lr=3e-3, wd=1e-2)
        summary.append(res)

# Aggregate best AP per label across seeds
agg = {}
for name in LABEL_NAMES:
    best = max((r for r in summary if r["label"] == name), key=lambda r: (r["best_ap"] if not math.isnan(r["best_ap"]) else -1.0))
    agg[name] = {"best_ap": best["best_ap"]}

(ENSEMBLE_DIR / "ensemble_summary.json").write_text(json.dumps({"per_seed": summary, "best_per_label": agg}, indent=2))
print("\n✅ Ensemble training complete.")
print(json.dumps({"best_per_label": agg}, indent=2))



Cached train: fused (6265, 768)
Cached val: fused (783, 768)
Fused shapes → train: (6265, 768) | val: (783, 768)

Training specialist heads for: NR-AR (label 0)
[NR-AR | seed 13] ep 01  loss 0.0007  val AP 0.0165
  ✅ New best AP: 0.0165 → v7\model\ensembles\NR-AR\seed13\best.pt
[NR-AR | seed 13] ep 02  loss 0.0005  val AP 0.0169
  ✅ New best AP: 0.0169 → v7\model\ensembles\NR-AR\seed13\best.pt
[NR-AR | seed 13] ep 03  loss 0.0005  val AP 0.0176
  ✅ New best AP: 0.0176 → v7\model\ensembles\NR-AR\seed13\best.pt
[NR-AR | seed 13] ep 04  loss 0.0005  val AP 0.0187
  ✅ New best AP: 0.0187 → v7\model\ensembles\NR-AR\seed13\best.pt
[NR-AR | seed 13] ep 05  loss 0.0006  val AP 0.0224
  ✅ New best AP: 0.0224 → v7\model\ensembles\NR-AR\seed13\best.pt
[NR-AR | seed 13] ep 06  loss 0.0005  val AP 0.0935
  ✅ New best AP: 0.0935 → v7\model\ensembles\NR-AR\seed13\best.pt
[NR-AR | seed 13] ep 07  loss 0.0005  val AP 0.1298
  ✅ New best AP: 0.1298 → v7\model\ensembles\NR-AR\seed13\best.pt
[NR-AR | seed

## phase 4 (Calibrate/Threshold)

### 1: Calibration (Temperature Scaling) + Thresholds (F1 / Fβ)

- Pick best specialist head per label (from `ensemble_summary.json`)
- Calibrate with per-label **temperature scaling** (on validation logits)
- Select **F1-max** and **Fβ=1.5-max** thresholds per label
- Save to `v7/model/calibration/` and curves to `v7/results/calibration/`


In [None]:
import os, json, math
from pathlib import Path
from typing import Dict, Tuple
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import precision_recall_curve, average_precision_score

# ------------- Paths -------------
BASE_DIR    = Path("v7")
FUSED_DIR   = BASE_DIR / "data" / "fused"
ENS_DIR     = BASE_DIR / "model" / "ensembles"
CAL_DIR     = BASE_DIR / "model" / "calibration"
CAL_DIR.mkdir(parents=True, exist_ok=True)
CAL_RES_DIR = BASE_DIR / "results" / "calibration"
CAL_RES_DIR.mkdir(parents=True, exist_ok=True)

# Load label names from earlier phase
with open(BASE_DIR / "data" / "prepared" / "dataset_manifest.json") as f:
    ds_manifest = json.load(f)
LABEL_NAMES = ds_manifest["labels"]

# Load ensemble summary (decide best seed per label)
ens_summary = json.loads((ENS_DIR / "ensemble_summary.json").read_text())
best_per_label = ens_summary["best_per_label"]  # label -> {best_ap: ...}
print("Loaded ensemble summary for", len(best_per_label), "labels.")

# Fused validation cache
Xva = np.load(FUSED_DIR / "val_fused.npy")    # (N_val, 768)
Yva = np.load(FUSED_DIR / "val_Y.npy")        # (N_val, 12)
Mva = np.load(FUSED_DIR / "val_mask.npy")     # (N_val, 12) True where missing
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------- Head definition (must match saved config) ---------
class LabelHead(nn.Module):
    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(in_dim, h1), nn.GELU(), nn.LayerNorm(h1), nn.Dropout(p))
        self.block2 = nn.Sequential(nn.Linear(h1, h2), nn.GELU(), nn.LayerNorm(h2), nn.Dropout(p))
        self.block3 = nn.Sequential(nn.Linear(h2, h3), nn.GELU(), nn.LayerNorm(h3), nn.Dropout(p))
        self.out    = nn.Linear(h3, 1)
        self.short  = nn.Linear(in_dim, h3)
    def forward(self, x):
        z1 = self.block1(x); z2 = self.block2(z1); z3 = self.block3(z2)
        z  = z3 + self.short(x)
        return self.out(z).squeeze(-1)

def load_best_seed_dir(label: str) -> Path:
    # Find the seed folder with highest best_ap for this label
    candidates = []
    for seed_dir in sorted((ENS_DIR / label).glob("seed*/")):
        mpath = seed_dir / "metrics.json"
        if mpath.exists():
            try:
                m = json.loads(mpath.read_text())
                candidates.append((float(m.get("best_ap", float("nan"))), seed_dir))
            except Exception:
                pass
    if not candidates:
        raise FileNotFoundError(f"No seed folders with metrics for label {label}")
    candidates.sort(key=lambda x: (x[0] if not math.isnan(x[0]) else -1.0), reverse=True)
    return candidates[0][1]  # best seed dir

@torch.no_grad()
def head_logits_on_val(label: str) -> Tuple[np.ndarray, np.ndarray]:
    """Return logits and ground truth for valid val rows of this label."""
    # Filter valid rows (not missing)
    j = LABEL_NAMES.index(label)
    valid = ~Mva[:, j]
    X = torch.tensor(Xva[valid], dtype=torch.float32, device=device)
    y = Yva[valid, j].astype(np.float32)

    # Load head
    best_dir = load_best_seed_dir(label)
    ckpt = torch.load(best_dir / "best.pt", map_location=device)
    cfg  = ckpt.get("config", {"in_dim":768,"h1":512,"h2":256,"h3":128,"dropout":0.30})
    head = LabelHead(in_dim=cfg["in_dim"], h1=cfg["h1"], h2=cfg["h2"], h3=cfg["h3"], p=cfg.get("dropout",0.30)).to(device)
    head.load_state_dict(ckpt["model"], strict=True)
    head.eval()

    logits = []
    BS = 4096  # very fast on 4070 Ti
    for i in range(0, X.shape[0], BS):
        l = head(X[i:i+BS])
        logits.append(l.detach().cpu().numpy())
    logits = np.concatenate(logits, axis=0)  # (Nv,)
    return logits, y

def fit_temperature(logits: np.ndarray, y: np.ndarray, max_iter: int = 200, lr: float = 0.05) -> float:
    """
    Fit scalar temperature T>0, minimizing NLL on validation.
    """
    t = torch.tensor([1.0], dtype=torch.float32, requires_grad=True, device=device)
    x = torch.tensor(logits, dtype=torch.float32, device=device)
    y = torch.tensor(y,      dtype=torch.float32, device=device)
    opt = torch.optim.Adam([t], lr=lr)
    for _ in range(max_iter):
        opt.zero_grad(set_to_none=True)
        z = x / (t.clamp(min=1e-3))
        p = torch.sigmoid(z).clamp(1e-6, 1-1e-6)
        loss = - (y*torch.log(p) + (1-y)*torch.log(1-p)).mean()
        loss.backward()
        opt.step()
    return float(t.detach().cpu().item())

def best_thresholds(y_true: np.ndarray, probs: np.ndarray) -> Dict[str, float]:
    """
    Compute thresholds that maximize F1 and F-beta (beta=1.5) on validation.
    """
    if probs.ndim != 1: probs = probs.ravel()
    precision, recall, th = precision_recall_curve(y_true, probs)
    # PR curve returns len(th)+1 points; align F1/Fb on thresholds
    eps = 1e-8
    f1 = (2*precision*recall) / np.maximum(precision+recall, eps)
    beta = 1.5
    fb = ((1+beta**2)*precision*recall) / np.maximum((beta**2)*precision + recall, eps)

    # The first PR point has no threshold; we’ll map scores to thresholds array length
    f1_th  = th[np.nanargmax(f1[1:])]  if th.size>0 else 0.5
    fb_th  = th[np.nanargmax(fb[1:])]  if th.size>0 else 0.5
    # also report AP for reference
    try:
        ap = float(average_precision_score(y_true, probs))
    except Exception:
        ap = float("nan")
    return {"th_f1": float(f1_th), "th_fbeta15": float(fb_th), "ap_val": ap}

temps = {}
thresholds = {}

for label in LABEL_NAMES:
    print(f"\nCalibrating: {label}")
    logits, y = head_logits_on_val(label)
    if logits.size == 0 or np.all(y == y[0]):
        print("  ⚠️ Skipping (no variance or no valid rows). Using defaults.")
        temps[label] = 1.0
        thresholds[label] = {"th_f1": 0.5, "th_fbeta15": 0.5, "ap_val": float("nan")}
        continue

    T = fit_temperature(logits, y, max_iter=200, lr=0.05)
    probs_cal = 1.0 / (1.0 + np.exp(-logits / max(T, 1e-3)))

    th_dict = best_thresholds(y, probs_cal)
    temps[label] = T
    thresholds[label] = th_dict

    # optional: save PR curve arrays for later plotting/debug
    np.savez_compressed(CAL_RES_DIR / f"{label}_val_calib.npz", logits=logits, y=y, T=T, **th_dict)
    print(f"  T={T:.3f}  AP_val={th_dict['ap_val']:.4f}  th_f1={th_dict['th_f1']:.3f}  th_fβ1.5={th_dict['th_fbeta15']:.3f}")

# Save calibration artifacts
(Path(CAL_DIR / "temps.json")).write_text(json.dumps(temps, indent=2))
(Path(CAL_DIR / "thresholds.json")).write_text(json.dumps(thresholds, indent=2))
print("\n✅ Saved:")
print("  • temperatures →", CAL_DIR / "temps.json")
print("  • thresholds   →", CAL_DIR / "thresholds.json")

Loaded ensemble summary for 12 labels.

Calibrating: NR-AR
  T=4.549  AP_val=0.1770  th_f1=0.573  th_fβ1.5=0.573

Calibrating: NR-AR-LBD
  T=3.660  AP_val=0.3007  th_f1=0.610  th_fβ1.5=0.610

Calibrating: NR-AhR
  T=1.500  AP_val=0.5251  th_f1=0.681  th_fβ1.5=0.631

Calibrating: NR-Aromatase
  T=2.483  AP_val=0.2754  th_f1=0.604  th_fβ1.5=0.552

Calibrating: NR-ER
  T=4.657  AP_val=0.2309  th_f1=0.542  th_fβ1.5=0.523

Calibrating: NR-ER-LBD
  T=1.857  AP_val=0.1527  th_f1=0.650  th_fβ1.5=0.616

Calibrating: NR-PPAR-gamma
  T=4.558  AP_val=0.0918  th_f1=0.521  th_fβ1.5=0.521

Calibrating: SR-ARE
  T=4.500  AP_val=0.3360  th_f1=0.532  th_fβ1.5=0.532

Calibrating: SR-ATAD5
  T=4.475  AP_val=0.2277  th_f1=0.567  th_fβ1.5=0.556

Calibrating: SR-HSE
  T=4.639  AP_val=0.2323  th_f1=0.551  th_fβ1.5=0.545

Calibrating: SR-MMP
  T=1.896  AP_val=0.4462  th_f1=0.600  th_fβ1.5=0.600

Calibrating: SR-p53
  T=4.245  AP_val=0.2346  th_f1=0.546  th_fβ1.5=0.546

✅ Saved:
  • temperatures → v7\model\cali

## phase 5 (Inference)

### 1:  Inference (calibrated specialist ensemble) + test export

In [None]:
# === Cold-start Inference (checkpoint name-compatible) ===
import os, json, math
from pathlib import Path
from typing import List, Dict, Optional

import numpy as np
import torch
import torch.nn as nn

# ---------------- Paths & basics ----------------
BASE       = Path("v7")
PREP_DIR   = BASE / "data" / "prepared"
DESC_DIR   = BASE / "data" / "descriptors"
MODEL_DIR  = BASE / "model"
CKPT_BEST  = MODEL_DIR / "checkpoints" / "shared" / "best.pt"
ENS_DIR    = MODEL_DIR / "ensembles"
CAL_DIR    = MODEL_DIR / "calibration"

assert CKPT_BEST.exists(), f"Missing shared checkpoint: {CKPT_BEST}"
assert (PREP_DIR / "dataset_manifest.json").exists(), "Missing dataset manifest."

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------- Labels, temps, thresholds ----------------
ds_manifest = json.loads((PREP_DIR / "dataset_manifest.json").read_text())
LABEL_NAMES = ds_manifest["labels"]
DESC_IN_DIM = ds_manifest["n_features"]  # 208

temps      = json.loads((CAL_DIR / "temps.json").read_text())
thresholds = json.loads((CAL_DIR / "thresholds.json").read_text())

# ---------------- Text encoder (ChemBERTa) ----------------
from transformers import AutoTokenizer, AutoModel

class ChemBERTaEncoder(nn.Module):
    def __init__(self, ckpt_name="seyonec/ChemBERTa-zinc-base-v1", fusion_dim=256, dropout_p=0.1):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
        self.backbone  = AutoModel.from_pretrained(ckpt_name)
        self.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
        self.proj = nn.Sequential(nn.Dropout(dropout_p), nn.Linear(self.backbone.config.hidden_size, fusion_dim))
        self.ln = nn.LayerNorm(fusion_dim)
    def forward(self, smiles_list: List[str], max_length=256, add_special_tokens=True):
        enc = self.tokenizer(list(smiles_list), padding=True, truncation=True,
                             max_length=max_length, add_special_tokens=add_special_tokens,
                             return_tensors="pt")
        input_ids, attention_mask = enc["input_ids"].to(device), enc["attention_mask"].to(device)
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state  # (B,L,H)
        toks = self.ln(self.proj(out))  # (B,L,256)
        return toks, attention_mask.to(dtype=torch.int32)

# ---------------- Graph encoder (match checkpoint names) ----------------
from rdkit import Chem as _Chem

ATOM_LIST = ["H","C","N","O","F","P","S","Cl","Br","I"]

def _one_hot(v, choices):
    z = [0]*len(choices)
    if v in choices:
        z[choices.index(v)] = 1
    return z

def _bucket_oh(v, lo, hi):
    buckets = list(range(lo, hi+1))
    o = [0]*(len(buckets)+1)
    idx = v - lo
    o[idx if 0 <= idx < len(buckets) else -1] = 1
    return o

def _atom_feat(atom):
    hybs = [
        _Chem.rdchem.HybridizationType.S, _Chem.rdchem.HybridizationType.SP,
        _Chem.rdchem.HybridizationType.SP2, _Chem.rdchem.HybridizationType.SP3,
        _Chem.rdchem.HybridizationType.SP3D, _Chem.rdchem.HybridizationType.SP3D2
    ]
    chir = [
        _Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        _Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        _Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        _Chem.rdchem.ChiralType.CHI_OTHER
    ]
    sym = atom.GetSymbol()
    feat = _one_hot(sym if sym in ATOM_LIST else "other", ATOM_LIST+["other"])
    feat += _bucket_oh(atom.GetDegree(), 0, 5)
    feat += _bucket_oh(atom.GetFormalCharge(), -2, 2)
    feat += (_one_hot(atom.GetHybridization(), hybs)+[0])  # +other
    feat += [int(atom.GetIsAromatic())]
    feat += [int(atom.IsInRing())]
    feat += _one_hot(atom.GetChiralTag(), chir)
    feat += _bucket_oh(atom.GetTotalNumHs(includeNeighbors=True), 0, 4)
    feat += _bucket_oh(atom.GetTotalValence(), 0, 5)
    feat += [atom.GetMass()/200.0]
    return feat  # ~51 dims

def _smiles_to_graph(smi, max_nodes=128):
    mol = _Chem.MolFromSmiles(smi)
    if mol is None or mol.GetNumAtoms() == 0:
        return np.zeros((0,0), dtype=np.float32), np.zeros((0,0), dtype=np.float32)
    feats = [_atom_feat(mol.GetAtomWithIdx(i)) for i in range(mol.GetNumAtoms())]
    x = np.asarray(feats, dtype=np.float32)
    N = mol.GetNumAtoms()
    adj = np.zeros((N, N), dtype=np.float32)
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        adj[i, j] = 1.0; adj[j, i] = 1.0
    if N > max_nodes:
        x = x[:max_nodes]; adj = adj[:max_nodes, :max_nodes]
    return x, adj

def _collate_graphs(smiles_batch, max_nodes=128):
    graphs = [_smiles_to_graph(s) for s in smiles_batch]
    Nmax = max([g[0].shape[0] for g in graphs] + [1])
    Fnode = graphs[0][0].shape[1] if graphs[0][0].size>0 else 51
    B = len(graphs)
    X = np.zeros((B, Nmax, Fnode), dtype=np.float32)
    A = np.zeros((B, Nmax, Nmax), dtype=np.float32)
    M = np.zeros((B, Nmax), dtype=np.int64)
    for i, (x, a) in enumerate(graphs):
        n = x.shape[0]
        if n == 0: continue
        X[i, :n, :] = x
        A[i, :n, :n] = a
        M[i, :n] = 1
    return torch.from_numpy(X).to(device), torch.from_numpy(A).to(device), torch.from_numpy(M).to(device)

class GINLayer(nn.Module):
    def __init__(self, h=256, p=0.1):
        super().__init__()
        self.eps = nn.Parameter(torch.tensor(0.0))
        self.mlp = nn.Sequential(nn.Linear(h, h), nn.GELU(), nn.LayerNorm(h), nn.Dropout(p))
    def forward(self, x, adj, mask):
        out = (1.0 + self.eps) * x + torch.matmul(adj, x)
        out = self.mlp(out)
        return out * mask.unsqueeze(-1).to(out.dtype)

class GraphGINEncoder(nn.Module):
    def __init__(self, node_in_dim=51, hidden_dim=256, n_layers=4, p=0.1):
        super().__init__()
        self.inp = nn.Sequential(nn.Linear(node_in_dim, hidden_dim), nn.GELU(), nn.Dropout(p))
        self.layers = nn.ModuleList([GINLayer(hidden_dim, p) for _ in range(n_layers)])
        # IMPORTANT: name must be 'out_ln' to match checkpoint
        self.out_ln = nn.LayerNorm(hidden_dim)
    def forward(self, smiles_list: List[str], max_nodes=128):
        X, A, M = _collate_graphs(smiles_list, max_nodes=max_nodes)
        h = self.inp(X)
        for layer in self.layers:
            h = layer(h, A, M)
        return self.out_ln(h), M.to(dtype=torch.int32)

# ---------------- Fusion blocks ----------------
def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    mask = mask.to(dtype=x.dtype, device=x.device)
    denom = mask.sum(dim=dim, keepdim=True).clamp(min=1.0)
    return (x * mask.unsqueeze(-1)).sum(dim=dim) / denom

class CrossAttentionBlock(nn.Module):
    def __init__(self, dim=256, n_heads=4, p=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, n_heads, dropout=p, batch_first=False)
        self.ln  = nn.LayerNorm(dim)
        self.do  = nn.Dropout(p)
    def forward(self, text_tokens, text_mask, graph_nodes, graph_mask):
        Q = text_tokens.transpose(0,1)   # (L,B,D)
        K = graph_nodes.transpose(0,1)   # (N,B,D)
        V = graph_nodes.transpose(0,1)
        kpm = (graph_mask == 0)          # (B,N) True where pad
        attn, _ = self.mha(Q, K, V, key_padding_mask=kpm)
        attn = attn.transpose(0,1)       # (B,L,D)
        return self.ln(text_tokens + self.do(attn))

class DescriptorMLP(nn.Module):
    def __init__(self, in_dim, out_dim=256, hidden=256, p=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(p),
            nn.Linear(hidden, out_dim), nn.GELU(), nn.Dropout(p)
        )
    def forward(self, x): return self.net(x)

# IMPORTANT: name must be 'mlp' to match checkpoint ('shared_head.mlp.*')
class FusionClassifier(nn.Module):
    def __init__(self, dim=256, n_labels=12, p=0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(p),
            nn.Linear(dim*2, n_labels)
        )
    def forward(self, fused_vec):
        return self.mlp(fused_vec)

class V7FusionModel(nn.Module):
    def __init__(self, text_encoder, graph_encoder, desc_in_dim=208, dim=256, n_labels=12, n_heads=4, p=0.1):
        super().__init__()
        self.text_encoder=text_encoder
        self.graph_encoder=graph_encoder
        self.cross=CrossAttentionBlock(dim, n_heads, p)
        self.desc_mlp=DescriptorMLP(desc_in_dim, out_dim=dim, hidden=256, p=p)
        self.shared_head=FusionClassifier(dim, n_labels, p)
    def forward(self, smiles_list, desc_feats, return_intermediates=False):
        tt, tm = self.text_encoder(smiles_list, max_length=256)
        gn, gm = self.graph_encoder(smiles_list, max_nodes=128)
        tt, tm, gn, gm, desc_feats = tt.to(device), tm.to(device), gn.to(device), gm.to(device), desc_feats.to(device)
        tta = self.cross(tt, tm, gn, gm)
        de  = self.desc_mlp(desc_feats)
        text_pool  = masked_mean(tta, tm, 1)
        graph_pool = masked_mean(gn,  gm, 1)
        fused = torch.cat([text_pool, graph_pool, de], dim=-1)  # (B,768)
        logits = self.shared_head(fused)
        if return_intermediates:
            return logits, fused
        return logits

# ---------------- Build & load ----------------
text_encoder = ChemBERTaEncoder().to(device)
graph_encoder= GraphGINEncoder().to(device)
v7_shared    = V7FusionModel(text_encoder, graph_encoder, desc_in_dim=DESC_IN_DIM, n_labels=len(LABEL_NAMES)).to(device)
ckpt = torch.load(CKPT_BEST, map_location=device)
v7_shared.load_state_dict(ckpt["model"], strict=True)
v7_shared.eval()
print("✅ Loaded shared fusion model.")

# ---------------- Specialist heads (match boosted Cell 2) ----------------
class LabelHead(nn.Module):
    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(in_dim, h1), nn.GELU(), nn.LayerNorm(h1), nn.Dropout(p))
        self.block2 = nn.Sequential(nn.Linear(h1, h2), nn.GELU(), nn.LayerNorm(h2), nn.Dropout(p))
        self.block3 = nn.Sequential(nn.Linear(h2, h3), nn.GELU(), nn.LayerNorm(h3), nn.Dropout(p))
        self.out    = nn.Linear(h3, 1)
        self.short  = nn.Linear(in_dim, h3)
    def forward(self, x):
        z1 = self.block1(x); z2 = self.block2(z1); z3 = self.block3(z2)
        z  = z3 + self.short(x)
        return self.out(z).squeeze(-1)

def _load_best_head(label: str) -> nn.Module:
    # pick seed dir with highest best_ap
    cands = []
    for sd in sorted((ENS_DIR / label).glob("seed*/")):
        mfile = sd / "metrics.json"
        if mfile.exists():
            try:
                ap = float(json.loads(mfile.read_text()).get("best_ap", float("nan")))
                cands.append((ap, sd))
            except Exception:
                pass
    if not cands:
        raise FileNotFoundError(f"No trained heads for label {label}")
    cands.sort(key=lambda x: (-1.0 if math.isnan(x[0]) else x[0]), reverse=True)
    best_dir = cands[0][1]
    ck = torch.load(best_dir / "best.pt", map_location=device)
    cfg = ck.get("config", {"in_dim":768,"h1":512,"h2":256,"h3":128,"dropout":0.30})
    head = LabelHead(in_dim=cfg["in_dim"], h1=cfg["h1"], h2=cfg["h2"], h3=cfg["h3"], p=cfg.get("dropout",0.30)).to(device)
    head.load_state_dict(ck["model"], strict=True)
    head.eval()
    return head

HEADS: Dict[str, nn.Module] = {lbl: _load_best_head(lbl) for lbl in LABEL_NAMES}
print("✅ Loaded specialist heads for all labels.")

# ---------------- Descriptors for ad-hoc SMILES ----------------
# For quick testing without the exact 208-d extractor, use standardized zero vector for descriptors.
def prepare_desc_matrix(smiles_list: List[str]) -> torch.Tensor:
    n = len(smiles_list)
    Z = np.zeros((n, DESC_IN_DIM), dtype=np.float32)  # standardized zeros (mean feature)
    return torch.tensor(Z, dtype=torch.float32, device=device)

# ---------------- Fused feature builder ----------------
@torch.no_grad()
def fused_from_smiles(smiles_list: List[str], desc_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
    if desc_tensor is None:
        desc_tensor = prepare_desc_matrix(smiles_list)
    tt, tm = v7_shared.text_encoder(smiles_list, max_length=256)
    gn, gm = v7_shared.graph_encoder(smiles_list, max_nodes=128)
    tt, tm = tt.to(device), tm.to(device)
    gn, gm = gn.to(device), gm.to(device)
    de = v7_shared.desc_mlp(desc_tensor.to(device))
    # cross-attend & pool
    tta = v7_shared.cross(tt, tm, gn, gm)
    text_pool  = masked_mean(tta, tm, 1)
    graph_pool = masked_mean(gn,  gm, 1)
    return torch.cat([text_pool, graph_pool, de], dim=-1)  # (B,768)

# ---------------- Public API ----------------
def predict_smiles(smiles_list: List[str], threshold_mode: str = "fbeta15"):
    """
    Returns list[dict]: one per SMILES
      label -> {logit, prob_raw, prob_cal, decision}
    """
    assert threshold_mode in ("f1", "fbeta15")
    fused = fused_from_smiles(smiles_list)  # (B,768)
    out = []
    for i in range(fused.size(0)):
        row = {}
        x = fused[i:i+1]
        for label in LABEL_NAMES:
            head = HEADS[label]
            with torch.no_grad():
                logit = head(x).item()
            T   = max(float(temps.get(label, 1.0)), 1e-3)
            p_r = 1.0 / (1.0 + math.e**(-logit))
            p_c = 1.0 / (1.0 + math.e**(-logit / T))
            th  = thresholds[label]["th_fbeta15"] if threshold_mode=="fbeta15" else thresholds[label]["th_f1"]
            row[label] = {"logit": float(logit), "prob_raw": float(p_r), "prob_cal": float(p_c), "decision": bool(p_c >= float(th))}
        out.append(row)
    return out

print("✅ Inference is ready: call predict_smiles(['CCO'], threshold_mode='fbeta15' or 'f1').")

✅ Loaded shared fusion model.
✅ Loaded specialist heads for all labels.
✅ Inference is ready: call predict_smiles(['CCO'], threshold_mode='fbeta15' or 'f1').


In [None]:
# my_smiles = ["CCOc1ccc2nc(S(N)(=O)=O)sc2c1"]
# mode = "f1"  # or "f1" fbeta15

# results = predict_smiles(my_smiles, threshold_mode=mode)

# from operator import itemgetter
# for smi, rec in zip(my_smiles, results):
#     print("\nSMILES:", smi)
#     top = sorted([(lbl, d["prob_cal"], d["decision"]) for lbl, d in rec.items()],
#                  key=itemgetter(1), reverse=True)[:5]
#     for lbl, p, dec in top:
#         th = thresholds[lbl]["th_fbeta15"] if mode=="fbeta15" else thresholds[lbl]["th_f1"]
#         print(f"  {lbl:12s}  prob={p:.3f}  th={th:.3f}  → pred={int(dec)}")




# Ad-hoc evaluation on Excel truth labels (simple)
import pandas as pd
import numpy as np
from pathlib import Path
from operator import itemgetter
import math, json, os

# ----------- CONFIG -----------
EXCEL_PATH = Path("tox21_dualenc_v1/data/raw/Truth Lables.xlsx")
MODE = "f1"            # "f1" or "fbeta15"
N_DISPLAY = 5          # how many rows to pretty-print (set to None to print all)
OUT_CSV = Path("v7/results/inference/f1.csv")
OUT_CSV.parent.mkdir(parents=True, exist_ok=True)

# ----------- Checks -----------
assert 'predict_smiles' in globals(), "predict_smiles() not found. Run the cold-start inference cell first."
assert 'LABEL_NAMES' in globals(), "LABEL_NAMES not found. Run the cold-start inference cell first."
assert 'thresholds' in globals(), "thresholds not found. Run Phase 4 calibration cell first."
assert EXCEL_PATH.exists(), f"Cannot find: {EXCEL_PATH}"

# ----------- Load Excel -----------
df = pd.read_excel(EXCEL_PATH)
cols_lower = {c.lower(): c for c in df.columns}
# find smiles col (case-insensitive)
smiles_col = None
for key in ["smiles", "smile", "SMILES", "Smiles"]:
    if key.lower() in cols_lower:
        smiles_col = cols_lower[key.lower()]
        break
if smiles_col is None:
    # fallback: first column named like 'smile*'
    cand = [c for c in df.columns if c.lower().startswith("smiles")]
    smiles_col = cand[0] if cand else None
assert smiles_col is not None, "Could not locate a SMILES column in the Excel file."

# ----------- Match label columns (case/spacing/hyphen-insensitive) -----------
def _norm(s: str) -> str:
    return "".join(ch for ch in str(s).lower() if ch.isalnum())

label_norm = { _norm(lbl): lbl for lbl in LABEL_NAMES }
col_for_label = {}  # label -> column name in df (if present)

for col in df.columns:
    if col == smiles_col: 
        continue
    n = _norm(col)
    if n in label_norm:
        col_for_label[label_norm[n]] = col

available_labels = [lbl for lbl in LABEL_NAMES if lbl in col_for_label]
missing_labels = [lbl for lbl in LABEL_NAMES if lbl not in col_for_label]
print(f"Found {len(available_labels)}/{len(LABEL_NAMES)} label columns in the Excel.")
if missing_labels:
    print("Missing label columns (will be skipped in scoring):", ", ".join(missing_labels))

# ----------- Parse truth values -----------
def parse_truth(v):
    if pd.isna(v): 
        return None
    if isinstance(v, (int, np.integer)): 
        return int(v) == 1
    if isinstance(v, float): 
        if math.isnan(v): return None
        return int(v) == 1
    s = str(v).strip().lower()
    if s in ("1","y","yes","true","t","pos","positive"):
        return True
    if s in ("0","n","no","false","f","neg","negative"):
        return False
    # anything else → None (unknown)
    return None

# ----------- Run predictions -----------
smiles_list = df[smiles_col].astype(str).tolist()
preds = predict_smiles(smiles_list, threshold_mode=MODE)  # list[dict[label -> details]]

# ----------- Build a simple evaluation table -----------
rows = []
micro_tp = micro_fp = micro_fn = 0

for i, (smi, rec) in enumerate(zip(smiles_list, preds)):
    # truth set (only for labels available in Excel)
    true_pos = set()
    true_neg = set()
    for lbl in available_labels:
        val = parse_truth(df.loc[i, col_for_label[lbl]])
        if val is True:
            true_pos.add(lbl)
        elif val is False:
            true_neg.add(lbl)
        # None → skip

    # predicted positives at chosen threshold
    pred_pos = {lbl for lbl, d in rec.items() if d["decision"]}
    # accumulate micro counts only on labels where truth is known
    for lbl in available_labels:
        val = parse_truth(df.loc[i, col_for_label[lbl]])
        if val is None: 
            continue
        if lbl in pred_pos and val is True:
            micro_tp += 1
        elif lbl in pred_pos and val is False:
            micro_fp += 1
        elif lbl not in pred_pos and val is True:
            micro_fn += 1

    # top-5 by calibrated probability (for pretty print)
    top5 = sorted([(lbl, d["prob_cal"], d["decision"]) for lbl, d in rec.items()],
                  key=itemgetter(1), reverse=True)[:5]

    # save a row for CSV: include probs & preds, and truths if present
    row = {"smiles": smi}
    for lbl, det in rec.items():
        row[f"{lbl}_prob"] = det["prob_cal"]
        row[f"{lbl}_pred"] = int(det["decision"])
        if lbl in available_labels:
            tv = parse_truth(df.loc[i, col_for_label[lbl]])
            row[f"{lbl}_true"] = (None if tv is None else int(tv))
    rows.append(row)

    # pretty print a few rows
    if N_DISPLAY is None or i < N_DISPLAY:
        print("\nSMILES:", smi)
        for lbl, p, dec in top5:
            th = thresholds[lbl]["th_fbeta15"] if MODE=="fbeta15" else thresholds[lbl]["th_f1"]
            print(f"  {lbl:12s}  prob={p:.3f}  th={float(th):.3f}  → pred={int(dec)}")
        if available_labels:
            print("  True positives:", ", ".join(sorted(true_pos)) if true_pos else "—")
            chosen = ", ".join(sorted(pred_pos)) if pred_pos else "—"
            print(f"  Pred positives ({MODE}): {chosen}")

# ----------- Micro summary -----------
prec = micro_tp / (micro_tp + micro_fp) if (micro_tp + micro_fp) > 0 else 0.0
rec  = micro_tp / (micro_tp + micro_fn) if (micro_tp + micro_fn) > 0 else 0.0
f1   = (2*prec*rec)/(prec+rec) if (prec+rec) > 0 else 0.0

print("\n=== Summary (micro over labels with truth present) ===")
print(f"TP={micro_tp} FP={micro_fp} FN={micro_fn}")
print(f"Precision={prec:.3f} Recall={rec:.3f} F1={f1:.3f}")

# ----------- Save CSV -----------
pd.DataFrame(rows).to_csv(OUT_CSV, index=False)
print(f"\nSaved detailed results → {OUT_CSV}")


Found 12/12 label columns in the Excel.

SMILES: CCOc1ccc2nc(S(N)(=O)=O)sc2c1
  NR-AhR        prob=0.594  th=0.681  → pred=0
  SR-ARE        prob=0.533  th=0.532  → pred=1
  NR-ER         prob=0.529  th=0.542  → pred=0
  SR-ATAD5      prob=0.529  th=0.567  → pred=0
  NR-PPAR-gamma  prob=0.521  th=0.521  → pred=1
  True positives: NR-AhR, SR-ARE
  Pred positives (f1): NR-PPAR-gamma, SR-ARE

SMILES: CCN1C(=O)NC(c2ccccc2)C1=O
  NR-AhR        prob=0.616  th=0.681  → pred=0
  SR-MMP        prob=0.554  th=0.600  → pred=0
  SR-ATAD5      prob=0.537  th=0.567  → pred=0
  NR-PPAR-gamma  prob=0.536  th=0.521  → pred=1
  SR-p53        prob=0.534  th=0.546  → pred=0
  True positives: —
  Pred positives (f1): NR-PPAR-gamma

SMILES: O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1
  NR-AhR        prob=0.670  th=0.681  → pred=0
  SR-MMP        prob=0.641  th=0.600  → pred=1
  NR-ER-LBD     prob=0.580  th=0.650  → pred=0
  NR-Aromatase  prob=0.569  th=0.604  → pred=0
  SR-p53        prob=0.560  th=0.546  → pred

### 2: Calibrate shared head, create blended ensemble, refit thresholds

In [None]:
# Phase 5 — Cell 2 (optional): shared+specialist blend with calibration and new thresholds
import json, math
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import precision_recall_curve, average_precision_score

BASE      = Path("v7")
FUSED_DIR = BASE / "data" / "fused"
CAL_DIR   = BASE / "model" / "calibration"
ENS_DIR   = BASE / "model" / "ensembles"
CAL_DIR.mkdir(parents=True, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Expect these in memory from earlier cold-start cell:
# v7_shared (with .shared_head), HEADS (specialists), LABEL_NAMES, temps (specialist temps)
assert 'v7_shared' in globals() and 'HEADS' in globals() and 'LABEL_NAMES' in globals() and 'temps' in globals()

# ---- load val fused + labels/mask ----
Xva = np.load(FUSED_DIR / "val_fused.npy")     # (N,768)
Yva = np.load(FUSED_DIR / "val_Y.npy")         # (N,12)
Mva = np.load(FUSED_DIR / "val_mask.npy")      # (N,12) True where missing

Xva_t = torch.tensor(Xva, dtype=torch.float32, device=device)

# ---- helper: fit per-label temperature (on logits) ----
def fit_temperature(logits: np.ndarray, y: np.ndarray, max_iter=200, lr=0.05) -> float:
    t = torch.tensor([1.0], dtype=torch.float32, requires_grad=True, device=device)
    x = torch.tensor(logits, dtype=torch.float32, device=device)
    y = torch.tensor(y,      dtype=torch.float32, device=device)
    opt = torch.optim.Adam([t], lr=lr)
    for _ in range(max_iter):
        opt.zero_grad(set_to_none=True)
        z = x / (t.clamp(min=1e-3))
        p = torch.sigmoid(z).clamp(1e-6, 1-1e-6)
        loss = - (y*torch.log(p) + (1-y)*torch.log(1-p)).mean()
        loss.backward(); opt.step()
    return float(t.detach().cpu().item())

def best_thresholds(y_true: np.ndarray, probs: np.ndarray):
    prec, rec, th = precision_recall_curve(y_true, probs)
    eps = 1e-8
    f1 = (2*prec*rec) / np.maximum(prec+rec, eps)
    beta = 1.5
    fb = ((1+beta**2)*prec*rec) / np.maximum((beta**2)*prec + rec, eps)
    th_f1 = th[np.nanargmax(f1[1:])] if th.size>0 else 0.5
    th_fb = th[np.nanargmax(fb[1:])] if th.size>0 else 0.5
    try:
        ap = float(average_precision_score(y_true, probs))
    except Exception:
        ap = float("nan")
    return {"th_f1": float(th_f1), "th_fbeta15": float(th_fb), "ap_val": ap}

# ---- 1) Calibrate SHARED head per label on val ----
print("Calibrating shared head temperatures on val...")
logits_shared = v7_shared.shared_head(Xva_t).detach().cpu().numpy()  # (N,12)
temps_shared = {}
for j, lbl in enumerate(LABEL_NAMES):
    valid = ~Mva[:, j]
    if valid.sum() == 0 or np.all(Yva[valid, j] == Yva[valid, j][0]):
        temps_shared[lbl] = 1.0
        continue
    T = fit_temperature(logits_shared[valid, j], Yva[valid, j])
    temps_shared[lbl] = T
    print(f"  {lbl}: T_shared={T:.3f}")
(Path(CAL_DIR / "temps_shared.json")).write_text(json.dumps(temps_shared, indent=2))
print("Saved →", CAL_DIR / "temps_shared.json")

# ---- 2) Build BLENDED probs on val (alpha specialist, (1-alpha) shared) ----
ALPHA = 0.8  # weight on specialist; tweak if desired
print(f"\nBlending probs on val with alpha={ALPHA:.2f} (specialist weight)")

# specialist logits on val
spec_logits = np.zeros_like(logits_shared)
with torch.no_grad():
    for j, lbl in enumerate(LABEL_NAMES):
        head = HEADS[lbl]
        spec_logits[:, j] = head(Xva_t).detach().cpu().numpy()

# calibrate both streams
p_spec_val   = np.zeros_like(spec_logits)
p_shared_val = np.zeros_like(logits_shared)
for j, lbl in enumerate(LABEL_NAMES):
    T_spec   = max(float(temps.get(lbl, 1.0)), 1e-3)
    T_shared = max(float(temps_shared.get(lbl, 1.0)), 1e-3)
    p_spec_val[:, j]   = 1. / (1. + np.exp(-spec_logits[:, j]   / T_spec))
    p_shared_val[:, j] = 1. / (1. + np.exp(-logits_shared[:, j] / T_shared))

p_blend_val = ALPHA * p_spec_val + (1-ALPHA) * p_shared_val
p_blend_val = np.clip(p_blend_val, 0.0, 1.0)

# ---- 3) Refit thresholds for BLEND on val ----
thresholds_blend = {}
for j, lbl in enumerate(LABEL_NAMES):
    valid = ~Mva[:, j]
    if valid.sum() == 0 or np.all(Yva[valid, j] == Yva[valid, j][0]):
        thresholds_blend[lbl] = {"th_f1": 0.5, "th_fbeta15": 0.5, "ap_val": float("nan")}
        continue
    thresholds_blend[lbl] = best_thresholds(Yva[valid, j], p_blend_val[valid, j])
    print(f"  {lbl}: AP_val={thresholds_blend[lbl]['ap_val']:.3f} th_f1={thresholds_blend[lbl]['th_f1']:.3f} th_fb15={thresholds_blend[lbl]['th_fbeta15']:.3f}")

(Path(CAL_DIR / "thresholds_blend.json")).write_text(json.dumps({
    "alpha": ALPHA,
    "thresholds": thresholds_blend
}, indent=2))
print("\nSaved →", CAL_DIR / "thresholds_blend.json")

# ---- 4) Provide a convenience predictor using the BLEND (keep specialist predictor unchanged) ----
def predict_smiles_blend(smiles_list, mode: str = "fbeta15", alpha: float = ALPHA):
    """
    Returns list[dict]: per SMILES -> label -> {prob_spec, prob_shared, prob_blend, decision}
    """
    assert mode in ("f1","fbeta15")
    # fused features from shared encoders (desc branch is already wired)
    fused = fused_from_smiles(smiles_list)  # (B,768)
    out = []
    X = fused  # torch Tensor
    with torch.no_grad():
        logits_shared = v7_shared.shared_head(X).detach().cpu().numpy()
    for i in range(X.size(0)):
        row = {}
        xi = X[i:i+1]
        for j, lbl in enumerate(LABEL_NAMES):
            # specialist
            with torch.no_grad():
                logit_spec = HEADS[lbl](xi).item()
            T_spec   = max(float(temps.get(lbl, 1.0)), 1e-3)
            p_spec   = 1. / (1. + math.e**(-logit_spec / T_spec))
            # shared
            T_shared = max(float(temps_shared.get(lbl, 1.0)), 1e-3)
            logit_sh = logits_shared[i, j]
            p_shared = 1. / (1. + math.e**(-logit_sh   / T_shared))
            # blend
            p_blend = alpha * p_spec + (1-alpha) * p_shared
            # threshold (use blended thresholds we just computed)
            th = thresholds_blend[lbl]["th_fbeta15"] if mode=="fbeta15" else thresholds_blend[lbl]["th_f1"]
            row[lbl] = {
                "prob_spec": float(p_spec),
                "prob_shared": float(p_shared),
                "prob_blend": float(p_blend),
                "decision": bool(p_blend >= float(th)),
            }
        out.append(row)
    return out

print("\n✅ Blend ready: use predict_smiles_blend([...], mode='fbeta15' or 'f1').")

Calibrating shared head temperatures on val...
  NR-AR: T_shared=0.134
  NR-AR-LBD: T_shared=0.132
  NR-AhR: T_shared=0.167
  NR-Aromatase: T_shared=0.126
  NR-ER: T_shared=0.134
  NR-ER-LBD: T_shared=0.110
  NR-PPAR-gamma: T_shared=0.167
  SR-ARE: T_shared=0.260
  SR-ATAD5: T_shared=0.146
  SR-HSE: T_shared=0.100
  SR-MMP: T_shared=0.250
  SR-p53: T_shared=0.119
Saved → v7\model\calibration\temps_shared.json

Blending probs on val with alpha=0.80 (specialist weight)
  NR-AR: AP_val=0.171 th_f1=0.653 th_fb15=0.653
  NR-AR-LBD: AP_val=0.253 th_f1=0.621 th_fb15=0.621
  NR-AhR: AP_val=0.524 th_f1=0.709 th_fb15=0.642
  NR-Aromatase: AP_val=0.295 th_f1=0.564 th_fb15=0.474
  NR-ER: AP_val=0.253 th_f1=0.547 th_fb15=0.480
  NR-ER-LBD: AP_val=0.139 th_f1=0.589 th_fb15=0.589
  NR-PPAR-gamma: AP_val=0.063 th_f1=0.441 th_fb15=0.427
  SR-ARE: AP_val=0.344 th_f1=0.528 th_fb15=0.528
  SR-ATAD5: AP_val=0.171 th_f1=0.483 th_fb15=0.483
  SR-HSE: AP_val=0.196 th_f1=0.472 th_fb15=0.459
  SR-MMP: AP_val=0.

### 3: Evaluate on test set & export CSV (choose specialist or blend)

In [None]:
# Phase 5 — Cell 3: Test export + quick metrics
import json, math
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.metrics import average_precision_score, precision_recall_curve

BASE       = Path("v7")
PREP_DIR   = BASE / "data" / "prepared"
FUSED_DIR  = BASE / "data" / "fused"
RESULTS_DIR= BASE / "results" / "inference"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
CAL_DIR    = BASE / "model" / "calibration"

# Choose which predictor to use:
USE_BLEND = True     # True → use predict_smiles_blend; False → use specialist-only predict_smiles
MODE      = "fbeta15"  # "fbeta15" or "f1"

# Load test blobs
blob = np.load(PREP_DIR / "test.npz", allow_pickle=True)
smiles = [str(s) for s in blob["smiles"].tolist()]
Yte    = blob["Y"].astype(np.float32)
Mte    = blob["y_missing_mask"].astype(bool)

# Also load fused for test to speed shared head for blend
Xte_fused = np.load(FUSED_DIR / "test_fused.npy") if (FUSED_DIR / "test_fused.npy").exists() else None

# Ensure thresholds for selected path
if USE_BLEND:
    data = json.loads((CAL_DIR / "thresholds_blend.json").read_text())
    thresholds_blend = data["thresholds"]
else:
    thresholds_spec = json.loads((CAL_DIR / "thresholds.json").read_text())

rows = []
probs_mat = np.zeros((len(smiles), len(LABEL_NAMES)), dtype=np.float32)

if USE_BLEND:
    # Compute via blend predictor
    preds = predict_smiles_blend(smiles, mode=MODE)
    for i, (smi, rec) in enumerate(zip(smiles, preds)):
        row = {"smiles": smi}
        for j, lbl in enumerate(LABEL_NAMES):
            p = rec[lbl]["prob_blend"]
            d = int(rec[lbl]["decision"])
            row[f"{lbl}_prob"] = p
            row[f"{lbl}_pred"] = d
            probs_mat[i, j] = p
        rows.append(row)
    out_csv = RESULTS_DIR / f"predictions_test_blend_{MODE}.csv"
else:
    # Specialist-only
    preds = predict_smiles(smiles, threshold_mode=MODE)
    for i, (smi, rec) in enumerate(zip(smiles, preds)):
        row = {"smiles": smi}
        for j, lbl in enumerate(LABEL_NAMES):
            p = rec[lbl]["prob_cal"]
            d = int(rec[lbl]["decision"])
            row[f"{lbl}_prob"] = p
            row[f"{lbl}_pred"] = d
            probs_mat[i, j] = p
        rows.append(row)
    out_csv = RESULTS_DIR / f"predictions_test_specialist_{MODE}.csv"

pd.DataFrame(rows).to_csv(out_csv, index=False)
print("✅ Saved:", out_csv)

# ---- Tiny metrics (test) ----
per_label_ap = {}
for j, lbl in enumerate(LABEL_NAMES):
    valid = ~Mte[:, j]
    if valid.sum() == 0 or np.all(Yte[valid, j] == Yte[valid, j][0]):
        per_label_ap[lbl] = float("nan"); continue
    try:
        per_label_ap[lbl] = float(average_precision_score(Yte[valid, j], probs_mat[valid, j]))
    except Exception:
        per_label_ap[lbl] = float("nan")

macro_pr = float(np.nanmean([v for v in per_label_ap.values()]))

# micro P/R/F1 using chosen thresholds
tp = fp = fn = 0
for i in range(len(smiles)):
    for j, lbl in enumerate(LABEL_NAMES):
        if Mte[i, j]: 
            continue
        truth = int(Yte[i, j])
        pred  = rows[i][f"{lbl}_pred"]
        tp += int(pred == 1 and truth == 1)
        fp += int(pred == 1 and truth == 0)
        fn += int(pred == 0 and truth == 1)

prec = tp/(tp+fp) if (tp+fp)>0 else 0.0
rec  = tp/(tp+fn) if (tp+fn)>0 else 0.0
f1   = (2*prec*rec)/(prec+rec) if (prec+rec)>0 else 0.0

report = {
    "mode": ("blend" if USE_BLEND else "specialist"),
    "threshold_mode": MODE,
    "macro_pr_auc": macro_pr,
    "micro_precision": prec,
    "micro_recall": rec,
    "micro_f1": f1,
    "per_label_ap": per_label_ap
}
report_path = RESULTS_DIR / f"test_report_{'blend' if USE_BLEND else 'specialist'}_{MODE}.json"
report_path.write_text(json.dumps(report, indent=2))
print("\nSummary (test):")
print(json.dumps({k: (round(v,4) if isinstance(v, float) else v) for k,v in report.items() if k!='per_label_ap'}, indent=2))
print("Per-label AP saved in report JSON.")

✅ Saved: v7\results\inference\predictions_test_blend_fbeta15.csv

Summary (test):
{
  "mode": "blend",
  "threshold_mode": "fbeta15",
  "macro_pr_auc": 0.3208,
  "micro_precision": 0.2079,
  "micro_recall": 0.5734,
  "micro_f1": 0.3052
}
Per-label AP saved in report JSON.


### 4: test reg after cell 2& 3 (gave very strong results!)

In [None]:
# === V7: Single-SMILES/SMARTS Test Rig (BLENDED: specialist + shared) ===
# Uses:
#   v7/model/checkpoints/shared/best.pt
#   v7/model/ensembles/<label>/seed*/best.pt
#   v7/model/calibration/temps.json           (specialist temps)
#   v7/model/calibration/temps_shared.json    (shared temps)
#   v7/model/calibration/thresholds_blend.json (alpha + per-label thresholds)

import os, json, math
from pathlib import Path
from typing import List, Dict, Optional

import numpy as np
import torch
import torch.nn as nn

BASE       = Path("v7")
PREP_DIR   = BASE / "data" / "prepared"
DESC_DIR   = BASE / "data" / "descriptors"
MODEL_DIR  = BASE / "model"
CKPT_BEST  = MODEL_DIR / "checkpoints" / "shared" / "best.pt"
ENS_DIR    = MODEL_DIR / "ensembles"
CAL_DIR    = MODEL_DIR / "calibration"

assert CKPT_BEST.exists(), f"Missing shared checkpoint: {CKPT_BEST}"
assert (PREP_DIR / "dataset_manifest.json").exists(), "Missing dataset manifest."

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Labels & calibration artifacts ---
ds_manifest = json.loads((PREP_DIR / "dataset_manifest.json").read_text())
LABEL_NAMES: List[str] = ds_manifest["labels"]
DESC_IN_DIM = ds_manifest["n_features"]  # 208

temps_spec    = json.loads((CAL_DIR / "temps.json").read_text())           # specialist
temps_shared  = json.loads((CAL_DIR / "temps_shared.json").read_text())    # shared
blend_payload = json.loads((CAL_DIR / "thresholds_blend.json").read_text())
ALPHA         = float(blend_payload.get("alpha", 0.8))
thr_blend     = blend_payload["thresholds"]  # label -> {th_f1, th_fbeta15, ap_val}

# --- Text encoder (ChemBERTa) ---
from transformers import AutoTokenizer, AutoModel
class ChemBERTaEncoder(nn.Module):
    def __init__(self, ckpt_name="seyonec/ChemBERTa-zinc-base-v1", fusion_dim=256, dropout_p=0.1):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
        self.backbone  = AutoModel.from_pretrained(ckpt_name)
        self.proj = nn.Sequential(nn.Dropout(dropout_p), nn.Linear(self.backbone.config.hidden_size, fusion_dim))
        self.ln = nn.LayerNorm(fusion_dim)
    def forward(self, smiles_list: List[str], max_length=256, add_special_tokens=True):
        enc = self.tokenizer(list(smiles_list), padding=True, truncation=True,
                             max_length=max_length, add_special_tokens=add_special_tokens,
                             return_tensors="pt")
        input_ids, attention_mask = enc["input_ids"].to(device), enc["attention_mask"].to(device)
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state  # (B,L,H)
        toks = self.ln(self.proj(out))  # (B,L,256)
        return toks, attention_mask.to(dtype=torch.int32)

# --- Graph encoder (names matched to checkpoint) ---
from rdkit import Chem
ATOM_LIST = ["H","C","N","O","F","P","S","Cl","Br","I"]

def _one_hot(v, choices):
    z = [0]*len(choices)
    if v in choices: z[choices.index(v)] = 1
    return z

def _bucket_oh(v, lo, hi):
    buckets = list(range(lo, hi+1))
    o = [0]*(len(buckets)+1)
    idx = v - lo
    o[idx if 0 <= idx < len(buckets) else -1] = 1
    return o

def _atom_feat(atom):
    hybs = [Chem.rdchem.HybridizationType.S, Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2]
    chir = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER]
    sym = atom.GetSymbol()
    feat = _one_hot(sym if sym in ATOM_LIST else "other", ATOM_LIST+["other"])
    feat += _bucket_oh(atom.GetDegree(), 0, 5)
    feat += _bucket_oh(atom.GetFormalCharge(), -2, 2)
    feat += (_one_hot(atom.GetHybridization(), hybs)+[0])
    feat += [int(atom.GetIsAromatic())]
    feat += [int(atom.IsInRing())]
    feat += _one_hot(atom.GetChiralTag(), chir)
    feat += _bucket_oh(atom.GetTotalNumHs(includeNeighbors=True), 0, 4)
    feat += _bucket_oh(atom.GetTotalValence(), 0, 5)
    feat += [atom.GetMass()/200.0]
    return feat  # ~51 dims

def _smiles_to_graph(smi, max_nodes=128):
    mol = Chem.MolFromSmiles(smi)
    if mol is None or mol.GetNumAtoms() == 0:
        return np.zeros((0,0), dtype=np.float32), np.zeros((0,0), dtype=np.float32)
    feats = [_atom_feat(mol.GetAtomWithIdx(i)) for i in range(mol.GetNumAtoms())]
    x = np.asarray(feats, dtype=np.float32)
    N = mol.GetNumAtoms()
    adj = np.zeros((N, N), dtype=np.float32)
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        adj[i, j] = 1.0; adj[j, i] = 1.0
    if N > max_nodes:
        x = x[:max_nodes]; adj = adj[:max_nodes, :max_nodes]
    return x, adj

def _collate_graphs(smiles_batch, max_nodes=128):
    graphs = [_smiles_to_graph(s) for s in smiles_batch]
    Nmax = max([g[0].shape[0] for g in graphs] + [1])
    Fnode = graphs[0][0].shape[1] if graphs[0][0].size>0 else 51
    B = len(graphs)
    X = np.zeros((B, Nmax, Fnode), dtype=np.float32)
    A = np.zeros((B, Nmax, Nmax), dtype=np.float32)
    M = np.zeros((B, Nmax), dtype=np.int64)
    for i, (x, a) in enumerate(graphs):
        n = x.shape[0]
        if n == 0: continue
        X[i, :n, :] = x
        A[i, :n, :n] = a
        M[i, :n] = 1
    return torch.from_numpy(X).to(device), torch.from_numpy(A).to(device), torch.from_numpy(M).to(device)

class GINLayer(nn.Module):
    def __init__(self, h=256, p=0.1):
        super().__init__()
        self.eps = nn.Parameter(torch.tensor(0.0))
        self.mlp = nn.Sequential(nn.Linear(h, h), nn.GELU(), nn.LayerNorm(h), nn.Dropout(p))
    def forward(self, x, adj, mask):
        out = (1.0 + self.eps) * x + torch.matmul(adj, x)
        out = self.mlp(out)
        return out * mask.unsqueeze(-1).to(out.dtype)

class GraphGINEncoder(nn.Module):
    def __init__(self, node_in_dim=51, hidden_dim=256, n_layers=4, p=0.1):
        super().__init__()
        self.inp = nn.Sequential(nn.Linear(node_in_dim, hidden_dim), nn.GELU(), nn.Dropout(p))
        self.layers = nn.ModuleList([GINLayer(hidden_dim, p) for _ in range(n_layers)])
        self.out_ln = nn.LayerNorm(hidden_dim)  # name matches checkpoint
    def forward(self, smiles_list: List[str], max_nodes=128):
        X, A, M = _collate_graphs(smiles_list, max_nodes=max_nodes)
        h = self.inp(X)
        for layer in self.layers:
            h = layer(h, A, M)
        return self.out_ln(h), M.to(dtype=torch.int32)

# --- Fusion & heads ---
def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    mask = mask.to(dtype=x.dtype, device=x.device)
    denom = mask.sum(dim=dim, keepdim=True).clamp(min=1.0)
    return (x * mask.unsqueeze(-1)).sum(dim=dim) / denom

class CrossAttentionBlock(nn.Module):
    def __init__(self, dim=256, n_heads=4, p=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, n_heads, dropout=p, batch_first=False)
        self.ln  = nn.LayerNorm(dim)
        self.do  = nn.Dropout(p)
    def forward(self, text_tokens, text_mask, graph_nodes, graph_mask):
        Q = text_tokens.transpose(0,1)   # (L,B,D)
        K = graph_nodes.transpose(0,1)   # (N,B,D)
        V = graph_nodes.transpose(0,1)
        kpm = (graph_mask == 0)          # (B,N)
        attn, _ = self.mha(Q, K, V, key_padding_mask=kpm)
        attn = attn.transpose(0,1)       # (B,L,D)
        return self.ln(text_tokens + self.do(attn))

class DescriptorMLP(nn.Module):
    def __init__(self, in_dim, out_dim=256, hidden=256, p=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(p),
            nn.Linear(hidden, out_dim), nn.GELU(), nn.Dropout(p)
        )
    def forward(self, x): return self.net(x)

class FusionClassifier(nn.Module):
    # name 'mlp' matches checkpoint ('shared_head.mlp.*')
    def __init__(self, dim=256, n_labels=12, p=0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(p),
            nn.Linear(dim*2, n_labels)
        )
    def forward(self, fused_vec): return self.mlp(fused_vec)

class V7FusionModel(nn.Module):
    def __init__(self, text_encoder, graph_encoder, desc_in_dim=208, dim=256, n_labels=12, n_heads=4, p=0.1):
        super().__init__()
        self.text_encoder=text_encoder
        self.graph_encoder=graph_encoder
        self.cross=CrossAttentionBlock(dim, n_heads, p)
        self.desc_mlp=DescriptorMLP(desc_in_dim, out_dim=dim, hidden=256, p=p)
        self.shared_head=FusionClassifier(dim, n_labels, p)
    def forward(self, smiles_list, desc_feats):
        tt, tm = self.text_encoder(smiles_list, max_length=256)
        gn, gm = self.graph_encoder(smiles_list, max_nodes=128)
        tta = self.cross(tt.to(device), tm.to(device), gn.to(device), gm.to(device))
        de  = self.desc_mlp(desc_feats.to(device))
        text_pool  = masked_mean(tta, tm.to(device), 1)
        graph_pool = masked_mean(gn.to(device),  gm.to(device), 1)
        fused = torch.cat([text_pool, graph_pool, de], dim=-1)  # (B,768)
        logits = self.shared_head(fused)
        return logits, fused

# Build model & load checkpoint
text_encoder = ChemBERTaEncoder().to(device)
graph_encoder= GraphGINEncoder().to(device)
v7_shared    = V7FusionModel(text_encoder, graph_encoder, desc_in_dim=DESC_IN_DIM, n_labels=len(LABEL_NAMES)).to(device)
ckpt = torch.load(CKPT_BEST, map_location=device)
v7_shared.load_state_dict(ckpt["model"], strict=True)
v7_shared.eval()

# Specialist heads (same as trained)
class LabelHead(nn.Module):
    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(in_dim, h1), nn.GELU(), nn.LayerNorm(h1), nn.Dropout(p))
        self.block2 = nn.Sequential(nn.Linear(h1, h2), nn.GELU(), nn.LayerNorm(h2), nn.Dropout(p))
        self.block3 = nn.Sequential(nn.Linear(h2, h3), nn.GELU(), nn.LayerNorm(h3), nn.Dropout(p))
        self.out    = nn.Linear(h3, 1)
        self.short  = nn.Linear(in_dim, h3)
    def forward(self, x):
        z1 = self.block1(x); z2 = self.block2(z1); z3 = self.block3(z2)
        z  = z3 + self.short(x)
        return self.out(z).squeeze(-1)

def _load_best_head(label: str) -> nn.Module:
    cands = []
    for sd in sorted((ENS_DIR / label).glob("seed*/")):
        mfile = sd / "metrics.json"
        if mfile.exists():
            try:
                ap = float(json.loads(mfile.read_text()).get("best_ap", float("nan")))
                cands.append((ap, sd))
            except Exception:
                pass
    if not cands: raise FileNotFoundError(f"No trained heads for label {label}")
    cands.sort(key=lambda x: (-1.0 if math.isnan(x[0]) else x[0]), reverse=True)
    best_dir = cands[0][1]
    ck = torch.load(best_dir / "best.pt", map_location=device)
    cfg = ck.get("config", {"in_dim":768,"h1":512,"h2":256,"h3":128,"dropout":0.30})
    head = LabelHead(in_dim=cfg["in_dim"], h1=cfg["h1"], h2=cfg["h2"], h3=cfg["h3"], p=cfg.get("dropout",0.30)).to(device)
    head.load_state_dict(ck["model"], strict=True)
    head.eval()
    return head

HEADS: Dict[str, nn.Module] = {lbl: _load_best_head(lbl) for lbl in LABEL_NAMES}

# Descriptors for ad-hoc inputs: standardized zeros (keeps it simple & robust)
def prepare_desc_matrix(smiles_list: List[str]) -> torch.Tensor:
    Z = np.zeros((len(smiles_list), DESC_IN_DIM), dtype=np.float32)
    return torch.tensor(Z, dtype=torch.float32, device=device)

# Normalize SMARTS→SMILES if needed
def normalize_smiles_or_smarts(s: str) -> str:
    if not isinstance(s, str): s = str(s)
    mol = Chem.MolFromSmiles(s)
    if mol: return Chem.MolToSmiles(mol)
    q = Chem.MolFromSmarts(s)
    if q:
        try:
            smi = Chem.MolToSmiles(q)
            return smi if smi else s
        except Exception:
            return s
    return s

@torch.no_grad()
def fused_from_smiles(smiles_list: List[str]) -> torch.Tensor:
    smiles_list = [normalize_smiles_or_smarts(s) for s in smiles_list]
    desc = prepare_desc_matrix(smiles_list)
    logits_sh, fused = v7_shared(smiles_list, desc)  # logits not used here directly
    return fused  # (B,768)

def predict_one_blend(smi: str, mode: str = "fbeta15", topk: int = 5):
    """
    Blended prediction for one SMILES/SMARTS using:
      prob_blend = alpha*P_spec + (1-alpha)*P_shared
    Thresholds taken from thresholds_blend.json for chosen mode ("f1" or "fbeta15").
    Prints a clean summary and returns a dict[label]->details.
    """
    assert mode in ("f1","fbeta15")
    fused = fused_from_smiles([smi])
    x = fused[0:1]

    # Shared logits and calibrated probs
    with torch.no_grad():
        logits_shared = v7_shared.shared_head(x).detach().cpu().numpy()[0]  # (12,)

    rec = {}
    for j, lbl in enumerate(LABEL_NAMES):
        # Specialist prob (with its temperature)
        with torch.no_grad():
            logit_spec = HEADS[lbl](x).item()
        T_spec   = max(float(temps_spec.get(lbl, 1.0)), 1e-3)
        p_spec   = 1. / (1. + math.e**(-logit_spec / T_spec))

        # Shared prob (with shared temperature)
        T_shared = max(float(temps_shared.get(lbl, 1.0)), 1e-3)
        p_shared = 1. / (1. + math.e**(-float(logits_shared[j]) / T_shared))

        # Blend
        p_blend = ALPHA * p_spec + (1.0 - ALPHA) * p_shared

        # Threshold
        th = thr_blend[lbl]["th_fbeta15"] if mode=="fbeta15" else thr_blend[lbl]["th_f1"]
        rec[lbl] = {
            "prob_spec": float(p_spec),
            "prob_shared": float(p_shared),
            "prob_blend": float(p_blend),
            "threshold": float(th),
            "decision": bool(p_blend >= float(th)),
        }

    # Pretty print
    print("\nSMILES/SMARTS:", smi, f"(alpha={ALPHA:.2f}, mode={mode})")
    top = sorted([(lbl, d["prob_blend"], d["decision"]) for lbl, d in rec.items()],
                 key=lambda z: z[1], reverse=True)[:topk]
    for lbl, p, dec in top:
        th = rec[lbl]["threshold"]
        print(f"  {lbl:12s}  prob_blend={p:.3f}  th={th:.3f}  → pred={int(dec)}")
    pos = [lbl for lbl, d in rec.items() if d["decision"]]
    print("  Positives:", (", ".join(sorted(pos)) if pos else "none"))
    return rec

print("✅ Blend test rig ready. Example:")


✅ Blend test rig ready. Example:


In [None]:
predict_one_blend("O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1", mode="fbeta15", topk=12)


SMILES/SMARTS: O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1 (alpha=0.80, mode=fbeta15)
  NR-AhR        prob_blend=0.700  th=0.642  → pred=1
  SR-MMP        prob_blend=0.680  th=0.589  → pred=1
  SR-ARE        prob_blend=0.588  th=0.528  → pred=1
  NR-ER         prob_blend=0.556  th=0.480  → pred=1
  SR-p53        prob_blend=0.551  th=0.478  → pred=1
  NR-ER-LBD     prob_blend=0.499  th=0.589  → pred=0
  NR-Aromatase  prob_blend=0.498  th=0.474  → pred=1
  SR-ATAD5      prob_blend=0.479  th=0.483  → pred=0
  NR-PPAR-gamma  prob_blend=0.478  th=0.427  → pred=1
  SR-HSE        prob_blend=0.460  th=0.459  → pred=1
  NR-AR         prob_blend=0.432  th=0.653  → pred=0
  NR-AR-LBD     prob_blend=0.423  th=0.621  → pred=0
  Positives: NR-AhR, NR-Aromatase, NR-ER, NR-PPAR-gamma, SR-ARE, SR-HSE, SR-MMP, SR-p53


{'NR-AR': {'prob_spec': 0.5297900819654926,
  'prob_shared': 0.039244673619722704,
  'prob_blend': 0.43168100029633866,
  'threshold': 0.653282642364502,
  'decision': False},
 'NR-AR-LBD': {'prob_spec': 0.527085290472785,
  'prob_shared': 0.007392770345504927,
  'prob_blend': 0.42314678644732895,
  'threshold': 0.6206690669059753,
  'decision': False},
 'NR-AhR': {'prob_spec': 0.6696848171077073,
  'prob_shared': 0.8209122313828833,
  'prob_blend': 0.6999302999627425,
  'threshold': 0.6417197585105896,
  'decision': True},
 'NR-Aromatase': {'prob_spec': 0.5691290816416109,
  'prob_shared': 0.21573499976865418,
  'prob_blend': 0.4984502652670196,
  'threshold': 0.4737248420715332,
  'decision': True},
 'NR-ER': {'prob_spec': 0.5319725845509377,
  'prob_shared': 0.6530290641575132,
  'prob_blend': 0.5561838804722528,
  'threshold': 0.4802268147468567,
  'decision': True},
 'NR-ER-LBD': {'prob_spec': 0.5796462377733185,
  'prob_shared': 0.17392469045136985,
  'prob_blend': 0.498501928308

In [None]:
predict_one_blend("O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1", mode="f1", topk=12)


SMILES/SMARTS: O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1 (alpha=0.80, mode=f1)
  NR-AhR        prob_blend=0.700  th=0.709  → pred=0
  SR-MMP        prob_blend=0.680  th=0.589  → pred=1
  SR-ARE        prob_blend=0.588  th=0.528  → pred=1
  NR-ER         prob_blend=0.556  th=0.547  → pred=1
  SR-p53        prob_blend=0.551  th=0.513  → pred=1
  NR-ER-LBD     prob_blend=0.499  th=0.589  → pred=0
  NR-Aromatase  prob_blend=0.498  th=0.564  → pred=0
  SR-ATAD5      prob_blend=0.479  th=0.483  → pred=0
  NR-PPAR-gamma  prob_blend=0.478  th=0.441  → pred=1
  SR-HSE        prob_blend=0.460  th=0.472  → pred=0
  NR-AR         prob_blend=0.432  th=0.653  → pred=0
  NR-AR-LBD     prob_blend=0.423  th=0.621  → pred=0
  Positives: NR-ER, NR-PPAR-gamma, SR-ARE, SR-MMP, SR-p53


{'NR-AR': {'prob_spec': 0.5297900819654926,
  'prob_shared': 0.039244673619722704,
  'prob_blend': 0.43168100029633866,
  'threshold': 0.653282642364502,
  'decision': False},
 'NR-AR-LBD': {'prob_spec': 0.527085290472785,
  'prob_shared': 0.007392770345504927,
  'prob_blend': 0.42314678644732895,
  'threshold': 0.6206690669059753,
  'decision': False},
 'NR-AhR': {'prob_spec': 0.6696848171077073,
  'prob_shared': 0.8209122313828833,
  'prob_blend': 0.6999302999627425,
  'threshold': 0.7087583541870117,
  'decision': False},
 'NR-Aromatase': {'prob_spec': 0.5691290816416109,
  'prob_shared': 0.21573499976865418,
  'prob_blend': 0.4984502652670196,
  'threshold': 0.5641032457351685,
  'decision': False},
 'NR-ER': {'prob_spec': 0.5319725845509377,
  'prob_shared': 0.6530290641575132,
  'prob_blend': 0.5561838804722528,
  'threshold': 0.547207772731781,
  'decision': True},
 'NR-ER-LBD': {'prob_spec': 0.5796462377733185,
  'prob_shared': 0.17392469045136985,
  'prob_blend': 0.49850192830

## phase 6 (Evaluation)

### 1: Ground truth and fused features

In [None]:
# =========================
# Phase 6 — Evaluation (robust to restarts; will rebuild fused features if missing)
# =========================
import os, json, math
from pathlib import Path
from typing import List, Dict, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import (
    average_precision_score, roc_auc_score,
    precision_recall_curve, roc_curve
)

# ---- Config ----
BASE         = Path("v7")
PREP_DIR     = BASE / "data" / "prepared"
DESC_DIR     = BASE / "data" / "descriptors"
FUSED_DIR    = BASE / "data" / "fused"
MODEL_DIR    = BASE / "model"
CAL_DIR      = MODEL_DIR / "calibration"
ENS_DIR      = MODEL_DIR / "ensembles"
EVAL_DIR     = BASE / "eval"
PLOT_PR_DIR  = EVAL_DIR / "plots" / "pr"
PLOT_REL_DIR = EVAL_DIR / "plots" / "reliability"

EVAL_DIR.mkdir(parents=True, exist_ok=True)
PLOT_PR_DIR.mkdir(parents=True, exist_ok=True)
PLOT_REL_DIR.mkdir(parents=True, exist_ok=True)
FUSED_DIR.mkdir(parents=True, exist_ok=True)

# Choose path: "specialist" OR "blend"
EVAL_MODE   = "blend"       # "specialist" or "blend"
THRESH_MODE = "fbeta15"     # "fbeta15" or "f1"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Load manifest & test blobs ----
mani_path = PREP_DIR / "dataset_manifest.json"
assert mani_path.exists(), f"Missing manifest: {mani_path}"
mani      = json.loads(mani_path.read_text())
LABELS    = mani["labels"]
N_LABELS  = len(LABELS)
DESC_IN_DIM = int(mani["n_features"])  # 208

blob_path = PREP_DIR / "test.npz"
assert blob_path.exists(), f"Missing test blob: {blob_path}"
blob  = np.load(blob_path, allow_pickle=True)
smiles= [str(s) for s in blob["smiles"].tolist()]
Yte   = blob["Y"].astype(np.float32)            # (N, L)
Mte   = blob["y_missing_mask"].astype(bool)     # (N, L) True where missing
N     = Yte.shape[0]

# ---- Helper: rebuild fused features if absent ----------------------------
def ensure_fused(split: str = "test") -> np.ndarray:
    """Return fused features for split. If missing, recompute and cache."""
    path = FUSED_DIR / f"{split}_fused.npy"
    if path.exists():
        return np.load(path).astype(np.float32)

    print(f"[Rebuild] {path} not found → recomputing {split} fused features...")

    # 1) Load descriptor transformer
    from joblib import load as joblib_load
    imputer = joblib_load(DESC_DIR / "imputer.joblib")
    scaler  = joblib_load(DESC_DIR / "scaler.joblib")

    # 2) RDKit descriptor function that matches training order via feature_names.txt
    from rdkit import Chem
    from rdkit.Chem import Descriptors as RDDesc

    feat_list_path = DESC_DIR / "feature_names.txt"
    assert feat_list_path.exists(), f"Missing feature_names.txt at {feat_list_path}"
    feature_names = [ln.strip() for ln in feat_list_path.read_text().splitlines() if ln.strip()]
    # Build callables dict for RDKit Descriptors.*
    rd_fns = {name: getattr(RDDesc, name, None) for name in feature_names}

    def compute_rdkit_descriptors_for_smiles(smiles_list: List[str]) -> np.ndarray:
        rows = []
        for smi in smiles_list:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                # keep row length consistent; fill with NaN
                rows.append([np.nan]*len(feature_names))
                continue
            vals = []
            for name in feature_names:
                fn = rd_fns.get(name, None)
                if fn is None:
                    vals.append(np.nan)
                    continue
                try:
                    v = fn(mol)
                except Exception:
                    v = np.nan
                vals.append(float(v) if (v is not None and np.isfinite(v)) else np.nan)
            rows.append(vals)
        return np.asarray(rows, dtype=np.float32)

    # 3) Build shared model (text+graph encoders + desc MLP) and load checkpoint
    CKPT_BEST = MODEL_DIR / "checkpoints" / "shared" / "best.pt"
    assert CKPT_BEST.exists(), f"Missing shared checkpoint: {CKPT_BEST}"

    # -- Text encoder (ChemBERTa) --
    from transformers import AutoTokenizer, AutoModel
    class ChemBERTaEncoder(nn.Module):
        def __init__(self, ckpt_name="seyonec/ChemBERTa-zinc-base-v1", fusion_dim=256, dropout_p=0.1):
            super().__init__()
            self.tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
            self.backbone  = AutoModel.from_pretrained(ckpt_name)
            self.proj = nn.Sequential(nn.Dropout(dropout_p), nn.Linear(self.backbone.config.hidden_size, fusion_dim))
            self.ln = nn.LayerNorm(fusion_dim)
        def forward(self, smiles_list: List[str], max_length=256, add_special_tokens=True):
            enc = self.tokenizer(list(smiles_list), padding=True, truncation=True,
                                 max_length=max_length, add_special_tokens=add_special_tokens,
                                 return_tensors="pt")
            input_ids, attention_mask = enc["input_ids"].to(device), enc["attention_mask"].to(device)
            out = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state  # (B,L,H)
            toks = self.ln(self.proj(out))  # (B,L,256)
            return toks, attention_mask.to(dtype=torch.int32)

    # -- Graph encoder (names match checkpoint) --
    ATOM_LIST = ["H","C","N","O","F","P","S","Cl","Br","I"]
    def _one_hot(v, choices):
        z = [0]*len(choices)
        if v in choices: z[choices.index(v)] = 1
        return z
    def _bucket_oh(v, lo, hi):
        buckets = list(range(lo, hi+1))
        o = [0]*(len(buckets)+1)
        idx = v - lo
        o[idx if 0 <= idx < len(buckets) else -1] = 1
        return o
    def _atom_feat(atom):
        hybs = [Chem.rdchem.HybridizationType.S, Chem.rdchem.HybridizationType.SP,
                Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3,
                Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2]
        chir = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
                Chem.rdchem.ChiralType.CHI_OTHER]
        sym = atom.GetSymbol()
        feat = _one_hot(sym if sym in ATOM_LIST else "other", ATOM_LIST+["other"])
        feat += _bucket_oh(atom.GetDegree(), 0, 5)
        feat += _bucket_oh(atom.GetFormalCharge(), -2, 2)
        feat += (_one_hot(atom.GetHybridization(), hybs)+[0])
        feat += [int(atom.GetIsAromatic())]
        feat += [int(atom.IsInRing())]
        feat += _one_hot(atom.GetChiralTag(), chir)
        feat += _bucket_oh(atom.GetTotalNumHs(includeNeighbors=True), 0, 4)
        feat += _bucket_oh(atom.GetTotalValence(), 0, 5)
        feat += [atom.GetMass()/200.0]
        return feat  # ~51 dims
    def _smiles_to_graph(smi, max_nodes=128):
        mol = Chem.MolFromSmiles(smi)
        if mol is None or mol.GetNumAtoms() == 0:
            return np.zeros((0,0), dtype=np.float32), np.zeros((0,0), dtype=np.float32)
        feats = [_atom_feat(mol.GetAtomWithIdx(i)) for i in range(mol.GetNumAtoms())]
        x = np.asarray(feats, dtype=np.float32)
        N = mol.GetNumAtoms()
        adj = np.zeros((N, N), dtype=np.float32)
        for b in mol.GetBonds():
            i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
            adj[i, j] = 1.0; adj[j, i] = 1.0
        if N > max_nodes:
            x = x[:max_nodes]; adj = adj[:max_nodes, :max_nodes]
        return x, adj
    def _collate_graphs(smiles_batch, max_nodes=128):
        graphs = [_smiles_to_graph(s) for s in smiles_batch]
        Nmax = max([g[0].shape[0] for g in graphs] + [1])
        Fnode = graphs[0][0].shape[1] if graphs[0][0].size>0 else 51
        B = len(graphs)
        X = np.zeros((B, Nmax, Fnode), dtype=np.float32)
        A = np.zeros((B, Nmax, Nmax), dtype=np.float32)
        M = np.zeros((B, Nmax), dtype=np.int64)
        for i, (x, a) in enumerate(graphs):
            n = x.shape[0]
            if n == 0: continue
            X[i, :n, :] = x
            A[i, :n, :n] = a
            M[i, :n] = 1
        return torch.from_numpy(X).to(device), torch.from_numpy(A).to(device), torch.from_numpy(M).to(device)

    class GINLayer(nn.Module):
        def __init__(self, h=256, p=0.1):
            super().__init__()
            self.eps = nn.Parameter(torch.tensor(0.0))
            self.mlp = nn.Sequential(nn.Linear(h, h), nn.GELU(), nn.LayerNorm(h), nn.Dropout(p))
        def forward(self, x, adj, mask):
            out = (1.0 + self.eps) * x + torch.matmul(adj, x)
            out = self.mlp(out)
            return out * mask.unsqueeze(-1).to(out.dtype)

    class GraphGINEncoder(nn.Module):
        def __init__(self, node_in_dim=51, hidden_dim=256, n_layers=4, p=0.1):
            super().__init__()
            self.inp = nn.Sequential(nn.Linear(node_in_dim, hidden_dim), nn.GELU(), nn.Dropout(p))
            self.layers = nn.ModuleList([GINLayer(hidden_dim, p) for _ in range(n_layers)])
            self.out_ln = nn.LayerNorm(hidden_dim)  # name matches checkpoint
        def forward(self, smiles_list: List[str], max_nodes=128):
            X, A, M = _collate_graphs(smiles_list, max_nodes=max_nodes)
            h = self.inp(X)
            for layer in self.layers:
                h = layer(h, A, M)
            return self.out_ln(h), M.to(dtype=torch.int32)

    # -- Fusion parts --
    def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
        mask = mask.to(dtype=x.dtype, device=x.device)
        denom = mask.sum(dim=dim, keepdim=True).clamp(min=1.0)
        return (x * mask.unsqueeze(-1)).sum(dim=dim) / denom
    class CrossAttentionBlock(nn.Module):
        def __init__(self, dim=256, n_heads=4, p=0.1):
            super().__init__()
            self.mha = nn.MultiheadAttention(dim, n_heads, dropout=p, batch_first=False)
            self.ln  = nn.LayerNorm(dim)
            self.do  = nn.Dropout(p)
        def forward(self, text_tokens, text_mask, graph_nodes, graph_mask):
            Q = text_tokens.transpose(0,1); K = graph_nodes.transpose(0,1); V = graph_nodes.transpose(0,1)
            kpm = (graph_mask == 0)
            attn, _ = self.mha(Q, K, V, key_padding_mask=kpm)
            attn = attn.transpose(0,1)
            return self.ln(text_tokens + self.do(attn))
    class DescriptorMLP(nn.Module):
        def __init__(self, in_dim, out_dim=256, hidden=256, p=0.1):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(p),
                nn.Linear(hidden, out_dim), nn.GELU(), nn.Dropout(p)
            )
        def forward(self, x): return self.net(x)
    class FusionClassifier(nn.Module):
        # name 'mlp' to match checkpoint
        def __init__(self, dim=256, n_labels=N_LABELS, p=0.1):
            super().__init__()
            self.mlp = nn.Sequential(
                nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(p),
                nn.Linear(dim*2, n_labels)
            )
        def forward(self, fused_vec): return self.mlp(fused_vec)
    class V7FusionModel(nn.Module):
        def __init__(self, text_encoder, graph_encoder, desc_in_dim=DESC_IN_DIM, dim=256, n_labels=N_LABELS, n_heads=4, p=0.1):
            super().__init__()
            self.text_encoder=text_encoder
            self.graph_encoder=graph_encoder
            self.cross=CrossAttentionBlock(dim, n_heads, p)
            self.desc_mlp=DescriptorMLP(desc_in_dim, out_dim=dim, hidden=256, p=p)
            self.shared_head=FusionClassifier(dim, n_labels, p)
        def forward(self, smiles_list, desc_feats, return_fused=False):
            tt, tm = self.text_encoder(smiles_list, max_length=256)
            gn, gm = self.graph_encoder(smiles_list, max_nodes=128)
            tta = self.cross(tt.to(device), tm.to(device), gn.to(device), gm.to(device))
            de  = self.desc_mlp(desc_feats.to(device))
            text_pool  = masked_mean(tta, tm.to(device), 1)
            graph_pool = masked_mean(gn.to(device),  gm.to(device), 1)
            fused = torch.cat([text_pool, graph_pool, de], dim=-1)  # (B,768)
            logits = self.shared_head(fused)
            return (logits, fused) if return_fused else logits

    text_encoder = ChemBERTaEncoder().to(device)
    graph_encoder= GraphGINEncoder().to(device)
    model        = V7FusionModel(text_encoder, graph_encoder).to(device)
    ckpt = torch.load(CKPT_BEST, map_location=device)
    model.load_state_dict(ckpt["model"], strict=True)
    model.eval()

    # 4) Compute descriptors → impute/scale → fused features
    X_raw = compute_rdkit_descriptors_for_smiles(smiles)           # (N, 208 with NaNs)
    X_imp = imputer.transform(X_raw)
    X_std = scaler.transform(X_imp)
    desc_t= torch.tensor(X_std, dtype=torch.float32, device=device)

    fused_list = []
    B = 64
    for i in range(0, N, B):
        batch_smiles = smiles[i:i+B]
        logits, fused = model(batch_smiles, desc_t[i:i+B], return_fused=True)
        fused_list.append(fused.detach().cpu().numpy())
    fused_all = np.concatenate(fused_list, axis=0).astype(np.float32)
    np.save(path, fused_all)
    print(f"[Rebuild] Saved → {path}")
    return fused_all

# ---- Get fused test features (rebuild if missing) ----
X_fused = ensure_fused("test")
X_fused_t = torch.tensor(X_fused, dtype=torch.float32, device=device)

# ---- Load calibration/thresholds ----
temps_spec = json.loads((CAL_DIR / "temps.json").read_text())                # specialist temps
if EVAL_MODE == "specialist":
    thresholds_spec = json.loads((CAL_DIR / "thresholds.json").read_text())  # specialist thresholds
else:
    # Blend
    temps_shared   = json.loads((CAL_DIR / "temps_shared.json").read_text())
    blend_payload  = json.loads((CAL_DIR / "thresholds_blend.json").read_text())
    ALPHA          = float(blend_payload.get("alpha", 0.8))
    thresholds_blend = blend_payload["thresholds"]

# ---- Define heads (specialists) ----
class LabelHead(nn.Module):
    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(in_dim, h1), nn.GELU(), nn.LayerNorm(h1), nn.Dropout(p))
        self.block2 = nn.Sequential(nn.Linear(h1, h2), nn.GELU(), nn.LayerNorm(h2), nn.Dropout(p))
        self.block3 = nn.Sequential(nn.Linear(h2, h3), nn.GELU(), nn.LayerNorm(h3), nn.Dropout(p))
        self.out    = nn.Linear(h3, 1)
        self.short  = nn.Linear(in_dim, h3)
    def forward(self, x):
        z1 = self.block1(x); z2 = self.block2(z1); z3 = self.block3(z2)
        z  = z3 + self.short(x)
        return self.out(z).squeeze(-1)

def load_best_head(label: str) -> nn.Module:
    # choose seed with highest best_ap
    cands = []
    for sd in sorted((ENS_DIR / label).glob("seed*/")):
        mfile = sd / "metrics.json"
        if mfile.exists():
            try:
                ap = float(json.loads(mfile.read_text()).get("best_ap", float("nan")))
                cands.append((ap, sd))
            except Exception:
                pass
    if not cands:
        raise FileNotFoundError(f"No trained heads for label {label}")
    cands.sort(key=lambda x: (-1.0 if math.isnan(x[0]) else x[0]), reverse=True)
    best_dir = cands[0][1]
    ck = torch.load(best_dir / "best.pt", map_location=device)
    cfg = ck.get("config", {"in_dim":768,"h1":512,"h2":256,"h3":128,"dropout":0.30})
    head = LabelHead(in_dim=cfg["in_dim"], h1=cfg["h1"], h2=cfg["h2"], h3=cfg["h3"], p=cfg.get("dropout",0.30)).to(device)
    head.load_state_dict(ck["model"], strict=True)
    head.eval()
    return head

HEADS = {lbl: load_best_head(lbl) for lbl in LABELS}

# ---- Shared head (for blend only): load just the classifier on fused
if EVAL_MODE == "blend":
    CKPT_BEST = MODEL_DIR / "checkpoints" / "shared" / "best.pt"
    class SharedHeadOnly(nn.Module):
        def __init__(self, dim=256, n_labels=N_LABELS, p=0.1):
            super().__init__()
            self.mlp = nn.Sequential(
                nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(p),
                nn.Linear(dim*2, n_labels)
            )
        def forward(self, fused):
            return self.mlp(fused)
    shared_head = SharedHeadOnly().to(device)
    ckpt = torch.load(CKPT_BEST, map_location=device)
    sh_state = {k.replace("shared_head.", ""): v for k,v in ckpt["model"].items() if k.startswith("shared_head.")}
    shared_head.load_state_dict(sh_state, strict=True)
    shared_head.eval()

# ---- Specialist logits on fused test (fast) ----
with torch.no_grad():
    spec_logits = torch.zeros((N, N_LABELS), dtype=torch.float32, device=device)
    for j, lbl in enumerate(LABELS):
        spec_logits[:, j] = HEADS[lbl](X_fused_t)
spec_logits = spec_logits.cpu().numpy()

# ---- Shared logits on fused test (for blend path) ----
if EVAL_MODE == "blend":
    with torch.no_grad():
        shared_logits = shared_head(X_fused_t).cpu().numpy()
else:
    shared_logits = None

# ---- Build probability matrix according to EVAL_MODE ----
def sigmoid(x): return 1.0/(1.0+np.exp(-x))

if EVAL_MODE == "specialist":
    PROBS = np.zeros_like(spec_logits, dtype=np.float32)
    for j, lbl in enumerate(LABELS):
        T = max(float(temps_spec.get(lbl, 1.0)), 1e-3)
        PROBS[:, j] = sigmoid(spec_logits[:, j] / T)
else:
    PROBS = np.zeros_like(spec_logits, dtype=np.float32)
    for j, lbl in enumerate(LABELS):
        T_spec   = max(float(temps_spec.get(lbl, 1.0)), 1e-3)
        T_shared = max(float(temps_shared.get(lbl, 1.0)), 1e-3)
        p_spec   = sigmoid(spec_logits[:, j]   / T_spec)
        p_shared = sigmoid(shared_logits[:, j] / T_shared)
        PROBS[:, j] = np.clip(ALPHA * p_spec + (1-ALPHA) * p_shared, 0.0, 1.0)

# ---- Helper: ECE & reliability curve ----
def reliability_and_ece(y_true, y_prob, n_bins=15):
    bins = np.linspace(0.0, 1.0, n_bins+1)
    bin_ids = np.digitize(y_prob, bins) - 1
    bin_acc, bin_conf, bin_count = [], [], []
    ece = 0.0
    for b in range(n_bins):
        mask = (bin_ids == b)
        n = mask.sum()
        if n == 0:
            bin_acc.append(np.nan); bin_conf.append(np.nan); bin_count.append(0)
            continue
        p = y_prob[mask]; t = y_true[mask]
        acc = t.mean(); conf = p.mean()
        bin_acc.append(acc); bin_conf.append(conf); bin_count.append(n)
        ece += (n/len(y_true)) * abs(acc - conf)
    return (bins, np.array(bin_acc), np.array(bin_conf), np.array(bin_count)), float(ece)

# ---- Compute metrics per label & global ----
rows = []
tp_micro = fp_micro = fn_micro = 0
macro_ap_vals = []
macro_roc_vals = []

for j, lbl in enumerate(LABELS):
    valid = ~Mte[:, j]
    y = Yte[valid, j].astype(int)
    p = PROBS[valid, j]

    # AUCs
    ap = float(average_precision_score(y, p)) if valid.sum() > 0 else float("nan")
    macro_ap_vals.append(ap)
    try:
        roc = float(roc_auc_score(y, p))
    except Exception:
        roc = float("nan")
    macro_roc_vals.append(roc)

    # Operating thresholds (from saved calibration)
    if EVAL_MODE == "specialist":
        th_f1  = float(json.loads((CAL_DIR / "thresholds.json").read_text())[lbl]["th_f1"])
        th_fb  = float(json.loads((CAL_DIR / "thresholds.json").read_text())[lbl]["th_fbeta15"])
    else:
        th_f1  = float(thresholds_blend[lbl]["th_f1"])
        th_fb  = float(thresholds_blend[lbl]["th_fbeta15"])

    def prf_at_thresh(th):
        pred = (p >= th).astype(int)
        tp = int(((pred==1) & (y==1)).sum())
        fp = int(((pred==1) & (y==0)).sum())
        fn = int(((pred==0) & (y==1)).sum())
        prec = tp/(tp+fp) if (tp+fp)>0 else 0.0
        rec  = tp/(tp+fn) if (tp+fn)>0 else 0.0
        f1   = (2*prec*rec)/(prec+rec) if (prec+rec)>0 else 0.0
        return tp, fp, fn, prec, rec, f1

    tp1, fp1, fn1, pr1, rc1, f1 = prf_at_thresh(th_f1)
    tpb, fpb, fnb, prb, rcb, fb = prf_at_thresh(th_fb)

    if THRESH_MODE == "f1":
        tp_micro += tp1; fp_micro += fp1; fn_micro += fn1
    else:
        tp_micro += tpb; fp_micro += fpb; fn_micro += fnb

    # Prevalence
    prev = float(y.mean()) if valid.sum() > 0 else float("nan")

    # ECE + save reliability plot
    (bins, acc, conf, counts), ece = reliability_and_ece(y, p, n_bins=15)
    plt.figure()
    mask = ~np.isnan(acc)
    plt.plot([0,1], [0,1], linestyle="--")
    if mask.any():
        plt.plot(conf[mask], acc[mask], marker="o")
    plt.xlabel("Mean predicted probability"); plt.ylabel("Fraction of positives")
    plt.title(f"Reliability: {lbl} (ECE={ece:.3f})")
    plt.tight_layout()
    plt.savefig(PLOT_REL_DIR / f"{lbl}.png", dpi=160); plt.close()

    # PR curve plot
    prec, rec, _ = precision_recall_curve(y, p)
    plt.figure()
    plt.step(rec, prec, where="post")
    plt.xlabel("Recall"); plt.ylabel("Precision")
    plt.title(f"PR curve: {lbl} (AP={ap:.3f})")
    plt.tight_layout()
    plt.savefig(PLOT_PR_DIR / f"{lbl}.png", dpi=160); plt.close()

    rows.append({
        "label": lbl,
        "n_valid": int(valid.sum()),
        "prevalence": prev,
        "ap": ap,
        "roc_auc": roc,
        "th_f1": th_f1,
        "prec@f1": pr1, "recall@f1": rc1, "f1": f1,
        "tp@f1": tp1, "fp@f1": fp1, "fn@f1": fn1,
        "th_fbeta15": th_fb,
        "prec@fbeta15": prb, "recall@fbeta15": rcb, "f_beta15": fb,
        "tp@fbeta15": tpb, "fp@fbeta15": fpb, "fn@fbeta15": fnb,
        "ece": ece
    })

# ---- Global summaries ----
macro_pr_auc  = float(np.nanmean([r["ap"] for r in rows]))
macro_roc_auc = float(np.nanmean([r["roc_auc"] for r in rows]))
micro_prec = tp_micro/(tp_micro+fp_micro) if (tp_micro+fp_micro)>0 else 0.0
micro_rec  = tp_micro/(tp_micro+fn_micro) if (tp_micro+fn_micro)>0 else 0.0
micro_f1   = (2*micro_prec*micro_rec)/(micro_prec+micro_rec) if (micro_prec+micro_rec)>0 else 0.0

# Cardinality (avg #positive labels per sample) – true vs predicted at chosen operating mode
if THRESH_MODE == "f1":
    if EVAL_MODE == "specialist":
        thobj = json.loads((CAL_DIR / "thresholds.json").read_text())
        THS = np.array([float(thobj[l]["th_f1"]) for l in LABELS], dtype=np.float32)
    else:
        THS = np.array([float(thresholds_blend[l]["th_f1"]) for l in LABELS], dtype=np.float32)
else:
    if EVAL_MODE == "specialist":
        thobj = json.loads((CAL_DIR / "thresholds.json").read_text())
        THS = np.array([float(thobj[l]["th_fbeta15"]) for l in LABELS], dtype=np.float32)
    else:
        THS = np.array([float(thresholds_blend[l]["th_fbeta15"]) for l in LABELS], dtype=np.float32)

pred_bin = (PROBS >= THS.reshape(1, -1)).astype(int)
pred_bin[Mte] = 0
true_bin = Yte.copy().astype(int)
true_bin[Mte] = 0

avg_true_card = float(true_bin.sum(axis=1).mean())
avg_pred_card = float(pred_bin.sum(axis=1).mean())
card_err      = float(avg_pred_card - avg_true_card)

# ---- Save reports ----
per_label_df = pd.DataFrame(rows)
per_label_csv = EVAL_DIR / "per_label_metrics.csv"
per_label_df.to_csv(per_label_csv, index=False)

summary = {
    "eval_mode": EVAL_MODE,
    "threshold_mode": THRESH_MODE,
    "n_test": int(N),
    "macro_pr_auc": macro_pr_auc,
    "macro_roc_auc": macro_roc_auc,
    "micro_precision": micro_prec,
    "micro_recall": micro_rec,
    "micro_f1": micro_f1,
    "avg_true_cardinality": avg_true_card,
    "avg_pred_cardinality": avg_pred_card,
    "cardinality_error": card_err,
    "plots": {
        "pr_curves_dir": str(PLOT_PR_DIR),
        "reliability_dir": str(PLOT_REL_DIR)
    },
}
(EVAL_DIR / "summary.json").write_text(json.dumps(summary, indent=2))

print("✅ Evaluation complete.")
print(f" Per-label CSV  → {per_label_csv}")
print(f" Summary JSON   → {EVAL_DIR / 'summary.json'}")
print(f" PR curves      → {PLOT_PR_DIR}")
print(f" Reliability    → {PLOT_REL_DIR}")
print("\nGlobal (test):")
for k in ["eval_mode","threshold_mode","n_test","macro_pr_auc","macro_roc_auc","micro_precision","micro_recall","micro_f1","avg_true_cardinality","avg_pred_cardinality","cardinality_error"]:
    print(f"  {k}: {summary[k]}")


[Rebuild] v7\data\fused\test_fused.npy not found → recomputing test fused features...
[Rebuild] Saved → v7\data\fused\test_fused.npy


  def sigmoid(x): return 1.0/(1.0+np.exp(-x))


✅ Evaluation complete.
 Per-label CSV  → v7\eval\per_label_metrics.csv
 Summary JSON   → v7\eval\summary.json
 PR curves      → v7\eval\plots\pr
 Reliability    → v7\eval\plots\reliability

Global (test):
  eval_mode: blend
  threshold_mode: fbeta15
  n_test: 783
  macro_pr_auc: 0.2179368491745162
  macro_roc_auc: 0.7544820729548566
  micro_precision: 0.2006872852233677
  micro_recall: 0.5793650793650794
  micro_f1: 0.29811128126595204
  avg_true_cardinality: 0.6436781609195402
  avg_pred_cardinality: 1.8582375478927202
  cardinality_error: 1.21455938697318


  true_bin = Yte.copy().astype(int)


### 2: Specialist vs Blend comparison

In [None]:
# =========================
# Phase 6 — Cell 2: Specialist vs Blend comparison (restart-proof)
# =========================
import os, json, math
from pathlib import Path
from typing import List, Dict
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

# ---- Paths & basic setup ----
BASE         = Path("v7")
PREP_DIR     = BASE / "data" / "prepared"
DESC_DIR     = BASE / "data" / "descriptors"
FUSED_DIR    = BASE / "data" / "fused"
MODEL_DIR    = BASE / "model"
CAL_DIR      = MODEL_DIR / "calibration"
ENS_DIR      = MODEL_DIR / "ensembles"
EVAL_DIR     = BASE / "eval"
EVAL_DIR.mkdir(parents=True, exist_ok=True)
FUSED_DIR.mkdir(parents=True, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Manifest & test blobs ----
mani_path = PREP_DIR / "dataset_manifest.json"
assert mani_path.exists(), f"Missing manifest: {mani_path}"
mani = json.loads(mani_path.read_text())
LABELS = mani["labels"]
N_LABELS = len(LABELS)
DESC_IN_DIM = int(mani["n_features"])

blob_path = PREP_DIR / "test.npz"
assert blob_path.exists(), f"Missing test blob: {blob_path}"
blob   = np.load(blob_path, allow_pickle=True)
smiles = [str(s) for s in blob["smiles"].tolist()]
Yte    = blob["Y"].astype(np.float32)
Mte    = blob["y_missing_mask"].astype(bool)
N      = Yte.shape[0]

# ---- Ensure fused features (rebuild if missing) ----
def ensure_fused(split="test") -> np.ndarray:
    path = FUSED_DIR / f"{split}_fused.npy"
    if path.exists():
        return np.load(path).astype(np.float32)

    print(f"[Rebuild] {path} not found → recomputing {split} fused features...")

    # Load descriptor imputer/scaler
    from joblib import load as joblib_load
    imp_path = DESC_DIR / "imputer.joblib"
    scl_path = DESC_DIR / "scaler.joblib"
    assert imp_path.exists() and scl_path.exists(), "Missing imputer/scaler joblib files."
    imputer = joblib_load(imp_path)
    scaler  = joblib_load(scl_path)

    # Prepare RDKit 208 descriptors in the SAME ORDER as training
    from rdkit import Chem
    from rdkit.Chem import Descriptors as RDDesc
    feat_names_file = DESC_DIR / "feature_names.txt"
    assert feat_names_file.exists(), f"Missing {feat_names_file}"
    feat_names = [ln.strip() for ln in feat_names_file.read_text().splitlines() if ln.strip()]
    rd_fns = {name: getattr(RDDesc, name, None) for name in feat_names}

    def rdkit_feats(smis: List[str]) -> np.ndarray:
        rows = []
        for s in smis:
            m = Chem.MolFromSmiles(s)
            if m is None:
                rows.append([np.nan]*len(feat_names)); continue
            vals = []
            for name in feat_names:
                fn = rd_fns.get(name, None)
                try:
                    v = float(fn(m)) if fn is not None else np.nan
                except Exception:
                    v = np.nan
                vals.append(v if np.isfinite(v) else np.nan)
            rows.append(vals)
        return np.asarray(rows, dtype=np.float32)

    # Build shared fusion model (text+graph+desc) & load checkpoint
    CKPT = MODEL_DIR / "checkpoints" / "shared" / "best.pt"
    assert CKPT.exists(), f"Missing shared checkpoint: {CKPT}"

    from transformers import AutoTokenizer, AutoModel
    class ChemBERTaEncoder(nn.Module):
        def __init__(self, ckpt="seyonec/ChemBERTa-zinc-base-v1", dim=256, p=0.1):
            super().__init__()
            self.tok = AutoTokenizer.from_pretrained(ckpt)
            self.m   = AutoModel.from_pretrained(ckpt)
            self.proj= nn.Sequential(nn.Dropout(p), nn.Linear(self.m.config.hidden_size, dim))
            self.ln  = nn.LayerNorm(dim)
        def forward(self, smis, max_length=256):
            enc = self.tok(list(smis), padding=True, truncation=True, max_length=max_length, return_tensors="pt")
            out = self.m(input_ids=enc["input_ids"].to(device), attention_mask=enc["attention_mask"].to(device)).last_hidden_state
            return self.ln(self.proj(out)), enc["attention_mask"].to(device, dtype=torch.int32)

    ATOM_LIST = ["H","C","N","O","F","P","S","Cl","Br","I"]
    def _one_hot(v, C): z=[0]*len(C); z[C.index(v) if v in C else -1]=1; return z
    def _bucket(v, lo, hi):
        buckets=list(range(lo,hi+1)); o=[0]*(len(buckets)+1); i=v-lo; o[i if 0<=i<len(buckets) else -1]=1; return o
    def _atom_feat(a):
        hyb=[Chem.rdchem.HybridizationType.S, Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2]
        chir=[Chem.rdchem.ChiralType.CHI_UNSPECIFIED, Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, Chem.rdchem.ChiralType.CHI_OTHER]
        f=_one_hot(a.GetSymbol(), ATOM_LIST+["other"])
        f+=_bucket(a.GetDegree(),0,5)+_bucket(a.GetFormalCharge(),-2,2)+(_one_hot(a.GetHybridization(),hyb)+[0])
        f+=[int(a.GetIsAromatic()), int(a.IsInRing())]+_one_hot(a.GetChiralTag(),chir)
        f+=_bucket(a.GetTotalNumHs(True),0,4)+_bucket(a.GetTotalValence(),0,5)+[a.GetMass()/200.0]
        return f
    def _smiles_to_graph(smi, max_nodes=128):
        m = Chem.MolFromSmiles(smi)
        if m is None or m.GetNumAtoms()==0: return np.zeros((0,0),np.float32), np.zeros((0,0),np.float32)
        X = np.asarray([_atom_feat(m.GetAtomWithIdx(i)) for i in range(m.GetNumAtoms())], np.float32)
        N = m.GetNumAtoms(); A = np.zeros((N,N),np.float32)
        for b in m.GetBonds(): i,j=b.GetBeginAtomIdx(), b.GetEndAtomIdx(); A[i,j]=A[j,i]=1.0
        if N>max_nodes: X=X[:max_nodes]; A=A[:max_nodes,:max_nodes]
        return X,A
    def _collate(smis, max_nodes=128):
        G=[_smiles_to_graph(s) for s in smis]
        Nmax=max([g[0].shape[0] for g in G]+[1]); F=G[0][0].shape[1] if G[0][0].size>0 else 51; B=len(G)
        X=np.zeros((B,Nmax,F),np.float32); A=np.zeros((B,Nmax,Nmax),np.float32); M=np.zeros((B,Nmax),np.int64)
        for i,(x,a) in enumerate(G):
            n=x.shape[0]
            if n==0: continue
            X[i,:n,:]=x; A[i,:n,:n]=a; M[i,:n]=1
        return torch.from_numpy(X).to(device), torch.from_numpy(A).to(device), torch.from_numpy(M).to(device)

    class GINLayer(nn.Module):
        def __init__(self,h=256,p=0.1): super().__init__(); self.eps=nn.Parameter(torch.tensor(0.0)); self.mlp=nn.Sequential(nn.Linear(h,h), nn.GELU(), nn.LayerNorm(h), nn.Dropout(p))
        def forward(self,x,A,M): return self.mlp((1.0+self.eps)*x + torch.matmul(A,x)) * M.unsqueeze(-1).to(x.dtype)
    class GraphGINEncoder(nn.Module):
        def __init__(self,node_in_dim=51,h=256,L=4,p=0.1): super().__init__(); self.inp=nn.Sequential(nn.Linear(node_in_dim,h), nn.GELU(), nn.Dropout(p)); self.layers=nn.ModuleList([GINLayer(h,p) for _ in range(L)]); self.out_ln=nn.LayerNorm(h)
        def forward(self,smis,max_nodes=128):
            X,A,M=_collate(smis,max_nodes); h=self.inp(X)
            for L in self.layers: h=L(h,A,M)
            return self.out_ln(h), M.to(dtype=torch.int32)

    def masked_mean(x, m, dim): m=m.to(dtype=x.dtype, device=x.device); denom=m.sum(dim=dim,keepdim=True).clamp(min=1.0); return (x*m.unsqueeze(-1)).sum(dim=dim)/denom
    class CrossAttentionBlock(nn.Module):
        def __init__(self,dim=256,heads=4,p=0.1): super().__init__(); self.mha=nn.MultiheadAttention(dim, heads, dropout=p, batch_first=False); self.ln=nn.LayerNorm(dim); self.do=nn.Dropout(p)
        def forward(self,T,TM,G,GM): Q=T.transpose(0,1); K=G.transpose(0,1); V=G.transpose(0,1); kpm=(GM==0); A,_=self.mha(Q,K,V,key_padding_mask=kpm); return self.ln(T+self.do(A.transpose(0,1)))
    class DescriptorMLP(nn.Module):
        def __init__(self,inp,dim=256,p=0.1): super().__init__(); self.net=nn.Sequential(nn.Linear(inp,256), nn.GELU(), nn.Dropout(p), nn.Linear(256,dim), nn.GELU(), nn.Dropout(p))
        def forward(self,x): return self.net(x)
    class FusionClassifier(nn.Module):
        def __init__(self,dim=256,L=N_LABELS,p=0.1): super().__init__(); self.mlp=nn.Sequential(nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(p), nn.Linear(dim*2, L))
        def forward(self,z): return self.mlp(z)
    class V7Shared(nn.Module):
        def __init__(self):
            super().__init__()
            self.text=ChemBERTaEncoder()
            self.graph=GraphGINEncoder()
            self.cross=CrossAttentionBlock()
            self.desc=DescriptorMLP(DESC_IN_DIM)
            self.head=FusionClassifier()
        def forward(self,smis,desc,return_fused=False):
            T,TM=self.text(smis,256); G,GM=self.graph(smis,128); Ta=self.cross(T,TM,G,GM); Dz=self.desc(desc)
            Tp=masked_mean(Ta,TM,1); Gp=masked_mean(G,GM,1); fused=torch.cat([Tp,Gp,Dz],-1)
            logits=self.head(fused); 
            return (logits, fused) if return_fused else logits

    v7 = V7Shared().to(device)
    v7.load_state_dict(torch.load(CKPT, map_location=device)["model"], strict=True)
    v7.eval()

    # descriptors → impute → scale
    X_raw = rdkit_feats(smiles)
    X_imp = imputer.transform(X_raw)
    X_std = scaler.transform(X_imp)
    desc_t = torch.tensor(X_std, dtype=torch.float32, device=device)

    fused = []
    B=64
    for i in range(0, N, B):
        _, f = v7(smiles[i:i+B], desc_t[i:i+B], return_fused=True)
        fused.append(f.detach().cpu().numpy())
    fused = np.concatenate(fused, 0).astype(np.float32)
    np.save(path, fused)
    print(f"[Rebuild] Saved → {path}")
    return fused

X_fused = ensure_fused("test")
X_fused_t = torch.tensor(X_fused, dtype=torch.float32, device=device)

# ---- Load calibration artifacts ----
temps_spec = json.loads((CAL_DIR / "temps.json").read_text())            # specialist temps
thr_spec   = json.loads((CAL_DIR / "thresholds.json").read_text())       # specialist thresholds
temps_sh   = json.loads((CAL_DIR / "temps_shared.json").read_text())     # shared temps
blend      = json.loads((CAL_DIR / "thresholds_blend.json").read_text()) # blend thresholds + alpha
ALPHA      = float(blend.get("alpha", 0.8))
thr_blend  = blend["thresholds"]

# ---- Robust specialist head loader (patch handles b1/b2/b3 vs block1/2/3) ----
class LabelHead(nn.Module):
    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(in_dim, h1), nn.GELU(), nn.LayerNorm(h1), nn.Dropout(p))
        self.block2 = nn.Sequential(nn.Linear(h1, h2), nn.GELU(), nn.LayerNorm(h2), nn.Dropout(p))
        self.block3 = nn.Sequential(nn.Linear(h2, h3), nn.GELU(), nn.LayerNorm(h3), nn.Dropout(p))
        self.out    = nn.Linear(h3, 1)
        self.short  = nn.Linear(in_dim, h3)
    def forward(self, x):
        z1 = self.block1(x); z2 = self.block2(z1); z3 = self.block3(z2)
        z  = z3 + self.short(x)
        return self.out(z).squeeze(-1)

def _remap_keys_if_needed(state_dict: dict) -> dict:
    needs = any(k.startswith(("b1.", "b2.", "b3.")) for k in state_dict.keys())
    if not needs: return state_dict
    remap = {}
    for k, v in state_dict.items():
        k2 = k.replace("b1.", "block1.").replace("b2.", "block2.").replace("b3.", "block3.")
        remap[k2] = v
    return remap

def load_best_head(label: str) -> nn.Module:
    cands = []
    for sd in sorted((ENS_DIR / label).glob("seed*/"), key=lambda p: p.name):
        mfile = sd / "metrics.json"
        if not mfile.exists(): continue
        try:
            ap = float(json.loads(mfile.read_text()).get("best_ap", float("nan")))
            cands.append((ap, sd))
        except Exception:
            pass
    if not cands:
        raise FileNotFoundError(f"No trained heads for label {label} under {ENS_DIR/label}")
    cands.sort(key=lambda x: (-1.0 if math.isnan(x[0]) else x[0]), reverse=True)
    best_dir = cands[0][1]
    ck = torch.load(best_dir / "best.pt", map_location=device)
    cfg = ck.get("config", {"in_dim":768,"h1":512,"h2":256,"h3":128,"dropout":0.30})
    head = LabelHead(in_dim=cfg["in_dim"], h1=cfg["h1"], h2=cfg["h2"], h3=cfg["h3"], p=cfg.get("dropout",0.30)).to(device)
    state = _remap_keys_if_needed(ck["model"])
    head.load_state_dict(state, strict=True)
    head.eval()
    return head

HEADS = {lbl: load_best_head(lbl) for lbl in LABELS}
print(f"✅ Loaded specialist heads: {len(HEADS)}/{len(LABELS)}")

# ---- Shared head (MLP on fused) for blend path ----
class SharedHeadMLP(nn.Module):
    def __init__(self, dim=256, n_labels=N_LABELS):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(dim*2, n_labels)
        )
    def forward(self, z): return self.mlp(z)

sh = SharedHeadMLP().to(device)
ckpt = torch.load(MODEL_DIR / "checkpoints" / "shared" / "best.pt", map_location=device)
sh_state = {k.replace("shared_head.", ""): v for k,v in ckpt["model"].items() if k.startswith("shared_head.")}
sh.load_state_dict(sh_state, strict=True)
sh.eval()

# ---- Build probability matrices for specialist & blend ----
with torch.no_grad():
    spec_logits = torch.stack([HEADS[lbl](X_fused_t) for lbl in LABELS], dim=1).cpu().numpy()  # (N,L)
    sh_logits   = sh(X_fused_t).cpu().numpy()                                                  # (N,L)

sigmoid = lambda x: 1.0/(1.0+np.exp(-x))

P_spec = np.zeros_like(spec_logits, np.float32)
P_blnd = np.zeros_like(spec_logits, np.float32)
for j, lbl in enumerate(LABELS):
    Ts = max(float(temps_spec.get(lbl, 1.0)), 1e-3)
    Th = max(float(temps_sh.get(lbl, 1.0)),   1e-3)
    ps = sigmoid(spec_logits[:, j] / Ts)
    ph = sigmoid(sh_logits[:, j]   / Th)
    P_spec[:, j] = ps
    P_blnd[:, j] = np.clip(ALPHA*ps + (1-ALPHA)*ph, 0.0, 1.0)

# ---- Scoring helper ----
def score(PROBS: np.ndarray, thresholds_obj: Dict[str, dict], mode: str = "fbeta15") -> dict:
    tp=fp=fn=0
    ap_list=[]; roc_list=[]
    TH = np.array([float(thresholds_obj[l]["th_f1" if mode=="f1" else "th_fbeta15"]) for l in LABELS], np.float32)
    Pred = (PROBS >= TH.reshape(1, -1)).astype(int)
    Pred[Mte] = 0
    Truth = Yte.astype(int)
    Truth[Mte] = 0

    from sklearn.metrics import average_precision_score, roc_auc_score
    for j in range(N_LABELS):
        y = Truth[:, j]; p = PROBS[:, j]
        if y.max() != y.min():
            try: ap_list.append(float(average_precision_score(y, p)))
            except Exception: pass
            try: roc_list.append(float(roc_auc_score(y, p)))
            except Exception: pass
        tp += int(((Pred[:, j]==1) & (y==1)).sum())
        fp += int(((Pred[:, j]==1) & (y==0)).sum())
        fn += int(((Pred[:, j]==0) & (y==1)).sum())
    micro_prec = tp/(tp+fp) if (tp+fp)>0 else 0.0
    micro_rec  = tp/(tp+fn) if (tp+fn)>0 else 0.0
    micro_f1   = (2*micro_prec*micro_rec)/(micro_prec+micro_rec) if (micro_prec+micro_rec)>0 else 0.0
    macro_pr   = float(np.nanmean(ap_list)) if ap_list else float("nan")
    macro_roc  = float(np.nanmean(roc_list)) if roc_list else float("nan")
    avg_true   = float(Truth.sum(1).mean())
    avg_pred   = float(Pred.sum(1).mean())
    return {
        "macro_pr_auc": macro_pr,
        "macro_roc_auc": macro_roc,
        "micro_precision": micro_prec,
        "micro_recall": micro_rec,
        "micro_f1": micro_f1,
        "avg_true_cardinality": avg_true,
        "avg_pred_cardinality": avg_pred,
        "cardinality_error": avg_pred - avg_true
    }

# ---- Evaluate all four configurations ----
res_spec_fb = score(P_spec, thr_spec,  mode="fbeta15")
res_blnd_fb = score(P_blnd, thr_blend, mode="fbeta15")
res_spec_f1 = score(P_spec, thr_spec,  mode="f1")
res_blnd_f1 = score(P_blnd, thr_blend, mode="f1")

table = pd.DataFrame([
    {"path":"specialist","mode":"fbeta15", **res_spec_fb},
    {"path":"blend",     "mode":"fbeta15", **res_blnd_fb},
    {"path":"specialist","mode":"f1",      **res_spec_f1},
    {"path":"blend",     "mode":"f1",      **res_blnd_f1},
])
table_path = EVAL_DIR / "compare_table.csv"
table.to_csv(table_path, index=False)
comp_json = {"fbeta15": {"specialist": res_spec_fb, "blend": res_blnd_fb},
             "f1":      {"specialist": res_spec_f1, "blend": res_blnd_f1}}
(EVAL_DIR / "compare_summary.json").write_text(json.dumps(comp_json, indent=2))

print("✅ Saved:")
print("  •", table_path)
print("  •", EVAL_DIR / "compare_summary.json")
print("\nQuick view:")
print(table.to_string(index=False))

✅ Loaded specialist heads: 12/12
✅ Saved:
  • v7\eval\compare_table.csv
  • v7\eval\compare_summary.json

Quick view:
      path    mode  macro_pr_auc  macro_roc_auc  micro_precision  micro_recall  micro_f1  avg_true_cardinality  avg_pred_cardinality  cardinality_error
specialist fbeta15      0.149778       0.709502         0.183659      0.503968  0.269210              0.643678              1.766284           1.122605
     blend fbeta15      0.157441       0.727127         0.200687      0.579365  0.298111              0.643678              1.858238           1.214559
specialist      f1      0.149778       0.709502         0.199797      0.390873  0.264430              0.643678              1.259259           0.615581
     blend      f1      0.157441       0.727127         0.230997      0.464286  0.308504              0.643678              1.293742           0.650064


  sigmoid = lambda x: 1.0/(1.0+np.exp(-x))
  Truth = Yte.astype(int)
  Truth = Yte.astype(int)
  Truth = Yte.astype(int)
  Truth = Yte.astype(int)


### 3: Per-label diagnostics & policy proposal

In [None]:
# ===============================
# Phase 6 — Cell 3: Per-label diagnostics & policy proposal
# ===============================
import os, json, math
from pathlib import Path
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import precision_recall_curve, average_precision_score

# ------- Config -------
BASE         = Path("v7")
PREP_DIR     = BASE / "data" / "prepared"
DESC_DIR     = BASE / "data" / "descriptors"
FUSED_DIR    = BASE / "data" / "fused"
MODEL_DIR    = BASE / "model"
ENS_DIR      = MODEL_DIR / "ensembles"
CAL_DIR      = MODEL_DIR / "calibration"
EVAL_DIR     = BASE / "eval"
POL_DIR      = MODEL_DIR / "policy"
for p in [EVAL_DIR, POL_DIR, FUSED_DIR]:
    p.mkdir(parents=True, exist_ok=True)

# How strict should we be about precision?
PREC_FLOOR = 0.55   # try 0.55–0.60; higher => fewer FPs, lower recall
DEFAULT_ALPHA = 0.8 # specialist weight used in your saved blend calibration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ------- Load manifests and artifacts -------
mani = json.loads((PREP_DIR / "dataset_manifest.json").read_text())
LABELS = mani["labels"]; N_LABELS = len(LABELS)
DESC_IN_DIM = int(mani["n_features"])

# Test per-label metrics (from Phase 6 Cell 1 with EVAL_MODE="blend")
test_metrics_csv = EVAL_DIR / "per_label_metrics.csv"
assert test_metrics_csv.exists(), f"Missing {test_metrics_csv}. Please run Phase 6 Cell 1 first."
test_df = pd.read_csv(test_metrics_csv)

# Blended calibration
temps_spec   = json.loads((CAL_DIR / "temps.json").read_text())            # specialist temps
temps_shared = json.loads((CAL_DIR / "temps_shared.json").read_text())     # shared temps
blend_payload= json.loads((CAL_DIR / "thresholds_blend.json").read_text()) # has alpha + thresholds
ALPHA        = float(blend_payload.get("alpha", DEFAULT_ALPHA))
thr_blend    = blend_payload["thresholds"]  # per-label th_f1 / th_fbeta15

# ------- Ensure fused VAL features (rebuild if missing) -------
def ensure_fused(split: str) -> np.ndarray:
    out = FUSED_DIR / f"{split}_fused.npy"
    if out.exists():
        return np.load(out).astype(np.float32)

    print(f"[Rebuild] {out} not found → recomputing {split} fused features...")

    # Load split blob
    blob = np.load(PREP_DIR / f"{split}.npz", allow_pickle=True)
    smi  = [str(s) for s in blob["smiles"].tolist()]
    N    = len(smi)

    # Descriptor transformers
    from joblib import load as joblib_load
    imputer = joblib_load(DESC_DIR / "imputer.joblib")
    scaler  = joblib_load(DESC_DIR / "scaler.joblib")

    # RDKit descriptor list (same order as training)
    from rdkit import Chem
    from rdkit.Chem import Descriptors as RDDesc
    feat_names = [ln.strip() for ln in (DESC_DIR / "feature_names.txt").read_text().splitlines() if ln.strip()]
    rd_fns = {name: getattr(RDDesc, name, None) for name in feat_names}

    def rdkit_feats(smiles_list: List[str]) -> np.ndarray:
        rows = []
        for s in smiles_list:
            m = Chem.MolFromSmiles(s)
            if m is None:
                rows.append([np.nan]*len(feat_names)); continue
            vals=[]
            for name in feat_names:
                fn = rd_fns.get(name, None)
                try: v = float(fn(m)) if fn is not None else np.nan
                except Exception: v = np.nan
                vals.append(v if np.isfinite(v) else np.nan)
            rows.append(vals)
        return np.asarray(rows, dtype=np.float32)

    # Rebuild shared fusion model (text+graph+desc) → fused
    CKPT = MODEL_DIR / "checkpoints" / "shared" / "best.pt"
    assert CKPT.exists(), f"Missing shared checkpoint: {CKPT}"

    from transformers import AutoTokenizer, AutoModel
    class ChemBERTaEncoder(nn.Module):
        def __init__(self, ckpt="seyonec/ChemBERTa-zinc-base-v1", dim=256, p=0.1):
            super().__init__()
            self.tok = AutoTokenizer.from_pretrained(ckpt)
            self.m   = AutoModel.from_pretrained(ckpt)
            self.proj= nn.Sequential(nn.Dropout(p), nn.Linear(self.m.config.hidden_size, dim))
            self.ln  = nn.LayerNorm(dim)
        def forward(self, smis, max_length=256):
            enc = self.tok(list(smis), padding=True, truncation=True, max_length=max_length, return_tensors="pt")
            out = self.m(input_ids=enc["input_ids"].to(device), attention_mask=enc["attention_mask"].to(device)).last_hidden_state
            return self.ln(self.proj(out)), enc["attention_mask"].to(device, dtype=torch.int32)

    from rdkit import Chem
    ATOM_LIST = ["H","C","N","O","F","P","S","Cl","Br","I"]
    def _one_hot(v, C): z=[0]*len(C); z[C.index(v) if v in C else -1]=1; return z
    def _bucket(v, lo, hi):
        buckets=list(range(lo,hi+1)); o=[0]*(len(buckets)+1); i=v-lo; o[i if 0<=i<len(buckets) else -1]=1; return o
    def _atom_feat(a):
        hyb=[Chem.rdchem.HybridizationType.S, Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2]
        chir=[Chem.rdchem.ChiralType.CHI_UNSPECIFIED, Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, Chem.rdchem.ChiralType.CHI_OTHER]
        f=_one_hot(a.GetSymbol(), ATOM_LIST+["other"])
        f+=_bucket(a.GetDegree(),0,5)+_bucket(a.GetFormalCharge(),-2,2)+(_one_hot(a.GetHybridization(),hyb)+[0])
        f+=[int(a.GetIsAromatic()), int(a.IsInRing())]+_one_hot(a.GetChiralTag(),chir)
        f+=_bucket(a.GetTotalNumHs(True),0,4)+_bucket(a.GetTotalValence(),0,5)+[a.GetMass()/200.0]
        return f
    def _smiles_to_graph(smi, max_nodes=128):
        m = Chem.MolFromSmiles(smi)
        if m is None or m.GetNumAtoms()==0: return np.zeros((0,0),np.float32), np.zeros((0,0),np.float32)
        X = np.asarray([_atom_feat(m.GetAtomWithIdx(i)) for i in range(m.GetNumAtoms())], np.float32)
        N = m.GetNumAtoms(); A = np.zeros((N,N),np.float32)
        for b in m.GetBonds(): i,j=b.GetBeginAtomIdx(), b.GetEndAtomIdx(); A[i,j]=A[j,i]=1.0
        if N>max_nodes: X=X[:max_nodes]; A=A[:max_nodes,:max_nodes]
        return X,A
    def _collate(smis, max_nodes=128):
        G=[_smiles_to_graph(s) for s in smis]
        Nmax=max([g[0].shape[0] for g in G]+[1]); F=G[0][0].shape[1] if G[0][0].size>0 else 51; B=len(G)
        X=np.zeros((B,Nmax,F),np.float32); A=np.zeros((B,Nmax,Nmax),np.float32); M=np.zeros((B,Nmax),np.int64)
        for i,(x,a) in enumerate(G):
            n=x.shape[0]
            if n==0: continue
            X[i,:n,:]=x; A[i,:n,:n]=a; M[i,:n]=1
        return torch.from_numpy(X).to(device), torch.from_numpy(A).to(device), torch.from_numpy(M).to(device)

    class GINLayer(nn.Module):
        def __init__(self,h=256,p=0.1): super().__init__(); self.eps=nn.Parameter(torch.tensor(0.0)); self.mlp=nn.Sequential(nn.Linear(h,h), nn.GELU(), nn.LayerNorm(h), nn.Dropout(p))
        def forward(self,x,A,M): return self.mlp((1.0+self.eps)*x + torch.matmul(A,x)) * M.unsqueeze(-1).to(x.dtype)
    class GraphGINEncoder(nn.Module):
        def __init__(self,node_in_dim=51,h=256,L=4,p=0.1): super().__init__(); self.inp=nn.Sequential(nn.Linear(node_in_dim,h), nn.GELU(), nn.Dropout(p)); self.layers=nn.ModuleList([GINLayer(h,p) for _ in range(L)]); self.out_ln=nn.LayerNorm(h)
        def forward(self,smis,max_nodes=128):
            X,A,M=_collate(smis,max_nodes); h=self.inp(X)
            for L in self.layers: h=L(h,A,M)
            return self.out_ln(h), M.to(dtype=torch.int32)

    def masked_mean(x, m, dim): m=m.to(dtype=x.dtype, device=x.device); denom=m.sum(dim=dim,keepdim=True).clamp(min=1.0); return (x*m.unsqueeze(-1)).sum(dim=dim)/denom
    class CrossAttentionBlock(nn.Module):
        def __init__(self,dim=256,heads=4,p=0.1): super().__init__(); self.mha=nn.MultiheadAttention(dim, heads, dropout=p, batch_first=False); self.ln=nn.LayerNorm(dim); self.do=nn.Dropout(p)
        def forward(self,T,TM,G,GM): Q=T.transpose(0,1); K=G.transpose(0,1); V=G.transpose(0,1); kpm=(GM==0); A,_=self.mha(Q,K,V,key_padding_mask=kpm); return self.ln(T+self.do(A.transpose(0,1)))
    class DescriptorMLP(nn.Module):
        def __init__(self,inp,dim=256,p=0.1): super().__init__(); self.net=nn.Sequential(nn.Linear(inp,256), nn.GELU(), nn.Dropout(p), nn.Linear(256,dim), nn.GELU(), nn.Dropout(p))
        def forward(self,x): return self.net(x)
    class FusionClassifier(nn.Module):
        def __init__(self,dim=256,L=N_LABELS,p=0.1): super().__init__(); self.mlp=nn.Sequential(nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(p), nn.Linear(dim*2, L))
        def forward(self,z): return self.mlp(z)
    class V7Shared(nn.Module):
        def __init__(self):
            super().__init__()
            self.text=ChemBERTaEncoder()
            self.graph=GraphGINEncoder()
            self.cross=CrossAttentionBlock()
            self.desc=DescriptorMLP(DESC_IN_DIM)
            self.head=FusionClassifier()
        def forward(self,smis,desc,return_fused=False):
            T,TM=self.text(smis,256); G,GM=self.graph(smis,128); Ta=self.cross(T,TM,G,GM); Dz=self.desc(desc)
            Tp=masked_mean(Ta,TM,1); Gp=masked_mean(G,GM,1); fused=torch.cat([Tp,Gp,Dz],-1)
            logits=self.head(fused)
            return (logits, fused) if return_fused else logits

    # Build model & fused features
    v7 = V7Shared().to(device)
    v7.load_state_dict(torch.load(MODEL_DIR / "checkpoints" / "shared" / "best.pt", map_location=device)["model"], strict=True)
    v7.eval()

    # descriptors → impute → scale
    X_raw = rdkit_feats(smi)
    X_imp = imputer.transform(X_raw); X_std = scaler.transform(X_imp)
    desc_t = torch.tensor(X_std, dtype=torch.float32, device=device)

    fused_chunks=[]
    B=64
    for i in range(0, N, B):
        _, f = v7(smi[i:i+B], desc_t[i:i+B], return_fused=True)
        fused_chunks.append(f.detach().cpu().numpy())
    fused = np.concatenate(fused_chunks, 0).astype(np.float32)
    np.save(out, fused)
    print(f"[Rebuild] Saved → {out}")
    return fused

# Get VAL fused and labels
Xva = ensure_fused("val")
val_blob = np.load(PREP_DIR / "val.npz", allow_pickle=True)
Yva = val_blob["Y"].astype(np.float32)                      # (Nv, L)
Mva = val_blob["y_missing_mask"].astype(bool)               # (Nv, L)

# ------- Heads: specialist + shared (MLP on fused) -------
class LabelHead(nn.Module):
    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(in_dim, h1), nn.GELU(), nn.LayerNorm(h1), nn.Dropout(p))
        self.block2 = nn.Sequential(nn.Linear(h1, h2), nn.GELU(), nn.LayerNorm(h2), nn.Dropout(p))
        self.block3 = nn.Sequential(nn.Linear(h2, h3), nn.GELU(), nn.LayerNorm(h3), nn.Dropout(p))
        self.out    = nn.Linear(h3, 1)
        self.short  = nn.Linear(in_dim, h3)
    def forward(self, x):
        z1 = self.block1(x); z2 = self.block2(z1); z3 = self.block3(z2)
        z  = z3 + self.short(x)
        return self.out(z).squeeze(-1)

def _remap_keys_if_needed(state_dict: dict) -> dict:
    needs = any(k.startswith(("b1.", "b2.", "b3.")) for k in state_dict.keys())
    if not needs: return state_dict
    remap = {}
    for k, v in state_dict.items():
        k2 = k.replace("b1.", "block1.").replace("b2.", "block2.").replace("b3.", "block3.")
        remap[k2] = v
    return remap

def load_best_head(label: str) -> nn.Module:
    cands = []
    for sd in sorted((ENS_DIR / label).glob("seed*/"), key=lambda p: p.name):
        mfile = sd / "metrics.json"
        if not mfile.exists(): continue
        try:
            ap = float(json.loads(mfile.read_text()).get("best_ap", float("nan")))
            cands.append((ap, sd))
        except Exception:
            pass
    if not cands:
        raise FileNotFoundError(f"No trained heads for label {label} under {ENS_DIR/label}")
    cands.sort(key=lambda x: (-1.0 if math.isnan(x[0]) else x[0]), reverse=True)
    best_dir = cands[0][1]
    ck = torch.load(best_dir / "best.pt", map_location=device)
    cfg = ck.get("config", {"in_dim":768,"h1":512,"h2":256,"h3":128,"dropout":0.30})
    head = LabelHead(in_dim=cfg["in_dim"], h1=cfg["h1"], h2=cfg["h2"], h3=cfg["h3"], p=cfg.get("dropout",0.30)).to(device)
    state = _remap_keys_if_needed(ck["model"])
    head.load_state_dict(state, strict=True)
    head.eval()
    return head

HEADS = {lbl: load_best_head(lbl) for lbl in LABELS}

class SharedHeadMLP(nn.Module):
    def __init__(self, dim=256, n_labels=N_LABELS):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(dim*2, n_labels)
        )
    def forward(self, z): return self.mlp(z)

sh = SharedHeadMLP().to(device)
ckpt = torch.load(MODEL_DIR / "checkpoints" / "shared" / "best.pt", map_location=device)
sh_state = {k.replace("shared_head.", ""): v for k,v in ckpt["model"].items() if k.startswith("shared_head.")}
sh.load_state_dict(sh_state, strict=True)
sh.eval()

# ------- Build blended probabilities on VAL -------
sigmoid = lambda x: 1.0/(1.0+np.exp(-x))
Xva_t = torch.tensor(Xva, dtype=torch.float32, device=device)
with torch.no_grad():
    spec_logits_val = torch.stack([HEADS[l](Xva_t) for l in LABELS], dim=1).cpu().numpy()
    sh_logits_val   = sh(Xva_t).cpu().numpy()

P_spec_val = np.zeros_like(spec_logits_val, np.float32)
P_sh_val   = np.zeros_like(sh_logits_val,   np.float32)
P_blend_val= np.zeros_like(sh_logits_val,   np.float32)

for j, lbl in enumerate(LABELS):
    Ts = max(float(temps_spec.get(lbl, 1.0)), 1e-3)
    Th = max(float(temps_shared.get(lbl, 1.0)), 1e-3)
    ps = sigmoid(spec_logits_val[:, j] / Ts)
    ph = sigmoid(sh_logits_val[:, j]   / Th)
    P_spec_val[:, j] = ps
    P_sh_val[:, j]   = ph
    P_blend_val[:, j]= np.clip(ALPHA*ps + (1-ALPHA)*ph, 0.0, 1.0)

# ------- Helper: precision-floor threshold on VAL -------
def precision_floor_threshold(y_true: np.ndarray, probs: np.ndarray, floor: float) -> Tuple[float, float, float]:
    """
    Returns (threshold, precision_at_th, recall_at_th) selecting the *highest recall*
    point on the PR curve with precision >= floor. Falls back to best F1 if none.
    """
    prec, rec, th = precision_recall_curve(y_true, probs)
    # Align thresholds with prec/rec arrays (sklearn: th has len-1)
    th_aligned = np.concatenate([th, [1.0]]) if th.size > 0 else np.array([0.5])
    mask = prec >= floor
    if mask.any():
        # Among points with precision>=floor, pick the one with highest recall
        idx = np.argmax(rec[mask])
        # position within masked array → original index
        candidates = np.where(mask)[0]
        i = candidates[idx]
        return float(th_aligned[i]), float(prec[i]), float(rec[i])
    # Fallback: best F1
    eps = 1e-8
    f1 = (2*prec*rec)/np.maximum(prec+rec, eps)
    i = int(np.nanargmax(f1))
    return float(th_aligned[i]), float(prec[i]), float(rec[i])

# ------- Build policy per label -------
rows = []
policy = {
    "alpha": ALPHA,
    "precision_floor": PREC_FLOOR,
    "default_mode": "fbeta15",
    "labels": {}
}

# Map test metrics into a dict for easy access
test_by_label = {r["label"]: r for _, r in test_df.iterrows()}

for j, lbl in enumerate(LABELS):
    # Use only VAL rows with non-missing labels
    valid = ~Mva[:, j]
    yv = Yva[valid, j].astype(int)
    pv = P_blend_val[valid, j]
    # Degenerate label on val?
    degenerate = (yv.max() == yv.min())
    # Test metrics (to judge FP-ness)
    td = test_by_label.get(lbl, {})
    prec_f1  = float(td.get("prec@f1", np.nan))
    rec_f1   = float(td.get("recall@f1", np.nan))
    f1_test  = float(td.get("f1", np.nan))
    prec_fb  = float(td.get("prec@fbeta15", np.nan))
    rec_fb   = float(td.get("recall@fbeta15", np.nan))
    fbeta_t  = float(td.get("f_beta15", np.nan))

    # Start with default: keep fbeta15 threshold from blend calibration
    th_f1  = float(thr_blend[lbl]["th_f1"])
    th_fb  = float(thr_blend[lbl]["th_fbeta15"])
    decision = "fbeta15"
    chosen_th = th_fb
    reason = "default_fbeta15"

    # Heuristic 1: if F1 better than Fβ on test *and* precision improves, prefer F1
    if (not math.isnan(f1_test) and not math.isnan(fbeta_t)) and (f1_test >= fbeta_t) and (not math.isnan(prec_f1) and not math.isnan(prec_fb)) and (prec_f1 > prec_fb):
        decision = "f1"
        chosen_th = th_f1
        reason = "test_f1_better_and_more_precise"

    # Heuristic 2: if precision at Fβ is below floor, try precision-floor on VAL
    if (not math.isnan(prec_fb)) and (prec_fb < PREC_FLOOR) and (not degenerate):
        th_pf, pr_pf, rc_pf = precision_floor_threshold(yv, pv, PREC_FLOOR)
        # Only adopt if recall doesn't collapse completely
        if rc_pf >= max(0.5*rec_fb, 0.10):  # keep at least 50% of recall@Fβ, or ≥0.10 absolute
            decision = "precision_floor"
            chosen_th = th_pf
            reason = f"val_precision_floor_{PREC_FLOOR:.2f}"

    policy["labels"][lbl] = {
        "mode": decision,            # "fbeta15" | "f1" | "precision_floor"
        "threshold": float(chosen_th),
        "diag": {
            "test_prec@f1": prec_f1, "test_rec@f1": rec_f1, "test_f1": f1_test,
            "test_prec@fb": prec_fb, "test_rec@fb": rec_fb, "test_fbeta15": fbeta_t
        }
    }

    rows.append({
        "label": lbl,
        "decision": decision,
        "chosen_threshold": chosen_th,
        "reason": reason,
        "test_prec@f1": prec_f1, "test_rec@f1": rec_f1, "test_f1": f1_test,
        "test_prec@fb": prec_fb, "test_rec@fb": rec_fb, "test_fbeta15": fbeta_t
    })

# Save artifacts
policy_path = POL_DIR / "policy.json"
policy_path.write_text(json.dumps(policy, indent=2))
pd.DataFrame(rows).to_csv(EVAL_DIR / "policy_table.csv", index=False)

# Console summary
dec_counts = pd.Series([r["decision"] for r in rows]).value_counts().to_dict()
print("✅ Per-label policy proposed and saved.")
print("  • Policy JSON →", policy_path)
print("  • Table CSV   →", EVAL_DIR / "policy_table.csv")
print("Decision counts:", dec_counts)
print("\nExamples:")
for r in rows[:4]:
    print(f"  {r['label']}: {r['decision']} (th={r['chosen_threshold']:.3f}) ← {r['reason']}")

✅ Per-label policy proposed and saved.
  • Policy JSON → v7\model\policy\policy.json
  • Table CSV   → v7\eval\policy_table.csv
Decision counts: {'fbeta15': 5, 'precision_floor': 4, 'f1': 3}

Examples:
  NR-AR: precision_floor (th=0.653) ← val_precision_floor_0.55
  NR-AR-LBD: precision_floor (th=0.723) ← val_precision_floor_0.55
  NR-AhR: precision_floor (th=0.709) ← val_precision_floor_0.55
  NR-Aromatase: fbeta15 (th=0.474) ← default_fbeta15


## phase 7 (extra squeezes and testing)

### New reg to comapre f1, fbeta15 and the new policy (6,3)

In [None]:
# === V7 Policy-aware Single-SMILES/SMARTS Test Rig (self-contained) ===
# Loads:
#   v7/model/checkpoints/shared/best.pt
#   v7/model/ensembles/<label>/seed*/best.pt
#   v7/model/calibration/temps.json, temps_shared.json, thresholds_blend.json
#   v7/model/policy/policy.json   (optional; for mode="policy")

import json, math
from pathlib import Path
from typing import List, Dict
import numpy as np
import torch
import torch.nn as nn
from rdkit import Chem

BASE       = Path("v7")
PREP_DIR   = BASE / "data" / "prepared"
MODEL_DIR  = BASE / "model"
ENS_DIR    = MODEL_DIR / "ensembles"
CAL_DIR    = MODEL_DIR / "calibration"
POL_DIR    = MODEL_DIR / "policy"
CKPT_BEST  = MODEL_DIR / "checkpoints" / "shared" / "best.pt"

assert CKPT_BEST.exists(), f"Missing shared checkpoint: {CKPT_BEST}"
assert (PREP_DIR / "dataset_manifest.json").exists(), "Missing dataset manifest."

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Labels & dims ---
ds_manifest = json.loads((PREP_DIR / "dataset_manifest.json").read_text())
LABEL_NAMES: List[str] = ds_manifest["labels"]
DESC_IN_DIM = int(ds_manifest["n_features"])  # 208

# --- Calibration + policy ---
temps_spec   = json.loads((CAL_DIR / "temps.json").read_text())            # specialist temps
temps_shared = json.loads((CAL_DIR / "temps_shared.json").read_text())     # shared temps
blend_payload= json.loads((CAL_DIR / "thresholds_blend.json").read_text()) # alpha + thresholds
ALPHA        = float(blend_payload.get("alpha", 0.8))
thr_blend    = blend_payload["thresholds"]                                  # per label
POL_PATH     = POL_DIR / "policy.json"
policy       = json.loads(POL_PATH.read_text()) if POL_PATH.exists() else None

# --- Text encoder (ChemBERTa) ---
from transformers import AutoTokenizer, AutoModel
class ChemBERTaEncoder(nn.Module):
    def __init__(self, ckpt_name="seyonec/ChemBERTa-zinc-base-v1", fusion_dim=256, dropout_p=0.1):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
        self.backbone  = AutoModel.from_pretrained(ckpt_name)
        self.proj = nn.Sequential(nn.Dropout(dropout_p), nn.Linear(self.backbone.config.hidden_size, fusion_dim))
        self.ln   = nn.LayerNorm(fusion_dim)
    def forward(self, smiles_list: List[str], max_length=256, add_special_tokens=True):
        enc = self.tokenizer(list(smiles_list), padding=True, truncation=True,
                             max_length=max_length, add_special_tokens=add_special_tokens,
                             return_tensors="pt")
        input_ids, attention_mask = enc["input_ids"].to(device), enc["attention_mask"].to(device)
        out  = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state  # (B,L,H)
        toks = self.ln(self.proj(out))  # (B,L,256)
        return toks, attention_mask.to(dtype=torch.int32)

# --- Graph encoder (names aligned to checkpoint) ---
ATOM_LIST = ["H","C","N","O","F","P","S","Cl","Br","I"]
def _one_hot(v, choices):
    z = [0]*len(choices); z[choices.index(v) if v in choices else -1] = 1; return z
def _bucket_oh(v, lo, hi):
    buckets = list(range(lo, hi+1)); o = [0]*(len(buckets)+1); idx = v - lo
    o[idx if 0<=idx<len(buckets) else -1] = 1; return o
def _atom_feat(atom):
    hybs = [Chem.rdchem.HybridizationType.S, Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2]
    chir = [Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER]
    feat  = _one_hot(atom.GetSymbol(), ATOM_LIST+["other"])
    feat += _bucket_oh(atom.GetDegree(), 0, 5)
    feat += _bucket_oh(atom.GetFormalCharge(), -2, 2)
    feat += (_one_hot(atom.GetHybridization(), hybs)+[0])
    feat += [int(atom.GetIsAromatic()), int(atom.IsInRing())]
    feat += _one_hot(atom.GetChiralTag(), chir)
    feat += _bucket_oh(atom.GetTotalNumHs(includeNeighbors=True), 0, 4)
    feat += _bucket_oh(atom.GetTotalValence(), 0, 5)
    feat += [atom.GetMass()/200.0]
    return feat  # ~51 dims
def _smiles_to_graph(smi, max_nodes=128):
    mol = Chem.MolFromSmiles(smi)
    if mol is None or mol.GetNumAtoms()==0: return np.zeros((0,0),np.float32), np.zeros((0,0),np.float32)
    X = np.asarray([_atom_feat(mol.GetAtomWithIdx(i)) for i in range(mol.GetNumAtoms())], np.float32)
    N = mol.GetNumAtoms(); A = np.zeros((N,N),np.float32)
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        A[i,j] = 1.0; A[j,i] = 1.0
    if N>max_nodes: X=X[:max_nodes]; A=A[:max_nodes,:max_nodes]
    return X, A
def _collate_graphs(smiles_batch, max_nodes=128):
    graphs = [_smiles_to_graph(s) for s in smiles_batch]
    Nmax = max([g[0].shape[0] for g in graphs] + [1])
    Fnode = graphs[0][0].shape[1] if graphs[0][0].size>0 else 51
    B = len(graphs)
    X = np.zeros((B, Nmax, Fnode), np.float32)
    A = np.zeros((B, Nmax, Nmax), np.float32)
    M = np.zeros((B, Nmax), np.int64)
    for i, (x, a) in enumerate(graphs):
        n = x.shape[0]
        if n==0: continue
        X[i,:n,:] = x; A[i,:n,:n] = a; M[i,:n] = 1
    return torch.from_numpy(X).to(device), torch.from_numpy(A).to(device), torch.from_numpy(M).to(device)

class GINLayer(nn.Module):
    def __init__(self, h=256, p=0.1):
        super().__init__()
        self.eps = nn.Parameter(torch.tensor(0.0))
        self.mlp = nn.Sequential(nn.Linear(h, h), nn.GELU(), nn.LayerNorm(h), nn.Dropout(p))
    def forward(self, x, adj, mask):
        out = (1.0 + self.eps) * x + torch.matmul(adj, x)
        out = self.mlp(out)
        return out * mask.unsqueeze(-1).to(out.dtype)

class GraphGINEncoder(nn.Module):
    def __init__(self, node_in_dim=51, hidden_dim=256, n_layers=4, p=0.1):
        super().__init__()
        self.inp = nn.Sequential(nn.Linear(node_in_dim, hidden_dim), nn.GELU(), nn.Dropout(p))
        self.layers = nn.ModuleList([GINLayer(hidden_dim, p) for _ in range(n_layers)])
        self.out_ln = nn.LayerNorm(hidden_dim)  # name matches checkpoint
    def forward(self, smiles_list: List[str], max_nodes=128):
        X, A, M = _collate_graphs(smiles_list, max_nodes=max_nodes)
        h = self.inp(X)
        for layer in self.layers:
            h = layer(h, A, M)
        return self.out_ln(h), M.to(dtype=torch.int32)

# --- Fusion & shared head ---
def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    mask = mask.to(dtype=x.dtype, device=x.device)
    denom = mask.sum(dim=dim, keepdim=True).clamp(min=1.0)
    return (x * mask.unsqueeze(-1)).sum(dim=dim) / denom

class CrossAttentionBlock(nn.Module):
    def __init__(self, dim=256, n_heads=4, p=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, n_heads, dropout=p, batch_first=False)
        self.ln  = nn.LayerNorm(dim)
        self.do  = nn.Dropout(p)
    def forward(self, text_tokens, text_mask, graph_nodes, graph_mask):
        Q = text_tokens.transpose(0,1)   # (L,B,D)
        K = graph_nodes.transpose(0,1)   # (N,B,D)
        V = graph_nodes.transpose(0,1)
        kpm = (graph_mask == 0)          # (B,N) 1=pad
        attn, _ = self.mha(Q, K, V, key_padding_mask=kpm)
        attn = attn.transpose(0,1)       # (B,L,D)
        return self.ln(text_tokens + self.do(attn))

class DescriptorMLP(nn.Module):
    def __init__(self, in_dim, out_dim=256, hidden=256, p=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(p),
            nn.Linear(hidden, out_dim), nn.GELU(), nn.Dropout(p)
        )
    def forward(self, x): return self.net(x)

class FusionClassifier(nn.Module):
    # name 'mlp' matches checkpoint ('shared_head.mlp.*')
    def __init__(self, dim=256, n_labels=12, p=0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(p),
            nn.Linear(dim*2, n_labels)
        )
    def forward(self, fused_vec): return self.mlp(fused_vec)

class V7FusionModel(nn.Module):
    def __init__(self, text_encoder, graph_encoder, desc_in_dim=208, dim=256, n_labels=12, n_heads=4, p=0.1):
        super().__init__()
        self.text_encoder=text_encoder
        self.graph_encoder=graph_encoder
        self.cross=CrossAttentionBlock(dim, n_heads, p)
        self.desc_mlp=DescriptorMLP(desc_in_dim, out_dim=dim, hidden=256, p=p)
        self.shared_head=FusionClassifier(dim, n_labels, p)
    def forward(self, smiles_list, desc_feats):
        tt, tm = self.text_encoder(smiles_list, max_length=256)
        gn, gm = self.graph_encoder(smiles_list, max_nodes=128)
        tta = self.cross(tt.to(device), tm.to(device), gn.to(device), gm.to(device))
        de  = self.desc_mlp(desc_feats.to(device))
        text_pool  = masked_mean(tta, tm.to(device), 1)
        graph_pool = masked_mean(gn.to(device),  gm.to(device), 1)
        fused = torch.cat([text_pool, graph_pool, de], dim=-1)  # (B,768)
        logits = self.shared_head(fused)
        return logits, fused

# Build shared model
text_encoder = ChemBERTaEncoder().to(device)
graph_encoder= GraphGINEncoder().to(device)
v7_shared    = V7FusionModel(text_encoder, graph_encoder, desc_in_dim=DESC_IN_DIM, n_labels=len(LABEL_NAMES)).to(device)
ckpt = torch.load(CKPT_BEST, map_location=device)
v7_shared.load_state_dict(ckpt["model"], strict=True)
v7_shared.eval()

# --- Specialist heads (robust loader) ---
class LabelHead(nn.Module):
    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(in_dim, h1), nn.GELU(), nn.LayerNorm(h1), nn.Dropout(p))
        self.block2 = nn.Sequential(nn.Linear(h1, h2), nn.GELU(), nn.LayerNorm(h2), nn.Dropout(p))
        self.block3 = nn.Sequential(nn.Linear(h2, h3), nn.GELU(), nn.LayerNorm(h3), nn.Dropout(p))
        self.out    = nn.Linear(h3, 1)
        self.short  = nn.Linear(in_dim, h3)
    def forward(self, x):
        z1 = self.block1(x); z2 = self.block2(z1); z3 = self.block3(z2)
        z  = z3 + self.short(x)
        return self.out(z).squeeze(-1)

def _remap_keys_if_needed(state_dict: dict) -> dict:
    if not any(k.startswith(("b1.","b2.","b3.")) for k in state_dict.keys()):
        return state_dict
    remap = {}
    for k, v in state_dict.items():
        k2 = k.replace("b1.", "block1.").replace("b2.", "block2.").replace("b3.", "block3.")
        remap[k2] = v
    return remap

def _load_best_head(label: str) -> nn.Module:
    cands = []
    for sd in sorted((ENS_DIR / label).glob("seed*/")):
        mfile = sd / "metrics.json"
        if mfile.exists():
            try:
                ap = float(json.loads(mfile.read_text()).get("best_ap", float("nan")))
                cands.append((ap, sd))
            except Exception:
                pass
    if not cands:
        raise FileNotFoundError(f"No trained heads for label {label} under {ENS_DIR/label}")
    cands.sort(key=lambda x: (-1.0 if math.isnan(x[0]) else x[0]), reverse=True)
    best_dir = cands[0][1]
    ck = torch.load(best_dir / "best.pt", map_location=device)
    cfg = ck.get("config", {"in_dim":768,"h1":512,"h2":256,"h3":128,"dropout":0.30})
    head = LabelHead(in_dim=cfg["in_dim"], h1=cfg["h1"], h2=cfg["h2"], h3=cfg["h3"], p=cfg.get("dropout",0.30)).to(device)
    state = _remap_keys_if_needed(ck["model"])
    head.load_state_dict(state, strict=True)
    head.eval()
    return head

HEADS: Dict[str, nn.Module] = {lbl: _load_best_head(lbl) for lbl in LABEL_NAMES}

# --- Descriptor prep for ad-hoc inputs: standardized zeros (robust & simple) ---
def prepare_desc_matrix(smiles_list: List[str]) -> torch.Tensor:
    Z = np.zeros((len(smiles_list), DESC_IN_DIM), dtype=np.float32)
    return torch.tensor(Z, dtype=torch.float32, device=device)

# --- Normalize SMARTS→SMILES if needed ---
def normalize_smiles_or_smarts(s: str) -> str:
    if not isinstance(s, str): s = str(s)
    mol = Chem.MolFromSmiles(s)
    if mol: return Chem.MolToSmiles(mol)
    q = Chem.MolFromSmarts(s)
    if q:
        try:
            smi = Chem.MolToSmiles(q)
            return smi if smi else s
        except Exception:
            return s
    return s

@torch.no_grad()
def fused_from_smiles(smiles_list: List[str]) -> torch.Tensor:
    smiles_list = [normalize_smiles_or_smarts(s) for s in smiles_list]
    desc = prepare_desc_matrix(smiles_list)
    _, fused = v7_shared(smiles_list, desc)
    return fused  # (B,768)

# --- Inference helpers ---
sigmoid = lambda x: 1.0/(1.0+math.e**(-x))

@torch.no_grad()
def _probs_for_one(smi: str) -> Dict[str, Dict[str, float]]:
    """Return per-label {prob_spec, prob_shared, prob_blend} for one SMILES using alpha & temps."""
    x = fused_from_smiles([smi])  # (1,768)
    logits_shared = v7_shared.shared_head(x).detach().cpu().numpy()[0]  # (L,)
    rec = {}
    for j, lbl in enumerate(LABEL_NAMES):
        logit_spec = HEADS[lbl](x).item()
        p_spec   = sigmoid(logit_spec / max(float(temps_spec.get(lbl, 1.0)), 1e-3))
        p_shared = sigmoid(float(logits_shared[j]) / max(float(temps_shared.get(lbl, 1.0)), 1e-3))
        p_blend  = ALPHA * p_spec + (1.0 - ALPHA) * p_shared
        rec[lbl] = {"prob_spec": float(p_spec), "prob_shared": float(p_shared), "prob_blend": float(p_blend)}
    return rec

def _threshold_for(lbl: str, mode: str) -> float:
    """Return threshold for a label under 'f1', 'fbeta15', or 'policy'."""
    if mode in ("f1", "fbeta15"):
        key = "th_f1" if mode == "f1" else "th_fbeta15"
        return float(thr_blend[lbl][key])
    elif mode == "policy":
        if policy is None:
            raise RuntimeError("policy.json not found; run Phase 6 — Cell 3 to create it.")
        return float(policy["labels"][lbl]["threshold"])
    else:
        raise ValueError("mode must be 'f1', 'fbeta15', or 'policy'")

def predict_one(smi: str, mode: str = "fbeta15", topk: int = 5, show_parts: bool = False):
    """
    mode: 'f1' | 'fbeta15' | 'policy'
    Prints top-k by blended prob and the positive set at chosen thresholds.
    If show_parts=True, also shows p_spec / p_shared next to p_blend.
    """
    assert mode in ("f1","fbeta15","policy")
    rec = _probs_for_one(smi)
    print(f"\nSMILES/SMARTS: {smi}\n mode={mode}, alpha={ALPHA:.2f}")
    rows = []
    for lbl, d in rec.items():
        th = _threshold_for(lbl, mode)
        rows.append((lbl, d["prob_blend"], th, d["prob_blend"] >= th, d["prob_spec"], d["prob_shared"]))
    rows.sort(key=lambda z: z[1], reverse=True)
    for lbl, p, th, dec, ps, ph in rows[:topk]:
        if show_parts:
            print(f"  {lbl:12s}  p_spec={ps:.3f}  p_shared={ph:.3f}  p_blend={p:.3f}  th={th:.3f}  → pred={int(dec)}")
        else:
            print(f"  {lbl:12s}  p_blend={p:.3f}  th={th:.3f}  → pred={int(dec)}")
    positives = [lbl for lbl, p, th, dec, *_ in rows if dec]
    print("  Positives:", ", ".join(positives) if positives else "none")
    return {lbl: {"prob_spec": float(ps), "prob_shared": float(ph), "prob_blend": float(p),
                  "threshold": float(th), "decision": bool(dec)}
            for (lbl, p, th, dec, ps, ph) in rows}

def compare_modes(smi: str, modes: List[str] = ("fbeta15","f1","policy"), topk: int = 5):
    """Side-by-side comparison for the same SMILES."""
    print("="*72)
    for m in modes:
        predict_one(smi, mode=m, topk=topk, show_parts=False)
        print("-"*72)

print("✅ Policy-aware tester ready.")

✅ Policy-aware tester ready.


#### test on the csv

In [None]:
# === Batch compare modes ('fbeta15', 'f1', 'policy') on Truth Lables.xlsx — FIXED NAN HANDLING ===
import json, math
from pathlib import Path
from typing import List, Dict
import numpy as np
import pandas as pd
import torch

# ---- prerequisites from your test rig ----
need = ['v7_shared','HEADS','LABEL_NAMES','temps_spec','temps_shared','ALPHA','thr_blend','_probs_for_one','_threshold_for']
for n in need:
    assert n in globals(), f"Missing '{n}'. Please run the self-contained test rig cell first."

# ---- paths ----
BASE    = Path("v7")
DATA_XL = BASE / "data" / "Truth Lables.xlsx"
OUT_DIR = BASE / "results" / "inference"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ---- load excel ----
assert DATA_XL.exists(), f"Missing file: {DATA_XL}"
df = pd.read_excel(DATA_XL)
df.columns = [c.strip() for c in df.columns]

# smiles column
smiles_col = next((c for c in ["smiles","SMILES","Smile","smile","SMILE"] if c in df.columns), None)
assert smiles_col is not None, f"Could not find a SMILES column in: {list(df.columns)}"

# label columns must match training label names exactly
LABELS = list(LABEL_NAMES)
label_cols = [c for c in df.columns if c in LABELS]
assert len(label_cols) == len(LABELS), \
    f"Expected 12 label columns matching training names.\nFound {len(label_cols)}: {label_cols}\nWanted: {LABELS}"

def _to01(x):
    if pd.isna(x): return np.nan
    if isinstance(x, str):
        xs = x.strip().lower()
        if xs in {"1","true","yes","y"}: return 1
        if xs in {"0","false","no","n"}: return 0
        try:
            xv = float(xs)
            if np.isnan(xv): return np.nan
            return 1 if xv >= 0.5 else 0
        except Exception:
            return np.nan
    try:
        xv = float(x)
        if np.isnan(xv): return np.nan
        return 1 if xv >= 0.5 else 0
    except Exception:
        return np.nan

MODES = ("fbeta15","f1","policy")

# ---- iterate molecules, compute predictions once, then threshold by mode ----
rows = []
for idx, r in df.iterrows():
    smi = str(r[smiles_col])
    probs = _probs_for_one(smi)  # {lbl: {prob_spec, prob_shared, prob_blend}}
    for lbl in LABELS:
        prob = float(probs[lbl]["prob_blend"])
        truth = _to01(r.get(lbl, np.nan))
        for mode in MODES:
            # skip policy if policy.json wasn’t created
            try:
                th = _threshold_for(lbl, mode)
            except Exception:
                if mode == "policy":
                    continue
                raise
            pred = int(prob >= th)
            rows.append({
                "row_id": int(idx),
                "smiles": smi,
                "label": lbl,
                "mode": mode,
                "prob_blend": prob,
                "threshold": float(th),
                "prediction": pred,
                "truth": (np.nan if pd.isna(truth) else int(truth))
            })

detailed = pd.DataFrame(rows)

# ---- metrics helpers (robust to empty/degenerate cases) ----
def _safe_div(n, d):
    return (n / d) if d > 0 else np.nan

def _prf(tp, fp, fn):
    prec = _safe_div(tp, tp+fp)
    rec  = _safe_div(tp, tp+fn)
    if np.isnan(prec) or np.isnan(rec) or (prec+rec) == 0:
        f1 = np.nan
    else:
        f1 = 2 * prec * rec / (prec + rec)
    return prec, rec, f1

def _metrics(df_long: pd.DataFrame, mode: str) -> Dict:
    dd = df_long[df_long["mode"] == mode].copy()
    dd = dd.dropna(subset=["truth"])  # drop rows with unknown truth
    if dd.empty:
        return {
            "mode": mode,
            "micro": {"precision": np.nan, "recall": np.nan, "f1": np.nan},
            "macro_f1": np.nan,
            "per_label": [{ "label": lbl, "precision": np.nan, "recall": np.nan, "f1": np.nan } for lbl in LABELS]
        }

    y = dd["truth"].astype(int).to_numpy()
    p = dd["prediction"].astype(int).to_numpy()
    tp = int(((p==1)&(y==1)).sum())
    fp = int(((p==1)&(y==0)).sum())
    fn = int(((p==0)&(y==1)).sum())
    m_prec, m_rec, m_f1 = _prf(tp, fp, fn)

    per_label = []
    for lbl in LABELS:
        d = dd[dd["label"] == lbl]
        if d.empty:
            per_label.append({"label": lbl, "precision": np.nan, "recall": np.nan, "f1": np.nan})
            continue
        yj = d["truth"].astype(int).to_numpy()
        pj = d["prediction"].astype(int).to_numpy()
        tpj = int(((pj==1)&(yj==1)).sum())
        fpj = int(((pj==1)&(yj==0)).sum())
        fnj = int(((pj==0)&(yj==1)).sum())
        prec, rec, f1 = _prf(tpj, fpj, fnj)
        per_label.append({
            "label": lbl,
            "precision": float(prec) if not np.isnan(prec) else np.nan,
            "recall":    float(rec)  if not np.isnan(rec)  else np.nan,
            "f1":        float(f1)   if not np.isnan(f1)   else np.nan
        })

    # macro F1 across labels (ignore NaNs)
    macro_f1 = float(np.nanmean([pl["f1"] for pl in per_label])) if len(per_label) else np.nan

    return {
        "mode": mode,
        "micro": {"precision": float(m_prec) if not np.isnan(m_prec) else np.nan,
                  "recall":    float(m_rec)  if not np.isnan(m_rec)  else np.nan,
                  "f1":        float(m_f1)   if not np.isnan(m_f1)   else np.nan},
        "macro_f1": macro_f1,
        "per_label": per_label
    }

# Only compute 'policy' summary if any policy rows exist
modes_present = detailed["mode"].unique().tolist()
summary = {m: _metrics(detailed, m) for m in MODES if m in modes_present}

# ---- save outputs ----
csv_path  = OUT_DIR / "Truth_labels_detailed.csv"
json_path = OUT_DIR / "Truth_labels_summary.json"
detailed.to_csv(csv_path, index=False)
json_path.write_text(json.dumps(summary, indent=2))

print("✅ Saved:")
print("  • Detailed predictions →", csv_path)
print("  • Summary metrics      →", json_path)

# quick leaderboard
print("\n=== Micro metrics by mode ===")
for m, rep in summary.items():
    pr = rep["micro"]["precision"]; rc = rep["micro"]["recall"]; f1 = rep["micro"]["f1"]
    print(f"  {m:8s}  P={np.nan if pr is None else pr:.3f}  R={np.nan if rc is None else rc:.3f}  F1={np.nan if f1 is None else f1:.3f}")
print("\n=== Macro F1 by mode ===")
for m, rep in summary.items():
    print(f"  {m:8s}  macro-F1={rep['macro_f1']:.3f}")


✅ Saved:
  • Detailed predictions → v7\results\inference\Truth_labels_detailed.csv
  • Summary metrics      → v7\results\inference\Truth_labels_summary.json

=== Micro metrics by mode ===
  fbeta15   P=0.625  R=0.833  F1=0.714
  f1        P=0.750  R=0.750  F1=0.750
  policy    P=0.667  R=0.667  F1=0.667

=== Macro F1 by mode ===
  fbeta15   macro-F1=0.783
  f1        macro-F1=0.829
  policy    macro-F1=0.772


### Recompute the F1 Recompute F1 thresholds (VAL) and save as thresholds_blend_v2.json

In [None]:
# === Phase 6 — Cell 4a: Recompute F1 thresholds (VAL) and save as thresholds_blend_v2.json ===
import json, math
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import precision_recall_curve, average_precision_score

BASE      = Path("v7")
PREP_DIR  = BASE / "data" / "prepared"
FUSED_DIR = BASE / "data" / "fused"
MODEL_DIR = BASE / "model"
CAL_DIR   = MODEL_DIR / "calibration"
ENS_DIR   = MODEL_DIR / "ensembles"
CAL_DIR.mkdir(parents=True, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load manifest, temps, old thresholds (keep fb15 as-is) ---
mani = json.loads((PREP_DIR / "dataset_manifest.json").read_text())
LABELS = mani["labels"]; N_LABELS = len(LABELS)
DESC_IN_DIM = int(mani["n_features"])

temps_spec   = json.loads((CAL_DIR / "temps.json").read_text())
temps_shared = json.loads((CAL_DIR / "temps_shared.json").read_text())
blend_payload= json.loads((CAL_DIR / "thresholds_blend.json").read_text())
ALPHA        = float(blend_payload.get("alpha", 0.8))
thr_old      = blend_payload["thresholds"]  # dict[label]->{th_f1, th_fbeta15, ap_val}

# --- Ensure VAL fused features ---
Xva = np.load(FUSED_DIR / "val_fused.npy").astype(np.float32)
val_blob = np.load(PREP_DIR / "val.npz", allow_pickle=True)
Yva = val_blob["Y"].astype(np.float32)
Mva = val_blob["y_missing_mask"].astype(bool)
Xva_t = torch.tensor(Xva, dtype=torch.float32, device=device)

# --- Shared head MLP (for logits) ---
class SharedHeadMLP(nn.Module):
    def __init__(self, dim=256, n_labels=N_LABELS):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(dim*2, n_labels)
        )
    def forward(self, z): return self.mlp(z)

sh = SharedHeadMLP().to(device)
ckpt = torch.load(MODEL_DIR / "checkpoints" / "shared" / "best.pt", map_location=device)
sh_state = {k.replace("shared_head.", ""): v for k,v in ckpt["model"].items() if k.startswith("shared_head.")}
sh.load_state_dict(sh_state, strict=True)
sh.eval()

# --- Specialist heads loader (best seed per label) ---
class LabelHead(nn.Module):
    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1=nn.Sequential(nn.Linear(in_dim,h1),nn.GELU(),nn.LayerNorm(h1),nn.Dropout(p))
        self.block2=nn.Sequential(nn.Linear(h1,h2),nn.GELU(),nn.LayerNorm(h2),nn.Dropout(p))
        self.block3=nn.Sequential(nn.Linear(h2,h3),nn.GELU(),nn.LayerNorm(h3),nn.Dropout(p))
        self.out=nn.Linear(h3,1); self.short=nn.Linear(in_dim,h3)
    def forward(self,x):
        z1=self.block1(x); z2=self.block2(z1); z3=self.block3(z2)
        return self.out(z3+self.short(x)).squeeze(-1)

def _remap_keys_if_needed(sd: dict)->dict:
    if not any(k.startswith(("b1.","b2.","b3.")) for k in sd.keys()): return sd
    out={}; 
    for k,v in sd.items():
        out[k.replace("b1.","block1.").replace("b2.","block2.").replace("b3.","block3.")] = v
    return out

def load_best_head(label: str) -> nn.Module:
    import math, json
    cands=[]
    for sd in sorted((ENS_DIR/label).glob("seed*/")):
        m=sd/"metrics.json"
        if m.exists():
            try: cands.append((float(json.loads(m.read_text()).get("best_ap", float("nan"))), sd))
            except: pass
    if not cands: raise FileNotFoundError(f"No heads for {label}")
    cands.sort(key=lambda x: (-1.0 if math.isnan(x[0]) else x[0]), reverse=True)
    best = cands[0][1]
    ck = torch.load(best/"best.pt", map_location=device)
    cfg = ck.get("config", {"in_dim":768,"h1":512,"h2":256,"h3":128,"dropout":0.30})
    head=LabelHead(cfg["in_dim"],cfg["h1"],cfg["h2"],cfg["h3"],cfg.get("dropout",0.30)).to(device)
    state=_remap_keys_if_needed(ck["model"]); head.load_state_dict(state, strict=True); head.eval(); return head

HEADS = {lbl: load_best_head(lbl) for lbl in LABELS}

# --- Build blended probabilities on VAL ---
sigmoid = lambda x: 1/(1+np.exp(-x))
with torch.no_grad():
    spec_logits = torch.stack([HEADS[l](Xva_t) for l in LABELS], dim=1).cpu().numpy()  # (Nv,L)
    sh_logits   = sh(Xva_t).cpu().numpy()

P_blend = np.zeros_like(sh_logits, np.float32)
for j,lbl in enumerate(LABELS):
    Ts=max(float(temps_spec.get(lbl,1.0)),1e-3); Th=max(float(temps_shared.get(lbl,1.0)),1e-3)
    ps=sigmoid(spec_logits[:,j]/Ts); ph=sigmoid(sh_logits[:,j]/Th)
    P_blend[:,j]=np.clip(ALPHA*ps + (1-ALPHA)*ph, 0, 1)

# --- Recompute th_f1 per label (keep old th_fbeta15) ---
from sklearn.metrics import precision_recall_curve
def best_f1_threshold(y_true, probs):
    prec, rec, th = precision_recall_curve(y_true, probs)
    eps=1e-8; f1=(2*prec*rec)/np.maximum(prec+rec, eps)
    if th.size==0: return 0.5
    # f1 is len = len(th)+1; pick argmax ignoring the first point
    idx = int(np.nanargmax(f1[1:]))+1
    # align threshold array with f1: append 1.0 at end
    th_aligned = np.concatenate([th, [1.0]])
    return float(th_aligned[idx])

new = {"alpha": ALPHA, "thresholds": {}}
print("Recomputing F1 thresholds on VAL...")
for j,lbl in enumerate(LABELS):
    valid = ~Mva[:,j]
    if valid.sum()==0 or np.all(Yva[valid,j]==Yva[valid,j][0]):
        new["thresholds"][lbl] = {
            "th_f1": float(thr_old[lbl]["th_f1"]),
            "th_fbeta15": float(thr_old[lbl]["th_fbeta15"]),
            "ap_val": float(thr_old[lbl].get("ap_val", float("nan")))
        }
        continue
    th_f1_new = best_f1_threshold(Yva[valid,j].astype(int), P_blend[valid,j])
    ap = float(average_precision_score(Yva[valid,j].astype(int), P_blend[valid,j]))
    th_f1_old = float(thr_old[lbl]["th_f1"])
    th_fb = float(thr_old[lbl]["th_fbeta15"])
    print(f"  {lbl:12s}  old_f1={th_f1_old:.3f} → new_f1={th_f1_new:.3f}  (Δ={th_f1_new-th_f1_old:+.3f})  AP={ap:.3f}")
    new["thresholds"][lbl] = {"th_f1": th_f1_new, "th_fbeta15": th_fb, "ap_val": ap}

out_path = CAL_DIR / "thresholds_blend_v2.json"
out_path.write_text(json.dumps(new, indent=2))
print("\n✅ Saved updated thresholds →", out_path)

Recomputing F1 thresholds on VAL...
  NR-AR         old_f1=0.653 → new_f1=0.660  (Δ=+0.007)  AP=0.171
  NR-AR-LBD     old_f1=0.621 → new_f1=0.621  (Δ=+0.001)  AP=0.253
  NR-AhR        old_f1=0.709 → new_f1=0.709  (Δ=+0.001)  AP=0.524
  NR-Aromatase  old_f1=0.564 → new_f1=0.565  (Δ=+0.001)  AP=0.295
  NR-ER         old_f1=0.547 → new_f1=0.547  (Δ=+0.000)  AP=0.253
  NR-ER-LBD     old_f1=0.589 → new_f1=0.594  (Δ=+0.004)  AP=0.139
  NR-PPAR-gamma  old_f1=0.441 → new_f1=0.441  (Δ=+0.000)  AP=0.063
  SR-ARE        old_f1=0.528 → new_f1=0.528  (Δ=+0.001)  AP=0.344
  SR-ATAD5      old_f1=0.483 → new_f1=0.483  (Δ=+0.001)  AP=0.171
  SR-HSE        old_f1=0.472 → new_f1=0.472  (Δ=+0.000)  AP=0.196
  SR-MMP        old_f1=0.589 → new_f1=0.591  (Δ=+0.002)  AP=0.444
  SR-p53        old_f1=0.513 → new_f1=0.513  (Δ=+0.000)  AP=0.210

✅ Saved updated thresholds → v7\model\calibration\thresholds_blend_v2.json


#### Test of CSV 

In [None]:
# === Phase 6 — Cell 4b: Truth test with new f1_soft mode (uses thresholds_blend_v2.json if present) ===
import json, math
from pathlib import Path
from typing import List, Dict
import numpy as np
import pandas as pd
import torch

BASE     = Path("v7")
DATA_XL  = BASE / "data" / "Truth Lables.xlsx"
OUT_DIR  = BASE / "results" / "inference"
CAL_DIR  = BASE / "model" / "calibration"
POL_DIR  = BASE / "model" / "policy"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Prereqs from the self-contained rig
need = ['v7_shared','HEADS','LABEL_NAMES','temps_spec','temps_shared','ALPHA','_probs_for_one']
for n in need: assert n in globals(), f"Missing {n}; run the policy-aware test rig cell first."

# Load thresholds: prefer v2 if exists
tb_v2 = CAL_DIR / "thresholds_blend_v2.json"
tb_v1 = CAL_DIR / "thresholds_blend.json"
assert tb_v1.exists(), f"Missing {tb_v1}"
thr_payload = json.loads((tb_v2 if tb_v2.exists() else tb_v1).read_text())
thr = thr_payload["thresholds"]   # label -> {th_f1, th_fbeta15}
ALPHA = float(thr_payload.get("alpha", ALPHA))  # keep same alpha

# Optional policy
POL_PATH = POL_DIR / "policy.json"
policy = json.loads(POL_PATH.read_text()) if POL_PATH.exists() else None

# New soft band (delta)
DELTA = 0.04

def _threshold_for(lbl: str, mode: str, p_blend: float = None) -> float:
    """For 'f1_soft' we return the f1 threshold (decision rule adds softness outside)."""
    if mode == "f1":
        return float(thr[lbl]["th_f1"])
    if mode == "fbeta15":
        return float(thr[lbl]["th_fbeta15"])
    if mode == "policy":
        if policy is None: raise RuntimeError("policy.json not found")
        return float(policy["labels"][lbl]["threshold"])
    if mode == "f1_soft":
        return float(thr[lbl]["th_f1"])
    raise ValueError("mode must be one of: 'f1','fbeta15','policy','f1_soft'")

def _decide(lbl: str, mode: str, p: float) -> int:
    if mode in ("f1","fbeta15","policy"):
        th = _threshold_for(lbl, mode)
        return int(p >= th)
    if mode == "f1_soft":
        th_f1 = float(thr[lbl]["th_f1"])
        th_fb = float(thr[lbl]["th_fbeta15"])
        if p >= th_f1: 
            return 1
        if (p >= th_fb) and ((th_f1 - p) <= DELTA):
            return 1
        return 0

# --- Load excel truth set ---
assert DATA_XL.exists(), f"Missing {DATA_XL}"
df = pd.read_excel(DATA_XL)
df.columns = [c.strip() for c in df.columns]
smiles_col = next((c for c in ["smiles","SMILES","Smile","smile","SMILE"] if c in df.columns), None)
assert smiles_col is not None, f"No SMILES column found. Got: {list(df.columns)}"
LABELS = list(LABEL_NAMES)
label_cols = [c for c in df.columns if c in LABELS]
assert len(label_cols) == len(LABELS), f"Label columns mismatch. Found {label_cols}, expected {LABELS}"

def _to01(x):
    if pd.isna(x): return np.nan
    try:
        if isinstance(x,str):
            xs=x.strip().lower()
            if xs in {"1","true","yes","y"}: return 1
            if xs in {"0","false","no","n"}: return 0
            xv=float(xs); 
            if np.isnan(xv): return np.nan
            return 1 if xv>=0.5 else 0
        xv=float(x); 
        if np.isnan(xv): return np.nan
        return 1 if xv>=0.5 else 0
    except: 
        return np.nan

MODES = ["fbeta15","f1","f1_soft"]
if policy is not None: MODES.append("policy")

# --- Compute predictions ---
rows=[]
for idx, r in df.iterrows():
    smi = str(r[smiles_col])
    per_label = _probs_for_one(smi)  # {lbl:{prob_spec,prob_shared,prob_blend}}
    for lbl in LABELS:
        p = float(per_label[lbl]["prob_blend"])
        truth = _to01(r.get(lbl, np.nan))
        for m in MODES:
            th = _threshold_for(lbl, m, p)
            pred = _decide(lbl, m, p)
            rows.append({
                "row_id": int(idx),
                "smiles": smi,
                "label": lbl,
                "mode": m,
                "prob_blend": p,
                "threshold": float(th),
                "prediction": int(pred),
                "truth": (np.nan if pd.isna(truth) else int(truth))
            })
detailed = pd.DataFrame(rows)

# --- Metrics (robust) ---
def _safe_div(n,d): return (n/d) if d>0 else np.nan
def _prf(tp,fp,fn):
    prec=_safe_div(tp,tp+fp); rec=_safe_div(tp,tp+fn)
    f1=np.nan if (np.isnan(prec) or np.isnan(rec) or (prec+rec)==0) else (2*prec*rec/(prec+rec))
    return prec,rec,f1

def _metrics(df_long, mode):
    dd=df_long[df_long["mode"]==mode].dropna(subset=["truth"]).copy()
    if dd.empty:
        return {"mode":mode,"micro":{"precision":np.nan,"recall":np.nan,"f1":np.nan},"macro_f1":np.nan,"per_label":[]}
    y=dd["truth"].astype(int).to_numpy(); p=dd["prediction"].astype(int).to_numpy()
    tp=int(((p==1)&(y==1)).sum()); fp=int(((p==1)&(y==0)).sum()); fn=int(((p==0)&(y==1)).sum())
    M=_prf(tp,fp,fn)
    per=[]
    for lbl in LABELS:
        d=dd[dd["label"]==lbl]; 
        if d.empty: per.append({"label":lbl,"precision":np.nan,"recall":np.nan,"f1":np.nan}); continue
        yj=d["truth"].astype(int).to_numpy(); pj=d["prediction"].astype(int).to_numpy()
        tpj=int(((pj==1)&(yj==1)).sum()); fpj=int(((pj==1)&(yj==0)).sum()); fnj=int(((pj==0)&(yj==1)).sum())
        pr,rc,f1=_prf(tpj,fpj,fnj)
        per.append({"label":lbl,"precision":float(pr) if not np.isnan(pr) else np.nan,
                           "recall":float(rc) if not np.isnan(rc) else np.nan,
                           "f1":float(f1) if not np.isnan(f1) else np.nan})
    macro_f1=float(np.nanmean([pl["f1"] for pl in per])) if per else np.nan
    return {"mode":mode,"micro":{"precision":float(M[0]) if not np.isnan(M[0]) else np.nan,
                                 "recall":float(M[1]) if not np.isnan(M[1]) else np.nan,
                                 "f1":float(M[2]) if not np.isnan(M[2]) else np.nan},
            "macro_f1":macro_f1,"per_label":per}

summary = {m:_metrics(detailed,m) for m in MODES}

# --- Save v2 outputs ---
csv_path  = OUT_DIR / "Truth_labels_detailed_v2.csv"
json_path = OUT_DIR / "Truth_labels_summary_v2.json"
detailed.to_csv(csv_path, index=False)
json_path.write_text(json.dumps(summary, indent=2))

print("✅ Saved:")
print("  • Detailed predictions →", csv_path)
print("  • Summary metrics      →", json_path)
print(f"  • f1_soft delta = {DELTA:.3f}")

print("\n=== Micro metrics by mode ===")
for m, rep in summary.items():
    pr, rc, f1 = rep["micro"]["precision"], rep["micro"]["recall"], rep["micro"]["f1"]
    print(f"  {m:8s}  P={np.nan if pr is None else pr:.3f}  R={np.nan if rc is None else rc:.3f}  F1={np.nan if f1 is None else f1:.3f}")
print("\n=== Macro F1 by mode ===")
for m, rep in summary.items():
    print(f"  {m:8s}  macro-F1={rep['macro_f1']:.3f}")

✅ Saved:
  • Detailed predictions → v7\results\inference\Truth_labels_detailed_v2.csv
  • Summary metrics      → v7\results\inference\Truth_labels_summary_v2.json
  • f1_soft delta = 0.040

=== Micro metrics by mode ===
  fbeta15   P=0.703  R=0.818  F1=0.756
  f1        P=0.741  R=0.727  F1=0.734
  f1_soft   P=0.705  R=0.782  F1=0.741
  policy    P=0.691  R=0.691  F1=0.691

=== Macro F1 by mode ===
  fbeta15   macro-F1=0.705
  f1        macro-F1=0.760
  f1_soft   macro-F1=0.703
  policy    macro-F1=0.739


### Per-label stacking combiner (VAL), calibrate & threshold

#### A)

In [1]:
# === Cell A: Per-label stacking on VAL (logistic combiner of [spec_logit, shared_logit])
# Saves under v7_exp/stacking_asl_v1/* without touching v7/*
import json, math, os
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve, average_precision_score
from joblib import dump as joblib_dump

# ---------- Paths (new experiment root) ----------
ROOT_OLD = Path("v7")
ROOT_NEW = Path("v7_exp") / "stacking_asl_v1"
ROOT_NEW.mkdir(parents=True, exist_ok=True)

MODEL_OLD = ROOT_OLD / "model"
CAL_OLD   = MODEL_OLD / "calibration"
ENS_OLD   = MODEL_OLD / "ensembles"
FUSED     = ROOT_OLD / "data" / "fused"
PREP      = ROOT_OLD / "data" / "prepared"

MODEL_NEW = ROOT_NEW / "model"
STACK_DIR = MODEL_NEW / "stacking_lr"
CAL_NEW   = ROOT_NEW / "calibration"
EVAL_NEW  = ROOT_NEW / "eval"
for p in [STACK_DIR, CAL_NEW, EVAL_NEW]:
    p.mkdir(parents=True, exist_ok=True)

# ---------- Load manifests / artifacts ----------
mani = json.loads((PREP / "dataset_manifest.json").read_text())
LABELS: List[str] = mani["labels"]; N_LABELS = len(LABELS)

# fused VAL features (B,768)
Xv = np.load(FUSED / "val_fused.npy").astype(np.float32)
val_blob = np.load(PREP / "val.npz", allow_pickle=True)
Yv = val_blob["Y"].astype(np.float32)                # (Nv, L)
Mv = val_blob["y_missing_mask"].astype(bool)         # True where missing

# specialist temperatures and shared temperatures
temps_spec   = json.loads((CAL_OLD / "temps.json").read_text())
temps_shared = json.loads((CAL_OLD / "temps_shared.json").read_text())

# shared-head MLP: reuse weights from v7/model/checkpoints/shared/best.pt
class SharedHeadMLP(nn.Module):
    def __init__(self, dim=256, n_labels=N_LABELS):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(dim*2, n_labels)
        )
    def forward(self, z): return self.mlp(z)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sh = SharedHeadMLP().to(device)
ck = torch.load(MODEL_OLD / "checkpoints" / "shared" / "best.pt", map_location=device)
state = {k.replace("shared_head.",""): v for k,v in ck["model"].items() if k.startswith("shared_head.")}
sh.load_state_dict(state, strict=True); sh.eval()

# specialist heads: load best per label from v7/ensembles/*
class LabelHead(nn.Module):
    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1=nn.Sequential(nn.Linear(in_dim,h1), nn.GELU(), nn.LayerNorm(h1), nn.Dropout(p))
        self.block2=nn.Sequential(nn.Linear(h1,h2), nn.GELU(), nn.LayerNorm(h2), nn.Dropout(p))
        self.block3=nn.Sequential(nn.Linear(h2,h3), nn.GELU(), nn.LayerNorm(h3), nn.Dropout(p))
        self.out=nn.Linear(h3,1); self.short=nn.Linear(in_dim,h3)
    def forward(self, x):
        z1=self.block1(x); z2=self.block2(z1); z3=self.block3(z2)
        return self.out(z3 + self.short(x)).squeeze(-1)

def _remap(sd: dict)->dict:
    if not any(k.startswith(("b1.","b2.","b3.")) for k in sd): return sd
    m={}
    for k,v in sd.items():
        m[k.replace("b1.","block1.").replace("b2.","block2.").replace("b3.","block3.")] = v
    return m

def load_best_head(label: str) -> nn.Module:
    cands=[]
    for sd in sorted((ENS_OLD/label).glob("seed*/")):
        m = sd/"metrics.json"
        if m.exists():
            try: cands.append((float(json.loads(m.read_text()).get("best_ap", float("nan"))), sd))
            except: pass
    if not cands: raise FileNotFoundError(f"No heads for {label}")
    cands.sort(key=lambda x: (-1.0 if math.isnan(x[0]) else x[0]), reverse=True)
    best = cands[0][1]
    ckpt = torch.load(best/"best.pt", map_location=device)
    cfg = ckpt.get("config", {"in_dim":768,"h1":512,"h2":256,"h3":128,"dropout":0.30})
    head = LabelHead(cfg["in_dim"],cfg["h1"],cfg["h2"],cfg["h3"],cfg.get("dropout",0.30)).to(device)
    head.load_state_dict(_remap(ckpt["model"]), strict=True); head.eval(); return head

HEADS = {lbl: load_best_head(lbl) for lbl in LABELS}

# ---------- Build VAL logits: shared & specialist ----------
Xv_t = torch.tensor(Xv, dtype=torch.float32, device=device)
with torch.no_grad():
    sh_logits = sh(Xv_t).cpu().numpy()                           # (Nv, L)
    spec_logits = torch.stack([HEADS[l](Xv_t) for l in LABELS],  # -> (L, Nv)
                               dim=1).cpu().numpy()              # -> (Nv, L)

# ---------- Train per-label logistic combiner on VAL ----------
def fit_temperature(logits: np.ndarray, y: np.ndarray, max_iter=200, lr=0.05) -> float:
    t = torch.tensor([1.0], dtype=torch.float32, requires_grad=True, device=device)
    x = torch.tensor(logits, dtype=torch.float32, device=device)
    y = torch.tensor(y,      dtype=torch.float32, device=device)
    opt = torch.optim.Adam([t], lr=lr)
    for _ in range(max_iter):
        opt.zero_grad(set_to_none=True)
        z = x / (t.clamp(min=1e-3))
        p = torch.sigmoid(z).clamp(1e-6, 1-1e-6)
        loss = - (y*torch.log(p) + (1-y)*torch.log(1-p)).mean()
        loss.backward(); opt.step()
    return float(t.detach().cpu().item())

def best_thresholds(y_true: np.ndarray, probs: np.ndarray) -> Dict[str,float]:
    prec, rec, th = precision_recall_curve(y_true, probs)
    eps=1e-8
    f1  = (2*prec*rec)/np.maximum(prec+rec, eps)
    beta=1.5
    fb  = ((1+beta**2)*prec*rec)/np.maximum((beta**2)*prec+rec, eps)
    th_f1 = th[np.nanargmax(f1[1:])] if th.size>0 else 0.5
    th_fb = th[np.nanargmax(fb[1:])] if th.size>0 else 0.5
    ap = float(average_precision_score(y_true, probs)) if (~np.isnan(y_true)).any() else float("nan")
    return {"th_f1": float(th_f1), "th_fbeta15": float(th_fb), "ap_val": ap}

comb_meta = {}
temps_stack = {}
thr_stack   = {}

print("Fitting per-label logistic stackers on VAL...")
for j, lbl in enumerate(LABELS):
    valid = ~Mv[:, j]
    x = np.stack([spec_logits[valid, j], sh_logits[valid, j]], axis=1)   # (Nv_valid, 2)
    y = Yv[valid, j].astype(int)
    if x.shape[0] == 0 or y.max()==y.min():
        print(f"  {lbl}: degenerate label on VAL → skipping (copying old thresholds).")
        continue
    # Logistic combiner with class_weight='balanced' to help rare positives
    lr = LogisticRegression(
        penalty="l2", C=1.0, solver="lbfgs", max_iter=500,
        class_weight="balanced", fit_intercept=True
    )
    lr.fit(x, y)
    # raw combiner logits on VAL
    z_val = lr.decision_function(x)  # shape (Nv_valid,)
    # temperature on combiner logits
    T = fit_temperature(z_val, y, max_iter=300, lr=0.05)
    temps_stack[lbl] = T
    # calibrated probs
    p_cal = 1.0 / (1.0 + np.exp(-(z_val / max(T,1e-3))))
    thr = best_thresholds(y, p_cal)
    thr_stack[lbl] = thr
    # save model per label
    lbl_dir = STACK_DIR / lbl
    lbl_dir.mkdir(parents=True, exist_ok=True)
    joblib_dump(lr, lbl_dir / "stack_lr.joblib")
    json.dump({"coef": lr.coef_.tolist(), "intercept": lr.intercept_.tolist(),
               "n": int(x.shape[0]), "pos_rate": float(y.mean())},
              open(lbl_dir / "meta.json","w"), indent=2)
    print(f"  {lbl:12s}: T={T:.3f}  AP_val={thr['ap_val']:.3f}  th_f1={thr['th_f1']:.3f}  th_fb={thr['th_fbeta15']:.3f}")

# Save global calibration/thresholds for stackers
json.dump(temps_stack, open(CAL_NEW / "temps_stack.json","w"), indent=2)
json.dump({"thresholds": thr_stack, "note": "Per-label logistic stacker on VAL"},
          open(CAL_NEW / "thresholds_stack.json","w"), indent=2)

print("\n✅ Stacking complete.")
print("  • Models         →", STACK_DIR)
print("  • Temps (stack)  →", CAL_NEW / "temps_stack.json")
print("  • Thresholds     →", CAL_NEW / "thresholds_stack.json")

Fitting per-label logistic stackers on VAL...
  NR-AR       : T=0.826  AP_val=0.173  th_f1=0.966  th_fb=0.966
  NR-AR-LBD   : T=0.602  AP_val=0.261  th_f1=0.945  th_fb=0.945
  NR-AhR      : T=1.009  AP_val=0.518  th_f1=0.818  th_fb=0.713
  NR-Aromatase: T=0.969  AP_val=0.293  th_f1=0.784  th_fb=0.622
  NR-ER       : T=0.703  AP_val=0.251  th_f1=0.672  th_fb=0.504
  NR-ER-LBD   : T=0.591  AP_val=0.116  th_f1=0.714  th_fb=0.714
  NR-PPAR-gamma: T=1.134  AP_val=0.062  th_f1=0.513  th_fb=0.513
  SR-ARE      : T=0.988  AP_val=0.344  th_f1=0.531  th_fb=0.524
  SR-ATAD5    : T=0.946  AP_val=0.167  th_f1=0.709  th_fb=0.709
  SR-HSE      : T=0.952  AP_val=0.216  th_f1=0.651  th_fb=0.651
  SR-MMP      : T=0.953  AP_val=0.445  th_f1=0.647  th_fb=0.612
  SR-p53      : T=1.015  AP_val=0.217  th_f1=0.625  th_fb=0.625

✅ Stacking complete.
  • Models         → v7_exp\stacking_asl_v1\model\stacking_lr
  • Temps (stack)  → v7_exp\stacking_asl_v1\calibration\temps_stack.json
  • Thresholds     → v7_exp\

#### B)

In [None]:
# === Cell B: Fast specialist-head retrain with ASL tweaks (heads only, fused inputs)
import json, math, os
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import average_precision_score

# ---------- Paths ----------
ROOT_OLD = Path("v7")
FUSED    = ROOT_OLD / "data" / "fused"
PREP     = ROOT_OLD / "data" / "prepared"

ROOT_NEW = Path("v7_exp") / "stacking_asl_v1"
MODEL_NEW= ROOT_NEW / "model"
ENS_NEW  = MODEL_NEW / "ensembles_v2"
LOG_NEW  = ROOT_NEW / "logs"
for p in [ENS_NEW, LOG_NEW]:
    p.mkdir(parents=True, exist_ok=True)

# ---------- Data ----------
# fused train/val + labels/masks
Xtr = np.load(FUSED / "train_fused.npy").astype(np.float32)
Xva = np.load(FUSED / "val_fused.npy").astype(np.float32)
tr = np.load(PREP / "train.npz", allow_pickle=True)
va = np.load(PREP / "val.npz", allow_pickle=True)
Ytr = tr["Y"].astype(np.float32); Mtr = tr["y_missing_mask"].astype(bool)
Yva = va["Y"].astype(np.float32); Mva = va["y_missing_mask"].astype(bool)

mani = json.loads((PREP / "dataset_manifest.json").read_text())
LABELS: List[str] = mani["labels"]; N_LABELS=len(LABELS)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Simple dataset ----------
class FusedDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i): return self.X[i], self.y[i]

# ---------- Label head (same architecture as before) ----------
class LabelHead(nn.Module):
    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1=nn.Sequential(nn.Linear(in_dim,h1), nn.GELU(), nn.LayerNorm(h1), nn.Dropout(p))
        self.block2=nn.Sequential(nn.Linear(h1,h2), nn.GELU(), nn.LayerNorm(h2), nn.Dropout(p))
        self.block3=nn.Sequential(nn.Linear(h2,h3), nn.GELU(), nn.LayerNorm(h3), nn.Dropout(p))
        self.out=nn.Linear(h3,1); self.short=nn.Linear(in_dim,h3)
    def forward(self, x):
        z1=self.block1(x); z2=self.block2(z1); z3=self.block3(z2)
        return self.out(z3 + self.short(x)).squeeze(-1)

# ---------- Asymmetric Loss (recall-leaning) ----------
class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_pos=0.0, gamma_neg=4.0, clip=0.05, eps=1e-8, alpha_pos=1.0, alpha_neg=1.0):
        super().__init__()
        self.gamma_pos=gamma_pos; self.gamma_neg=gamma_neg
        self.clip=clip; self.eps=eps; self.alpha_pos=alpha_pos; self.alpha_neg=alpha_neg
    def forward(self, logits, y):
        x_sigmoid = torch.sigmoid(logits)
        xs_pos = x_sigmoid; xs_neg = 1.0 - x_sigmoid
        if self.clip is not None and self.clip>0:
            xs_neg = (xs_neg + self.clip).clamp(max=1.0)
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1.0 - y) * torch.log(xs_neg.clamp(min=self.eps))
        if self.gamma_pos>0 or self.gamma_neg>0:
            pt = torch.where(y>=0.5, xs_pos, xs_neg)
            gamma = torch.where(y>=0.5, torch.tensor(self.gamma_pos, device=logits.device),
                                           torch.tensor(self.gamma_neg, device=logits.device))
            los_pos = los_pos * torch.pow(1.0 - xs_pos, self.gamma_pos)
            los_neg = los_neg * torch.pow(1.0 - xs_neg, self.gamma_neg)
        # class weighting
        los = -(self.alpha_pos*los_pos + self.alpha_neg*los_neg)
        return los.mean()

# ---------- Auto-pick labels to tweak ----------
# Use VAL prevalence and (if present) stack AP to identify "hard/rare" labels
thr_stack_path = ROOT_NEW / "calibration" / "thresholds_stack.json"
ap_hint = {}
if thr_stack_path.exists():
    pay = json.loads(thr_stack_path.read_text())
    for lbl, d in pay["thresholds"].items():
        ap_hint[lbl] = float(d.get("ap_val", np.nan))

# prevalence on VAL
prev_val = {LABELS[j]: float(Yva[~Mva[:,j], j].mean()) if (~Mva[:,j]).any() else np.nan for j in range(N_LABELS)}

# selection rule (editable): low prevalence (<8%) OR AP hint < 0.20
TO_TWEAK = sorted([lbl for lbl in LABELS if
                   (not np.isnan(prev_val.get(lbl, np.nan)) and prev_val[lbl] < 0.08) or
                   (not np.isnan(ap_hint.get(lbl, np.nan)) and ap_hint[lbl] < 0.20)])

print("Labels selected for ASL tweak:", TO_TWEAK)

# ---------- Training hyperparams ----------
BS = 256
EPOCHS = 6
LR = 1e-3
WD = 1e-4
PATIENCE = 3

# Recall-leaning ASL defaults for tweaked labels (you can adjust)
ASL_TWEAK = dict(gamma_neg=2.0, gamma_pos=0.0, alpha_pos=1.3, alpha_neg=1.0, clip=0.05)
ASL_BASE  = dict(gamma_neg=4.0, gamma_pos=0.0, alpha_pos=1.0, alpha_neg=1.0, clip=0.05)

# ---------- Train per label ----------
def _val_ap(head: nn.Module, X: np.ndarray, y: np.ndarray) -> float:
    head.eval()
    with torch.no_grad():
        z = head(torch.tensor(X, dtype=torch.float32, device=device)).cpu().numpy()
        p = 1.0/(1.0+np.exp(-z))
    try:
        return float(average_precision_score(y.astype(int), p))
    except Exception:
        return float("nan")

for j, lbl in enumerate(LABELS):
    # slice train/val rows with known labels
    tr_mask = ~Mtr[:, j]; va_mask = ~Mva[:, j]
    Xtr_j, ytr_j = Xtr[tr_mask], Ytr[tr_mask, j]
    Xva_j, yva_j = Xva[va_mask], Yva[va_mask, j]
    if Xtr_j.shape[0] == 0 or Xva_j.shape[0] == 0:
        print(f"[{lbl}] no data → skip.")
        continue

    # choose ASL config
    cfg = ASL_TWEAK if lbl in TO_TWEAK else ASL_BASE
    loss_fn = AsymmetricLoss(**cfg)

    head = LabelHead().to(device)
    opt = torch.optim.AdamW(head.parameters(), lr=LR, weight_decay=WD)

    ds_tr = FusedDataset(Xtr_j, ytr_j)
    dl_tr = DataLoader(ds_tr, batch_size=BS, shuffle=True, num_workers=0, pin_memory=True)

    best_ap = -1.0; best_state=None; bad=0
    for epoch in range(1, EPOCHS+1):
        head.train()
        losses=[]
        for xb, yb in dl_tr:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            logits = head(xb)
            loss = loss_fn(logits, yb)
            loss.backward(); opt.step()
            losses.append(float(loss.item()))
        ap = _val_ap(head, Xva_j, yva_j)
        print(f"[{lbl}] epoch {epoch}/{EPOCHS}  train_loss={np.mean(losses):.4f}  val_AP={ap:.4f}")
        if ap > best_ap + 1e-4:
            best_ap = ap; best_state = {k: v.detach().cpu().clone() for k, v in head.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= PATIENCE: break

    # save best
    out_dir = ENS_NEW / lbl / "seed00"
    out_dir.mkdir(parents=True, exist_ok=True)
    torch.save({"model": best_state if best_state is not None else head.state_dict(),
                "config": {"in_dim":768,"h1":512,"h2":256,"h3":128,"dropout":0.30},
                "best_ap": best_ap},
               out_dir / "best.pt")
    json.dump({"label": lbl, "best_ap": best_ap, "asl_cfg": cfg, "epochs": epoch},
              open(out_dir / "metrics.json","w"), indent=2)

print("\n✅ ASL retrain complete.")
print("  • New heads →", ENS_NEW)


Labels selected for ASL tweak: ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ATAD5', 'SR-HSE', 'SR-p53']
[NR-AR] epoch 1/6  train_loss=0.0878  val_AP=0.1491
[NR-AR] epoch 2/6  train_loss=0.0643  val_AP=0.1636
[NR-AR] epoch 3/6  train_loss=0.0616  val_AP=0.1550
[NR-AR] epoch 4/6  train_loss=0.0575  val_AP=0.1657
[NR-AR] epoch 5/6  train_loss=0.0562  val_AP=0.1765
[NR-AR] epoch 6/6  train_loss=0.0555  val_AP=0.1734
[NR-AR-LBD] epoch 1/6  train_loss=0.0808  val_AP=0.1397
[NR-AR-LBD] epoch 2/6  train_loss=0.0486  val_AP=0.2891
[NR-AR-LBD] epoch 3/6  train_loss=0.0467  val_AP=0.3058
[NR-AR-LBD] epoch 4/6  train_loss=0.0472  val_AP=0.2826
[NR-AR-LBD] epoch 5/6  train_loss=0.0460  val_AP=0.2834
[NR-AR-LBD] epoch 6/6  train_loss=0.0447  val_AP=0.2944
[NR-AhR] epoch 1/6  train_loss=0.2295  val_AP=0.5028
[NR-AhR] epoch 2/6  train_loss=0.1584  val_AP=0.5323
[NR-AhR] epoch 3/6  train_loss=0.1501  val_AP=0.5355
[NR-AhR] epoch 4/6  train_loss=0.1470  val_AP=0.538

#### C) test

In [None]:
# === Compare old F1 blend vs NEW stacked F1 on VAL & TEST (self-contained) ===
# Saves under: v7_exp/stacking_asl_v1/eval/*
import json, math
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from joblib import load as joblib_load

# -------------------- Paths --------------------
ROOT_OLD   = Path("v7")
PREP       = ROOT_OLD / "data" / "prepared"
FUSED      = ROOT_OLD / "data" / "fused"
MODEL_OLD  = ROOT_OLD / "model"
ENS_OLD    = MODEL_OLD / "ensembles"
CAL_OLD    = MODEL_OLD / "calibration"

ROOT_NEW   = Path("v7_exp") / "stacking_asl_v1"
MODEL_NEW  = ROOT_NEW / "model"
ENS_NEW    = MODEL_NEW / "ensembles_v2"          # optional: ASL-tweaked heads
STACK_DIR  = MODEL_NEW / "stacking_lr"
CAL_NEW    = ROOT_NEW / "calibration"
EVAL_DIR   = ROOT_NEW / "eval"
EVAL_DIR.mkdir(parents=True, exist_ok=True)

# Toggle: use ASL-tweaked heads for *both* methods (may mismatch the stackers' training)
USE_NEW_HEADS = False   # default False (recommended)

# -------------------- Manifests --------------------
mani = json.loads((PREP / "dataset_manifest.json").read_text())
LABELS: List[str] = mani["labels"]; N_LABELS = len(LABELS)

# Baseline (old) calibration
thr_blend_v = CAL_OLD / "thresholds_blend_v2.json"
thr_blend_1 = CAL_OLD / "thresholds_blend.json"
assert thr_blend_v.exists() or thr_blend_1.exists(), "Missing thresholds_blend file(s)"
blend_payload = json.loads((thr_blend_v if thr_blend_v.exists() else thr_blend_1).read_text())
ALPHA   = float(blend_payload.get("alpha", 0.8))
THR_OLD = blend_payload["thresholds"]  # label -> {th_f1, th_fbeta15, ap_val}
temps_spec   = json.loads((CAL_OLD / "temps.json").read_text())
temps_shared = json.loads((CAL_OLD / "temps_shared.json").read_text())

# Stacked calibration
temps_stack  = json.loads((CAL_NEW / "temps_stack.json").read_text())
thr_stack    = json.loads((CAL_NEW / "thresholds_stack.json").read_text())["thresholds"]

# -------------------- Models (shared head + label heads) --------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SharedHeadMLP(nn.Module):
    def __init__(self, dim=256, n_labels=N_LABELS):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim*3, dim*2), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(dim*2, n_labels)
        )
    def forward(self, z): return self.mlp(z)

# load shared head weights from v7 (unchanged)
shared_ckpt = torch.load(MODEL_OLD / "checkpoints" / "shared" / "best.pt", map_location=device)
sh = SharedHeadMLP().to(device)
sh_state = {k.replace("shared_head.",""): v for k,v in shared_ckpt["model"].items() if k.startswith("shared_head.")}
sh.load_state_dict(sh_state, strict=True); sh.eval()

class LabelHead(nn.Module):
    def __init__(self, in_dim=768, h1=512, h2=256, h3=128, p=0.30):
        super().__init__()
        self.block1=nn.Sequential(nn.Linear(in_dim,h1), nn.GELU(), nn.LayerNorm(h1), nn.Dropout(p))
        self.block2=nn.Sequential(nn.Linear(h1,h2), nn.GELU(), nn.LayerNorm(h2), nn.Dropout(p))
        self.block3=nn.Sequential(nn.Linear(h2,h3), nn.GELU(), nn.LayerNorm(h3), nn.Dropout(p))
        self.out=nn.Linear(h3,1); self.short=nn.Linear(in_dim,h3)
    def forward(self, x):
        z1=self.block1(x); z2=self.block2(z1); z3=self.block3(z2)
        return self.out(z3 + self.short(x)).squeeze(-1)

def _remap(sd: dict)->dict:
    if not any(k.startswith(("b1.","b2.","b3.")) for k in sd): return sd
    m={}; 
    for k,v in sd.items():
        m[k.replace("b1.","block1.").replace("b2.","block2.").replace("b3.","block3.")] = v
    return m

def load_best_head(label: str, use_new: bool=False) -> nn.Module:
    base = ENS_NEW if use_new else ENS_OLD
    cands=[]
    for sd in sorted((base/label).glob("seed*/")):
        m = sd/"metrics.json"
        if m.exists():
            try: cands.append((float(json.loads(m.read_text()).get("best_ap", float("nan"))), sd))
            except: pass
    if not cands: raise FileNotFoundError(f"No heads for {label} under {base/label}")
    cands.sort(key=lambda x: (-1.0 if math.isnan(x[0]) else x[0]), reverse=True)
    best = cands[0][1]
    ck = torch.load(best/"best.pt", map_location=device)
    cfg = ck.get("config", {"in_dim":768,"h1":512,"h2":256,"h3":128,"dropout":0.30})
    head = LabelHead(cfg["in_dim"],cfg["h1"],cfg["h2"],cfg["h3"],cfg.get("dropout",0.30)).to(device)
    head.load_state_dict(_remap(ck["model"]), strict=True); head.eval()
    return head

HEADS = {lbl: load_best_head(lbl, use_new=USE_NEW_HEADS) for lbl in LABELS}

# Load per-label stackers (trained on OLD heads; if USE_NEW_HEADS=True, results may shift)
STACKERS = {lbl: joblib_load((STACK_DIR/lbl/"stack_lr.joblib")) for lbl in LABELS}

# -------------------- Helpers --------------------
sigmoid = lambda x: 1.0/(1.0+np.exp(-x))

def build_logits(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute shared and specialist logits on fused features."""
    X_t = torch.tensor(X, dtype=torch.float32, device=device)
    with torch.no_grad():
        sh_logits   = sh(X_t).cpu().numpy()                          # (N, L)
        spec_logits = torch.stack([HEADS[l](X_t) for l in LABELS],   # (L, N)
                                  dim=1).cpu().numpy()               # -> (N, L)
    return spec_logits, sh_logits

def old_blend_probs(spec_logits: np.ndarray, sh_logits: np.ndarray) -> np.ndarray:
    """Calibrated per-stream → blended probs with ALPHA."""
    N, L = spec_logits.shape
    P = np.zeros_like(spec_logits, dtype=np.float32)
    for j,lbl in enumerate(LABELS):
        Ts = max(float(temps_spec.get(lbl,1.0)), 1e-3)
        Th = max(float(temps_shared.get(lbl,1.0)), 1e-3)
        ps = sigmoid(spec_logits[:,j]/Ts)
        ph = sigmoid(sh_logits[:,j]/Th)
        P[:,j] = np.clip(ALPHA*ps + (1-ALPHA)*ph, 0, 1)
    return P  # (N,L)

def stack_probs(spec_logits: np.ndarray, sh_logits: np.ndarray) -> np.ndarray:
    """Per-label logistic combiner [spec,shared] → logit → temp → prob."""
    N, L = spec_logits.shape
    P = np.zeros_like(spec_logits, dtype=np.float32)
    for j,lbl in enumerate(LABELS):
        lr = STACKERS[lbl]
        z  = lr.decision_function(np.stack([spec_logits[:,j], sh_logits[:,j]], axis=1))
        T  = max(float(temps_stack.get(lbl,1.0)), 1e-3)
        P[:,j] = sigmoid(z / T)
    return P  # (N,L)

def metric_micro_macro(Y_true: np.ndarray, M_missing: np.ndarray,
                       P: np.ndarray, thresholds: Dict[str, Dict[str, float]], mode="f1") -> Dict:
    """Compute micro/macro F1 given probs and per-label thresholds."""
    assert mode in ("f1","fbeta15")
    L = Y_true.shape[1]
    valid_mask = ~M_missing
    # predictions
    preds = np.zeros_like(P, dtype=np.int32)
    for j,lbl in enumerate(LABELS):
        th = thresholds[lbl]["th_f1"] if mode=="f1" else thresholds[lbl]["th_fbeta15"]
        preds[:,j] = (P[:,j] >= float(th)).astype(np.int32)
    # micro
    y = Y_true[valid_mask].astype(int); p = preds[valid_mask]
    tp = int(((p==1)&(y==1)).sum()); fp = int(((p==1)&(y==0)).sum()); fn = int(((p==0)&(y==1)).sum())
    micro_prec = tp/(tp+fp) if (tp+fp)>0 else 0.0
    micro_rec  = tp/(tp+fn) if (tp+fn)>0 else 0.0
    micro_f1   = (2*micro_prec*micro_rec)/(micro_prec+micro_rec) if (micro_prec+micro_rec)>0 else 0.0
    # per-label
    per=[]
    for j,lbl in enumerate(LABELS):
        mask = valid_mask[:,j]
        if not mask.any():
            per.append({"label": lbl, "precision": np.nan, "recall": np.nan, "f1": np.nan}); continue
        yj = Y_true[mask, j].astype(int); pj = preds[mask, j]
        tpj = int(((pj==1)&(yj==1)).sum()); fpj = int(((pj==1)&(yj==0)).sum()); fnj = int(((pj==0)&(yj==1)).sum())
        prec = (tpj/(tpj+fpj)) if (tpj+fpj)>0 else np.nan
        rec  = (tpj/(tpj+fnj)) if (tpj+fnj)>0 else np.nan
        f1   = (2*prec*rec/(prec+rec)) if (not np.isnan(prec) and not np.isnan(rec) and (prec+rec)>0) else np.nan
        per.append({"label": lbl, "precision": prec, "recall": rec, "f1": f1})
    macro_f1 = float(np.nanmean([d["f1"] for d in per]))
    return {"micro": {"precision": micro_prec, "recall": micro_rec, "f1": micro_f1},
            "macro_f1": macro_f1, "per_label": per, "n_samples": int(Y_true.shape[0])}

def evaluate_split(split: str) -> Dict:
    assert split in ("val","test")
    blob = np.load(PREP / f"{split}.npz", allow_pickle=True)
    Y = blob["Y"].astype(np.float32)
    M = blob["y_missing_mask"].astype(bool)
    X = np.load(FUSED / f"{split}_fused.npy").astype(np.float32)

    spec_z, sh_z = build_logits(X)

    # OLD blend (F1)
    P_old = old_blend_probs(spec_z, sh_z)
    m_old = metric_micro_macro(Y, M, P_old, THR_OLD, mode="f1")

    # NEW stacked (F1 thresholds from stack)
    P_stk = stack_probs(spec_z, sh_z)
    m_stk = metric_micro_macro(Y, M, P_stk, thr_stack, mode="f1")

    # Save detailed per-label CSV
    per_rows=[]
    for j,lbl in enumerate(LABELS):
        per_rows.append({
            "label": lbl,
            "old_precision": m_old["per_label"][j]["precision"],
            "old_recall":    m_old["per_label"][j]["recall"],
            "old_f1":        m_old["per_label"][j]["f1"],
            "stack_precision": m_stk["per_label"][j]["precision"],
            "stack_recall":    m_stk["per_label"][j]["recall"],
            "stack_f1":        m_stk["per_label"][j]["f1"],
        })
    df = pd.DataFrame(per_rows)
    df.to_csv(EVAL_DIR / f"compare_{split}.csv", index=False)

    # Save summary JSON
    out = {
        "split": split,
        "use_new_heads": USE_NEW_HEADS,
        "old_f1": m_old,
        "stack_f1": m_stk,
        "alpha_old": ALPHA
    }
    (EVAL_DIR / f"compare_{split}.json").write_text(json.dumps(out, indent=2))
    return out

print("Evaluating on VAL and TEST...")
rep_val  = evaluate_split("val")
rep_test = evaluate_split("test")

def _fmt(m): 
    return f"P={m['micro']['precision']:.3f}  R={m['micro']['recall']:.3f}  F1={m['micro']['f1']:.3f}  | macro-F1={m['macro_f1']:.3f}"

print("\n=== VAL ===")
print(" Old F1   :", _fmt(rep_val["old_f1"]))
print(" Stack F1 :", _fmt(rep_val["stack_f1"]))
print("\n=== TEST ===")
print(" Old F1   :", _fmt(rep_test["old_f1"]))
print(" Stack F1 :", _fmt(rep_test["stack_f1"]))

print("\n✅ Saved:")
print("  •", EVAL_DIR / "compare_val.json")
print("  •", EVAL_DIR / "compare_val.csv")
print("  •", EVAL_DIR / "compare_test.json")
print("  •", EVAL_DIR / "compare_test.csv")
print(f"\n(use_new_heads={USE_NEW_HEADS})")

Evaluating on VAL and TEST...

=== VAL ===
 Old F1   : P=0.244  R=0.474  F1=0.322  | macro-F1=0.311
 Stack F1 : P=0.221  R=0.506  F1=0.308  | macro-F1=0.307

=== TEST ===
 Old F1   : P=0.232  R=0.462  F1=0.309  | macro-F1=0.278
 Stack F1 : P=0.212  R=0.490  F1=0.296  | macro-F1=0.261

✅ Saved:
  • v7_exp\stacking_asl_v1\eval\compare_val.json
  • v7_exp\stacking_asl_v1\eval\compare_val.csv
  • v7_exp\stacking_asl_v1\eval\compare_test.json
  • v7_exp\stacking_asl_v1\eval\compare_test.csv

(use_new_heads=False)


  sigmoid = lambda x: 1.0/(1.0+np.exp(-x))
  sigmoid = lambda x: 1.0/(1.0+np.exp(-x))
  sigmoid = lambda x: 1.0/(1.0+np.exp(-x))
  sigmoid = lambda x: 1.0/(1.0+np.exp(-x))
  sigmoid = lambda x: 1.0/(1.0+np.exp(-x))
  sigmoid = lambda x: 1.0/(1.0+np.exp(-x))
