In [20]:
#!/usr/bin/env python3
"""
Modular RxNorm ↔ PubChem/ChEMBL/InChIKey pipeline.

Run stages independently:
  A) Historystatus & lineage   → debug_historystatus.tsv, lineage.parquet
  B) Resolution (current RXCUI → CID/InChIKey) → rxnorm_resolved.parquet
  C) Join (single full parquet) → mapping_combined.parquet

Toggle STAGE_* flags and options in the CONTROL BLOCK at the end.
"""

from typing import List, Dict, Any, Optional, Tuple, Set, Iterable
from pathlib import Path
import time
import threading
import re
import csv
import json
import pandas as pd
from tqdm import tqdm
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed

In [21]:
# ========================= CONFIG (paths common to all runs) =========================
# Dataset "A": the table that the resolved RxNorm mapping is joined against in Stage C.
# Despite the name, this currently points at a single parquet file; Stage C also accepts a
# directory and then reads every file matching A_GLOB.
DATASET_A_DIR = Path(r"../Data/Helper_Datasets/merged_id.parquet")  # single parquet file or a dir (defensive fallback)
A_GLOB = "*.parquet"

# OnSIDES CSV exports that supply the original (possibly outdated) RxCUIs for Stage A.
ONSIDES_CSV_DIR = Path(r"../Data/Main_Datasets/onsides-v3.1.0")
ONSIDES_INGREDIENT_CSV = ONSIDES_CSV_DIR / "vocab_rxnorm_ingredient.csv"              # has rxnorm_id or rxcui
ONSIDES_PRODUCT_TO_RXNORM_CSV = ONSIDES_CSV_DIR / "product_to_rxnorm.csv"             # has rxnorm_product_id
# (Optional helper) if you ever want to pivot product↔ingredient; not required for product lineage itself:
ONSIDES_RXNORM_ING_TO_PRODUCT_CSV = ONSIDES_CSV_DIR / "vocab_rxnorm_ingredient_to_product.csv"

# Output directory. NOTE: it is created as a side effect of running this cell.
OUT_DIR = Path("mapping_out/")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Canonical cache/artifacts (reused across runs)
HISTORY_TSV_PATH = OUT_DIR / "debug_historystatus.tsv"     # full historystatus (audit) with source_type
LINEAGE_PATH     = OUT_DIR / "lineage.parquet"             # columns: original_rxcui, rxcui, source_type
CURRENT_SET_JSON = OUT_DIR / "current_rxcui_set.json"      # {"ingredient":[...], "product":[...], "all":[...]}
RESOLVED_PATH    = OUT_DIR / "rxnorm_resolved.parquet"     # current rxcui → pubchem_cid/inchikey cache (shared)
COMBINED_OUT     = OUT_DIR / "mapping_combined.parquet"    # Stage C output: A joined with the resolved mapping
RES_DEBUG_TSV    = OUT_DIR / "debug_resolution.tsv"        # Stage B per-RxCUI debug trace (rewritten on each run that resolves something)

# Column names in your A shards / full file
COL_CHEMBL   = "chembl_id"
COL_PUBCHEM  = "pubchem_id"   # normalized to pubchem_cid
COL_INCHIKEY = "inchikey"

# Rate limiting
REQUESTS_PER_SECOND = 4.0   # budget of the single token bucket shared by all HTTP calls (RxNav and PubChem together)
HTTP_TIMEOUT_SEC = 20       # passed as the per-request `timeout` to requests

# Threading
NUM_WORKERS_RESOLVE = 8     # Stage B worker threads; throughput is still capped by REQUESTS_PER_SECOND

# ---- Endpoints ----
RXNAV   = "https://rxnav.nlm.nih.gov/REST"
PUBCHEM = "https://pubchem.ncbi.nlm.nih.gov/rest/pug"
# Sent with every request issued through safe_get().
HEADERS = {"User-Agent": "rxnorm-pubchem-mapper-batch/6.0", "Accept": "application/json"}

In [22]:
# ========================= Utilities =========================
class TokenBucket:
    """Thread-safe token-bucket rate limiter.

    Tokens refill continuously at ``rate_per_sec`` up to ``capacity``. ``consume``
    blocks the calling thread until the requested number of tokens is available.

    Args:
        rate_per_sec: Refill rate in tokens per second; must be > 0.
        capacity: Maximum number of stored tokens (burst size). Defaults to
            ``max(1.0, rate_per_sec)``.

    Raises:
        ValueError: If the rate or the capacity is not strictly positive.
    """

    def __init__(self, rate_per_sec: float, capacity: Optional[float] = None):
        self.rate = float(rate_per_sec)
        if self.rate <= 0:
            # A non-positive rate can never refill and would divide by zero while waiting.
            raise ValueError(f"rate_per_sec must be > 0, got {rate_per_sec!r}")
        self.capacity = float(capacity) if capacity is not None else max(1.0, self.rate)
        if self.capacity <= 0:
            raise ValueError(f"capacity must be > 0, got {capacity!r}")
        self.tokens = self.capacity  # start full so the first burst is not delayed
        self.timestamp = time.monotonic()
        self.lock = threading.Lock()

    def consume(self, tokens: float = 1.0):
        """Block until ``tokens`` tokens are available, then take them.

        Raises:
            ValueError: If ``tokens`` exceeds the bucket capacity (the request could
                never be satisfied and would otherwise wait forever).
        """
        if tokens > self.capacity:
            raise ValueError(
                f"cannot consume {tokens} tokens from a bucket of capacity {self.capacity}"
            )
        while True:
            with self.lock:
                now = time.monotonic()
                elapsed = now - self.timestamp
                self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
                self.timestamp = now
                if self.tokens >= tokens:
                    self.tokens -= tokens
                    return
                # Time until enough tokens have refilled for this request.
                wait = (tokens - self.tokens) / self.rate
            # Sleep outside the lock so other threads can refill/consume; another thread may
            # win the race for the refilled tokens, in which case we simply loop again.
            time.sleep(max(0.001, wait))

