Cell 0 — Notebook header, imports, environment print, config loader, seeding

In [168]:
# === Cell 0: Header / Imports / Env print / Strict YAML config loader / Seeding ===

from __future__ import annotations
import os, sys, json, random
from pathlib import Path
from typing import Any, Dict

# Scientific stack
import numpy as np
import pandas as pd

# PyTorch (2.7.0+cu128; cuDNN 9.7.1; CUDA 12.8)
import torch

# We require YAML; no fallbacks.
try:
    import yaml  # pip install pyyaml
except Exception as e:
    raise ImportError("PyYAML is required. Install with: pip install pyyaml") from e

# ----------------- Project Layout (STRICT) -----------------
# This notebook must live inside the 'Model_v1' folder.
NB_ROOT = Path.cwd().resolve()
expected_root_name = "Model_v1"
if NB_ROOT.name != expected_root_name:
    raise RuntimeError(
        f"This notebook must be run from inside '{expected_root_name}'\n"
        f"Current working directory is: {NB_ROOT}"
    )

# Mandatory config path
CONFIG_PATH = NB_ROOT / "configs" / "dti_adr_v1.yaml"
if not CONFIG_PATH.exists():
    raise FileNotFoundError(
        f"Config file not found at {CONFIG_PATH}\n"
        f"Create it first under: {NB_ROOT / 'configs'}  (filename: dti_adr_v1.yaml)"
    )

# ----------------- Load Config (YAML only) -----------------
def load_cfg_yaml(path: Path) -> Dict[str, Any]:
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        raise ValueError(f"Config at {path} is not a YAML mapping/dict.")
    return cfg

cfg = load_cfg_yaml(CONFIG_PATH)

# ----------------- Reproducibility -----------------
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed = int(cfg.get("run", {}).get("seed", 1337))
set_seed(seed)

# Determinism (slower but safer/reproducible)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Prefer high precision matmul on Ada; safe if not available.
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# AMP config
amp_cfg = cfg.get("train", {}).get("amp", {})
AMP_ENABLED = bool(amp_cfg.get("enabled", True))
AMP_PRECISION = str(amp_cfg.get("precision", "bf16")).lower()
AMP_DTYPE = (
    torch.bfloat16 if AMP_PRECISION == "bf16"
    else (torch.float16 if AMP_PRECISION == "fp16" else None)
)

# Device
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ----------------- Pretty Environment Print -----------------
def print_env_summary():
    def yn(x): return "Yes" if x else "No"
    print("=== ML ENVIRONMENT SUMMARY (runtime) ===")
    print(f"{'OS':>18}: {os.name} / {sys.platform}")
    print(f"{'Python':>18}: {sys.version.split()[0]}")
    print(f"{'Torch':>18}: {torch.__version__}")
    print(f"{'CUDA available':>18}: {yn(torch.cuda.is_available())}")
    print(f"{'CUDA runtime':>18}: {torch.version.cuda}")
    print(f"{'cuDNN':>18}: {torch.backends.cudnn.version()}")
    print(f"{'Device':>18}: {DEVICE}")
    print(f"{'AMP enabled':>18}: {yn(AMP_ENABLED)}")
    print(f"{'AMP precision':>18}: {AMP_PRECISION if AMP_ENABLED else 'off'}")
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        print(f"{'GPU name':>18}: {props.name}")
        print(f"{'GPU cc':>18}: {props.major}.{props.minor}")
        print(f"{'GPU mem (GiB)':>18}: {props.total_memory/2**30:.2f}")
    # print("\n=== CONFIG SNAPSHOT (key bits) ===")
    snap = {
        "run": cfg.get("run", {}),
        "model": {
            "drug_encoder": cfg.get("model", {}).get("drug_encoder"),
            "protein_encoder": cfg.get("model", {}).get("protein_encoder"),
            "shared_dim": cfg.get("model", {}).get("shared_dim"),
            "dti_head": cfg.get("model", {}).get("dti_head"),
            "contrastive": cfg.get("model", {}).get("contrastive"),
        },
        "train": {
            "batch_size": cfg.get("train", {}).get("batch_size"),
            "epochs": cfg.get("train", {}).get("epochs"),
            "amp": cfg.get("train", {}).get("amp"),
        },
        "data": {
            "K": cfg.get("data", {}).get("K"),
            "adr_root": cfg.get("data", {}).get("adr_root"),
            "dti_pairs_path": cfg.get("data", {}).get("dti_pairs_path"),
        }
    }
    # print(json.dumps(snap, indent=2))

print_env_summary()

# ----------------- Output directory & config snapshot -----------------
OUT_DIR = Path(cfg["run"]["output_dir"]).resolve()
OUT_DIR.mkdir(parents=True, exist_ok=True)
with open(OUT_DIR / "resolved_config.json", "w", encoding="utf-8") as f:
    json.dump(cfg, f, indent=2)
print(f"\nConfig snapshot written to: {OUT_DIR / 'resolved_config.json'}")


=== ML ENVIRONMENT SUMMARY (runtime) ===
                OS: nt / win32
            Python: 3.12.11
             Torch: 2.7.0+cu128
    CUDA available: Yes
      CUDA runtime: 12.8
             cuDNN: 90701
            Device: cuda
       AMP enabled: Yes
     AMP precision: bf16
          GPU name: NVIDIA GeForce RTX 5070 Ti
            GPU cc: 12.0
     GPU mem (GiB): 15.92

Config snapshot written to: F:\Thesis Korbi na\dti-prediction-with-adr\Model_v1\runs\dti_adr_v1\resolved_config.json


Cell 1 — Data inventory & schema checks (ADR artifacts + global stats, strict paths)

In [169]:
# === Cell 1: Data inventory & schema checks (ADR / indices / stats) ===

from pathlib import Path

def _read_json(path: Path) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def _read_parquet(path: Path) -> pd.DataFrame:
    try:
        return pd.read_parquet(path)  # requires pyarrow or fastparquet
    except Exception as e:
        raise RuntimeError(
            f"Failed to read parquet at {path}. "
            f"Install pyarrow: pip install pyarrow\nOriginal error: {e}"
        )

# Resolve ADR paths
adr_root = Path(cfg["data"]["adr_root"]).resolve()
adr_files = cfg["data"]["adr_files"]
p_idf = adr_root / adr_files["idf_table"]
p_adr_index = adr_root / adr_files["adr_index"]
p_drug_index = adr_root / adr_files["drug_index"]
p_global_stats = adr_root / adr_files["global_stats"]

# Check existence
for p in [p_idf, p_adr_index, p_drug_index, p_global_stats]:
    if not p.exists():
        raise FileNotFoundError(f"Missing required ADR artifact: {p}")

# Load artifacts
idf_table = _read_parquet(p_idf)
adr_index = _read_parquet(p_adr_index)
drug_index = _read_parquet(p_drug_index)
global_stats = _read_json(p_global_stats)

# Sanity prints
print("=== ADR Artifacts ===")
print(f"IDF table        : {p_idf}  shape={idf_table.shape}")
print(f"ADR index        : {p_adr_index}  rows={len(adr_index)}")
print(f"Drug index       : {p_drug_index}  rows={len(drug_index)}")
print(f"Global stats json: {p_global_stats}")

# Validate K (number of ADR columns to keep)
K_cfg = int(cfg["data"]["K"])
K_idf = int(len(idf_table))
if K_cfg != K_idf:
    raise ValueError(f"K mismatch: cfg K={K_cfg}, idf_table rows={K_idf}. "
                     "Ensure config matches ADR column space.")

# Inspect per-split stats.json to ensure TF-IDF column order stability and density
split_dirs = [adr_root / "train", adr_root / "val", adr_root / "test"]
SPLIT_INFO = {}

for sd in split_dirs:
    if not sd.exists():
        raise FileNotFoundError(f"Missing split directory: {sd}")

    stats_path = sd / "stats.json"
    tfidf_wide_path = sd / "tfidf_wide.parquet"
    preview_path = sd / "preview_top_tfidf.parquet"

    if not stats_path.exists() or not tfidf_wide_path.exists():
        raise FileNotFoundError(f"Missing files under {sd}: need stats.json and tfidf_wide.parquet")

    # Load small stats for quick checks
    s = _read_json(stats_path)
    df_wide = _read_parquet(tfidf_wide_path)

    # Column order check against IDF table
    # Assume idf_table has a column that defines sorted ADR keys; use the first column name as ADR id
    idf_adr_col = idf_table.columns[0]
    idf_keys = idf_table[idf_adr_col].tolist() if idf_adr_col != 0 else idf_table.iloc[:,0].tolist()

    # tfidf_wide should have the exact K ADR columns in the same order, aside from its ID column
    wide_cols = list(df_wide.columns)
    # Heuristic: first column may be an ID (e.g., rxcui or drug_chembl_id). Detect & strip if non-numeric/non-adr.
    if len(wide_cols) != K_cfg and len(wide_cols) == K_cfg + 1:
        adr_cols = wide_cols[1:]
        id_col = wide_cols[0]
    elif len(wide_cols) == K_cfg:
        adr_cols = wide_cols
        id_col = None
    else:
        raise ValueError(f"{tfidf_wide_path} columns={len(wide_cols)} not matching expected K (={K_cfg}) or K+1.")

    # If the ADR column names match the IDF keys exactly, good; otherwise warn.
    # (Some pipelines store ADR columns as positional indices 0..K-1. In that case, we skip strict name check.)
    name_match = False
    if isinstance(adr_cols[0], str) and isinstance(idf_keys[0], (str, int)):
        # Attempt direct name match
        name_match = (adr_cols == idf_keys)
    else:
        name_match = True  # Can't compare reliably; assume positional match

    # Density stats
    nnz = float(s.get("nnz", -1))
    density = float(s.get("density", -1.0))
    n_rows = int(s.get("n_rows", len(df_wide)))
    SPLIT_INFO[sd.name] = {
        "n_rows": n_rows,
        "has_preview": (sd / "preview_top_tfidf.parquet").exists(),
        "id_col": id_col,
        "adr_name_match_idf": name_match,
        "nnz": nnz,
        "density": density,
    }

# Summarize
print("\n=== ADR split checks ===")
for k, v in SPLIT_INFO.items():
    print(f"{k:>5} | rows={v['n_rows']:>6} | preview={v['has_preview']} | "
          f"ADR-names==IDF? {v['adr_name_match_idf']} | nnz={v['nnz']} | dens={v['density']:.6f}")

print("\nAll ADR artifacts look consistent ✅")


=== ADR Artifacts ===
IDF table        : F:\Thesis Korbi na\dti-prediction-with-adr\Model_v1\1. Adr_embeddings\idf_table.parquet  shape=(4048, 3)
ADR index        : F:\Thesis Korbi na\dti-prediction-with-adr\Model_v1\1. Adr_embeddings\adr_index.parquet  rows=4817
Drug index       : F:\Thesis Korbi na\dti-prediction-with-adr\Model_v1\1. Adr_embeddings\drug_index.parquet  rows=1028
Global stats json: F:\Thesis Korbi na\dti-prediction-with-adr\Model_v1\1. Adr_embeddings\global_stats.json

=== ADR split checks ===
train | rows=   719 | preview=True | ADR-names==IDF? False | nnz=47047.0 | dens=0.016165
  val | rows=   154 | preview=True | ADR-names==IDF? False | nnz=10324.0 | dens=0.016561
 test | rows=   155 | preview=True | ADR-names==IDF? False | nnz=11222.0 | dens=0.017885

All ADR artifacts look consistent ✅

=== ADR split checks ===
train | rows=   719 | preview=True | ADR-names==IDF? False | nnz=47047.0 | dens=0.016165
  val | rows=   154 | preview=True | ADR-names==IDF? False | nnz=

Cell 2 — Load chosen Drug & Protein embeddings (dtype-safe, shape-checked)

In [170]:
# === Cell 2: Load chosen Drug & Protein embeddings (dtype-safe, shape-checked) ===

from pathlib import Path
import json
import numpy as np
import pandas as pd
import torch

# Expect config keys:
#   cfg["model"]["drug_encoder"] in {"chemberta", "smiles2vec", "egnn"}
#   cfg["model"]["protein_encoder"] in {"esm", "gvp"}

# ---------- Helper: parquet reader (pyarrow already checked in Cell 1) ----------
def read_parquet_strict(p: Path) -> pd.DataFrame:
    if not p.exists():
        raise FileNotFoundError(f"Missing embedding parquet: {p}")
    try:
        return pd.read_parquet(p)
    except Exception as e:
        raise RuntimeError(f"Failed to read parquet at {p}\n{e}")

# ---------- Helper: coerce 'embedding' column → 2D float32 array ----------
def coerce_embedding_column(df: pd.DataFrame, col: str = "embedding") -> np.ndarray:
    """
    Coerce a DataFrame's 'embedding' column into a 2D np.ndarray[float32]
    Handles cases where dtype=object with lists/ndarrays, or JSON-like strings.
    """
    if col not in df.columns:
        raise KeyError(f"Expected embedding column '{col}' not found in columns: {list(df.columns)[:8]} ...")

    raw = df[col].values
    # If first element is already a numpy array
    if isinstance(raw[0], np.ndarray):
        arr = np.stack(raw, axis=0)
    else:
        # Could be python lists or strings like "[1.2, -0.3, ...]"
        processed = []
        for i, v in enumerate(raw):
            if isinstance(v, list):
                processed.append(np.asarray(v, dtype=np.float32))
                continue
            if isinstance(v, (bytes, bytearray)):
                v = v.decode("utf-8")
            if isinstance(v, str):
                # Attempt to parse quick JSON-ish arrays without the cost of full json if possible
                v_str = v.strip()
                if v_str.startswith("[") and v_str.endswith("]"):
                    try:
                        parsed = json.loads(v_str)
                        processed.append(np.asarray(parsed, dtype=np.float32))
                        continue
                    except Exception as e:
                        raise ValueError(f"Row {i}: failed to parse string embedding JSON: {e}")
                else:
                    raise TypeError(f"Row {i}: string embedding not in JSON list format.")
            # Fallback if someone stored tuples or other iterables
            try:
                processed.append(np.asarray(v, dtype=np.float32))
            except Exception as e:
                raise TypeError(f"Row {i}: cannot coerce embedding to float32 array. Got type={type(v)}; err={e}")
        arr = np.stack(processed, axis=0)

    if arr.ndim != 2:
        raise ValueError(f"Embedding array must be 2D [N, D]; got shape {arr.shape}")
    # Ensure float32 and contiguous
    if arr.dtype != np.float32:
        arr = arr.astype(np.float32, copy=False)
    return np.ascontiguousarray(arr)

# ---------- Helper: pick ID column names ----------
def detect_drug_id_column(df: pd.DataFrame) -> str:
    # Strict expectation from your dir summary: "drug_chembl_id"
    if "drug_chembl_id" in df.columns:
        return "drug_chembl_id"
    # Fallbacks (in case of future variants)
    for c in ["chembl_id", "drug_id", "id"]:
        if c in df.columns:
            return c
    raise KeyError(f"Could not find a drug ID column in {list(df.columns)}")

def detect_protein_id_column(df: pd.DataFrame) -> str:
    # ESM file uses "id", GVP uses "uniprot_id" — support both.
    if "uniprot_id" in df.columns:
        return "uniprot_id"
    if "id" in df.columns:
        return "id"
    # Conservative fallback
    for c in ["protein_id", "target_uniprot_id"]:
        if c in df.columns:
            return c
    raise KeyError(f"Could not find a protein ID column in {list(df.columns)}")

# ---------- Load Drug Embeddings ----------
drug_choice = str(cfg["model"]["drug_encoder"]).lower()
drug_paths = cfg["data"]["drug_embeddings"]
if drug_choice not in drug_paths:
    raise KeyError(f"Unknown drug_encoder '{drug_choice}'. Available: {list(drug_paths.keys())}")

p_drug = Path(drug_paths[drug_choice])
if not p_drug.is_absolute():
    p_drug = (NB_ROOT / p_drug).resolve()

df_drug = read_parquet_strict(p_drug)
drug_id_col = detect_drug_id_column(df_drug)
drug_vecs = coerce_embedding_column(df_drug, col="embedding")
N_drug, D_drug = drug_vecs.shape

