In [7]:
# ===== Part 0: setup & config loader (no auto-fix) =====

from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, Optional, List, Tuple
import json
import os
import sys

import numpy as np
import pandas as pd

# If pyyaml isn't installed, run:  !pip install pyyaml
import yaml

In [8]:
# ---- User: set path to your config.yaml here ----
# A relative path is opened as-is by load_config, i.e. relative to the
# current working directory of the notebook kernel.
CONFIG_PATH = Path("config.yaml")


# -------- helper: small DotDict for cfg access --------
class DotDict(dict):
    """
    dict subclass whose keys can also be read, set and deleted as attributes.

    - Reading a key that is not present returns None (same as dict.get).
    - Reading a key whose value is a plain dict wraps that value in a DotDict
      and stores it back, so chained access such as `cfg.data.dti_path` works
      and attribute writes on the nested object persist.
    - Dunder names are never answered from the mapping: protocol probes such
      as `__deepcopy__` made by copy/pickle must see AttributeError, not None.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup has already failed.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        value = self.get(name)
        if isinstance(value, dict) and not isinstance(value, DotDict):
            value = DotDict(value)
            self[name] = value
        return value

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def load_config(path: Path) -> DotDict:
    """
    Load a YAML config file and wrap its top-level mapping in a DotDict.

    Args:
        path: location of the YAML file.

    Returns:
        DotDict built from the parsed top-level mapping.

    Raises:
        FileNotFoundError: if `path` does not exist.
        TypeError: if the file is empty or its top level is not a mapping.
    """
    # Explicit raise rather than `assert`: assertions are removed under
    # `python -O`, which would silently skip the check.
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")
    with open(path, "r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    # An empty document parses to None and a YAML list/scalar is not usable
    # as a config; fail with a message that names the file.
    if not isinstance(raw, dict):
        raise TypeError(
            f"Config root must be a mapping, got {type(raw).__name__}: {path}"
        )
    return DotDict(raw)


def resolve_path(p: str | Path) -> Path:
    """
    Resolve a path relative to current working directory.
    If it's already absolute, returns as-is.
    """
    p = Path(p)
    return p if p.is_absolute() else (Path.cwd() / p).resolve()


def require_exists(path: Path, kind: str):
    """
    Fail fast when a required file or directory is missing.

    Args:
        path: filesystem location that must exist.
        kind: short human-readable label used in the error message.

    Raises:
        FileNotFoundError: if `path` does not exist.
    """
    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    if not path.exists():
        raise FileNotFoundError(f"Expected {kind} at: {path}")


def validate_config(cfg: DotDict) -> Dict[str, Optional[Path]]:
    """
    Resolve every data path named in the config and check that it exists.

    Required keys under cfg["data"]: dti_path, adr_tfidf_root, pdb_dir.
    Under the TF-IDF root, idf_table.parquet and <split>/tfidf_wide.parquet
    for train/val/test must exist.
    cfg["data"]["splits"] and its <split>_rxcui_path entries are optional:
    a missing section, a missing key or an empty value all yield None.

    Returns:
        Dict with keys dti_path, adr_root, pdb_dir, train_list, val_list,
        test_list. The *_list values are None when no split file is configured.

    Raises:
        KeyError: if a required key is absent from the config.
        Whatever require_exists raises when a path is missing.
    """
    data_cfg = cfg["data"]

    # Resolve key paths
    dti_path = resolve_path(data_cfg["dti_path"])
    adr_root = resolve_path(data_cfg["adr_tfidf_root"])
    pdb_dir = resolve_path(data_cfg["pdb_dir"])

    # Validate existence
    require_exists(dti_path, "DTI parquet")
    require_exists(adr_root, "TF-IDF ADR root directory")
    require_exists(adr_root / "idf_table.parquet", "idf_table.parquet in TF-IDF root")
    for split in ("train", "val", "test"):
        split_dir = adr_root / split
        require_exists(split_dir, f"{split} split directory in TF-IDF root")
        require_exists(split_dir / "tfidf_wide.parquet", f"{split}/tfidf_wide.parquet")

    require_exists(pdb_dir, "AlphaFold PDB directory")

    # Optional split files: tolerate an absent/empty `splits` section and
    # absent keys instead of raising KeyError on something documented as optional.
    splits = data_cfg.get("splits") or {}
    split_lists: Dict[str, Optional[Path]] = {}
    for split in ("train", "val", "test"):
        configured = splits.get(f"{split}_rxcui_path")
        list_path = resolve_path(configured) if configured else None
        if list_path is not None:
            require_exists(list_path, "split list file")
        split_lists[f"{split}_list"] = list_path

    # The DTI column schema is not checked here; the table is loaded and
    # validated in Part 1.

    return {
        "dti_path": dti_path,
        "adr_root": adr_root,
        "pdb_dir": pdb_dir,
        "train_list": split_lists["train_list"],
        "val_list": split_lists["val_list"],
        "test_list": split_lists["test_list"],
    }


def print_summary(cfg: DotDict, paths: Dict[str, Path]):
    """
    Print a human-readable overview of the resolved data locations.

    `cfg` is accepted but not read; everything printed comes from `paths`.
    Reading idf_table.parquet is best-effort: any failure is reported inline
    rather than raised.
    """
    adr_root = paths['adr_root']
    idf_path = adr_root / "idf_table.parquet"

    print("=== Config summary ===")
    print(f"DTI parquet:         {paths['dti_path']}")
    print(f"TF-IDF ADR root:     {adr_root}")
    print(f"  - idf_table:       {idf_path}")
    print(f"  - splits:          {[(adr_root / name).name for name in ('train', 'val', 'test')]}")
    print(f"PDB directory:       {paths['pdb_dir']}")

    # (line prefix, configured path or None) for each optional split list
    split_lines = (
        ("  - train_rxcui_path: ", paths['train_list']),
        ("  - val_rxcui_path:   ", paths['val_list']),
        ("  - test_rxcui_path:  ", paths['test_list']),
    )
    if any(list_path for _, list_path in split_lines):
        print("Split lists provided:")
        for prefix, list_path in split_lines:
            if list_path:
                print(f"{prefix}{list_path}")
    else:
        print("Split lists:         None (will infer splits from TF-IDF folders)")

    # Peek at the ADR count: one row of idf_table per ADR
    try:
        print(f"ADRs kept (columns): {len(pd.read_parquet(idf_path))}")
    except Exception as e:
        print(f"(Could not read idf_table now; will load in Part 1) -> {e}")

    print("======================\n")


# ---- Load & validate ----
# Runs when the cell executes: parse the YAML at CONFIG_PATH, resolve and
# existence-check the data paths, then print a summary of what was found.
cfg = load_config(CONFIG_PATH)
paths = validate_config(cfg)
print_summary(cfg, paths)

# We'll reuse `cfg` and `paths` in Part 1.


=== Config summary ===
DTI parquet:         F:\Thesis Korbi na\dti-prediction-with-adr\Data\scope_onside_common_v3.parquet
TF-IDF ADR root:     F:\Thesis Korbi na\dti-prediction-with-adr\Data\TFIDF_ADR_vectors
  - idf_table:       F:\Thesis Korbi na\dti-prediction-with-adr\Data\TFIDF_ADR_vectors\idf_table.parquet
  - splits:          ['train', 'val', 'test']
PDB directory:       F:\Thesis Korbi na\dti-prediction-with-adr\AlphaFoldData
Split lists:         None (will infer splits from TF-IDF folders)
ADRs kept (columns): 4048



In [9]:
# -- expects `cfg` and `paths` from Part 0 --

# ---------- ADR metadata container ----------
@dataclass
class ADRInfo:
    """
    Container for ADR vocabulary metadata.

    Field order matters: it is the positional order of the generated __init__.
    """
    idf_table: pd.DataFrame          # columns: meddra_id, df, idf (ordered)
    meddra_ids: np.ndarray           # shape [M], dtype=int64
    adr_cols: List[str]              # ["meddra_<id>", ...] aligned to idf_table order
    n_adrs: int                      # number of ADR columns

def _adr_cols_from_idf(idf_table: pd.DataFrame) -> List[str]:
    """Build the "meddra_<id>" column names, one per idf_table row, in row order."""
    meddra_ids = idf_table["meddra_id"].astype(np.int64)
    return ["meddra_" + str(meddra_id) for meddra_id in meddra_ids.tolist()]

def load_adr_info(adr_root: Path) -> ADRInfo:
    """
    Read <adr_root>/idf_table.parquet and derive the ADR vocabulary from it.

    The table is sorted by meddra_id (index reset) before the id array and the
    "meddra_<id>" column names are derived, so all three share that order.
    """
    raw_table = pd.read_parquet(adr_root / "idf_table.parquet")
    ordered = raw_table.sort_values("meddra_id").reset_index(drop=True)
    id_array = ordered["meddra_id"].astype(np.int64).to_numpy()
    column_names = _adr_cols_from_idf(ordered)
    return ADRInfo(
        idf_table=ordered,
        meddra_ids=id_array,
        adr_cols=column_names,
        n_adrs=len(column_names),
    )

def _ensure_adr_column_order(df: pd.DataFrame, adr_cols: List[str]) -> pd.DataFrame:
    """
    Return `df` restricted to ["rxcui"] + adr_cols, in exactly that order.

    ADR columns absent from `df` are added as all-zero float32 columns; any
    column that is neither "rxcui" nor listed in adr_cols is dropped.
    """
    wanted = set(adr_cols)
    present = set(df.columns)

    # dict.fromkeys de-duplicates while keeping first-seen order
    unexpected = [
        col for col in dict.fromkeys(df.columns)
        if col != "rxcui" and col not in wanted
    ]
    if unexpected:
        df = df.drop(columns=unexpected)

    absent = [col for col in dict.fromkeys(adr_cols) if col not in present]
    if absent:
        # zero-filled block with the float dtype used for TF-IDF weights
        filler = pd.DataFrame(
            data=np.zeros((len(df), len(absent)), dtype=np.float32),
            columns=absent,
            index=df.index,
        )
        df = pd.concat([df, filler], axis=1)

    # final column order: rxcui first, then ADRs as requested
    return df[["rxcui", *adr_cols]]

def load_tfidf_split(adr_root: Path, split: str, adr_cols: List[str]) -> pd.DataFrame:
    """
    Read <adr_root>/<split>/tfidf_wide.parquet as an ADR matrix.

    The result has the ADR columns in `adr_cols` order cast to float32, and is
    indexed by rxcui, with the index values converted to stripped strings.
    """
    parquet_path = adr_root.joinpath(split, "tfidf_wide.parquet")
    wide = pd.read_parquet(parquet_path)
    assert "rxcui" in wide.columns, f"'rxcui' column missing in {parquet_path}"

    wide = _ensure_adr_column_order(wide, adr_cols)
    wide[adr_cols] = wide[adr_cols].astype(np.float32, copy=False)

    wide = wide.set_index("rxcui")
    # rxcui keys are compared as text elsewhere, so normalize them here
    wide.index = wide.index.astype(str).str.strip()
    return wide

def discover_split_drugs(tfidf_by_split: Dict[str, pd.DataFrame]) -> Dict[str, set]:
    """Map each split name to the set of index values of its TF-IDF frame."""
    drugs_by_split: Dict[str, set] = {}
    for split_name, frame in tfidf_by_split.items():
        drugs_by_split[split_name] = set(frame.index.tolist())
    return drugs_by_split

def load_dti_table(dti_path: Path) -> pd.DataFrame:
    """
    Read the DTI parquet, check the required columns and tidy key dtypes.

    rxcui and target_uniprot_id become stripped strings; label is converted
    with pd.to_numeric and downcast to the smallest integer dtype. The other
    columns are returned as read.

    Raises:
        AssertionError: if any required column is absent.
    """
    required = (
        "drug_chembl_id",
        "target_uniprot_id",
        "label",
        "smiles",
        "sequence",
        "molfile_3d",
        "rxcui",
    )
    dti = pd.read_parquet(dti_path)

    missing = []
    for col in required:
        if col not in dti.columns:
            missing.append(col)
    assert not missing, f"DTI parquet missing columns: {missing}"

    # identifier columns are matched as text downstream
    for key_col in ("rxcui", "target_uniprot_id"):
        dti[key_col] = dti[key_col].astype(str).str.strip()
    dti["label"] = pd.to_numeric(dti["label"], downcast="integer")
    return dti

def make_per_split_dti(dti_df: pd.DataFrame, split_drugs: Dict[str, set]) -> Dict[str, pd.DataFrame]:
    """
    Slice the DTI table into one frame per split.

    A row belongs to a split when its rxcui is in that split's drug set. Each
    returned frame gets a fresh 0..n-1 index. Splitting is by drug only (a drug
    may pair with many proteins), which keeps a drug from leaking across splits.
    """
    return {
        split_name: dti_df[dti_df["rxcui"].isin(drug_ids)].reset_index(drop=True)
        for split_name, drug_ids in split_drugs.items()
    }

# ---------- ADR accessors ----------
class ADRAccessor:
    """
    Lookup helper for per-drug ADR vectors, one matrix per split.

    Vectors follow the column order given by `adr_cols`. An rxcui that is not
    in the requested split yields an all-zero vector instead of an error.
    """

    def __init__(self, tfidf_by_split: Dict[str, pd.DataFrame], adr_cols: List[str]):
        self.tfidf = tfidf_by_split   # split name -> DataFrame indexed by rxcui
        self.adr_cols = adr_cols
        self.n_adrs = len(adr_cols)
        # Per split: the float32 matrix and an rxcui -> row-position lookup
        self._np = {}
        for split_name, frame in tfidf_by_split.items():
            self._np[split_name] = frame[self.adr_cols].to_numpy(dtype=np.float32, copy=False)
        self._idx = {}
        for split_name, frame in tfidf_by_split.items():
            self._idx[split_name] = dict(zip(frame.index, range(len(frame.index))))

    def get_weighted(self, split: str, rxcui: str) -> np.ndarray:
        """Return the TF-IDF row for `rxcui` in `split`, or zeros if it is unknown."""
        row_lookup = self._idx[split]
        row = row_lookup.get(str(rxcui).strip())
        if row is not None:
            return self._np[split][row]
        return np.zeros(self.n_adrs, dtype=np.float32)

    def get_binary(self, split: str, rxcui: str) -> np.ndarray:
        """Return a float32 0/1 vector: 1 where the TF-IDF weight is positive."""
        weights = self.get_weighted(split, rxcui)
        return np.greater(weights, 0).astype(np.float32)

    def df(self, split: str) -> pd.DataFrame:
        """Return the split's frame restricted to the ADR columns."""
        frame = self.tfidf[split]
        return frame[self.adr_cols]