# Module-level limiter: safe_get() takes one token from it before every HTTP request, so the
# REQUESTS_PER_SECOND budget is shared by all callers (including the Stage B worker threads).
bucket = TokenBucket(REQUESTS_PER_SECOND)
# Single requests.Session used by safe_get() for every RxNav and PubChem call.
session = requests.Session()

def safe_get(url: str, params: Optional[Dict[str, Any]] = None, max_retries: int = 3) -> Optional[Dict[str, Any]]:
    """Rate-limited GET that returns the decoded JSON body, or None on failure.

    Each attempt first takes one token from the shared ``bucket``. HTTP 200 returns the
    parsed JSON; HTTP 400/404 are treated as a definitive "no result" and return None
    without retrying. Any other status, a network error, or an undecodable body is
    retried with exponential backoff (1s, 2s, ... capped at 10s).

    Args:
        url: Absolute URL to request.
        params: Optional query-string parameters.
        max_retries: Maximum number of attempts.

    Returns:
        The decoded JSON payload, or None if no attempt succeeded.
    """
    for attempt in range(max_retries):
        try:
            bucket.consume(1.0)
            r = session.get(url, params=params, headers=HEADERS, timeout=HTTP_TIMEOUT_SEC)
            if r.status_code == 200:
                return r.json()
            if r.status_code in (400, 404):
                return None
        except (requests.RequestException, ValueError):
            # ValueError: a 200 response whose body is not valid JSON (older requests
            # versions raise a plain ValueError subclass from r.json()).
            pass
        # Back off only when another attempt will follow; sleeping after the final
        # failure would just delay the caller.
        if attempt < max_retries - 1:
            time.sleep(min(10.0, (2 ** attempt)))
    return None

# Sanitizers
# DrugBank accession: "DB" followed by five digits (case-insensitive, first hit wins).
_DB_RE = re.compile(r"(DB\d{5})", re.IGNORECASE)
# UNII candidate: the first run of 4-12 upper-case letters/digits.
_UNII_RE = re.compile(r"[A-Z0-9]{4,12}")


def clean_drugbank_id(raw: str) -> Optional[str]:
    """Extract a normalized DrugBank ID (e.g. ``DB00316``) from a raw code string.

    The input is trimmed and upper-cased, the namespace prefixes ``DRUGBANK:``,
    ``DBANK:`` and ``DB:`` are blanked out, and the first ``DB`` + five digits
    match is returned. Returns None for empty input or when nothing matches.
    """
    if not raw:
        return None
    text = raw.strip().upper()
    for prefix in ("DRUGBANK:", "DBANK:", "DB:"):
        text = text.replace(prefix, " ")
    found = _DB_RE.search(text)
    if found is None:
        return None
    return found.group(1).upper()


def clean_unii(raw: str) -> Optional[str]:
    """Extract a UNII-like token from a raw code string.

    The input is trimmed and upper-cased, the prefixes ``UNII:`` and ``UNII ``
    are blanked out, and the first run of 4-12 alphanumerics is returned.
    Returns None for empty input or when nothing matches.
    """
    if not raw:
        return None
    text = raw.strip().upper()
    for prefix in ("UNII:", "UNII "):
        text = text.replace(prefix, " ")
    found = _UNII_RE.search(text)
    if found is None:
        return None
    return found.group(0)

In [23]:
# ========================= RxNav helpers =========================
def rxnav_get_historystatus(rxcui: str) -> Dict[str, Any]:
    """Fetch the RxNav ``historystatus`` JSON for ``rxcui``; empty dict when unavailable."""
    payload = safe_get(f"{RXNAV}/rxcui/{rxcui}/historystatus.json")
    return payload if payload else {}

def _parse_historystatus_payload(data: dict) -> tuple[str, list[str]]:
    if not data: return ("Unknown", [])
    hs = data.get("historystatus")
    if isinstance(hs, dict):
        status = hs.get("conceptStatus", "Unknown")
        remapped = hs.get("remappedConcept", [])
        if isinstance(remapped, dict): remapped = [remapped]
        targets = []
        for x in remapped:
            rcui = x.get("rxcui") or x.get("remappedRxCui")
            if rcui: targets.append(str(rcui))
        return (status, sorted(set(targets)))
    h2 = data.get("rxcuiStatusHistory")
    if isinstance(h2, dict):
        status = (h2.get("metaData") or {}).get("status", "Unknown")
        remapped = (h2.get("derivedConcepts") or {}).get("remappedConcept", []) or []
        if isinstance(remapped, dict): remapped = [remapped]
        targets = []
        for x in remapped:
            rcui = x.get("remappedRxCui") or x.get("rxcui")
            if rcui: targets.append(str(rcui))
        return (status, sorted(set(targets)))
    return ("Unknown", [])

def normalize_rxcui(rxcui: str) -> list[str]:
    """Map an RxCUI to its current equivalent(s) using RxNav historystatus.

    Returns ``[rxcui]`` when the concept is active, the remap targets when it was
    remapped, and an empty list for every other (or unknown) status.
    """
    rx = str(rxcui)
    status, targets = _parse_historystatus_payload(rxnav_get_historystatus(rx))
    state = status.upper() if status else ""
    if state == "ACTIVE":
        return [rx]
    if state == "REMAPPED":
        return targets
    return []