# Deduplicate by ID if necessary (keep first occurrence)
if df_drug[drug_id_col].duplicated().any():
    keep_mask = ~df_drug[drug_id_col].duplicated()
    df_drug = df_drug.loc[keep_mask].reset_index(drop=True)
    drug_vecs = drug_vecs[keep_mask.values]
    N_drug, D_drug = drug_vecs.shape

# Build CPU tensors (we'll move to GPU during training to avoid long-lived CUDA mem)
DRUG_ID_LIST = df_drug[drug_id_col].astype(str).tolist()
DRUG_TENSOR = torch.from_numpy(drug_vecs)  # float32, CPU
assert DRUG_TENSOR.dtype == torch.float32 and DRUG_TENSOR.ndim == 2

# Fast lookup map: chembl_id -> row index
DRUG_ID2IDX = {k: i for i, k in enumerate(DRUG_ID_LIST)}

print(f"Loaded DRUG embeddings from: {p_drug.name}")
print(f"  IDs: column='{drug_id_col}', unique={len(DRUG_ID_LIST)}")
print(f"  Tensor shape: {tuple(DRUG_TENSOR.shape)} (dtype={DRUG_TENSOR.dtype})")

# ---------- Load Protein Embeddings ----------
prot_choice = str(cfg["model"]["protein_encoder"]).lower()
prot_paths = cfg["data"]["protein_embeddings"]
if prot_choice not in prot_paths:
    raise KeyError(f"Unknown protein_encoder '{prot_choice}'. Available: {list(prot_paths.keys())}")

p_prot = Path(prot_paths[prot_choice])
if not p_prot.is_absolute():
    p_prot = (NB_ROOT / p_prot).resolve()

df_prot = read_parquet_strict(p_prot)
prot_id_col = detect_protein_id_column(df_prot)
prot_vecs = coerce_embedding_column(df_prot, col="embedding")
N_prot, D_prot = prot_vecs.shape

# Deduplicate by ID if necessary
if df_prot[prot_id_col].duplicated().any():
    keep_mask = ~df_prot[prot_id_col].duplicated()
    df_prot = df_prot.loc[keep_mask].reset_index(drop=True)
    prot_vecs = prot_vecs[keep_mask.values]
    N_prot, D_prot = prot_vecs.shape

PROT_ID_LIST = df_prot[prot_id_col].astype(str).tolist()
PROT_TENSOR = torch.from_numpy(prot_vecs)  # float32, CPU
assert PROT_TENSOR.dtype == torch.float32 and PROT_TENSOR.ndim == 2

PROT_ID2IDX = {k: i for i, k in enumerate(PROT_ID_LIST)}

print(f"\nLoaded PROTEIN embeddings from: {p_prot.name}")
print(f"  IDs: column='{prot_id_col}', unique={len(PROT_ID_LIST)}")
print(f"  Tensor shape: {tuple(PROT_TENSOR.shape)} (dtype={PROT_TENSOR.dtype})")

# ---------- Shape sanity & warnings ----------
if N_drug != len(set(DRUG_ID_LIST)):
    raise AssertionError("Drug IDs are not unique after deduplication.")
if N_prot != len(set(PROT_ID_LIST)):
    raise AssertionError("Protein IDs are not unique after deduplication.")

print("\n=== Embedding Dimensionalities ===")
print(f"  Drug encoder '{drug_choice}': N={N_drug}, D={D_drug}")
print(f"  Protein encoder '{prot_choice}': N={N_prot}, D={D_prot}")

# Store global dims for later modules (adapters)
DRUG_IN_DIM = int(D_drug)
PROT_IN_DIM = int(D_prot)

# ---------- AMP & dtype guidance (informational) ----------
if AMP_ENABLED and AMP_DTYPE is torch.float16:
    print("\n[Note] AMP fp16 selected. On Ada, bf16 is often more stable than fp16.")
elif AMP_ENABLED and AMP_DTYPE is torch.bfloat16:
    print("\n[Note] AMP bf16 selected. Keep model weights & losses in fp32; autocast forward only.")

print("\nEmbedding loaders ready ✅")


Loaded DRUG embeddings from: EGNN_drug_embeddings_v2.parquet
  IDs: column='drug_chembl_id', unique=1028
  Tensor shape: (1028, 256) (dtype=torch.float32)

Loaded PROTEIN embeddings from: GVP-GNN_protein_embeddings.parquet
  IDs: column='uniprot_id', unique=2385
  Tensor shape: (2385, 1024) (dtype=torch.float32)

=== Embedding Dimensionalities ===
  Drug encoder 'egnn': N=1028, D=256
  Protein encoder 'gvp': N=2385, D=1024

[Note] AMP bf16 selected. Keep model weights & losses in fp32; autocast forward only.

Embedding loaders ready ✅


In [171]:
# === Patch: sanitize + standardize embedding tensors (run after Cell 2) ===
import torch

def _sanitize_and_standardize(T: torch.Tensor, clip_sigma: float = 5.0):
    """
    - Replace non-finite values with per-dim median
    - Z-score standardize per dimension
    - Clip to ±clip_sigma to tame extreme outliers
    Returns standardized tensor and (mu, sd) for logging.
    """
    T = T.clone().to(torch.float32, non_blocking=True)
    # non-finite → median per column
    nonfinite = ~torch.isfinite(T)
    if nonfinite.any():
        # compute column-wise medians with nan handling
        X = T.masked_fill(nonfinite, float('nan'))
        med = torch.nanmedian(X, dim=0).values
        row_idx, col_idx = nonfinite.nonzero(as_tuple=True)
        T[row_idx, col_idx] = med[col_idx]

    mu = T.mean(dim=0)
    sd = T.std(dim=0).clamp_min(1e-6)
    Z = (T - mu) / sd
    Z = Z.clamp_(-clip_sigma, clip_sigma).contiguous()
    return Z, mu, sd

# Apply to both embedding banks
DRUG_TENSOR, DRUG_MU, DRUG_SD = _sanitize_and_standardize(DRUG_TENSOR)
PROT_TENSOR, PROT_MU, PROT_SD = _sanitize_and_standardize(PROT_TENSOR)

# Quick norms after standardization (should be ~sqrt(D) in L2 but z-scored per-dim)
print("Post-standardization norms:",
      f"drug mean={DRUG_TENSOR.norm(dim=1).mean().item():.3f} ",
      f"prot mean={PROT_TENSOR.norm(dim=1).mean().item():.3f}")


Post-standardization norms: drug mean=12.285  prot mean=0.997


Cell 3 — DTI pairs + ADR TF-IDF split loaders, Dataset, BalancedSampler, Collate

In [172]:
# === Cell 3: Load DTI pairs + ADR TF-IDF, build datasets, balanced sampler, collate ===

from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, Sampler

# ---------- Strict paths ----------
pairs_path = (NB_ROOT / cfg["data"]["dti_pairs_path"]).resolve() \
    if not Path(cfg["data"]["dti_pairs_path"]).is_absolute() else Path(cfg["data"]["dti_pairs_path"])
if not pairs_path.exists():
    raise FileNotFoundError(f"Missing DTI pairs file: {pairs_path}")

# ---------- Load DTI pairs (full) ----------
pairs_df = pd.read_parquet(pairs_path)
required_cols = ["drug_chembl_id", "target_uniprot_id", "label", "rxcui"]
missing = [c for c in required_cols if c not in pairs_df.columns]
if missing:
    raise KeyError(f"DTI pairs parquet missing columns: {missing}")

# Normalize dtypes to string for id keys; label to int64
pairs_df["drug_chembl_id"] = pairs_df["drug_chembl_id"].astype(str)
pairs_df["target_uniprot_id"] = pairs_df["target_uniprot_id"].astype(str)
pairs_df["rxcui"] = pairs_df["rxcui"].astype(str)
pairs_df["label"] = pairs_df["label"].astype(np.int64)

print("=== DTI pairs (global) ===")
print(pairs_df[required_cols].head(3))
print(f"Pairs total: {len(pairs_df)} | pos={(pairs_df['label']==1).sum()} | neg={(pairs_df['label']==0).sum()}")

# ---------- ADR TF-IDF loaders ----------
def load_tfidf_split(split_name: str) -> Tuple[List[str], torch.Tensor]:
    """
    Returns:
      rxcui_list: list[str] of length U_split
      tfidf_tensor: torch.FloatTensor [U_split, K]
    """
    assert split_name in ("train", "val", "test")
    sd = (NB_ROOT / cfg["data"]["adr_root"] / split_name).resolve()
    tfidf_wide_path = sd / "tfidf_wide.parquet"
    if not tfidf_wide_path.exists():
        raise FileNotFoundError(f"Missing TF-IDF matrix at {tfidf_wide_path}")

    df_wide = pd.read_parquet(tfidf_wide_path)

    # Detect ID column (first col often ID). Expect K or K+1 columns.
    K = int(cfg["data"]["K"])
    cols = list(df_wide.columns)
    if len(cols) == K:
        # No explicit ID column — use drug_index.parquet to map, but your pipeline stores ID in col 0 in practice.
        # Enforce the safer convention: require K+1 with an ID column.
        raise ValueError(
            f"{tfidf_wide_path} has exactly K columns (K={K}) but no ID column to align by rxcui.\n"
            f"Expected K+1 with ID in column 0."
        )
    elif len(cols) == K + 1:
        id_col = cols[0]
        adr_cols = cols[1:]
    else:
        raise ValueError(f"{tfidf_wide_path} columns={len(cols)} not matching expected K+1 (K={K}).")

    # Build rxcui list and TF-IDF matrix (float32, contiguous)
    rxcui_list = df_wide[id_col].astype(str).tolist()
    tfidf_mat = df_wide[adr_cols].to_numpy(dtype=np.float32, copy=False)
    if tfidf_mat.ndim != 2 or tfidf_mat.shape[1] != K:
        raise AssertionError(f"TF-IDF shape mismatch: got {tfidf_mat.shape}, expected [U,{K}]")
    tfidf_tensor = torch.from_numpy(np.ascontiguousarray(tfidf_mat, dtype=np.float32))
    return rxcui_list, tfidf_tensor

R_train, T_train = load_tfidf_split("train")
R_val,   T_val   = load_tfidf_split("val")
R_test,  T_test  = load_tfidf_split("test")

print("\n=== TF-IDF splits ===")
print(f"train: U={len(R_train)}  T.shape={tuple(T_train.shape)} dtype={T_train.dtype}")
print(f"  val: U={len(R_val)}    T.shape={tuple(T_val.shape)} dtype={T_val.dtype}")
print(f" test: U={len(R_test)}   T.shape={tuple(T_test.shape)} dtype={T_test.dtype}")

# rxcui -> row index for each split
R2I_train = {r: i for i, r in enumerate(R_train)}
R2I_val   = {r: i for i, r in enumerate(R_val)}
R2I_test  = {r: i for i, r in enumerate(R_test)}

# ---------- Build split-specific DTI datasets ----------
# We align pairs to a split if their rxcui exists in that split's TF-IDF table.

def filter_pairs_for_split(pairs: pd.DataFrame, split: str) -> pd.DataFrame:
    if split == "train":
        ok = pairs["rxcui"].isin(R2I_train)
    elif split == "val":
        ok = pairs["rxcui"].isin(R2I_val)
    elif split == "test":
        ok = pairs["rxcui"].isin(R2I_test)
    else:
        raise ValueError(split)
    df = pairs.loc[ok, required_cols].copy().reset_index(drop=True)
    # Also drop pairs whose drug/protein embeddings are missing
    df = df[df["drug_chembl_id"].isin(DRUG_ID2IDX)].copy()
    df = df[df["target_uniprot_id"].isin(PROT_ID2IDX)].copy()
    df = df.reset_index(drop=True)
    return df

pairs_train = filter_pairs_for_split(pairs_df, "train")
pairs_val   = filter_pairs_for_split(pairs_df, "val")
pairs_test  = filter_pairs_for_split(pairs_df, "test")

def pos_neg_counts(df: pd.DataFrame) -> Tuple[int, int]:
    pos = int((df["label"] == 1).sum())
    neg = int((df["label"] == 0).sum())
    return pos, neg

print("\n=== Split-aligned DTI pairs ===")
for name, df in [("train", pairs_train), ("val", pairs_val), ("test", pairs_test)]:
    pos, neg = pos_neg_counts(df)
    print(f"{name:>5}: n={len(df)} | pos={pos} | neg={neg}")

# ---------- Dataset definition ----------
@dataclass
class PairRow:
    drug_idx: int
    prot_idx: int
    label: int
    rxcui: str

class DTIDataset(Dataset):
    def __init__(self, df_pairs: pd.DataFrame, rxcui2idx: Dict[str, int]):
        self.rows: List[PairRow] = []
        self.rxcui2idx = rxcui2idx
        # Build rows
        for _, row in df_pairs.iterrows():
            d_id = row["drug_chembl_id"]
            p_id = row["target_uniprot_id"]
            y    = int(row["label"])
            rx   = row["rxcui"]
            # map to indices (already filtered, but assert anyway)
            if d_id not in DRUG_ID2IDX or p_id not in PROT_ID2IDX or rx not in rxcui2idx:
                continue
            self.rows.append(PairRow(
                drug_idx=DRUG_ID2IDX[d_id],
                prot_idx=PROT_ID2IDX[p_id],
                label=y,
                rxcui=rx
            ))
        if len(self.rows) == 0:
            raise ValueError("Empty dataset after alignment. Check IDs & TF-IDF splits.")

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> PairRow:
        return self.rows[idx]

# Instantiate datasets
ds_train = DTIDataset(pairs_train, R2I_train)
ds_val   = DTIDataset(pairs_val,   R2I_val)
ds_test  = DTIDataset(pairs_test,  R2I_test)

print("\nDatasets ready ✅")
print(f"  train: {len(ds_train)} rows")
print(f"    val: {len(ds_val)} rows")
print(f"   test: {len(ds_test)} rows")

# ---------- Balanced Batch Sampler ----------
# class BalancedBatchSampler(Sampler[List[int]]):
#     """
#     Yields lists of indices for balanced batches (B/2 pos, B/2 neg).
#     With replacement where necessary.
#     """
#     def __init__(self, dataset: DTIDataset, batch_size: int):
#         assert batch_size % 2 == 0, "Batch size must be even for balanced sampling."
#         self.ds = dataset
#         self.batch_size = batch_size
#         # Precompute pos/neg index pools
#         self.pos_idx = [i for i, r in enumerate(self.ds.rows) if r.label == 1]
#         self.neg_idx = [i for i, r in enumerate(self.ds.rows) if r.label == 0]
#         if len(self.pos_idx) == 0 or len(self.neg_idx) == 0:
#             raise ValueError("Balanced sampler requires both positive and negative samples.")
#         self.n_batches = math.ceil(len(self.ds) / self.batch_size)

#     def __len__(self) -> int:
#         return self.n_batches

#     def __iter__(self):
#         B2 = self.batch_size // 2
#         for _ in range(self.n_batches):
#             # Sample with replacement if pools smaller than needed
#             pos = np.random.choice(self.pos_idx, size=B2, replace=(len(self.pos_idx) < B2))
#             neg = np.random.choice(self.neg_idx, size=B2, replace=(len(self.neg_idx) < B2))
#             batch = np.concatenate([pos, neg])
#             np.random.shuffle(batch)
#             yield batch.tolist()

class BalancedBatchSampler(Sampler[List[int]]):
    def __init__(self, dataset: DTIDataset, batch_size: int, pos_frac: float = 0.65):
        assert 0.0 < pos_frac < 1.0
        assert batch_size % 2 == 0, "Batch size must be even."
        self.ds = dataset
        self.batch_size = batch_size
        self.pos_frac = pos_frac
        self.pos_idx = [i for i, r in enumerate(self.ds.rows) if r.label == 1]
        self.neg_idx = [i for i, r in enumerate(self.ds.rows) if r.label == 0]
        if not self.pos_idx or not self.neg_idx:
            raise ValueError("Need both pos and neg.")
        self.n_batches = math.ceil(len(self.ds) / self.batch_size)
        
    def __len__(self) -> int:
        return self.n_batches

    def __iter__(self):
        Bp = int(round(self.batch_size * self.pos_frac))
        Bn = self.batch_size - Bp
        for _ in range(self.n_batches):
            pos = np.random.choice(self.pos_idx, size=Bp, replace=(len(self.pos_idx) < Bp))
            neg = np.random.choice(self.neg_idx, size=Bn, replace=(len(self.neg_idx) < Bn))
            batch = np.concatenate([pos, neg])
            np.random.shuffle(batch)
            yield batch.tolist()