# ---------------- Load everything ----------------
adr_info = load_adr_info(paths["adr_root"])

# One TF-IDF frame per split, aligned to adr_info.adr_cols and indexed by rxcui
tfidf_by_split = {
    split_name: load_tfidf_split(paths["adr_root"], split_name, adr_info.adr_cols)
    for split_name in ("train", "val", "test")
}
split_drugs = discover_split_drugs(tfidf_by_split)

# sanity on split disjointness: count drugs shared by each pair of splits
overlaps = {
    (left, right): len(split_drugs[left] & split_drugs[right])
    for left, right in (("train", "val"), ("train", "test"), ("val", "test"))
}
print("Split overlaps (should be 0):", overlaps)

# Load DTI table and make per-split views by drug
dti_df = load_dti_table(paths["dti_path"])
dti_by_split = make_per_split_dti(dti_df, split_drugs)

# Report basic shapes
print(f"DTI total rows: {len(dti_df):,}")
for s in ("train", "val", "test"):
    n_rows = len(dti_by_split[s])
    n_drugs = len(split_drugs[s])
    print(f"{s:>5}: {n_rows:>7,} rows | {n_drugs:>5,} drugs | ADR matrix shape = {tfidf_by_split[s].shape}")

# Build ADR accessor
adr_accessor = ADRAccessor(tfidf_by_split, adr_info.adr_cols)

# Quick peek: fetch ADR vectors for an arbitrary train drug (sets are unordered)
if split_drugs["train"]:
    example_rx = next(iter(split_drugs["train"]))
    w = adr_accessor.get_weighted("train", example_rx)
    b = adr_accessor.get_binary("train", example_rx)
    print(f"\nExample train rxcui: {example_rx}")
    print(f"  Weighted ADR vector (len={len(w)}), nnz={int(np.count_nonzero(w > 0))}")
    print(f"  Binary ADR vector   (len={len(b)}), nnz={int(b.sum())}")


Split overlaps (should be 0): {('train', 'val'): 0, ('train', 'test'): 0, ('val', 'test'): 0}
DTI total rows: 34,741
train:  22,310 rows |   719 drugs | ADR matrix shape = (719, 4048)
  val:   4,835 rows |   154 drugs | ADR matrix shape = (154, 4048)
 test:   7,596 rows |   155 drugs | ADR matrix shape = (155, 4048)

Example train rxcui: 6676
  Weighted ADR vector (len=4048), nnz=4
  Binary ADR vector   (len=4048), nnz=4