def rxnav_properties_min(rxcui: str) -> Tuple[Optional[str], Optional[str]]:
    """Return ``(name, tty)`` from the RxNav ``properties`` endpoint for ``rxcui``.

    Both elements are None when the request fails or the payload carries no
    (or a null) ``properties`` object.
    """
    url = f"{RXNAV}/rxcui/{rxcui}/properties.json"
    data = safe_get(url) or {}
    # `or {}` also covers an explicit `"properties": null`, which `.get(key, {})` would
    # pass through as None and then crash on `.get`.
    props = data.get("properties") or {}
    return (props.get("name"), props.get("tty"))

def rxnav_all_properties(rxcui: str) -> Dict[str, Any]:
    """Fetch the RxNav ``allProperties`` JSON (names + codes) for ``rxcui``; ``{}`` on failure."""
    endpoint = f"{RXNAV}/rxcui/{rxcui}/allProperties.json"
    query = {"prop": "names codes"}
    result = safe_get(endpoint, params=query)
    return result if result else {}

def extract_codes_from_all_props(all_props: Dict[str, Any]) -> Dict[str, List[str]]:
    """Pull DrugBank IDs, UNII codes and RxNorm names out of an ``allProperties`` payload.

    Args:
        all_props: Decoded payload; expected shape is
            ``{"propConceptGroup": {"propConcept": [ {propCategory, propName, propValue}, ... ]}}``.
            Missing or null levels and a single-dict ``propConcept`` are tolerated.

    Returns:
        A dict with any of the keys ``"drugbank_id"``, ``"unii"``, ``"rxnorm_name"``,
        each mapping to the raw property values in payload order. Keys with no
        values are omitted.
    """
    # (category, property name) -> output key; both compared upper-cased.
    key_for = {
        ("CODES", "DRUGBANK"): "drugbank_id",
        ("CODES", "UNII_CODE"): "unii",
        ("NAMES", "RXNORM NAME"): "rxnorm_name",
        ("NAMES", "RXNAV_STR"): "rxnorm_name",
    }
    out: Dict[str, List[str]] = {}
    # `or {}` guards against an absent payload and an explicit null group.
    group = (all_props or {}).get("propConceptGroup") or {}
    pcs = group.get("propConcept")
    if not pcs:
        return out
    if isinstance(pcs, dict):
        pcs = [pcs]
    for pc in pcs:
        if not isinstance(pc, dict):
            continue
        val = pc.get("propValue")
        if not val:
            continue
        cat = (pc.get("propCategory") or "").upper()
        name = (pc.get("propName") or "").upper()
        key = key_for.get((cat, name))
        if key:
            out.setdefault(key, []).append(val)
    return out

In [24]:
# ========================= PubChem helpers =========================
def pubchem_sids_from_drugbank(dbid: str) -> List[int]:
    """Return the PubChem substance IDs deposited by DrugBank under ``dbid`` (``[]`` if none)."""
    payload = safe_get(f"{PUBCHEM}/substance/sourceid/DrugBank/{dbid}/sids/JSON") or {}
    sids = payload.get("IdentifierList", {}).get("SID", [])
    return sids if sids else []
def pubchem_sids_from_unii(unii: str) -> List[int]:
    """Return the PubChem substance IDs cross-referenced to ``unii`` via RegistryID (``[]`` if none)."""
    payload = safe_get(f"{PUBCHEM}/substance/xref/RegistryID/{unii}/sids/JSON") or {}
    sids = payload.get("IdentifierList", {}).get("SID", [])
    return sids if sids else []
def pubchem_cids_from_sid(sid: int) -> List[int]:
    """Return the compound IDs linked to substance ``sid``.

    Only the first entry of the response's ``Information`` list is read; an absent
    or empty result yields ``[]``.
    """
    payload = safe_get(f"{PUBCHEM}/substance/sid/{sid}/cids/JSON") or {}
    entries = payload.get("InformationList", {}).get("Information", [])
    if entries:
        return entries[0].get("CID", []) or []
    return []
def pubchem_cids_from_name_exact(name: str) -> List[int]:
    """Look up PubChem compound IDs by name.

    Args:
        name: Compound name (e.g. the RxNorm preferred name).

    Returns:
        The CIDs PubChem reports for the name, or ``[]`` when the name is blank, the
        lookup fails, or nothing matches.
    """
    # Standard-library quote (requests.utils.quote is this same function re-exported).
    from urllib.parse import quote

    if not name or not name.strip():
        return []
    # safe="" so that "/" inside a name (common in multi-ingredient RxNorm names such as
    # "amoxicillin / clavulanate") is percent-encoded instead of splitting the URL path.
    url = f"{PUBCHEM}/compound/name/{quote(name, safe='')}/cids/JSON"
    # NOTE(review): PUG REST documents `name_type` values "complete" (default) and "word";
    # confirm that "exact" is accepted rather than rejected or ignored.
    data = safe_get(url, params={"name_type": "exact"}) or {}
    return data.get("IdentifierList", {}).get("CID", []) or []
def pubchem_inchikey_from_cid(cid: int) -> Optional[str]:
    """Return the InChIKey PubChem reports for compound ``cid``, or None when unavailable."""
    payload = safe_get(f"{PUBCHEM}/compound/cid/{cid}/property/InChIKey/JSON") or {}
    table = payload.get("PropertyTable", {}).get("Properties", [])
    return table[0].get("InChIKey") if table else None