# ---------- pos_weight for BCE ----------
def compute_pos_weight(dataset: DTIDataset) -> torch.Tensor:
    y = np.array([r.label for r in dataset.rows], dtype=np.int64)
    n_pos = int((y == 1).sum())
    n_neg = int((y == 0).sum())
    if n_pos == 0:
        raise ValueError("No positive samples in dataset.")
    pw = float(n_neg) / float(n_pos)
    return torch.tensor(pw, dtype=torch.float32)

POS_WEIGHT_TRAIN = compute_pos_weight(ds_train)
# print(f"\npos_weight (train) = {POS_WEIGHT_TRAIN.item():.4f}")
PW_SCALE = 1.5  # try 1.2–2.0
POS_WEIGHT_TRAIN = POS_WEIGHT_TRAIN * PW_SCALE
print(f"pos_weight (scaled) = {POS_WEIGHT_TRAIN.item():.4f}")

# ---------- Collate function ----------
def collate_pairs_with_unique_drugs(
    rows: List[PairRow],
    split: str
) -> Dict[str, torch.Tensor]:
    """
    Returns:
      x_d [B, D_drug] float32 (CPU)
      x_p [B, D_prot] float32 (CPU)
      y   [B] float32
      t   [U, K] float32 (TF-IDF for unique drugs in batch for this split)
      pair_to_u [B] int64  (maps pair i -> row in t)
    """
    B = len(rows)
    d_idx = np.fromiter((r.drug_idx for r in rows), dtype=np.int64, count=B)
    p_idx = np.fromiter((r.prot_idx for r in rows), dtype=np.int64, count=B)
    y_arr = np.fromiter((r.label for r in rows), dtype=np.float32, count=B)

    # Unique drugs by RXCUI mapping per split
    if split == "train":
        rmap = R2I_train
        T = T_train
    elif split == "val":
        rmap = R2I_val
        T = T_val
    elif split == "test":
        rmap = R2I_test
        T = T_test
    else:
        raise ValueError(split)

    # For each pair, find the rxcui row; then build the unique set for the batch
    rxcui_rows = [rmap[r.rxcui] for r in rows]
    unique_rows, inverse = np.unique(np.asarray(rxcui_rows, dtype=np.int64), return_inverse=True)
    # inverse gives pair_to_u mapping
    pair_to_u = torch.from_numpy(inverse.astype(np.int64))

    # Gather TF-IDF for U unique drugs
    t = T[torch.from_numpy(unique_rows)]  # [U, K], still CPU float32

    # Gather embedding rows for pairs
    x_d = DRUG_TENSOR[torch.from_numpy(d_idx)]
    x_p = PROT_TENSOR[torch.from_numpy(p_idx)]
    y   = torch.from_numpy(y_arr)

    # Final integrity checks
    assert x_d.dtype == torch.float32 and x_p.dtype == torch.float32 and t.dtype == torch.float32
    assert x_d.ndim == 2 and x_p.ndim == 2 and t.ndim == 2
    assert y.ndim == 1 and pair_to_u.ndim == 1

    return {
        "x_d": x_d,           # [B, D_drug]
        "x_p": x_p,           # [B, D_prot]
        "y": y,               # [B]
        "t": t,               # [U, K]
        "pair_to_u": pair_to_u,  # [B]
    }

print("\nCollate, datasets, and sampler scaffolding are ready ✅\n"
      "Next step: define adapters + heads + contrastives and wire the training loop.")


=== DTI pairs (global) ===
  drug_chembl_id target_uniprot_id  label  rxcui
0     CHEMBL1000            O15245      0  20610
1     CHEMBL1000            P08183      1  20610
2     CHEMBL1000            P35367      1  20610
Pairs total: 34741 | pos=12234 | neg=22507
  drug_chembl_id target_uniprot_id  label  rxcui
0     CHEMBL1000            O15245      0  20610
1     CHEMBL1000            P08183      1  20610
2     CHEMBL1000            P35367      1  20610
Pairs total: 34741 | pos=12234 | neg=22507

=== TF-IDF splits ===
train: U=719  T.shape=(719, 4048) dtype=torch.float32
  val: U=154    T.shape=(154, 4048) dtype=torch.float32
 test: U=155   T.shape=(155, 4048) dtype=torch.float32

=== Split-aligned DTI pairs ===
train: n=22310 | pos=8439 | neg=13871
  val: n=4835 | pos=1861 | neg=2974
 test: n=7596 | pos=1934 | neg=5662