In [25]:
# ========================= Stage A: Historystatus + Lineage (ING + PROD) =========================
def load_onsides_ingredient_rxcui(ingredient_csv: Path) -> pd.DataFrame:
    """Load the distinct ingredient RxCUIs from the OnSIDES ingredient vocabulary CSV.

    Args:
        ingredient_csv: Path to the CSV; it must have an ``rxnorm_id`` or ``rxcui`` column
            (``rxnorm_id`` wins when both exist).

    Returns:
        DataFrame with columns ``original_rxcui`` (str) and ``source_type`` (always
        ``"ingredient"``), one row per distinct non-empty identifier.

    Raises:
        FileNotFoundError: If the CSV does not exist.
        ValueError: If neither identifier column is present.
    """
    if not ingredient_csv.exists():
        raise FileNotFoundError(f"{ingredient_csv} not found")
    # Read everything as text: with numeric inference a single blank cell turns the column
    # into float64 and the identifiers would later stringify as "161.0".
    df = pd.read_csv(ingredient_csv, dtype=str)
    rx_col = next((c for c in ("rxnorm_id", "rxcui") if c in df.columns), None)
    if rx_col is None:
        raise ValueError("vocab_rxnorm_ingredient.csv must contain 'rxnorm_id' or 'rxcui'")
    out = df[[rx_col]].rename(columns={rx_col: "original_rxcui"}).dropna()
    out["original_rxcui"] = out["original_rxcui"].astype(str).str.strip()
    out = out[out["original_rxcui"] != ""].drop_duplicates()
    out["source_type"] = "ingredient"
    return out

def load_onsides_product_rxcui(product_to_rxnorm_csv: Path) -> pd.DataFrame:
    """Load the distinct product RxCUIs from the OnSIDES product-to-RxNorm CSV.

    Args:
        product_to_rxnorm_csv: Path to the CSV; it must have an ``rxnorm_product_id`` column.

    Returns:
        DataFrame with columns ``original_rxcui`` (str) and ``source_type`` (always
        ``"product"``), one row per distinct non-empty identifier.

    Raises:
        FileNotFoundError: If the CSV does not exist.
        ValueError: If the ``rxnorm_product_id`` column is missing.
    """
    if not product_to_rxnorm_csv.exists():
        raise FileNotFoundError(f"{product_to_rxnorm_csv} not found")
    # Read everything as text: with numeric inference a single blank cell turns the column
    # into float64 and the identifiers would later stringify as "198440.0".
    dfp = pd.read_csv(product_to_rxnorm_csv, dtype=str)
    if "rxnorm_product_id" not in dfp.columns:
        raise ValueError("product_to_rxnorm.csv must contain 'rxnorm_product_id'")
    out = dfp[["rxnorm_product_id"]].rename(columns={"rxnorm_product_id": "original_rxcui"}).dropna()
    out["original_rxcui"] = out["original_rxcui"].astype(str).str.strip()
    out = out[out["original_rxcui"] != ""].drop_duplicates()
    out["source_type"] = "product"
    return out

def _stage_a_historystatus_pass(df_orig: pd.DataFrame, writer=None, flush=None) -> List[Dict[str, Any]]:
    """Normalize every original RxCUI once and return the lineage rows.

    For each ``(original_rxcui, source_type)`` row, ``normalize_rxcui`` is called; one
    lineage row is produced per current target, or a single row with ``rxcui=None`` when
    there is no target. When ``writer`` (a ``csv.DictWriter``) is given, an audit row is
    written per original RxCUI and ``flush`` (if given) is called every 500 rows.
    """
    lineage_rows: List[Dict[str, Any]] = []
    progress = tqdm(df_orig.itertuples(index=False), total=len(df_orig),
                    desc="Historystatus (RxCUI remaps)", unit="rxcui")
    for i, row in enumerate(progress):
        rx = row.original_rxcui
        st = row.source_type
        targets = normalize_rxcui(rx)
        if writer is not None:
            status = "Active" if targets == [rx] else ("Remapped" if targets else "Obsolete/None")
            writer.writerow({"original_rxcui": rx, "status": status, "targets": ";".join(targets), "source_type": st})
            if flush is not None and (i + 1) % 500 == 0:
                flush()
        if targets:
            lineage_rows.extend({"original_rxcui": rx, "rxcui": t, "source_type": st} for t in targets)
        else:
            lineage_rows.append({"original_rxcui": rx, "rxcui": None, "source_type": st})
    return lineage_rows


def stage_A_build_historystatus_and_lineage(
    include_types: Iterable[str] = ("ingredient","product"),
    overwrite_tsv: bool = False,
    overwrite_lineage: bool = False
) -> Tuple[pd.DataFrame, Dict[str, List[str]]]:
    """Stage A: map original OnSIDES RxCUIs to current RxCUIs and persist the lineage.

    Side effects: writes ``HISTORY_TSV_PATH`` (only when missing or ``overwrite_tsv``),
    writes/merges ``LINEAGE_PATH`` and rewrites ``CURRENT_SET_JSON``. RxNav is queried
    live for every original RxCUI on each call, whether or not the TSV is rewritten.

    Args:
        include_types: Which sources to load: ``"ingredient"`` and/or ``"product"``.
        overwrite_tsv: Re-create the audit TSV even if it already exists.
        overwrite_lineage: Replace ``LINEAGE_PATH`` instead of merging with its rows.

    Returns:
        ``(df_lineage, current_sets)`` where ``df_lineage`` has the columns
        ``original_rxcui``, ``rxcui`` (None for orphans) and ``source_type``, and
        ``current_sets`` maps ``"ingredient"``, ``"product"`` and ``"all"`` to sorted
        lists of current RxCUIs.

    Raises:
        ValueError: If ``include_types`` selects neither source.
        FileNotFoundError: If a required OnSIDES CSV is missing.
    """
    print("[A] Loading RxCUIs from OnSIDES ...")
    # Materialize one-shot iterables so both membership checks below see every element.
    wanted = include_types if isinstance(include_types, str) else tuple(include_types)
    parts = []
    if "ingredient" in wanted:
        parts.append(load_onsides_ingredient_rxcui(ONSIDES_INGREDIENT_CSV))
    if "product" in wanted:
        parts.append(load_onsides_product_rxcui(ONSIDES_PRODUCT_TO_RXNORM_CSV))
    if not parts:
        raise ValueError("include_types must contain 'ingredient' and/or 'product'")
    df_orig = pd.concat(parts, ignore_index=True).drop_duplicates()
    n_ing = int((df_orig["source_type"] == "ingredient").sum())
    n_prod = int((df_orig["source_type"] == "product").sum())
    print(f"    found {len(df_orig)} original RxCUIs (ingredients: {n_ing}, products: {n_prod})")

    # Historystatus + lineage (one pass) with source_type
    print("[A] Normalizing via historystatus (one pass) ...")
    if overwrite_tsv or not HISTORY_TSV_PATH.exists():
        with HISTORY_TSV_PATH.open("w", newline="", encoding="utf-8") as f:
            w = csv.DictWriter(f, fieldnames=["original_rxcui", "status", "targets", "source_type"], delimiter="\t")
            w.writeheader()
            lineage_rows = _stage_a_historystatus_pass(df_orig, writer=w, flush=f.flush)
    else:
        print("    using existing debug_historystatus.tsv (set overwrite_tsv=True to re-create)")
        # The TSV is audit-only; lineage is still rebuilt from the live API so it stays fresh.
        lineage_rows = _stage_a_historystatus_pass(df_orig)

    # Explicit columns keep the frame well-formed even when there are no rows at all.
    lineage_cols = ["original_rxcui", "rxcui", "source_type"]
    df_lineage = pd.DataFrame(lineage_rows, columns=lineage_cols).drop_duplicates()
    if not overwrite_lineage and LINEAGE_PATH.exists():
        try:
            old = pd.read_parquet(LINEAGE_PATH)
            df_lineage = pd.concat([old, df_lineage], ignore_index=True).drop_duplicates()
        except Exception as exc:
            # Best effort: an unreadable/incompatible old file is replaced by the fresh lineage,
            # but the reason is reported instead of being silently dropped.
            print(f"    warning: could not merge existing {LINEAGE_PATH} ({exc}); replacing it")
    df_lineage.to_parquet(LINEAGE_PATH, index=False)

    # Build current sets per type + combined
    has_target = df_lineage["rxcui"].notna()

    def _current_for(kind: str) -> List[str]:
        selected = df_lineage.loc[has_target & (df_lineage["source_type"] == kind), "rxcui"]
        return sorted(selected.astype(str).unique().tolist())

    current_ing = _current_for("ingredient")
    current_prod = _current_for("product")
    current_all  = sorted(set(current_ing) | set(current_prod))
    current_sets = {"ingredient": current_ing, "product": current_prod, "all": current_all}

    with CURRENT_SET_JSON.open("w", encoding="utf-8") as f:
        json.dump(current_sets, f)

    n_orphans = (df_lineage["rxcui"].isna()).sum()
    print(f"[A] current RxCUIs → ingredient:{len(current_ing)} | product:{len(current_prod)} | all:{len(current_all)}  (orphans: {n_orphans})")
    return df_lineage, current_sets

In [26]:
# ========================= Stage B: Resolve current RXCUIs (shared cache) =========================
def _unique_cleaned_ids(raw_values: List[str], cleaner) -> List[str]:
    """Apply ``cleaner`` once per raw value; keep distinct non-empty results in first-seen order."""
    cleaned: Dict[str, None] = {}
    for raw in raw_values:
        value = cleaner(raw)
        if value:
            cleaned.setdefault(value, None)
    return list(cleaned)


def _cids_from_sids(sids: List[int]) -> Set[int]:
    """Expand PubChem substance IDs into the set of compound IDs they link to."""
    cids: Set[int] = set()
    for sid in sids:
        cids.update(int(cid) for cid in pubchem_cids_from_sid(sid))
    return cids