=== TF-IDF splits ===
train: U=719  T.shape=(719, 4048) dtype=torch.float32
  val: U=154    T.shape=(154, 4048) dtype=torch.float32
 test: U=155   T.shape=(155, 40

Cell 4 — Adapters, Heads, Contrastives scaffolding, AMP-guarded forward

In [173]:
# === Cell 4: Adapters, Heads, Contrastives scaffolding, AMP-guarded forward ===

import math
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------- Utilities --------------------
def l2_normalize_rows(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Row-wise L2 normalization. x: [N, D] -> normalized [N, D]
    """
    return x / (x.norm(p=2, dim=1, keepdim=True) + eps)

def first_occurrence_indices(pair_to_u: torch.Tensor, B: Optional[int] = None) -> torch.Tensor:
    """
    Given pair_to_u [B], returns indices 'pos' of length U such that
    u_unique[j] = u[pos[j]] aligns with t[j] for ADR head.
    Picks the first pair position that maps to each unique u index.
    """
    if B is None:
        B = pair_to_u.numel()
    U = int(pair_to_u.max().item() + 1)
    pos = torch.full((U,), -1, dtype=torch.long, device=pair_to_u.device)
    # linear scan; first seen wins
    for i in range(B):
        u_idx = int(pair_to_u[i].item())
        if pos[u_idx] == -1:
            pos[u_idx] = i
            if (pos != -1).all():
                break
    if (pos == -1).any():
        # Defensive fallback: fill remaining with the earliest valid index (keeps computation defined)
        fill_idx = int((pair_to_u == 0).nonzero(as_tuple=True)[0][0].item()) if (pair_to_u == 0).any() else 0
        pos[pos == -1] = fill_idx
    return pos  # shape [U]

# -------------------- Modules --------------------
class Adapter(nn.Module):
    """
    Two-layer MLP + LayerNorm projector into shared space of dim d.
    """
    def __init__(self, in_dim: int, d: int, hidden_ratio: int = 2, p_drop: float = 0.10):
        super().__init__()
        h = hidden_ratio * d
        self.net = nn.Sequential(
            nn.Linear(in_dim, h),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(h, d),
            nn.LayerNorm(d)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class DTIHeadCosine(nn.Module):
    """
    Cosine similarity with learnable scale -> logit.
    """
    def __init__(self, d: int):
        super().__init__()
        self.logit_scale = nn.Parameter(torch.tensor(math.log(10.0), dtype=torch.float32))  # init scale≈10

    def forward(self, u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # u, v: [B, d]
        u_n = l2_normalize_rows(u)
        v_n = l2_normalize_rows(v)
        scale = self.logit_scale.exp().clamp(1e-3, 1e3)
        # cosine in [-1,1]; scale to logits
        logits = scale * (u_n * v_n).sum(dim=1)
        return logits  # [B]

class DTIHeadBilinear(nn.Module):
    """
    Bilinear u^T W v + b -> logit.
    """
    def __init__(self, d: int):
        super().__init__()
        self.W = nn.Parameter(torch.empty(d, d))
        self.b = nn.Parameter(torch.zeros(1))
        nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))

    def forward(self, u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # [B,d] @ [d,d] -> [B,d], then dot with v
        uW = torch.matmul(u, self.W)  # [B, d]
        logits = (uW * v).sum(dim=1) + self.b  # [B]
        return logits

class ADRHead(nn.Module):
    """
    Linear(d -> K) with Softplus to enforce non-negative TF-IDF predictions.
    """
    def __init__(self, d: int, K: int):
        super().__init__()
        self.out = nn.Linear(d, K)

    def forward(self, u_unique: torch.Tensor) -> torch.Tensor:
        y = self.out(u_unique)  # [U, K]
        return F.softplus(y)    # non-negative

# -------------------- Main Multi-Task Model --------------------
class MultiTaskModel(nn.Module):
    """
    f_d: drug adapter (in_dim=D_drug -> d)
    f_p: protein adapter (in_dim=D_prot -> d)
    f_a: ADR adapter over TF-IDF (K -> d)  [used for contrastive DA and prototype building]
    h_dti: binding head (cosine or bilinear)
    h_adr: ADR regression head (Softplus)
    """
    def __init__(self, Dd: int, Dp: int, K: int, cfg: dict):
        super().__init__()
        mcfg = cfg["model"]
        d = int(mcfg["shared_dim"])
        hidden_ratio = int(mcfg.get("hidden_ratio", 2))
        p_drop = float(mcfg.get("p_drop", 0.10))

        self.d = d
        self.K = K

        # Adapters
        self.f_d = Adapter(Dd, d, hidden_ratio=hidden_ratio, p_drop=p_drop)
        self.f_p = Adapter(Dp, d, hidden_ratio=hidden_ratio, p_drop=p_drop)
        self.f_a = Adapter(K,  d, hidden_ratio=hidden_ratio, p_drop=p_drop)

        # DTI head
        head_type = str(mcfg.get("dti_head", "cosine")).lower()
        if head_type == "cosine":
            self.h_dti = DTIHeadCosine(d)
        elif head_type == "bilinear":
            self.h_dti = DTIHeadBilinear(d)
        else:
            raise ValueError(f"Unknown dti_head: {head_type}")
        self.dti_head_type = head_type

        # ADR head
        self.h_adr = ADRHead(d, K)

        # Contrastive settings
        c = mcfg.get("contrastive", {})
        self.use_dp = bool(c.get("use_dp", True))
        self.use_da = bool(c.get("use_da", True))
        self.tau    = float(c.get("tau", 0.07))

    @torch.inference_mode(False)
    def forward(
        self,
        x_d: torch.Tensor,         # [B, Dd] float32 CPU/GPU
        x_p: torch.Tensor,         # [B, Dp] float32 CPU/GPU
        t: torch.Tensor,           # [U, K]  float32 CPU/GPU (TF-IDF for unique drugs in batch)
        pair_to_u: torch.Tensor,   # [B] int64 mapping each pair -> its u index (0..U-1)
        amp_enabled: bool = AMP_ENABLED,
        amp_dtype: Optional[torch.dtype] = AMP_DTYPE
    ) -> Dict[str, torch.Tensor]:
        """
        Returns:
          u [B,d], v [B,d], w [U,d]
          logits [B]
          y_adr_hat [U, K]
          u_unique_index [U]   # indices into batch rows used to form u_unique
        """
        # Move to module device (supports CPU debug)
        dev = next(self.parameters()).device
        x_d = x_d.to(dev, dtype=torch.float32, non_blocking=True)
        x_p = x_p.to(dev, dtype=torch.float32, non_blocking=True)
        t   = t.to(dev,   dtype=torch.float32, non_blocking=True)
        pair_to_u = pair_to_u.to(dev)

        # AMP guard: forward only autocast (losses kept in fp32 outside)
        ac_dtype = amp_dtype if (amp_enabled and amp_dtype is not None) else torch.float32
        ctx = torch.autocast(device_type=dev.type, dtype=ac_dtype) if (amp_enabled and amp_dtype is not None and dev.type == "cuda") else torch.cuda.amp.autocast(enabled=False)

        with ctx:
            # Adapters
            u = self.f_d(x_d)   # [B, d]
            v = self.f_p(x_p)   # [B, d]
            w = self.f_a(t)     # [U, d]

            # DTI head
            logits = self.h_dti(u, v)  # [B]

            # Build u_unique aligned with 't' (same U order) via first occurrences in the batch
            pos = first_occurrence_indices(pair_to_u, B=x_d.shape[0])  # [U]
            u_unique = u.index_select(0, pos)  # [U, d]

            # ADR head
            y_adr_hat = self.h_adr(u_unique)  # [U, K] (Softplus)

        return {
            "u": u, "v": v, "w": w,
            "logits": logits,
            "y_adr_hat": y_adr_hat,
            "u_unique_index": pos,   # useful for diagnostics / contrastive DA if needed
        }

# -------------------- Contrastive losses (InfoNCE) --------------------
def info_nce_from_pairs(
    u: torch.Tensor, v: torch.Tensor, labels: torch.Tensor, tau: float = 0.07
) -> Tuple[torch.Tensor, int]:
    """
    Drug–Protein InfoNCE on positive pairs only.
    Args:
      u, v: [B, d]
      labels: [B] in {0,1}
    Returns:
      loss, n_pos_used
    """
    dev = u.device
    pos_mask = (labels.to(dev) == 1)
    if pos_mask.sum() <= 1:
        # Not enough positives to build a batchwise contrast — return 0
        return u.new_zeros(()), int(pos_mask.sum().item())
    u_pos = u[pos_mask]
    v_pos = v[pos_mask]
    # Normalize rows
    u_pos = l2_normalize_rows(u_pos)
    v_pos = l2_normalize_rows(v_pos)
    # Similarity matrices
    logits_uv = (u_pos @ v_pos.T) / tau   # [P, P]
    logits_vu = (v_pos @ u_pos.T) / tau   # [P, P]
    targets = torch.arange(u_pos.size(0), device=dev, dtype=torch.long)
    loss = F.cross_entropy(logits_uv, targets) + F.cross_entropy(logits_vu, targets)
    return loss * 0.5, int(pos_mask.sum().item())

def info_nce_drug_adr(
    u_unique: torch.Tensor, w: torch.Tensor, tau: float = 0.07
) -> torch.Tensor:
    """
    Drug–ADR InfoNCE: diagonal positives between u_unique and w (same U).
    """
    if u_unique.size(0) <= 1:
        return u_unique.new_zeros(())
    u_n = l2_normalize_rows(u_unique)
    w_n = l2_normalize_rows(w)
    logits_uw = (u_n @ w_n.T) / tau   # [U, U]
    logits_wu = (w_n @ u_n.T) / tau   # [U, U]
    targets = torch.arange(u_n.size(0), device=u_n.device, dtype=torch.long)
    loss = F.cross_entropy(logits_uw, targets) + F.cross_entropy(logits_wu, targets)
    return loss * 0.5

# -------------------- Instantiate model (moves later to DEVICE) --------------------
K = int(cfg["data"]["K"])
model = MultiTaskModel(DRUG_IN_DIM, PROT_IN_DIM, K, cfg)
model = model.to(DEVICE)
print(model.__class__.__name__, "initialized on", DEVICE)
print(f"Shared dim d = {model.d}, ADR K = {model.K}, head = {model.dti_head_type}")


MultiTaskModel initialized on cuda
Shared dim d = 512, ADR K = 4048, head = cosine


Cell 5 — Loss assembly, optimizer/scheduler, and a single-batch dry run

In [174]:
# === Cell 5: Loss assembly, optimizer/scheduler, and a single-batch dry run ===

import math
from typing import Dict, Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

# ---------------------- Loss builders ----------------------
def build_dti_bce_loss(pos_weight: torch.Tensor) -> nn.Module:
    """
    BCEWithLogits with pos_weight on the SAME device/dtype as logits.
    We'll recast pos_weight at call-time to logits.dtype/device for safety.
    """
    # We'll wrap a callable to inject the (device,dtype)-corrected pos_weight each step.
    class _BCEWithDynamicPos(nn.Module):
        def __init__(self, pw: torch.Tensor):
            super().__init__()
            self.register_buffer("pos_weight_buf", pw.float(), persistent=False)

        def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            # Ensure fp32 loss math even under autocast (logits will be autocast dtype)
            logits_f32 = logits.float()
            target_f32 = target.float()
            # Move/cast pos_weight
            pos_w = self.pos_weight_buf.to(device=logits_f32.device, dtype=logits_f32.dtype)
            return F.binary_cross_entropy_with_logits(logits_f32, target_f32, pos_weight=pos_w)

    return _BCEWithDynamicPos(pos_weight)


def build_adr_regression_loss(cfg_loss: Dict) -> nn.Module:
    """
    ADR regression to TF-IDF using Huber (default) or MSE.
    """
    mode = str(cfg_loss.get("type", "huber")).lower()
    if mode == "huber":
        delta = float(cfg_loss.get("delta", 1.0))
        class _Huber(nn.Module):
            def __init__(self, delta: float):
                super().__init__()
                self.delta = delta
            def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                # Keep loss in fp32
                return F.huber_loss(pred.float(), target.float(), delta=self.delta, reduction="mean")
        return _Huber(delta)
    elif mode == "mse":
        class _MSE(nn.Module):
            def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                return F.mse_loss(pred.float(), target.float(), reduction="mean")
        return _MSE()
    else:
        raise ValueError(f"Unknown ADR loss type: {mode}")


# ---------------------- Optimizer & Scheduler ----------------------
def build_optimizer_and_scheduler(model: nn.Module, cfg: dict, steps_per_epoch: int) -> Tuple[torch.optim.Optimizer, Optional[LambdaLR]]:
    opt_cfg = cfg["train"]["optimizer"]
    sch_cfg = cfg["train"]["scheduler"]

    lr = float(opt_cfg.get("lr", 2e-4))
    wd = float(opt_cfg.get("weight_decay", 1e-4))

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=wd)

    if not bool(sch_cfg.get("use_cosine", True)):
        return optimizer, None

    warmup_pct = float(sch_cfg.get("warmup_pct", 0.05))
    total_epochs = int(cfg["train"]["epochs"])
    total_steps = max(1, steps_per_epoch * total_epochs)
    warmup_steps = int(total_steps * warmup_pct)

    def lr_lambda(step: int):
        if step < warmup_steps:
            return max(1e-8, step / max(1, warmup_steps))  # linear warmup
        # cosine decay from 1.0 to 0.0 over remaining steps
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler


# ---------------------- Glue losses together ----------------------
lambda_dti = float(cfg["train"]["loss_weights"]["lambda_dti"])
lambda_adr = float(cfg["train"]["loss_weights"]["lambda_adr"])
lambda_con = float(cfg["train"]["loss_weights"]["lambda_con"])

# BCE pos_weight (already computed on train set in Cell 3 -> POS_WEIGHT_TRAIN)
criterion_dti = build_dti_bce_loss(POS_WEIGHT_TRAIN)

# ADR loss config (allow override in cfg, else default huber)
cfg.setdefault("losses", {})
cfg["losses"].setdefault("adr", {"type": "huber", "delta": 1.0})
criterion_adr = build_adr_regression_loss(cfg["losses"]["adr"])

# AMP scaler: useful only for fp16; bf16 doesn't need GradScaler
use_scaler = AMP_ENABLED and (AMP_PRECISION == "fp16")
scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)

# ---------------------- One-step dry run (no weight update) ----------------------
# We'll assemble a small balanced batch from ds_train and run forward + losses.
from torch.utils.data import DataLoader

test_bs = min(8, int(cfg["train"]["batch_size"]))  # small test; must be even
if test_bs % 2 == 1:
    test_bs += 1

sampler = BalancedBatchSampler(ds_train, batch_size=test_bs)
# Windows-safe DataLoader: num_workers=0 to avoid multiprocessing pitfalls with CUDA context on Win
loader = DataLoader(ds_train, batch_sampler=sampler, num_workers=0, collate_fn=lambda rows: collate_pairs_with_unique_drugs(rows, split="train"))

# Build optimizer/scheduler for later (need steps_per_epoch)
steps_per_epoch = max(1, len(ds_train) // max(2, int(cfg["train"]["batch_size"])))
optimizer, scheduler = build_optimizer_and_scheduler(model, cfg, steps_per_epoch=steps_per_epoch)

# Contrastive flags
use_dp = model.use_dp
use_da = model.use_da
tau    = model.tau

# Get one batch
batch = next(iter(loader))
# Move label to device for DP contrastive selection and BCE
y = batch["y"].to(DEVICE, dtype=torch.float32, non_blocking=True)

# Forward pass (AMP only affects forward)
out = model(
    x_d=batch["x_d"].to(DEVICE, non_blocking=True),
    x_p=batch["x_p"].to(DEVICE, non_blocking=True),
    t=batch["t"].to(DEVICE, non_blocking=True),
    pair_to_u=batch["pair_to_u"].to(DEVICE, non_blocking=True),
    amp_enabled=AMP_ENABLED,
    amp_dtype=AMP_DTYPE
)

logits = out["logits"]                   # [B]
y_adr_hat = out["y_adr_hat"]             # [U, K]
u = out["u"]; v = out["v"]; w = out["w"] # [B,d], [B,d], [U,d]
pos = out["u_unique_index"]              # [U]

# Gather TF-IDF targets aligned to y_adr_hat (already aligned in collate via t)
t_targets = batch["t"].to(DEVICE, dtype=torch.float32)

# Compute each loss part in fp32
L_dti = criterion_dti(logits, y)  # BCEWithLogits (pos_weighted)
L_adr = criterion_adr(y_adr_hat, t_targets)

# Contrastive pieces
L_dp = torch.zeros((), device=DEVICE)
L_da = torch.zeros((), device=DEVICE)
if lambda_con > 0.0:
    if use_dp:
        L_dp, npos_used = info_nce_from_pairs(u, v, y, tau=tau)
    if use_da:
        # Build u_unique aligned with t: out["u_unique_index"] maps u rows -> unique drugs
        u_unique = u.index_select(0, pos.to(DEVICE))
        L_da = info_nce_drug_adr(u_unique, w, tau=tau)

L_con = (L_dp + L_da)

L_total = (lambda_dti * L_dti) + (lambda_adr * L_adr) + (lambda_con * L_con)

# Print diagnostics (dtype/shape/values)
print("=== Dry Run: Loss breakdown ===")
print(f"L_dti (BCE pos-w): {L_dti.item():.6f}")
print(f"L_adr (regress) : {L_adr.item():.6f}")
print(f"L_dp (InfoNCE)  : {float(L_dp.item()):.6f}")
print(f"L_da (InfoNCE)  : {float(L_da.item()):.6f}")
print(f"--> L_total     : {L_total.item():.6f}")

print("\n=== Sanity: tensor shapes/dtypes ===")
print(f"logits: {tuple(logits.shape)} {logits.dtype}  | y: {tuple(y.shape)} {y.dtype}")
print(f"y_adr_hat: {tuple(y_adr_hat.shape)} {y_adr_hat.dtype}  | t_targets: {tuple(t_targets.shape)} {t_targets.dtype}")
print(f"u: {tuple(u.shape)}, v: {tuple(v.shape)}, w: {tuple(w.shape)}  | pos: {tuple(pos.shape)}")
print(f"AMP enabled: {AMP_ENABLED} / {AMP_PRECISION}")
print("Dry run forward+loss ✅  (no weights updated)")

# NOTE: We will construct the full training loop (Cell 6) next,
# including AMP autocast for forward, fp32 loss math, GradScaler (fp16 only),
# gradient clipping, optimizer.step(), and optional scheduler.step().


=== Dry Run: Loss breakdown ===
L_dti (BCE pos-w): 1.726675
L_adr (regress) : 0.302557
L_dp (InfoNCE)  : 1.628032
L_da (InfoNCE)  : 0.000000
--> L_total     : 2.203559

=== Sanity: tensor shapes/dtypes ===
logits: (8,) torch.float32  | y: (8,) torch.float32
y_adr_hat: (8, 4048) torch.float32  | t_targets: (8, 4048) torch.float32
u: (8, 512), v: (8, 512), w: (8, 512)  | pos: (8,)
AMP enabled: True / bf16
Dry run forward+loss ✅  (no weights updated)


  scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)


Cell 6 — Full training loop with validation, early stopping, and checkpointing

In [175]:
# === Cell 6: Full training loop with validation, early stopping, and checkpointing ===

import math, time, json
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# ----------------- DataLoaders -----------------
BATCH_SIZE = int(cfg["train"]["batch_size"])

# train_loader = DataLoader(
#     ds_train,
#     batch_sampler=BalancedBatchSampler(ds_train, batch_size=BATCH_SIZE),
#     num_workers=0,  # Windows-safe
#     collate_fn=lambda rows: collate_pairs_with_unique_drugs(rows, split="train"),
# )

train_loader = DataLoader(
    ds_train,
    batch_sampler=BalancedBatchSampler(ds_train, batch_size=BATCH_SIZE, pos_frac=0.65),
    num_workers=0,
    collate_fn=lambda rows: collate_pairs_with_unique_drugs(rows, split="train"),
)

val_loader = DataLoader(
    ds_val,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=lambda rows: collate_pairs_with_unique_drugs(rows, split="val"),
)

test_loader = DataLoader(
    ds_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=lambda rows: collate_pairs_with_unique_drugs(rows, split="test"),
)

steps_per_epoch = max(1, len(train_loader))
optimizer, scheduler = build_optimizer_and_scheduler(model, cfg, steps_per_epoch)

# ----------------- Metric helpers (pure numpy) -----------------
def _binary_classification_curves(y_true: np.ndarray, y_score: np.ndarray):
    """Returns ROC and PR curves sorted by score descending."""
    # Sort by score desc
    order = np.argsort(-y_score)
    y_true = y_true[order]
    y_score = y_score[order]

    # Cum sums for TP/FP
    tp = np.cumsum(y_true == 1)
    fp = np.cumsum(y_true == 0)
    fn = tp[-1] - tp
    tn = fp[-1] - fp

    # ROC
    tpr = tp / np.maximum(tp[-1], 1)
    fpr = fp / np.maximum(fp[-1], 1)

    # Precision-Recall
    precision = tp / np.maximum(tp + fp, 1)
    recall = tp / np.maximum(tp[-1], 1)

    return fpr, tpr, precision, recall, y_true, y_score

def _auc(x: np.ndarray, y: np.ndarray) -> float:
    """Trapezoidal integral; assumes x is monotonic increasing."""
    if len(x) < 2:
        return float("nan")
    return float(np.trapz(y, x))

def roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    fpr, tpr, *_ = _binary_classification_curves(y_true, y_score)
    # Ensure increasing fpr
    order = np.argsort(fpr)
    return _auc(fpr[order], tpr[order])

def pr_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    *_, precision, recall, _, _ = _binary_classification_curves(y_true, y_score)
    # Ensure increasing recall
    order = np.argsort(recall)
    return _auc(recall[order], precision[order])

def best_f1_threshold(y_true: np.ndarray, y_score: np.ndarray) -> Tuple[float, float]:
    """Returns (best_threshold, best_f1)."""
    # evaluate on unique sorted scores plus 0.5 heuristic
    thresholds = np.unique(y_score)
    best_t, best_f1 = 0.5, 0.0
    for t in thresholds:
        y_hat = (y_score >= t).astype(np.int32)
        tp = (y_hat & (y_true == 1)).sum()
        fp = (y_hat & (y_true == 0)).sum()
        fn = ((1 - y_hat) & (y_true == 1)).sum()
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2*prec*rec/(prec+rec) if (prec+rec) > 0 else 0.0
        if f1 > best_f1:
            best_f1, best_t = float(f1), float(t)
    return best_t, best_f1

def rmse(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.sqrt(np.mean((a - b) ** 2)))

def mae(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.mean(np.abs(a - b)))

def recall_at_k(y_true_row: np.ndarray, y_score_row: np.ndarray, k: int) -> float:
    """y_true_row, y_score_row are 1D arrays for a single drug; positives are non-zero TF-IDF."""
    k = min(k, y_score_row.size)
    topk = np.argpartition(-y_score_row, k-1)[:k]
    hits = (y_true_row[topk] > 0).sum()
    total_pos = (y_true_row > 0).sum()
    return float(hits / max(total_pos, 1))

def ndcg_at_k(y_true_row: np.ndarray, y_score_row: np.ndarray, k: int) -> float:
    k = min(k, y_score_row.size)
    order = np.argsort(-y_score_row)[:k]
    gains = y_true_row[order]  # TF-IDF as relevance proxy
    discounts = 1.0 / np.log2(np.arange(2, k+2))
    dcg = float(np.sum(gains * discounts))
    # Ideal DCG
    ideal_order = np.argsort(-y_true_row)[:k]
    ideal = float(np.sum(y_true_row[ideal_order] * discounts))
    return float(dcg / ideal) if ideal > 0 else 0.0

# ----------------- Evaluation -----------------
def evaluate(model: torch.nn.Module, loader: DataLoader, split_name: str, val_threshold: Optional[float]=None, ks: List[int]=[5,10]) -> Dict[str, float]:
    model.eval()
    y_all, p_all = [], []

    # ADR metrics: aggregate per unique-drug rows over all batches
    adr_preds, adr_tgts = [], []

    with torch.no_grad():
        for batch in loader:
            y = batch["y"].to(DEVICE, dtype=torch.float32, non_blocking=True)
            out = model(
                x_d=batch["x_d"].to(DEVICE, non_blocking=True),
                x_p=batch["x_p"].to(DEVICE, non_blocking=True),
                t=batch["t"].to(DEVICE, non_blocking=True),
                pair_to_u=batch["pair_to_u"].to(DEVICE, non_blocking=True),
                amp_enabled=AMP_ENABLED,
                amp_dtype=AMP_DTYPE
            )
            logits = out["logits"]
            # prob = torch.sigmoid(logits).detach().float().cpu().numpy()
            
            
            # After loading T if present
            calib_path = OUT_DIR / "calibration.json"
            if calib_path.exists():
                T_star = json.loads(calib_path.read_text())["temperature"]
            else:
                T_star = 1.0

            prob = torch.sigmoid(out["logits"] / T_star).detach().float().cpu().numpy()
            
            
            
            y_np  = y.detach().cpu().numpy().astype(np.int32)

            y_all.append(y_np)
            p_all.append(prob)

            # ADR accumulators
            y_adr_hat = out["y_adr_hat"].detach().float().cpu().numpy()
            t_targets = batch["t"].detach().float().cpu().numpy()
            adr_preds.append(y_adr_hat)
            adr_tgts.append(t_targets)

    if len(y_all) == 0:
        return {"dti_pr_auc": float("nan"), "dti_roc_auc": float("nan"), "dti_f1": float("nan")}

    y_all = np.concatenate(y_all, axis=0)
    p_all = np.concatenate(p_all, axis=0)

    pr = pr_auc(y_all, p_all)
    roc = roc_auc(y_all, p_all)

    # Threshold for F1 (choose on val; use provided val_threshold for test)
    if split_name == "val" or val_threshold is None:
        thr, f1 = best_f1_threshold(y_all, p_all)
    else:
        thr = float(val_threshold)
        y_hat = (p_all >= thr).astype(np.int32)
        tp = (y_hat & (y_all == 1)).sum()
        fp = (y_hat & (y_all == 0)).sum()
        fn = ((1 - y_hat) & (y_all == 1)).sum()
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2*prec*rec/(prec+rec) if (prec+rec) > 0 else 0.0

    # ADR metrics across all unique-drug rows concatenated (duplicates across batches are fine—macro average)
    if len(adr_preds) > 0:
        P = np.concatenate(adr_preds, axis=0)  # [U_total, K]
        T = np.concatenate(adr_tgts, axis=0)   # [U_total, K]
        adr_rmse = rmse(P, T)
        adr_mae_ = mae(P, T)

        # Ranking metrics
        rec_at = {}
        ndcg_at = {}
        for k in ks:
            # average across drug rows
            recs = [recall_at_k(T[i], P[i], k) for i in range(P.shape[0])]
            ndcgs = [ndcg_at_k(T[i], P[i], k) for i in range(P.shape[0])]
            rec_at[f"recall@{k}"] = float(np.mean(recs))
            ndcg_at[f"ndcg@{k}"] = float(np.mean(ndcgs))
    else:
        adr_rmse = adr_mae_ = float("nan")
        rec_at, ndcg_at = {}, {}

    metrics = {
        "dti_pr_auc": float(pr),
        "dti_roc_auc": float(roc),
        "dti_f1": float(f1),
        "dti_thr": float(thr),
        "adr_rmse": float(adr_rmse),
        "adr_mae": float(adr_mae_),
        **rec_at,
        **ndcg_at,
    }
    return metrics

# ----------------- Training Loop -----------------
EPOCHS = int(cfg["train"]["epochs"])
CLIP_NORM = float(cfg["train"]["clip_grad_norm"])

best_val = -np.inf
best_epoch = -1
best_thr = 0.5
patience = int(cfg.get("train", {}).get("early_stop_patience", 7))
no_improve = 0
history = []

print(f"\n=== Training for {EPOCHS} epochs ===")
for epoch in range(1, EPOCHS+1):
    model.train()
    t0 = time.time()
    
    if epoch <= 5:
        lambda_con_epoch = lambda_con
    else:
        lambda_con_epoch = lambda_con * 0.25  # try 0.5 or 0.25
    # then use lambda_con_epoch below:
    
    L_total = (lambda_dti * L_dti) + (lambda_adr * L_adr) + (lambda_con_epoch * L_con)
    running = {"L_total": 0.0, "L_dti": 0.0, "L_adr": 0.0, "L_con": 0.0}
    steps = 0

    for batch in train_loader:
        steps += 1
        y = batch["y"].to(DEVICE, dtype=torch.float32, non_blocking=True)

        # Forward
        out = model(
            x_d=batch["x_d"].to(DEVICE, non_blocking=True),
            x_p=batch["x_p"].to(DEVICE, non_blocking=True),
            t=batch["t"].to(DEVICE, non_blocking=True),
            pair_to_u=batch["pair_to_u"].to(DEVICE, non_blocking=True),
            amp_enabled=AMP_ENABLED,
            amp_dtype=AMP_DTYPE
        )
        logits = out["logits"]
        y_adr_hat = out["y_adr_hat"]
        u, v, w = out["u"], out["v"], out["w"]
        pos = out["u_unique_index"]

        # Losses (fp32 math)
        L_dti = criterion_dti(logits, y)
        L_adr = criterion_adr(y_adr_hat, batch["t"].to(DEVICE, dtype=torch.float32))
        L_dp = torch.zeros((), device=DEVICE)
        L_da = torch.zeros((), device=DEVICE)
        if lambda_con > 0.0:
            if model.use_dp:
                L_dp, _ = info_nce_from_pairs(u, v, y, tau=model.tau)
            if model.use_da:
                u_unique = u.index_select(0, pos.to(DEVICE))
                L_da = info_nce_drug_adr(u_unique, w, tau=model.tau)
        L_con = L_dp + L_da
        L_total = (lambda_dti * L_dti) + (lambda_adr * L_adr) + (lambda_con * L_con)

        # Backward
        optimizer.zero_grad(set_to_none=True)
        if scaler.is_enabled():
            scaler.scale(L_total).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()
        else:
            L_total.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # Logs
        running["L_total"] += float(L_total.item())
        running["L_dti"]   += float(L_dti.item())
        running["L_adr"]   += float(L_adr.item())
        running["L_con"]   += float(L_con.item())

    # Epoch logs
    for k in running:
        running[k] /= max(1, steps)
    t1 = time.time()
    print(f"Epoch {epoch:03d} | {t1-t0:.1f}s | "
          f"L: {running['L_total']:.4f} (dti {running['L_dti']:.4f} | adr {running['L_adr']:.4f} | con {running['L_con']:.4f})")

    # ---------- Validation ----------
    m_val = evaluate(model, val_loader, split_name="val", val_threshold=None, ks=cfg["eval"]["metrics"].get("adr_k", [5,10]))
    print(f"  VAL: PR-AUC {m_val['dti_pr_auc']:.4f} | ROC-AUC {m_val['dti_roc_auc']:.4f} | F1 {m_val['dti_f1']:.4f} @thr={m_val['dti_thr']:.3f} "
          f"| ADR RMSE {m_val['adr_rmse']:.4f} MAE {m_val['adr_mae']:.4f}")

    # Early stopping on dti_pr_auc
    current = m_val["dti_pr_auc"]
    if current > best_val:
        best_val = current
        best_epoch = epoch
        best_thr = m_val["dti_thr"]
        no_improve = 0

        # Save checkpoint
        ckpt = {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict() if scheduler is not None else None,
            "metrics_val": m_val,
            "best_thr": best_thr,
            "config": cfg,
        }
        torch.save(ckpt, OUT_DIR / "best.pt")
        with open(OUT_DIR / "best_metrics_val.json", "w", encoding="utf-8") as f:
            json.dump(m_val, f, indent=2)
        print(f"  ✔ Saved new best checkpoint to {OUT_DIR / 'best.pt'}")
    else:
        no_improve += 1
        print(f"  (no improve: {no_improve}/{patience})")

    history.append({"epoch": epoch, "train": running, "val": m_val})

    if no_improve >= patience:
        print(f"Early stopping at epoch {epoch} (best epoch {best_epoch}, best val PR-AUC {best_val:.4f})")
        break

# ----------------- Final evaluation on TEST with best threshold -----------------
print("\n=== Evaluating on TEST with best val threshold ===")
m_test = evaluate(model, test_loader, split_name="test", val_threshold=best_thr, ks=cfg["eval"]["metrics"].get("adr_k", [5,10]))
print(f"TEST: PR-AUC {m_test['dti_pr_auc']:.4f} | ROC-AUC {m_test['dti_roc_auc']:.4f} | F1 {m_test['dti_f1']:.4f} @thr={best_thr:.3f} "
      f"| ADR RMSE {m_test['adr_rmse']:.4f} MAE {m_test['adr_mae']:.4f}")
with open(OUT_DIR / "test_metrics.json", "w", encoding="utf-8") as f:
    json.dump({**m_test, "thr": best_thr}, f, indent=2)

# ----------------- Optional: export ADR prototypes -----------------
art_cfg = cfg.get("artifacts", {})
if bool(art_cfg.get("save_prototypes", True)):
    proto_path = Path(art_cfg.get("prototypes_path", str(OUT_DIR / "prototypes_C.npy")))
    model.eval()
    with torch.no_grad():
        # Pass the entire TRAIN TF-IDF through f_a to produce prototypes per ADR via mean of drug embeddings.
        # But our design defines prototypes per ADR as mean of f_a(TFIDF_d) over drugs where TF-IDF(d,k) > 0.
        # We'll compute W = f_a(T_train) -> [U_train, d], then aggregate by ADR column.
        W = model.f_a(T_train.to(DEVICE, dtype=torch.float32))
        W = W.detach().cpu().numpy()  # [U, d]
        T_np = T_train.cpu().numpy()  # [U, K]

        K = T_np.shape[1]
        d = W.shape[1]
        C = np.zeros((K, d), dtype=np.float32)
        for k in range(K):
            mask = T_np[:, k] > 0
            if mask.any():
                C[k] = W[mask].mean(axis=0)
            else:
                C[k] = 0.0

        np.save(proto_path, C)
        print(f"Saved ADR prototypes to {proto_path}")

# ----------------- Save training history -----------------
with open(OUT_DIR / "history.json", "w", encoding="utf-8") as f:
    json.dump(history, f, indent=2)

print("\nTraining loop complete ✅")



=== Training for 40 epochs ===
Epoch 001 | 5.4s | L: 2.1000 (dti 0.9358 | adr 0.2263 | con 5.2548)
Epoch 001 | 5.4s | L: 2.1000 (dti 0.9358 | adr 0.2263 | con 5.2548)


  return float(np.trapz(y, x))


  VAL: PR-AUC 0.6517 | ROC-AUC 0.7593 | F1 0.6970 @thr=0.644 | ADR RMSE 0.4043 MAE 0.3772
  ✔ Saved new best checkpoint to F:\Thesis Korbi na\dti-prediction-with-adr\Model_v1\runs\dti_adr_v1\best.pt
Epoch 002 | 5.6s | L: 1.8280 (dti 0.7538 | adr 0.0272 | con 5.3028)
Epoch 002 | 5.6s | L: 1.8280 (dti 0.7538 | adr 0.0272 | con 5.3028)
  VAL: PR-AUC 0.6428 | ROC-AUC 0.7622 | F1 0.7043 @thr=0.655 | ADR RMSE 0.1030 MAE 0.0923
  (no improve: 1/7)
  VAL: PR-AUC 0.6428 | ROC-AUC 0.7622 | F1 0.7043 @thr=0.655 | ADR RMSE 0.1030 MAE 0.0923
  (no improve: 1/7)
Epoch 003 | 5.4s | L: 1.7955 (dti 0.7330 | adr 0.0037 | con 5.3033)
Epoch 003 | 5.4s | L: 1.7955 (dti 0.7330 | adr 0.0037 | con 5.3033)
  VAL: PR-AUC 0.6086 | ROC-AUC 0.7591 | F1 0.6994 @thr=0.654 | ADR RMSE 0.0562 MAE 0.0499
  (no improve: 2/7)
  VAL: PR-AUC 0.6086 | ROC-AUC 0.7591 | F1 0.6994 @thr=0.654 | ADR RMSE 0.0562 MAE 0.0499
  (no improve: 2/7)
Epoch 004 | 5.3s | L: 1.7787 (dti 0.7213 | adr 0.0017 | con 5.2824)
Epoch 004 | 5.3s | L:

In [176]:
# === Cell 12: Baselines & Leakage Audit ===
import numpy as np, pandas as pd, torch, math
from collections import Counter

# 1) Random baseline & frequency baseline (DTI)
def eval_loader_probs(loader, prob_fn):
    y_all, p_all = [], []
    with torch.no_grad():
        for batch in loader:
            y = batch["y"].cpu().numpy().astype(np.int32)
            p = prob_fn(batch)
            y_all.append(y); p_all.append(p)
    if not y_all: return np.array([]), np.array([])
    return np.concatenate(y_all), np.concatenate(p_all)

def pr_auc_simple(y, p):
    if y.size == 0: return float("nan")
    order = np.argsort(-p)
    y = y[order]
    tp = np.cumsum(y==1); fp = np.cumsum(y==0)
    prec = tp / np.maximum(tp+fp, 1)
    rec  = tp / np.maximum(tp[-1], 1)
    # trapezoid integral with recall ascending
    order = np.argsort(rec)
    return float(np.trapz(prec[order], rec[order]))

# Random baseline (DTI)
rng = np.random.RandomState(1337)
y_t, p_rand = eval_loader_probs(test_loader, lambda b: rng.rand(len(b["y"])))
pr_rand = pr_auc_simple(y_t, p_rand)
print(f"[Baseline] DTI Random PR-AUC: {pr_rand:.3f} (pos rate ~{(y_t.mean() if y_t.size else float('nan')):.3f})")

# Frequency baseline by protein: P(y=1 | protein) computed on TRAIN
prot_counts = Counter()
prot_pos = Counter()
for r in ds_train.rows:
    prot_counts[r.prot_idx] += 1
    prot_pos[r.prot_idx]    += int(r.label == 1)
def prot_freq_prob(b):
    idx = [PROT_ID2IDX[p] if isinstance(p, str) else p for p in []]  # not used; collate gives indices only
    # we have prot indices in rows; rebuild from x_p index gathered by collate
    # batch gives only tensors; we approximate by re-encoding and nearest rows:
    # Instead, compute mean P(y=1) over training; apply global prior as fallback.
    prior = (sum(prot_pos.values()) / max(1, sum(prot_counts.values())))
    # We cannot access prot_idx directly from batch; use prior baseline:
    return np.full(len(b["y"]), prior, dtype=np.float32)
y_t, p_prior = eval_loader_probs(test_loader, prot_freq_prob)
print(f"[Baseline] DTI Global prior PR-AUC: {pr_auc_simple(y_t, p_prior):.3f}")

# 2) Leakage/overlap checks
def overlap_counts(A: pd.DataFrame, B: pd.DataFrame, key_cols):
    a = set(map(tuple, A[key_cols].values))
    b = set(map(tuple, B[key_cols].values))
    return len(a & b), len(a), len(b)

# pairs_train/val/test defined in Cell 3
keys = ["drug_chembl_id","target_uniprot_id","label"]
ov_tv, n_t, n_v = overlap_counts(pairs_train, pairs_val, keys)
ov_tt, n_t2, n_te = overlap_counts(pairs_train, pairs_test, keys)
ov_vt, n_v2, n_te2 = overlap_counts(pairs_val, pairs_test, keys)
print(f"Pair overlap (train↔val): {ov_tv}/{n_t} vs {n_v} | (train↔test): {ov_tt}/{n_t2} vs {n_te} | (val↔test): {ov_vt}/{n_v2} vs {n_te2}")

# We also check marginal overlap by drug or protein (distribution shift)
def marginal_overlap(A, B, col):
    a = set(A[col].astype(str))
    b = set(B[col].astype(str))
    return len(a & b), len(a), len(b)
od_tv = marginal_overlap(pairs_train, pairs_val, "drug_chembl_id")
op_tv = marginal_overlap(pairs_train, pairs_val, "target_uniprot_id")
od_tt = marginal_overlap(pairs_train, pairs_test, "drug_chembl_id")
op_tt = marginal_overlap(pairs_train, pairs_test, "target_uniprot_id")
print(f"Drug overlap train↔val: {od_tv[0]}/{od_tv[1]} vs {od_tv[2]} | Protein overlap train↔val: {op_tv[0]}/{op_tv[1]} vs {op_tv[2]}")
print(f"Drug overlap train↔test:{od_tt[0]}/{od_tt[1]} vs {od_tt[2]} | Protein overlap train↔test:{op_tt[0]}/{op_tt[1]} vs {op_tt[2]}")

# 3) Embedding coverage & scales
missing_drugs = pairs_df[~pairs_df["drug_chembl_id"].isin(DRUG_ID2IDX)].shape[0]
missing_prots = pairs_df[~pairs_df["target_uniprot_id"].isin(PROT_ID2IDX)].shape[0]
print(f"Missing embeddings → drugs: {missing_drugs} rows | proteins: {missing_prots} rows")

# Norm stats
d_norm = DRUG_TENSOR.norm(dim=1).numpy()
p_norm = PROT_TENSOR.norm(dim=1).numpy()
print(f"Drug emb norm: mean={d_norm.mean():.3f} std={d_norm.std():.3f} | Protein emb norm: mean={p_norm.mean():.3f} std={p_norm.std():.3f}")

# 4) Label-shuffle test (should collapse to ~0.5 ROC / pos-rate PR)
def shuffle_test():
    import copy
    # duplicate dataset with shuffled labels
    y = np.array([r.label for r in ds_train.rows])
    y_shuf = np.random.RandomState(123).permutation(y)
    # quick check: if a fully-trained model still shows high ROC on shuffled labels, there’s leakage
    print("Label-shuffle test: expected ROC ≈ 0.5 after retrain; run only if you plan to retrain.")
shuffle_test()
print("Cell 12 done ✅ — If random/prior PR-AUC is near your model, suspect misalignment or severe class skew.")


[Baseline] DTI Random PR-AUC: 0.262 (pos rate ~0.255)
[Baseline] DTI Global prior PR-AUC: 0.368
Pair overlap (train↔val): 52/22257 vs 4835 | (train↔test): 12/22257 vs 7595 | (val↔test): 8/4835 vs 7595
Drug overlap train↔val: 4/723 vs 155 | Protein overlap train↔val: 1106/2065 vs 1311
Drug overlap train↔test:1/723 vs 156 | Protein overlap train↔test:1078/2065 vs 1209
Missing embeddings → drugs: 0 rows | proteins: 0 rows
Drug emb norm: mean=12.285 std=8.090 | Protein emb norm: mean=0.997 std=5.547
Label-shuffle test: expected ROC ≈ 0.5 after retrain; run only if you plan to retrain.
Cell 12 done ✅ — If random/prior PR-AUC is near your model, suspect misalignment or severe class skew.


  return float(np.trapz(prec[order], rec[order]))


In [177]:
# === Cell 11: Final Report (DTI & ADR) — pretty print + save ===
import json
import numpy as np
import torch
from pathlib import Path

# --- small metric helpers (dup from earlier so this cell is standalone) ---
def _binary_classification_curves(y_true: np.ndarray, y_score: np.ndarray):
    order = np.argsort(-y_score)
    y_true = y_true[order]
    y_score = y_score[order]
    tp = np.cumsum(y_true == 1)
    fp = np.cumsum(y_true == 0)
    fn = tp[-1] - tp
    tn = fp[-1] - fp
    tpr = tp / np.maximum(tp[-1], 1)
    fpr = fp / np.maximum(fp[-1], 1)
    precision = tp / np.maximum(tp + fp, 1)
    recall = tp / np.maximum(tp[-1], 1)
    return fpr, tpr, precision, recall

def _auc(x: np.ndarray, y: np.ndarray) -> float:
    if len(x) < 2: return float("nan")
    order = np.argsort(x)
    return float(np.trapz(y[order], x[order]))

def roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    fpr, tpr, *_ = _binary_classification_curves(y_true, y_score)
    return _auc(fpr, tpr)

def best_f1_threshold(y_true: np.ndarray, y_score: np.ndarray) -> float:
    # sweep on unique scores (cap to 200 evenly-spaced if huge for speed)
    uniq = np.unique(y_score)
    if uniq.size > 200:
        uniq = np.linspace(0.0, 1.0, 200)
    best_t, best_f1 = 0.5, 0.0
    for t in uniq:
        yhat = (y_score >= t).astype(np.int32)
        tp = (yhat & (y_true==1)).sum()
        fp = (yhat & (y_true==0)).sum()
        fn = ((1-yhat) & (y_true==1)).sum()
        prec = tp / (tp+fp) if (tp+fp)>0 else 0.0
        rec  = tp / (tp+fn) if (tp+fn)>0 else 0.0
        f1 = 2*prec*rec/(prec+rec) if (prec+rec)>0 else 0.0
        if f1 > best_f1:
            best_f1, best_t = float(f1), float(t)
    return float(best_t)

def acc_f1(y_true_bin: np.ndarray, y_pred_bin: np.ndarray) -> tuple[float,float]:
    acc = (y_true_bin == y_pred_bin).mean() if y_true_bin.size else float("nan")
    tp = np.logical_and(y_pred_bin==1, y_true_bin==1).sum()
    fp = np.logical_and(y_pred_bin==1, y_true_bin==0).sum()
    fn = np.logical_and(y_pred_bin==0, y_true_bin==1).sum()
    prec = tp / (tp+fp) if (tp+fp)>0 else 0.0
    rec  = tp / (tp+fn) if (tp+fn)>0 else 0.0
    f1 = 2*prec*rec/(prec+rec) if (prec+rec)>0 else 0.0
    return float(acc), float(f1)

# --- load best threshold + temperature (if present) ---
best_val_path = OUT_DIR / "best_metrics_val.json"
calib_path = OUT_DIR / "calibration.json"
dti_thr = 0.5
if best_val_path.exists():
    try:
        dti_thr = json.loads(best_val_path.read_text())["dti_thr"]
    except Exception:
        pass
T_STAR = 1.0
if calib_path.exists():
    try:
        T_STAR = json.loads(calib_path.read_text())["temperature"]
    except Exception:
        pass

# --- helper to gather scores from a loader ---
@torch.no_grad()
def collect_dti_scores(loader):
    model.eval()
    ys, ps = [], []
    for batch in loader:
        y = batch["y"].to(DEVICE, dtype=torch.float32)
        out = model(
            x_d=batch["x_d"].to(DEVICE, non_blocking=True),
            x_p=batch["x_p"].to(DEVICE, non_blocking=True),
            t=batch["t"].to(DEVICE, non_blocking=True),
            pair_to_u=batch["pair_to_u"].to(DEVICE, non_blocking=True),
            amp_enabled=AMP_ENABLED, amp_dtype=AMP_DTYPE
        )
        prob = torch.sigmoid(out["logits"] / T_STAR).float().cpu().numpy()
        ys.append(y.cpu().numpy().astype(np.int32))
        ps.append(prob)
    if not ys: return np.array([]), np.array([])
    return np.concatenate(ys), np.concatenate(ps)

@torch.no_grad()
def collect_adr_predictions(loader):
    """
    Returns flattened arrays over all unique-drug rows:
      y_true: (U_total*K,) regression targets (TF-IDF or continuous)
      y_pred: (U_total*K,) predicted regression values
    """
    model.eval()
    Ys, Ps = [], []
    for batch in loader:
        out = model(
            x_d=batch["x_d"].to(DEVICE, non_blocking=True),
            x_p=batch["x_p"].to(DEVICE, non_blocking=True),
            t=batch["t"].to(DEVICE, non_blocking=True),
            pair_to_u=batch["pair_to_u"].to(DEVICE, non_blocking=True),
            amp_enabled=AMP_ENABLED, amp_dtype=AMP_DTYPE
        )
        y_hat = out["y_adr_hat"].detach().float().cpu().numpy()  # [U, K], >=0
        t_true = batch["t"].detach().float().cpu().numpy()       # [U, K]
        Ps.append(y_hat.reshape(-1))
        Ys.append(t_true.reshape(-1))
    if not Ys: return np.array([]), np.array([])
    return np.concatenate(Ys), np.concatenate(Ps)

# --- collect scores ---
y_val_dti, p_val_dti   = collect_dti_scores(val_loader)
y_test_dti, p_test_dti = collect_dti_scores(test_loader)

y_val_adr, p_val_adr   = collect_adr_predictions(val_loader)
y_test_adr, p_test_adr = collect_adr_predictions(test_loader)

# --- DTI metrics (classification) ---
dti_auc  = roc_auc(y_test_dti, p_test_dti) if y_test_dti.size else float("nan")
dti_pred = (p_test_dti >= float(dti_thr)).astype(np.int32) if y_test_dti.size else np.array([])
dti_acc, dti_f1 = acc_f1(y_test_dti, dti_pred) if y_test_dti.size else (float("nan"), float("nan"))

# --- ADR metrics (regression) ---
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

adr_mse = mean_squared_error(y_test_adr, p_test_adr) if y_test_adr.size else float("nan")
adr_mae = mean_absolute_error(y_test_adr, p_test_adr) if y_test_adr.size else float("nan")
adr_r2 = r2_score(y_test_adr, p_test_adr) if y_test_adr.size else float("nan")

# Optionally, per-ADR metrics (if adr_cols is available and shapes are correct)
per_adr_mse = per_adr_mae = per_adr_r2 = None
if 'adr_cols' in globals() and len(y_test_adr) % len(adr_cols) == 0:
    n_adr = len(adr_cols)
    y_true_2d = y_test_adr.reshape(-1, n_adr)
    y_pred_2d = p_test_adr.reshape(-1, n_adr)
    per_adr_mse = mean_squared_error(y_true_2d, y_pred_2d, multioutput='raw_values')
    per_adr_mae = mean_absolute_error(y_true_2d, y_pred_2d, multioutput='raw_values')
    per_adr_r2 = r2_score(y_true_2d, y_pred_2d, multioutput='raw_values')

def _fmt(x): 
    return "nan" if (x!=x) else f"{x:.4f}"

print("\nDTI PREDICTION SCORES:")
print(f"Accuracy: {_fmt(dti_acc)}")
print(f"F1-Score: {_fmt(dti_f1)}")
print(f"ROC-AUC:  {_fmt(dti_auc)}\n")

print("ADR PREDICTION REGRESSION SCORES:")
print(f"MSE: {_fmt(adr_mse)}")
print(f"MAE: {_fmt(adr_mae)}")
print(f"R2: {_fmt(adr_r2)}")
if per_adr_mse is not None:
    print("\nPer-ADR Regression Metrics:")
    for idx, col in enumerate(adr_cols):
        print(f"{col}: MSE={per_adr_mse[idx]:.4f}, MAE={per_adr_mae[idx]:.4f}, R2={per_adr_r2[idx]:.4f}")
print("==================================================")
print("Test Scores at Best Epoch:")
print(f"  DTI - Acc: {_fmt(dti_acc)}, F1: {_fmt(dti_f1)}, AUC: {_fmt(dti_auc)}")
print(f"  ADR - MSE: {_fmt(adr_mse)}, MAE: {_fmt(adr_mae)}, R2: {_fmt(adr_r2)}")

# --- save to files ---
report_txt = OUT_DIR / "final_report.txt"
report_json = OUT_DIR / "final_report.json"
with open(report_txt, "w", encoding="utf-8") as f:
    f.write(
        f"DTI PREDICTION SCORES:\n"
        f"Accuracy: {_fmt(dti_acc)}\n"
        f"F1-Score: {_fmt(dti_f1)}\n"
        f"ROC-AUC:  {_fmt(dti_auc)}\n\n"
        f"ADR PREDICTION REGRESSION SCORES:\n"
        f"MSE: {_fmt(adr_mse)}\n"
        f"MAE: {_fmt(adr_mae)}\n"
        f"R2: {_fmt(adr_r2)}\n"
        + ("\nPer-ADR Regression Metrics:\n" + "\n".join(f"{col}: MSE={per_adr_mse[idx]:.4f}, MAE={per_adr_mae[idx]:.4f}, R2={per_adr_r2[idx]:.4f}" for idx, col in enumerate(adr_cols)) if per_adr_mse is not None else "")
        + "\n==================================================\n"
        f"Test Scores at Best Epoch:\n"
        f"  DTI - Acc: {_fmt(dti_acc)}, F1: {_fmt(dti_f1)}, AUC: {_fmt(dti_auc)}\n"
        f"  ADR - MSE: {_fmt(adr_mse)}, MAE: {_fmt(adr_mae)}, R2: {_fmt(adr_r2)}\n"
    )
json.dump({
    "DTI": {"accuracy": dti_acc, "f1": dti_f1, "roc_auc": dti_auc, "threshold": float(dti_thr), "temperature": float(T_STAR)},
    "ADR": {"mse": adr_mse, "mae": adr_mae, "r2": adr_r2},
    "ADR_per_label": {col: {"mse": float(per_adr_mse[idx]), "mae": float(per_adr_mae[idx]), "r2": float(per_adr_r2[idx])} for idx, col in enumerate(adr_cols)} if per_adr_mse is not None else {},
}, open(report_json, "w"), indent=2)

print(f"\nSaved report to:\n - {report_txt}\n - {report_json}")



DTI PREDICTION SCORES:
Accuracy: 0.8036
F1-Score: 0.5795
ROC-AUC:  0.7983

ADR PREDICTION REGRESSION SCORES:
MSE: 0.0005
MAE: 0.0107
R2: -0.8845

Per-ADR Regression Metrics:
meddra_10000045: MSE=0.0005, MAE=0.0122, R2=-6.7761
meddra_10000050: MSE=0.0002, MAE=0.0090, R2=0.0000
meddra_10000054: MSE=0.0001, MAE=0.0048, R2=0.0000
meddra_10000055: MSE=0.0002, MAE=0.0077, R2=-7.8035
meddra_10000056: MSE=0.0008, MAE=0.0154, R2=0.0000
meddra_10000057: MSE=0.0002, MAE=0.0082, R2=0.0000
meddra_10000059: MSE=0.0005, MAE=0.0146, R2=-0.4930
meddra_10000060: MSE=0.0009, MAE=0.0149, R2=-0.1674
meddra_10000062: MSE=0.0002, MAE=0.0088, R2=0.0000
meddra_10000064: MSE=0.0003, MAE=0.0094, R2=0.0000
meddra_10000065: MSE=0.0003, MAE=0.0106, R2=-11.9010
meddra_10000081: MSE=0.0033, MAE=0.0340, R2=-0.1940
meddra_10000084: MSE=0.0005, MAE=0.0126, R2=-1.1637
meddra_10000085: MSE=0.0003, MAE=0.0113, R2=-8.5521
meddra_10000087: MSE=0.0016, MAE=0.0169, R2=-0.0914
meddra_10000097: MSE=0.0003, MAE=0.0106, R2=-4.573

  return float(np.trapz(y[order], x[order]))


Cell 6.5 — Calibration

In [178]:
# === Cell 6.5: Temperature scaling on VAL logits ===
import torch
import numpy as np

model.eval()
# Collect raw logits and labels on VAL
logits_val, labels_val = [], []
with torch.no_grad():
    for batch in val_loader:
        out = model(
            x_d=batch["x_d"].to(DEVICE),
            x_p=batch["x_p"].to(DEVICE),
            t=batch["t"].to(DEVICE),
            pair_to_u=batch["pair_to_u"].to(DEVICE),
            amp_enabled=AMP_ENABLED, amp_dtype=AMP_DTYPE
        )
        logits_val.append(out["logits"].detach().float().cpu())
        labels_val.append(batch["y"].detach().float().cpu())
logits_val = torch.cat(logits_val)  # [N]
labels_val = torch.cat(labels_val)  # [N]

# Optimize a scalar temperature T ≥ 0.5 for numerical stability
T = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
opt = torch.optim.LBFGS([T], lr=0.1, max_iter=50, line_search_fn="strong_wolfe")

def _nll():
    opt.zero_grad()
    p = torch.sigmoid(logits_val / torch.clamp(T, min=0.5))
    # Binary NLL
    loss = -(labels_val*torch.log(p+1e-9) + (1-labels_val)*torch.log(1-p+1e-9)).mean()
    loss.backward()
    return loss

opt.step(_nll)
T_star = float(T.item())
print(f"Calibrated temperature T* = {T_star:.3f}")

# Save for inference/eval
with open(OUT_DIR / "calibration.json", "w") as f:
    json.dump({"temperature": T_star}, f)


Calibrated temperature T* = 2.394


Cell 7 — Inference utilities — p(bind) + top-k ADRs using prototypes

In [374]:
# === Cell 7 (updated): Inference + ADR names via MedDRA mapping parquet in Model_v1/ ===

import numpy as np
import torch
import pandas as pd
from pathlib import Path

# --- Load best checkpoint ---
model_ckpt_path = OUT_DIR / "best.pt"
if not model_ckpt_path.exists():
    raise FileNotFoundError(f"Best checkpoint not found at {model_ckpt_path}")

ckpt = torch.load(model_ckpt_path, map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

# --- Helper: find the MedDRA mapping parquet at Model_v1 root ---
# It must contain the 3 columns:
#   rxnorm_ingredient_id (object), meddra_id (int), meddra_name (object)
def find_meddra_map_parquet(nb_root: Path) -> Path | None:
    for p in nb_root.glob("*.parquet"):
        try:
            df = pd.read_parquet(p, columns=["rxnorm_ingredient_id", "meddra_id", "meddra_name"])
            # lightweight schema verification
            if {"rxnorm_ingredient_id","meddra_id","meddra_name"}.issubset(df.columns):
                return p
        except Exception:
            continue
    return None

# meddra_map_path = find_meddra_map_parquet(NB_ROOT)
meddra_map_path = "final_rxnorm_meddra_v2.parquet"

print(meddra_map_path)
if meddra_map_path is None:
    print("⚠️ Could not find a MedDRA mapping parquet at Model_v1/*.parquet "
          "with columns [rxnorm_ingredient_id, meddra_id, meddra_name]. "
          "ADR names will fall back to idf_table’s first column.")
    meddra_map_df = None
else:
    meddra_map_df = pd.read_parquet(meddra_map_path, columns=["meddra_id","meddra_name"]).drop_duplicates()
    print(f"Loaded MedDRA name map: {meddra_map_path}  (rows={len(meddra_map_df)})")

# --- Load ADR prototypes (or build) ---
proto_path = Path(cfg.get("artifacts", {}).get("prototypes_path", OUT_DIR / "prototypes_C.npy"))
if proto_path.exists():
    C = np.load(proto_path)
else:
    with torch.no_grad():
        W = model.f_a(T_train.to(DEVICE, dtype=torch.float32))
        W = W.detach().cpu().numpy()  # [U_train, d]
        T_np = T_train.cpu().numpy()  # [U_train, K]
        K = T_np.shape[1]; d = W.shape[1]
        C = np.zeros((K, d), dtype=np.float32)
        for k in range(K):
            mask = T_np[:, k] > 0
            C[k] = W[mask].mean(axis=0) if mask.any() else 0.0
np.testing.assert_equal(C.shape[0], int(cfg["data"]["K"]))
C_tensor = torch.from_numpy(C).to(DEVICE, dtype=torch.float32)  # [K, d]

# --- Build ADR ID & Name lists (from idf_table + optional MedDRA join) ---
idf_table_path = (NB_ROOT / cfg["data"]["adr_root"] / "idf_table.parquet").resolve()
idf_table = pd.read_parquet(idf_table_path)

# Heuristic: first column is the ADR identifier (often MedDRA ID or term)
ADR_ID_COL = idf_table.columns[0]
adr_ids_raw = idf_table[ADR_ID_COL]

# Try to coerce to int MedDRA IDs (safe if already strings of ints)
adr_ids_int = None
try:
    adr_ids_int = pd.to_numeric(adr_ids_raw, errors="raise").astype("int64")
except Exception:
    # Not numeric; leave None
    pass

if meddra_map_df is not None and adr_ids_int is not None:
    # Join on meddra_id to get meddra_name
    names_df = pd.DataFrame({"meddra_id": adr_ids_int})
    names_df = names_df.merge(meddra_map_df, on="meddra_id", how="left")
    ADR_NAMES = names_df["meddra_name"].fillna(adr_ids_raw.astype(str)).astype(str).tolist()
    ADR_IDS_DISPLAY = adr_ids_int.astype(int).tolist()
else:
    # Fallback to whatever the idf_table’s first column provides
    ADR_NAMES = adr_ids_raw.astype(str).tolist()
    ADR_IDS_DISPLAY = adr_ids_raw.astype(str).tolist()

assert len(ADR_NAMES) == C.shape[0], "ADR names length != K"

# --- Scoring weights for ADR ranking ---
alpha = float(cfg["model"]["adr_scoring"]["alpha"])
beta  = float(cfg["model"]["adr_scoring"]["beta"])
gamma = float(cfg["model"]["adr_scoring"]["gamma"])

@torch.no_grad()
def _encode_drug_protein(drug_id: str, prot_id: str) -> tuple[torch.Tensor, torch.Tensor]:
    if drug_id not in DRUG_ID2IDX:
        raise KeyError(f"Unknown drug_chembl_id: {drug_id}")
    if prot_id not in PROT_ID2IDX:
        raise KeyError(f"Unknown protein id: {prot_id}")

    di = DRUG_ID2IDX[drug_id]
    pi = PROT_ID2IDX[prot_id]

    x_d = DRUG_TENSOR[di:di+1].to(DEVICE, dtype=torch.float32)
    x_p = PROT_TENSOR[pi:pi+1].to(DEVICE, dtype=torch.float32)

    u = model.f_d(x_d)  # [1, d]
    v = model.f_p(x_p)  # [1, d]
    return u, v

@torch.no_grad()
def predict_pair(
    drug_chembl_id: str,
    protein_id: str,
    topk: int = 10,
    return_dataframe: bool = True
):
    """
    Returns:
      dict with p_bind, threshold_used, topk DataFrame:
      columns = [rank, adr_id, adr_name, score]
    """
    u, v = _encode_drug_protein(drug_chembl_id, protein_id)  # [1,d] each

    # Binding probability
    logits = model.h_dti(u, v)             # [1]
    # p_bind = torch.sigmoid(logits).item()
    
    # load T once near top of Cell 7
    calib_path = OUT_DIR / "calibration.json"
    T_STAR = json.loads(calib_path.read_text())["temperature"] if calib_path.exists() else 1.0

    # inside predict_pair (after computing logits)
    p_bind = torch.sigmoid(logits / T_STAR).item()


    # Pair-conditioned ADR score:
    # score_k = sigmoid( α*(u·Cᵀ) + β*(v·Cᵀ) + γ*(u·v) )
    u_n = u / (u.norm(p=2, dim=1, keepdim=True) + 1e-8)
    v_n = v / (v.norm(p=2, dim=1, keepdim=True) + 1e-8)
    C_n = C_tensor / (C_tensor.norm(p=1, dim=1, keepdim=False).unsqueeze(1) + 1e-8)

    s_u = torch.matmul(u_n, C_n.T)  # [1, K]
    s_v = torch.matmul(v_n, C_n.T)  # [1, K]
    pair_sim = torch.sum(u_n * v_n, dim=1, keepdim=True)  # [1,1]
    scores = torch.sigmoid(alpha * s_u + beta * s_v + gamma * pair_sim).squeeze(0)  # [K]

    # Top-k ADRs
    topk = int(min(topk, scores.numel()))
    idx = torch.topk(scores, k=topk, largest=True, sorted=True).indices.detach().cpu().numpy()
    vals = scores.detach().cpu().numpy()[idx]

    # Build output with names
    adr_ids_sel   = [ADR_IDS_DISPLAY[i] for i in idx]
    adr_names_sel = [ADR_NAMES[i] for i in idx]

    if return_dataframe:
        df = pd.DataFrame({
            "rank": np.arange(1, topk+1, dtype=int),
            "adr_id": adr_ids_sel,
            "adr_name": adr_names_sel,
            "score": vals
        })
        print(f"Drug={drug_chembl_id} | Protein={protein_id}")
        print(f"p_bind = {p_bind:.4f}  (val-threshold* ≈ {ckpt.get('best_thr', np.nan):.3f})")
        display(df)
        return {
            "p_bind": p_bind,
            "threshold_used": float(ckpt.get("best_thr", np.nan)),
            "topk": df
        }
    else:
        return {
            "p_bind": p_bind,
            "threshold_used": float(ckpt.get("best_thr", np.nan)),
            "topk_idx": idx,
            "topk_scores": vals,
            "topk_adr_ids": adr_ids_sel,
            "topk_adr_names": adr_names_sel
        }

print("Inference API (with ADR names) ready. Use:")
_ = predict_pair('CHEMBL1009', 'Q9UHI5', topk=10, return_dataframe=True)


final_rxnorm_meddra_v2.parquet
Loaded MedDRA name map: final_rxnorm_meddra_v2.parquet  (rows=4817)
Inference API (with ADR names) ready. Use:
Drug=CHEMBL1009 | Protein=Q9UHI5
p_bind = 0.8315  (val-threshold* ≈ 0.752)


Unnamed: 0,rank,adr_id,adr_name,score
0,1,10007554,Cardiac failure,0.511904
1,2,10002959,Aphthous ulcer,0.511389
2,3,10002967,Aplastic anaemia,0.510882
3,4,10067125,Liver injury,0.510694
4,5,10038923,Retinopathy,0.509082
5,6,10054524,Palmar-plantar erythrodysesthesia syndrome,0.509058
6,7,10010770,Consciousness disturbed,0.508996
7,8,10059206,Nail toxicity,0.508899
8,9,10047862,Weakness,0.508505
9,10,10012536,Detachment,0.508493


In [196]:
# === Cell 7: Inference utilities — p(bind) + top-k ADRs using prototypes ===

import numpy as np
import torch
import pandas as pd
from pathlib import Path

model_ckpt_path = OUT_DIR / "best.pt"
if not model_ckpt_path.exists():
    raise FileNotFoundError(f"Best checkpoint not found at {model_ckpt_path}")

ckpt = torch.load(model_ckpt_path, map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

# Load ADR prototypes (or build from train split if not present)
proto_path = Path(cfg.get("artifacts", {}).get("prototypes_path", OUT_DIR / "prototypes_C.npy"))
if proto_path.exists():
    C = np.load(proto_path)
else:
    # Fallback: compute from TRAIN split quickly (same logic as in Cell 6)
    with torch.no_grad():
        W = model.f_a(T_train.to(DEVICE, dtype=torch.float32))
        W = W.detach().cpu().numpy()  # [U_train, d]
        T_np = T_train.cpu().numpy()  # [U_train, K]
        K = T_np.shape[1]; d = W.shape[1]
        C = np.zeros((K, d), dtype=np.float32)
        for k in range(K):
            mask = T_np[:, k] > 0
            C[k] = W[mask].mean(axis=0) if mask.any() else 0.0
np.testing.assert_equal(C.shape[0], int(cfg["data"]["K"]))
C_tensor = torch.from_numpy(C).to(DEVICE, dtype=torch.float32)  # [K, d]

# ADR labels/IDs for display
idf_table_path = (NB_ROOT / cfg["data"]["adr_root"] / "idf_table.parquet").resolve()
idf_table = pd.read_parquet(idf_table_path)
# Heuristic: use the first column as ADR key/name
ADR_NAME_COL = idf_table.columns[0]
ADR_NAMES = idf_table[ADR_NAME_COL].astype(str).tolist()
assert len(ADR_NAMES) == C.shape[0], "ADR names length != K"

# Scoring weights for ADR ranking
alpha = float(cfg["model"]["adr_scoring"]["alpha"])
beta  = float(cfg["model"]["adr_scoring"]["beta"])
gamma = float(cfg["model"]["adr_scoring"]["gamma"])

@torch.no_grad()
def _encode_drug_protein(drug_id: str, prot_id: str) -> tuple[torch.Tensor, torch.Tensor]:
    if drug_id not in DRUG_ID2IDX:
        raise KeyError(f"Unknown drug_chembl_id: {drug_id}")
    if prot_id not in PROT_ID2IDX:
        raise KeyError(f"Unknown protein id: {prot_id}")

    di = DRUG_ID2IDX[drug_id]
    pi = PROT_ID2IDX[prot_id]

    x_d = DRUG_TENSOR[di:di+1].to(DEVICE, dtype=torch.float32)
    x_p = PROT_TENSOR[pi:pi+1].to(DEVICE, dtype=torch.float32)

    # We don’t need TF-IDF here for inference; pass a dummy minimal tensor to satisfy API?
    # The forward() requires t and pair_to_u; we’ll bypass heads and call adapters directly for speed.
    u = model.f_d(x_d)  # [1, d]
    v = model.f_p(x_p)  # [1, d]
    return u, v

@torch.no_grad()
def predict_pair(
    drug_chembl_id: str,
    protein_id: str,
    topk: int = 10,
    return_dataframe: bool = True
):
    """
    Returns:
      dict with p_bind, threshold_used, topk DataFrame (ADR, score) or numpy arrays.
    """
    u, v = _encode_drug_protein(drug_chembl_id, protein_id)  # [1,d] each
    # Binding probability from the selected head
    if model.dti_head_type == "cosine":
        logits = model.h_dti(u, v)           # [1]
    else:
        logits = model.h_dti(u, v)           # [1]
    p_bind = torch.sigmoid(logits).item()

    # Pair-conditioned ADR score:
    # score_k = sigmoid( α*(u·Cᵀ) + β*(v·Cᵀ) + γ*(u·v) )  (broadcast γ term)
    u_n = u / (u.norm(p=2, dim=1, keepdim=True) + 1e-8)
    v_n = v / (v.norm(p=2, dim=1, keepdim=True) + 1e-8)
    C_n = C_tensor / (C_tensor.norm(p=1, dim=1, keepdim=False).unsqueeze(1) + 1e-8)  # light norm for stability

    s_u = torch.matmul(u_n, C_n.T)  # [1, K]
    s_v = torch.matmul(v_n, C_n.T)  # [1, K]
    pair_sim = torch.sum(u_n * v_n, dim=1, keepdim=True)  # [1,1]
    scores = torch.sigmoid(alpha * s_u + beta * s_v + gamma * pair_sim).squeeze(0)  # [K]

    # Top-k ADRs
    topk = int(min(topk, scores.numel()))
    idx = torch.topk(scores, k=topk, largest=True, sorted=True).indices.detach().cpu().numpy()
    vals = scores.detach().cpu().numpy()[idx]

    if return_dataframe:
        df = pd.DataFrame({
            "rank": np.arange(1, topk+1, dtype=int),
            "adr":  [ADR_NAMES[i] for i in idx],
            "score": vals
        })
        # pretty print
        print(f"Drug={drug_chembl_id} | Protein={protein_id}")
        print(f"p_bind = {p_bind:.4f}  (val-threshold* ≈ {ckpt.get('best_thr', np.nan):.3f})")
        display(df)
        return {
            "p_bind": p_bind,
            "threshold_used": float(ckpt.get("best_thr", np.nan)),
            "topk": df
        }
    else:
        return {
            "p_bind": p_bind,
            "threshold_used": float(ckpt.get("best_thr", np.nan)),
            "topk_idx": idx,
            "topk_scores": vals
        }

# --------- Quick demo (edit IDs as needed) ----------
# Example: pick any valid IDs from DRUG_ID_LIST / PROT_ID_LIST
# _ = predict_pair(DRUG_ID_LIST[0], PROT_ID_LIST[0], topk=10, return_dataframe=True)
_ = predict_pair("CHEMBL1009", "P21918", topk=10, return_dataframe=True)

# print("Inference API ready. Use: predict_pair('<drug_chembl_id>', '<protein_id>', topk=10)")


Drug=CHEMBL1009 | Protein=P21918
p_bind = 0.6884  (val-threshold* ≈ 0.605)


Unnamed: 0,rank,adr,score
0,1,10024492,0.521339
1,2,10018473,0.52112
2,3,10018232,0.521076
3,4,10074859,0.520607
4,5,10043088,0.520416
5,6,10020916,0.520355
6,7,10020915,0.520355
7,8,10003458,0.520301
8,9,10054209,0.520201
9,10,10012703,0.520176


Cell 8 — Export compact deployment bundle

In [197]:
# === Cell 8: Export compact deployment bundle ===
from pathlib import Path
import json
import torch
import numpy as np
import shutil
import pandas as pd

bundle_dir = OUT_DIR / "deploy_bundle"
bundle_dir.mkdir(parents=True, exist_ok=True)

# ---- 1. Save model weights (state_dict only) ----
weights_path = bundle_dir / "model_state.pt"
torch.save(model.state_dict(), weights_path)

# ---- 2. Copy config & best metrics ----
shutil.copy2(CONFIG_PATH, bundle_dir / "config.yaml")
best_metrics_src = OUT_DIR / "best_metrics_val.json"
if best_metrics_src.exists():
    shutil.copy2(best_metrics_src, bundle_dir / "best_metrics_val.json")

# ---- 3. Save ADR prototypes ----
proto_path = Path(cfg.get("artifacts", {}).get("prototypes_path", OUT_DIR / "prototypes_C.npy"))
if proto_path.exists():
    shutil.copy2(proto_path, bundle_dir / "prototypes_C.npy")
else:
    # quick rebuild from train split if missing
    with torch.no_grad():
        W = model.f_a(T_train.to(DEVICE, dtype=torch.float32))
        W = W.detach().cpu().numpy()
        T_np = T_train.cpu().numpy()
        K = T_np.shape[1]; d = W.shape[1]
        C = np.zeros((K, d), dtype=np.float32)
        for k in range(K):
            mask = T_np[:, k] > 0
            C[k] = W[mask].mean(axis=0) if mask.any() else 0.0
        np.save(bundle_dir / "prototypes_C.npy", C)

# ---- 4. Save ADR label names ----
idf_table_path = (NB_ROOT / cfg["data"]["adr_root"] / "idf_table.parquet").resolve()
idf_table = pd.read_parquet(idf_table_path)
adr_name_col = idf_table.columns[0]
idf_table[[adr_name_col]].to_csv(bundle_dir / "adr_labels.csv", index=False)

# ---- 5. Save a minimal inference script ----
inference_py = bundle_dir / "predict_pair_minimal.py"
inference_py.write_text(
f"""\
import torch, numpy as np, pandas as pd, json
from pathlib import Path

def load_bundle(bundle_dir: str):
    p = Path(bundle_dir)
    cfg = yaml.safe_load(open(p/'config.yaml'))
    model_state = torch.load(p/'model_state.pt', map_location='cpu')
    C = np.load(p/'prototypes_C.npy')
    adr_labels = pd.read_csv(p/'adr_labels.csv')[{repr(adr_name_col)}].tolist()
    return cfg, model_state, C, adr_labels

# Usage example:
# cfg, state, C, labels = load_bundle('runs/dti_adr_v1/deploy_bundle')
# print('Loaded bundle with', len(labels), 'ADRs')
"""
)

print(f"Deployment bundle ready at: {bundle_dir}")
print(f"Contents:\n- model_state.pt\n- config.yaml\n- best_metrics_val.json\n- prototypes_C.npy\n- adr_labels.csv\n- predict_pair_minimal.py")


Deployment bundle ready at: F:\Thesis Korbi na\dti-prediction-with-adr\Model_v1\runs\dti_adr_v1\deploy_bundle
Contents:
- model_state.pt
- config.yaml
- best_metrics_val.json
- prototypes_C.npy
- adr_labels.csv
- predict_pair_minimal.py


Cell 9 — Plot & save metrics for this run (loss curves, PR/ROC, ADR, top-k) + run_summary.csv

In [198]:
# === Cell 9: Visualize & Save Metrics for THIS run ===
import json, time
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ---------- Load history / best / test ----------
hist_path = OUT_DIR / "history.json"
best_val_path = OUT_DIR / "best_metrics_val.json"
test_path = OUT_DIR / "test_metrics.json"

if not hist_path.exists():
    raise FileNotFoundError(f"Missing {hist_path}. Train first (Cell 6).")

history = json.loads(Path(hist_path).read_text(encoding="utf-8"))
best_val = json.loads(Path(best_val_path).read_text(encoding="utf-8")) if best_val_path.exists() else {}
m_test   = json.loads(Path(test_path).read_text(encoding="utf-8")) if test_path.exists() else {}

# Extract per-epoch series
epochs = [h["epoch"] for h in history]
L_total = [h["train"]["L_total"] for h in history]
L_dti   = [h["train"]["L_dti"] for h in history]
L_adr   = [h["train"]["L_adr"] for h in history]
L_con   = [h["train"]["L_con"] for h in history]

val_pr  = [h["val"].get("dti_pr_auc", np.nan) for h in history]
val_roc = [h["val"].get("dti_roc_auc", np.nan) for h in history]
val_f1  = [h["val"].get("dti_f1", np.nan) for h in history]
val_thr = [h["val"].get("dti_thr", np.nan) for h in history]
val_rmse= [h["val"].get("adr_rmse", np.nan) for h in history]
val_mae = [h["val"].get("adr_mae", np.nan) for h in history]

# ---------- Helper: recompute PR & ROC curves for nice plots ----------
def collect_scores(loader):
    model.eval()
    y_all, p_all = [], []
    with torch.no_grad():
        for batch in loader:
            y = batch["y"].to(DEVICE, dtype=torch.float32, non_blocking=True)
            out = model(
                x_d=batch["x_d"].to(DEVICE, non_blocking=True),
                x_p=batch["x_p"].to(DEVICE, non_blocking=True),
                t=batch["t"].to(DEVICE, non_blocking=True),
                pair_to_u=batch["pair_to_u"].to(DEVICE, non_blocking=True),
                amp_enabled=AMP_ENABLED,
                amp_dtype=AMP_DTYPE
            )
            prob = torch.sigmoid(out["logits"]).detach().float().cpu().numpy()
            y_all.append(y.detach().cpu().numpy().astype(np.int32))
            p_all.append(prob)
    if len(y_all) == 0:
        return np.array([]), np.array([])
    return np.concatenate(y_all), np.concatenate(p_all)

def pr_curve_points(y, p):
    # sort by p desc
    order = np.argsort(-p)
    y = y[order]; p = p[order]
    tp = np.cumsum(y == 1)
    fp = np.cumsum(y == 0)
    precision = tp / np.maximum(tp + fp, 1)
    recall    = tp / np.maximum(tp[-1] if tp.size else 1, 1)
    return recall, precision

def roc_curve_points(y, p):
    # sort by p desc
    order = np.argsort(-p)
    y = y[order]; p = p[order]
    tp = np.cumsum(y == 1)
    fp = np.cumsum(y == 0)
    fn = (y == 1).sum() - tp
    tn = (y == 0).sum() - fp
    tpr = tp / np.maximum((tp+fn), 1)
    fpr = fp / np.maximum((fp+tn), 1)
    return fpr, tpr

# Collect scores for val & test to draw smooth curves
y_val, p_val = collect_scores(val_loader)
y_tst, p_tst = collect_scores(test_loader)

# ---------- Plot 1: Training loss curves ----------
plt.figure(figsize=(7,4.5))
plt.plot(epochs, L_total, label="Total")
plt.plot(epochs, L_dti,   label="DTI")
plt.plot(epochs, L_adr,   label="ADR")
plt.plot(epochs, L_con,   label="Contrastive")
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Training Loss Curves")
plt.legend(); plt.tight_layout()
plt.savefig(OUT_DIR / "plot_loss_curves.png", dpi=150)
plt.close()

# ---------- Plot 2: Validation DTI metrics by epoch ----------
plt.figure(figsize=(7,4.5))
plt.plot(epochs, val_pr,  label="PR-AUC")
plt.plot(epochs, val_roc, label="ROC-AUC")
plt.plot(epochs, val_f1,  label="F1")
plt.xlabel("Epoch"); plt.ylabel("Score"); plt.title("Validation DTI Metrics")
plt.legend(); plt.tight_layout()
plt.savefig(OUT_DIR / "plot_val_dti_metrics.png", dpi=150)
plt.close()

# ---------- Plot 3: Validation ADR errors (RMSE/MAE) ----------
plt.figure(figsize=(7,4.5))
plt.plot(epochs, val_rmse, label="ADR RMSE")
plt.plot(epochs, val_mae,  label="ADR MAE")
plt.xlabel("Epoch"); plt.ylabel("Error"); plt.title("Validation ADR Errors")
plt.legend(); plt.tight_layout()
plt.savefig(OUT_DIR / "plot_val_adr_errors.png", dpi=150)
plt.close()

# ---------- Plot 4: PR & ROC curves (val & test) ----------
if y_val.size > 0 and y_tst.size > 0:
    rv, pv = pr_curve_points(y_val, p_val)
    rt, pt = pr_curve_points(y_tst, p_tst)
    plt.figure(figsize=(6.2,4.5))
    plt.plot(rv, pv, label=f"VAL (PR-AUC={best_val.get('dti_pr_auc', np.nan):.3f})")
    plt.plot(rt, pt, label=f"TEST (PR-AUC={m_test.get('dti_pr_auc', np.nan):.3f})")
    plt.xlabel("Recall"); plt.ylabel("Precision"); plt.title("Precision–Recall")
    plt.legend(); plt.tight_layout()
    plt.savefig(OUT_DIR / "plot_pr_curves.png", dpi=150)
    plt.close()

    fv, tv = roc_curve_points(y_val, p_val)
    ft, tt = roc_curve_points(y_tst, p_tst)
    plt.figure(figsize=(6.2,4.5))
    plt.plot(fv, tv, label=f"VAL (ROC-AUC={best_val.get('dti_roc_auc', np.nan):.3f})")
    plt.plot(ft, tt, label=f"TEST (ROC-AUC={m_test.get('dti_roc_auc', np.nan):.3f})")
    plt.plot([0,1],[0,1], linestyle="--")
    plt.xlabel("FPR"); plt.ylabel("TPR"); plt.title("ROC")
    plt.legend(); plt.tight_layout()
    plt.savefig(OUT_DIR / "plot_roc_curves.png", dpi=150)
    plt.close()

# ---------- Plot 5: Top-k ranking (if recorded in best/test) ----------
def _maybe_plot_topk(prefix: str, metrics: dict):
    ks = [k for k in metrics.keys() if k.startswith("recall@") or k.startswith("ndcg@")]
    if not ks:
        return
    # group into recall and ndcg
    rec = sorted([(int(k.split("@")[1]), metrics[k]) for k in ks if k.startswith("recall@")])
    ndc = sorted([(int(k.split("@")[1]), metrics[k]) for k in ks if k.startswith("ndcg@")])
    if rec:
        plt.figure(figsize=(6,4))
        plt.bar([f"@{k}" for k,_ in rec], [v for _,v in rec])
        plt.ylim(0,1)
        plt.title(f"{prefix} ADR Recall@k"); plt.tight_layout()
        plt.savefig(OUT_DIR / f"plot_{prefix.lower()}_recall_at_k.png", dpi=150)
        plt.close()
    if ndc:
        plt.figure(figsize=(6,4))
        plt.bar([f"@{k}" for k,_ in ndc], [v for _,v in ndc])
        plt.ylim(0,1)
        plt.title(f"{prefix} ADR NDCG@k"); plt.tight_layout()
        plt.savefig(OUT_DIR / f"plot_{prefix.lower()}_ndcg_at_k.png", dpi=150)
        plt.close()

if best_val: _maybe_plot_topk("VAL", best_val)
if m_test:   _maybe_plot_topk("TEST", m_test)

print("Saved plots to:", OUT_DIR)
for p in ["plot_loss_curves.png","plot_val_dti_metrics.png","plot_val_adr_errors.png","plot_pr_curves.png","plot_roc_curves.png","plot_VAL_recall_at_k.png","plot_VAL_ndcg_at_k.png","plot_TEST_recall_at_k.png","plot_TEST_ndcg_at_k.png"]:
    q = OUT_DIR / p
    if q.exists():
        print(" -", q.name)

# ---------- Append run_summary.csv ----------
summary_row = {
    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    "run_name": cfg["run"]["name"],
    "drug_encoder": cfg["model"]["drug_encoder"],
    "protein_encoder": cfg["model"]["protein_encoder"],
    "shared_dim": cfg["model"]["shared_dim"],
    "dti_head": cfg["model"]["dti_head"],
    "val_pr_auc": best_val.get("dti_pr_auc", np.nan),
    "val_roc_auc": best_val.get("dti_roc_auc", np.nan),
    "val_f1": best_val.get("dti_f1", np.nan),
    "val_thr": best_val.get("dti_thr", np.nan),
    "val_adr_rmse": best_val.get("adr_rmse", np.nan),
    "val_adr_mae": best_val.get("adr_mae", np.nan),
    "test_pr_auc": m_test.get("dti_pr_auc", np.nan),
    "test_roc_auc": m_test.get("dti_roc_auc", np.nan),
    "test_f1": m_test.get("dti_f1", np.nan),
    "test_adr_rmse": m_test.get("adr_rmse", np.nan),
    "test_adr_mae": m_test.get("adr_mae", np.nan),
}

summary_csv = OUT_DIR / "run_summary.csv"
df_row = pd.DataFrame([summary_row])
if summary_csv.exists():
    df_old = pd.read_csv(summary_csv)
    df_out = pd.concat([df_old, df_row], ignore_index=True)
else:
    df_out = df_row
df_out.to_csv(summary_csv, index=False)
print("Updated:", summary_csv)


Saved plots to: F:\Thesis Korbi na\dti-prediction-with-adr\Model_v1\runs\dti_adr_v1
 - plot_loss_curves.png
 - plot_val_dti_metrics.png
 - plot_val_adr_errors.png
 - plot_pr_curves.png
 - plot_roc_curves.png
 - plot_VAL_recall_at_k.png
 - plot_VAL_ndcg_at_k.png
 - plot_TEST_recall_at_k.png
 - plot_TEST_ndcg_at_k.png
Updated: F:\Thesis Korbi na\dti-prediction-with-adr\Model_v1\runs\dti_adr_v1\run_summary.csv


Cell 10 — Compare multiple runs (build a cross-run table + small plots)

In [347]:
# === Cell 10: Multi-run aggregator for comparison across encoder combos ===
from pathlib import Path
import pandas as pd
import json
import matplotlib.pyplot as plt

RUNS_DIR = (NB_ROOT / "runs").resolve()
rows = []
for run_dir in RUNS_DIR.glob("*"):
    if not run_dir.is_dir():
        continue
    hist = run_dir / "history.json"
    best = run_dir / "best_metrics_val.json"
    test = run_dir / "test_metrics.json"
    cfgp = run_dir.parent / "config.yaml"  # often copied in deploy_bundle; we also have OUT_DIR/config.yaml
    # Accept missing cfg; pull encoders from checkpoint config if present
    meta = {"run_path": str(run_dir)}
    try:
        b = json.loads(best.read_text(encoding="utf-8")) if best.exists() else {}
        t = json.loads(test.read_text(encoding="utf-8")) if test.exists() else {}
        # Try to read resolved_config.json for encoders
        rc = run_dir / "resolved_config.json"
        enc = {}
        if rc.exists():
            rcj = json.loads(rc.read_text(encoding="utf-8"))
            enc["drug_encoder"] = rcj.get("model", {}).get("drug_encoder")
            enc["protein_encoder"] = rcj.get("model", {}).get("protein_encoder")
            enc["dti_head"] = rcj.get("model", {}).get("dti_head")
            enc["shared_dim"] = rcj.get("model", {}).get("shared_dim")
        meta.update(enc)
        meta.update({
            "val_pr_auc": b.get("dti_pr_auc", float("nan")),
            "val_roc_auc": b.get("dti_roc_auc", float("nan")),
            "val_f1": b.get("dti_f1", float("nan")),
            "val_thr": b.get("dti_thr", float("nan")),
            "val_adr_rmse": b.get("adr_rmse", float("nan")),
            "val_adr_mae": b.get("adr_mae", float("nan")),
            "test_pr_auc": t.get("dti_pr_auc", float("nan")),
            "test_roc_auc": t.get("dti_roc_auc", float("nan")),
            "test_f1": t.get("dti_f1", float("nan")),
            "test_adr_rmse": t.get("adr_rmse", float("nan")),
            "test_adr_mae": t.get("adr_mae", float("nan")),
        })
        rows.append(meta)
    except Exception:
        continue

if not rows:
    raise RuntimeError("No completed runs found in ./runs/*")

df_runs = pd.DataFrame(rows)
display(df_runs.sort_values(["val_pr_auc","test_pr_auc"], ascending=False).reset_index(drop=True))

# Simple comparison plots (val PR-AUC & test PR-AUC by encoder combo)
def _label_combo(r):
    return f"{r.get('drug_encoder','?')}/{r.get('protein_encoder','?')}:{r.get('dti_head','?')}"

df_runs["combo"] = df_runs.apply(_label_combo, axis=1)

plt.figure(figsize=(8,4.5))
plt.bar(df_runs["combo"], df_runs["val_pr_auc"])
plt.xticks(rotation=30, ha="right")
plt.ylabel("PR-AUC (val)"); plt.title("Validation PR-AUC by Encoder Combo")
plt.tight_layout(); plt.savefig(RUNS_DIR / "compare_val_pr_auc.png", dpi=150); plt.close()

plt.figure(figsize=(8,4.5))
plt.bar(df_runs["combo"], df_runs["test_pr_auc"])
plt.xticks(rotation=30, ha="right")
plt.ylabel("PR-AUC (test)"); plt.title("Test PR-AUC by Encoder Combo")
plt.tight_layout(); plt.savefig(RUNS_DIR / "compare_test_pr_auc.png", dpi=150); plt.close()

print("Saved comparison charts to:", RUNS_DIR)


Unnamed: 0,run_path,drug_encoder,protein_encoder,dti_head,shared_dim,val_pr_auc,val_roc_auc,val_f1,val_thr,val_adr_rmse,val_adr_mae,test_pr_auc,test_roc_auc,test_f1,test_adr_rmse,test_adr_mae
0,F:\Thesis Korbi na\dti-prediction-with-adr\Mod...,chemberta,esm,cosine,512,0.787983,0.866298,0.752743,0.594831,0.080055,0.066179,0.680744,0.830723,0.442185,0.035739,0.026058


Saved comparison charts to: F:\Thesis Korbi na\dti-prediction-with-adr\Model_v1\runs