def resolve_single_rxcui_current(rxcui: str, debug_rows: list) -> List[Dict[str, Any]]:
    """Resolve one current RxCUI to PubChem CIDs and InChIKeys.

    Resolution order (each step runs only if the previous ones found no CID):
      1. DrugBank ID → PubChem SIDs → CIDs
      2. UNII → PubChem SIDs → CIDs
      3. PubChem name lookup on the RxNorm name (marked ``HEURISTIC_NAME``)

    Args:
        rxcui: Current RxCUI to resolve.
        debug_rows: Shared list that receives one debug dict per call (appended in place).

    Returns:
        One row per resolved CID (``confidence`` is ``"XREF"`` or ``"HEURISTIC_NAME"``),
        or a single ``"UNRESOLVED"`` row with null ``pubchem_cid``/``inchikey``.
    """
    rows: List[Dict[str, Any]] = []
    rxn_name, tty = rxnav_properties_min(rxcui)
    codes = extract_codes_from_all_props(rxnav_all_properties(rxcui))
    # De-duplicate so a code listed twice does not trigger repeated rate-limited lookups.
    db_ids = _unique_cleaned_ids(codes.get("drugbank_id", []), clean_drugbank_id)
    unii_ids = _unique_cleaned_ids(codes.get("unii", []), clean_unii)

    resolved_cids: Set[int] = set()
    prov: List[str] = []
    dbg = {"rxcui": rxcui, "rxnorm_name": rxn_name, "tty": tty, "db_ids": ";".join(db_ids), "unii_ids": ";".join(unii_ids), "step": "", "result_cids": ""}

    # DrugBank → Substance → SID → CID
    for dbid in db_ids:
        sids = pubchem_sids_from_drugbank(dbid)
        resolved_cids |= _cids_from_sids(sids)
        if sids:
            prov.append(f"DrugBank:{dbid}")
    if resolved_cids:
        dbg["step"] = "DrugBank"
        dbg["result_cids"] = ",".join(map(str, sorted(resolved_cids)))

    # UNII → Substance → SID → CID
    if not resolved_cids:
        for unii in unii_ids:
            sids = pubchem_sids_from_unii(unii)
            resolved_cids |= _cids_from_sids(sids)
            if sids:
                prov.append(f"UNII:{unii}")
        if resolved_cids:
            dbg["step"] = "UNII"
            dbg["result_cids"] = ",".join(map(str, sorted(resolved_cids)))

    # RxNorm preferred name exact (heuristic but conservative)
    heuristic = False
    if not resolved_cids and rxn_name:
        resolved_cids.update(int(cid) for cid in pubchem_cids_from_name_exact(rxn_name))
        if resolved_cids:
            heuristic = True
            prov.append(f"RxNormName:{rxn_name}")
            dbg["step"] = "NAME"
            dbg["result_cids"] = ",".join(map(str, sorted(resolved_cids)))

    # Emit rows
    drugbank_joined = ";".join(db_ids) or None
    unii_joined = ";".join(unii_ids) or None
    if resolved_cids:
        confidence = "HEURISTIC_NAME" if heuristic else "XREF"
        provenance = ",".join(prov) or None
        for cid in sorted(resolved_cids):
            ik = pubchem_inchikey_from_cid(cid)
            rows.append({"rxcui": rxcui, "rxnorm_name": rxn_name, "tty": tty, "pubchem_cid": cid,
                         "inchikey": (ik or None), "drugbank_id": drugbank_joined, "unii": unii_joined,
                         "confidence": confidence, "provenance": provenance})
    else:
        rows.append({"rxcui": rxcui, "rxnorm_name": rxn_name, "tty": tty, "pubchem_cid": None, "inchikey": None,
                     "drugbank_id": drugbank_joined, "unii": unii_joined,
                     "confidence": "UNRESOLVED", "provenance": None})
    debug_rows.append(dbg)
    return rows

def stage_B_resolve_current(current_set: List[str], append_cache: bool = True) -> pd.DataFrame:
    """Stage B: resolve current RxCUIs to PubChem CID / InChIKey, reusing the parquet cache.

    RxCUIs already present in ``RESOLVED_PATH`` are skipped; the rest are resolved in a
    thread pool. A worker that raises is reported and its RxCUI is left out of the cache
    so a later run retries it, instead of aborting the stage and discarding finished work.

    Side effects: rewrites ``RESOLVED_PATH`` and, when anything was resolved,
    ``RES_DEBUG_TSV``.

    Args:
        current_set: Current RxCUIs to resolve.
        append_cache: When True, new rows are appended to the cached rows; when False,
            the cache is replaced by the newly resolved rows (if there are any).

    Returns:
        The DataFrame written to ``RESOLVED_PATH``.
    """
    print("[B] Resolving current RxCUIs → CID/InChIKey ...")

    cached = pd.DataFrame()
    if RESOLVED_PATH.exists():
        cached = pd.read_parquet(RESOLVED_PATH)
        cached["rxcui"] = cached["rxcui"].astype(str)

    done = set(cached["rxcui"].unique()) if not cached.empty else set()
    # Compare as strings so numeric RxCUIs still hit the (string-keyed) cache.
    to_do = sorted({str(r) for r in current_set} - done)

    print(f"    total current RXCUIs: {len(current_set)} | cached: {len(done)} | to resolve: {len(to_do)}")

    out_rows: List[Dict[str, Any]] = []
    debug_rows: List[Dict[str, Any]] = []
    failed: List[str] = []
    if to_do:
        with ThreadPoolExecutor(max_workers=NUM_WORKERS_RESOLVE) as ex:
            futures = {ex.submit(resolve_single_rxcui_current, r, debug_rows): r for r in to_do}
            for fut in tqdm(as_completed(futures), total=len(futures), desc="Resolving (parallel)"):
                try:
                    out_rows.extend(fut.result())
                except Exception as exc:
                    # One bad RxCUI must not throw away every other completed resolution.
                    failed.append(futures[fut])
                    print(f"    warning: resolving RxCUI {futures[fut]} failed: {exc!r}")
    if failed:
        print(f"    {len(failed)} RxCUI(s) failed and are not cached; they will be retried on the next run")

    dbg_df = pd.DataFrame(debug_rows)
    if not dbg_df.empty:
        dbg_df.to_csv(RES_DEBUG_TSV, sep="\t", index=False)

    new_df = pd.DataFrame(out_rows)
    if not new_df.empty:
        if "inchikey" in new_df.columns:
            ik = new_df["inchikey"]
            # Upper-case real keys only; astype(str) on a missing value would otherwise
            # store the literal text "NONE" and later match other missing keys in the join.
            new_df["inchikey"] = ik.where(ik.isna(), ik.astype(str).str.upper())
        if append_cache:
            frames = [frame for frame in (cached, new_df) if not frame.empty]
            final = pd.concat(frames, ignore_index=True).drop_duplicates()
        else:
            final = new_df
    else:
        final = cached

    if "inchikey" in final.columns:
        # Heal caches written before the fix above: stringified nulls become real nulls.
        final["inchikey"] = final["inchikey"].mask(final["inchikey"].isin(["NONE", "NAN"]))

    final.to_parquet(RESOLVED_PATH, index=False)
    print(f"[B] cache updated at {RESOLVED_PATH} | total rows: {len(final)}")
    return final

In [27]:
# ========================= Stage C: Join single file =========================
def normalize_a_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Normalize the identifier columns of dataset A for joining.

    * ``COL_PUBCHEM`` / ``COL_INCHIKEY`` are renamed to ``pubchem_cid`` / ``inchikey``
      unless a column of that name already exists.
    * ``pubchem_cid`` is coerced to nullable ``Int64`` (unparseable values become NA).
    * ``inchikey`` is upper-cased; missing values stay missing.

    Args:
        df: Dataset A. When no rename is needed, the columns are normalized in place on
            this object.

    Returns:
        The normalized DataFrame.
    """
    rename = {}
    if COL_PUBCHEM in df.columns and "pubchem_cid" not in df.columns:
        rename[COL_PUBCHEM] = "pubchem_cid"
    if COL_INCHIKEY in df.columns and "inchikey" not in df.columns:
        rename[COL_INCHIKEY] = "inchikey"
    if rename:
        df = df.rename(columns=rename)

    if "pubchem_cid" in df.columns:
        df["pubchem_cid"] = pd.to_numeric(df["pubchem_cid"], errors="coerce").astype("Int64")
    if "inchikey" in df.columns:
        ik = df["inchikey"]
        # Keep nulls as nulls: astype(str) alone would turn them into the strings
        # "NAN"/"NONE", which then behave like real (and mutually matching) join keys.
        df["inchikey"] = ik.where(ik.isna(), ik.astype(str).str.upper())
    return df

def _join_normalize_inchikey(values: pd.Series) -> pd.Series:
    """Upper-case/trim InChIKeys; non-strings and stringified nulls become None."""
    missing_tokens = {"", "NONE", "NAN", "NULL", "<NA>"}

    def _one(value):
        if not isinstance(value, str):
            return None
        key = value.strip().upper()
        return None if key in missing_tokens else key

    # Force object dtype so both sides of the merge always have compatible key dtypes.
    return values.map(_one).astype(object)


def _join_select_keep_columns(frame: pd.DataFrame, keep: List[str]) -> pd.DataFrame:
    """Reduce one merge result to ``keep``, resolving ``_rx``/``_a`` suffixed duplicates.

    For a column that was suffixed by the merge, the RxNorm-side value is preferred and
    gaps are filled from the A side. Columns absent from both sides are added as None.
    """
    frame = frame.copy()
    for col in keep:
        if col in frame.columns:
            continue
        rx_col, a_col = f"{col}_rx", f"{col}_a"
        if rx_col in frame.columns and a_col in frame.columns:
            frame[col] = frame[rx_col].where(frame[rx_col].notna(), frame[a_col])
        elif rx_col in frame.columns:
            frame[col] = frame[rx_col]
        elif a_col in frame.columns:
            frame[col] = frame[a_col]
        else:
            frame[col] = None
    return frame[keep]


def join_full_with_rx(df_a: pd.DataFrame, df_rx: pd.DataFrame) -> pd.DataFrame:
    """
    Join A (your whole parquet) with resolved RxNorm mapping in two passes:
      1) exact on InChIKey
      2) then on remaining rows by PubChem CID

    Rows of A matched in pass 1 are excluded from pass 2. Rows whose join key is
    missing (null, or a stringified null such as "NONE"/"NAN") never match, because
    pandas would otherwise pair null keys with each other.

    Args:
        df_a: Dataset A with ``inchikey`` and/or ``pubchem_cid`` columns.
        df_rx: Lineage merged with the resolved RxNorm cache.

    Returns:
        De-duplicated DataFrame with the columns ``original_rxcui, rxcui, source_type,
        pubchem_cid, inchikey, <COL_CHEMBL>, rxnorm_name, tty, drugbank_id, unii,
        confidence, provenance``. Neither input frame is modified.
    """
    keep = [
        "original_rxcui","rxcui","source_type","pubchem_cid","inchikey",COL_CHEMBL,
        "rxnorm_name","tty","drugbank_id","unii","confidence","provenance"
    ]
    if df_a.empty:
        return pd.DataFrame(columns=keep)

    # Work on private copies with normalized keys so both passes see consistent dtypes.
    row_id = "__a_row_id__"
    left_all = df_a.reset_index(drop=True)
    right_all = df_rx.copy()
    for frame in (left_all, right_all):
        if "inchikey" in frame.columns:
            frame["inchikey"] = _join_normalize_inchikey(frame["inchikey"])
        if "pubchem_cid" in frame.columns:
            frame["pubchem_cid"] = pd.to_numeric(frame["pubchem_cid"], errors="coerce").astype("Int64")
    # Stable per-row id of A: the merged frame gets a fresh index, so A's own index
    # cannot be used to tell which rows were already matched.
    left_all[row_id] = range(len(left_all))

    passes: List[pd.DataFrame] = []
    matched_rows: set = set()

    # Pass 1: inchikey
    if "inchikey" in left_all.columns and "inchikey" in right_all.columns:
        left = left_all[left_all["inchikey"].notna()]
        right = right_all[right_all["inchikey"].notna()]
        out_ik = left.merge(right, on="inchikey", how="inner", suffixes=("_a", "_rx"))
        matched_rows.update(out_ik[row_id].tolist())
        passes.append(out_ik)

    # Pass 2: remaining by pubchem_cid
    if "pubchem_cid" in left_all.columns and "pubchem_cid" in right_all.columns:
        remaining = left_all[~left_all[row_id].isin(matched_rows)]
        left = remaining[remaining["pubchem_cid"].notna()]
        right = right_all[right_all["pubchem_cid"].notna()]
        out_cid = left.merge(right, on="pubchem_cid", how="inner", suffixes=("_a", "_rx"))
        passes.append(out_cid)

    if not passes:
        return pd.DataFrame(columns=keep)

    # Resolve suffixed columns per pass *before* concatenating: afterwards every key column
    # exists (partly null) in the union and the _rx/_a fallback could no longer fire.
    out = pd.concat([_join_select_keep_columns(p, keep) for p in passes], ignore_index=True)
    return out[keep].drop_duplicates()

def stage_C_join_single(join_types: Iterable[str] = ("ingredient","product")) -> Path:
    """Stage C: join dataset A with the lineage-annotated RxNorm resolution cache.

    Reads ``LINEAGE_PATH`` (filtered to ``join_types``), left-merges it with
    ``RESOLVED_PATH`` on ``rxcui``, loads dataset A from ``DATASET_A_DIR`` (one parquet
    file, or every ``A_GLOB`` match when it is a directory), joins both and writes the
    result to ``COMBINED_OUT``.

    Returns:
        ``COMBINED_OUT``.

    Raises:
        FileNotFoundError: If the lineage, the resolved cache, or dataset A is missing.
    """
    print("[C] Joining full dataset (single parquet) ...")

    # Both upstream artifacts must exist before anything is read.
    prerequisites = (
        (LINEAGE_PATH, "lineage.parquet not found. Run Stage A first."),
        (RESOLVED_PATH, "rxnorm_resolved.parquet not found. Run Stage B first."),
    )
    for artifact, message in prerequisites:
        if not artifact.exists():
            raise FileNotFoundError(message)

    # Lineage rows of the requested types, each annotated with its resolution (if any).
    lineage_all = pd.read_parquet(LINEAGE_PATH)
    type_mask = lineage_all["source_type"].isin(set(join_types))
    lineage_sel = lineage_all[type_mask].copy()
    resolved = pd.read_parquet(RESOLVED_PATH)
    rx_with_lineage = lineage_sel.merge(resolved, on="rxcui", how="left")

    # Dataset A: a single parquet file, or all shards of a directory concatenated.
    a_path = Path(DATASET_A_DIR)
    if not a_path.exists():
        raise FileNotFoundError(f"{a_path} not found")
    if not a_path.is_dir():
        df_a = pd.read_parquet(a_path)
    else:
        shards = sorted(a_path.glob(A_GLOB))
        if not shards:
            raise FileNotFoundError(f"No parquet files found under {a_path} matching {A_GLOB}")
        df_a = pd.concat([pd.read_parquet(shard) for shard in shards], ignore_index=True)

    df_a = normalize_a_cols(df_a)

    # Without any join key in A there is nothing to match: emit an empty, typed-by-name mapping.
    has_join_key = "inchikey" in df_a.columns or "pubchem_cid" in df_a.columns
    if not has_join_key:
        print("    No 'inchikey' or 'pubchem_cid' columns in A; writing empty combined mapping.")
        empty_columns = [
            "original_rxcui","rxcui","source_type","pubchem_cid","inchikey",COL_CHEMBL,
            "rxnorm_name","tty","drugbank_id","unii","confidence","provenance"
        ]
        pd.DataFrame(columns=empty_columns).to_parquet(COMBINED_OUT, index=False)
        print(f"[C] combined mapping written to {COMBINED_OUT}")
        return COMBINED_OUT

    combined = join_full_with_rx(df_a, rx_with_lineage)
    combined.to_parquet(COMBINED_OUT, index=False)
    print(f"[C] combined mapping written to {COMBINED_OUT} | rows: {len(combined)}")
    return COMBINED_OUT

In [None]:
# ========================= CONTROL BLOCK =========================
# Driver: runs the selected stages in order A → B → C. Each stage reads the artifacts
# written by the previous one from OUT_DIR, so stages can be toggled independently once
# their inputs exist on disk.
if __name__ == "__main__":
    # Toggle stages independently
    STAGE_A = True   # build historystatus + lineage
    STAGE_B = True   # resolve current RXCUIs (shared cache)
    STAGE_C = True   # join full dataset (single parquet)

    # Stage A options
    INCLUDE_TYPES = ("ingredient","product")  # choose any of: "ingredient", "product"
    OVERWRITE_HISTORY_TSV = False             # True: re-create debug_historystatus.tsv even if present
    OVERWRITE_LINEAGE     = False             # True: replace lineage.parquet instead of merging into it

    # Stage B options
    # (none: Stage B resolves current_sets["all"] once, so the cache is shared by both types)

    # Stage C (join) selection — single file mode
    JOIN_TYPES   = ("ingredient","product")   # which lineage types to join into the mapping

    # ===== Run =====
    lineage_df, current_sets = (None, None)
    if STAGE_A:
        lineage_df, current_sets = stage_A_build_historystatus_and_lineage(
            include_types=INCLUDE_TYPES,
            overwrite_tsv=OVERWRITE_HISTORY_TSV,
            overwrite_lineage=OVERWRITE_LINEAGE
        )
    else:
        # Stage A skipped: reload its artifacts from disk; fail early if they were never built.
        if not LINEAGE_PATH.exists() or not CURRENT_SET_JSON.exists():
            raise FileNotFoundError("Artifacts missing. Run STAGE_A once to create lineage & current set.")
        lineage_df = pd.read_parquet(LINEAGE_PATH)
        with CURRENT_SET_JSON.open("r", encoding="utf-8") as f:
            current_sets = json.load(f)

    if STAGE_B:
        # Resolve across all current RXCUIs from both types (deduped) so cache is shared
        stage_B_resolve_current(current_sets["all"], append_cache=True)

    if STAGE_C:
        # Stage C reads lineage.parquet and rxnorm_resolved.parquet from disk itself.
        stage_C_join_single(join_types=JOIN_TYPES)