
# Supplementary code



# In [None]:

import os
import re
import math
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict, Any

import numpy as np
import pandas as pd


# =====================
# Redaction-safe logging
# =====================
def _log(msg: str) -> None:
    """Print *msg* to stdout and flush immediately.

    No redaction is performed here, in SAFE_MODE or otherwise: callers must
    pass only aggregate, non-sensitive content (e.g. counts or exception type
    names), never message text or model output.
    """
    print(msg, flush=True)

def _log_debug(msg: str) -> None:
    """Print *msg* only for verbose, non-public runs (SAFE_MODE off and VERBOSE on)."""
    if SAFE_MODE or not VERBOSE:
        return
    print(msg, flush=True)

def _raise_if_public_run_on_text() -> None:
    """Guard that refuses to process real text in public-safe runs.

    Raises:
        RuntimeError: when SAFE_MODE is on and ALLOW_TEXT_INPUT is off.
    Otherwise this is a no-op.
    """
    text_blocked = SAFE_MODE and not ALLOW_TEXT_INPUT
    if not text_blocked:
        return
    raise RuntimeError(
        "SAFE_MODE is ON and ALLOW_TEXT_INPUT is OFF.\n"
        "This public-safe script refuses to process real text by default.\n"
        "Flip ALLOW_TEXT_INPUT=True only in an approved secure environment."
    )

def _assert_columns(df: pd.DataFrame, required: List[str], name: str = "df") -> None:
    """Raise ValueError naming every column of *required* that *df* lacks.

    Missing columns are reported in the order they appear in *required*;
    *name* labels the frame in the error message.
    """
    available = df.columns
    absent = list(filter(lambda col: col not in available, required))
    if not absent:
        return
    raise ValueError(f"{name} missing required columns: {absent}")


# =====================
# I. Topic Mapping (prompt template only)
# =====================
# System prompt for the LLM taxonomy-update step (MAIN -> SUB1 -> SUB2).
# Only the template is kept here; it is not referenced elsewhere in this section.
# The literal is sent verbatim, so edits change model behavior. The trailing .strip()
# removes the leading and trailing newlines of the triple-quoted string.
UPDATE_SYSTEM = """
You are updating a 3-level clinical taxonomy for de-identified clinical
portal messages: MAIN -> SUB1 -> SUB2.

You must:
- **KEEP** existing categories as much as possible.
- ADD new MAIN/SUB1/SUB2 categories to **CAPTURE NEW TOPICS**
- **DO NOT merge/remove** unless those are almost identical.
- Keep the taxonomy interpretable and moderately granular.
- For symptom taxonomy, be **especially GRANULAR**.
- In the REASONING line, only describe what you added, merged, or renamed.
Do NOT include general statements like “kept existing structure intact.”

Write your response in plain text.
""".strip()


# =====================
# II. BERTopic-based validation
# =====================

# Minimal stop-word set removed when turning seed-taxonomy phrases into keywords.
STOP = set("and or the of for to in on a an with without".split())

def normalize_text(x: str) -> str:
    """Return *x* as a trimmed string with every whitespace run collapsed to one space.

    Values that ``pd.isna`` treats as missing (None, NaN) become "".
    Non-string values are converted with ``str`` first.
    """
    if pd.isna(x):
        return ""
    return re.sub(r"\s+", " ", str(x).strip())

def phrase_to_keywords(phrase: str) -> List[str]:
    """Split *phrase* into lowercase alphabetic tokens for seeding topics.

    Tokens in STOP and tokens of two characters or fewer are dropped; the
    remaining tokens keep their original order (duplicates included).
    """
    lowered = normalize_text(phrase).lower()
    return [
        tok
        for tok in re.findall(r"[a-z]+", lowered)
        if len(tok) > 2 and tok not in STOP
    ]

def build_seed_topic_list(seed_df: pd.DataFrame, main_col: str = "main", sub1_col: str = "sub1"):
    """
    Derive BERTopic seed keyword lists from a two-level (main/sub1) taxonomy.

    seed_df should be non-sensitive / publishable (or synthetic).

    Returns:
        (ontology_labels, seed_topic_list, dedup_df):
        - ontology_labels: one normalized MAIN label per group, in groupby key order;
        - seed_topic_list: for each label, its keywords (MAIN keywords first, then
          SUB1 keywords), with only the first occurrence of each word kept;
        - dedup_df: the normalized (main, sub1) frame with duplicate pairs removed.

    Raises:
        ValueError: if seed_df is None/empty or lacks either column.
    """
    if seed_df is None or len(seed_df) == 0:
        raise ValueError("seed_df is empty. Provide a non-sensitive seed taxonomy (main/sub1).")

    _assert_columns(seed_df, [main_col, sub1_col], name="seed_df")

    pair_cols = [main_col, sub1_col]
    dedup = seed_df[pair_cols].copy()
    for col in pair_cols:
        dedup[col] = dedup[col].map(normalize_text)
    dedup = dedup.drop_duplicates(subset=pair_cols)

    labels: List[str] = []
    keyword_lists: List[List[str]] = []

    subs_by_main = dedup.groupby(main_col)[sub1_col].apply(list)
    for main_label, subs in subs_by_main.items():
        words = phrase_to_keywords(main_label)
        for sub in subs:
            words += phrase_to_keywords(sub)
        labels.append(main_label)
        # dict preserves insertion order, so this keeps each word's first occurrence.
        keyword_lists.append(list(dict.fromkeys(words)))

    return labels, keyword_lists, dedup


# ---- embeddings (in-memory only) ----
import torch
from sentence_transformers import SentenceTransformer

def compute_embeddings(texts: List[str], embedder_name: str = "all-MiniLM-L6-v2", batch_size: int = 128):
    """Encode *texts* with a SentenceTransformer, keeping everything in memory.

    Uses CUDA when available, otherwise CPU. Vectors are returned as a NumPy array
    without normalization. The progress bar is shown only when SAFE_MODE is off.

    Returns:
        (embedder, embeddings): the loaded model and the encoded array.

    Raises:
        RuntimeError: via _raise_if_public_run_on_text in locked-down public runs.
    """
    _raise_if_public_run_on_text()
    model = SentenceTransformer(
        embedder_name,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    vectors = model.encode(
        list(texts),
        batch_size=batch_size,
        show_progress_bar=not SAFE_MODE,  # reduce chatter in public runs
        convert_to_numpy=True,
        normalize_embeddings=False,
    )
    return model, vectors


# ---- BERTopic ----
from sklearn.feature_extraction.text import CountVectorizer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from umap import UMAP
from hdbscan import HDBSCAN

def fit_guided_bertopic(df: pd.DataFrame, seed_topic_list, embedder, embeddings):
    """
    Fit a seed-guided BERTopic model on df[TEXT_COL] using precomputed embeddings.

    Args:
        df: frame containing ID_COL and TEXT_COL.
        seed_topic_list: list of keyword lists (one per seed topic).
        embedder: embedding model handed to BERTopic.
        embeddings: precomputed document embeddings aligned with df rows.

    Returns:
        (model, topics, probs, info): the fitted model, per-document topic ids,
        per-document probabilities, and model.get_topic_info().

    SAFE_MODE: only aggregate counts are logged; topic names/words are never printed.
    """
    _raise_if_public_run_on_text()
    _assert_columns(df, [ID_COL, TEXT_COL], name="df")

    vectorizer_model = CountVectorizer(
        stop_words="english",
        ngram_range=(1, 2),
        min_df=2,
        max_df=0.8,
    )

    umap_model = UMAP(
        n_neighbors=30,
        n_components=10,
        min_dist=0.0,
        metric="cosine",
        random_state=42,
    )

    hdbscan_model = HDBSCAN(
        min_cluster_size=30,
        min_samples=10,
        prediction_data=True,
    )

    model = BERTopic(
        embedding_model=embedder,
        umap_model=umap_model,
        hdbscan_model=hdbscan_model,
        vectorizer_model=vectorizer_model,
        representation_model=KeyBERTInspired(),
        seed_topic_list=seed_topic_list,
        top_n_words=10,
        calculate_probabilities=True,
        verbose=False,
        low_memory=True,
    )

    topics, probs = model.fit_transform(df[TEXT_COL].astype(str).tolist(), embeddings=embeddings)
    info = model.get_topic_info()

    # Topic -1 is the outlier bucket. It occupies at most ONE row of `info`, so
    # counting those rows gives 0/1. The number of outlier documents comes
    # from the per-document assignments instead.
    n_topics = int((info["Topic"] != -1).sum())
    n_outliers = int(np.sum(np.asarray(topics) == -1))

    # SAFE_MODE: never print topic names/words
    _log(f"[INFO] BERTopic fit complete: n_docs={len(df):,}, n_topics={n_topics:,}, outliers={n_outliers:,}")
    return model, topics, probs, info


def _cosine_sim_matrix(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Return the (len(A), len(B)) matrix of row-wise cosine similarities.

    Each row is divided by its L2 norm plus 1e-12, so all-zero rows produce
    similarities of 0 rather than NaN.
    """
    eps = 1e-12
    a_unit = A / (np.linalg.norm(A, axis=1, keepdims=True) + eps)
    b_unit = B / (np.linalg.norm(B, axis=1, keepdims=True) + eps)
    similarity = a_unit @ b_unit.T
    return similarity

def align_topics_to_ontology_safer(model, embedder, ontology_labels: List[str]) -> pd.DataFrame:
    """
    Map each non-outlier BERTopic topic to its most similar ontology MAIN label.

    SAFER alignment:
    - Prefer topic centroid embeddings (model.topic_embeddings_) when their rows
      can be matched unambiguously to topic ids.
    - Otherwise use topic "Name" embeddings as fallback (still derived from corpus).
    In SAFE_MODE, we will NOT export or print the Name strings.

    Args:
        model: fitted BERTopic model.
        embedder: object whose .encode(list_of_str, normalize_embeddings=True)
            returns one vector per string (e.g. a SentenceTransformer).
        ontology_labels: candidate MAIN labels.

    Returns:
        DataFrame with columns topic_id, topic_size, main_label, alignment_score,
        sorted by main_label (asc), alignment_score (desc), topic_size (desc).
        Empty (same columns) when the model has no non-outlier topics.
    """
    columns = ["topic_id", "topic_size", "main_label", "alignment_score"]

    info = model.get_topic_info().copy()
    info = info[info["Topic"] != -1].reset_index(drop=True)
    topic_ids = info["Topic"].tolist()

    if not topic_ids:
        # Nothing to align (e.g. every document was an outlier). Without this
        # guard, the similarity step below fails on an empty vector set.
        return pd.DataFrame(columns=columns)

    topic_vecs = None

    # Try BERTopic topic centroids.
    centroids = getattr(model, "topic_embeddings_", None)
    if centroids is not None:
        centroids = np.asarray(centroids)
        # NOTE(review): in BERTopic, topic_embeddings_ rows follow sorted topic
        # ids and include the -1 outlier row when outliers exist. Comparing only
        # against the non-outlier ids therefore never matched in that case.
        # Accept either layout, decided by row count. Verify against the
        # installed BERTopic version.
        all_ids = sorted(model.get_topics().keys())
        non_outlier_ids = [t for t in all_ids if t != -1]
        if len(all_ids) == len(centroids):
            row_ids = all_ids
        elif len(non_outlier_ids) == len(centroids):
            row_ids = non_outlier_ids
        else:
            row_ids = None

        if row_ids is not None:
            id_to_row = {tid: i for i, tid in enumerate(row_ids)}
            if all(tid in id_to_row for tid in topic_ids):
                topic_vecs = centroids[[id_to_row[tid] for tid in topic_ids]]

    # Fallback: encode topic Name (avoid returning the raw strings)
    if topic_vecs is None:
        names = info["Name"].astype(str).tolist()
        topic_vecs = np.asarray(embedder.encode(names, normalize_embeddings=True))

    onto_vecs = np.asarray(embedder.encode(list(ontology_labels), normalize_embeddings=True))
    sims = _cosine_sim_matrix(topic_vecs, onto_vecs)

    best_idx = sims.argmax(axis=1)
    best_sim = sims.max(axis=1)

    aligned = pd.DataFrame({
        "topic_id": topic_ids,
        "topic_size": info["Count"].tolist(),
        "main_label": [ontology_labels[i] for i in best_idx],
        "alignment_score": best_sim,
    }).sort_values(["main_label", "alignment_score", "topic_size"], ascending=[True, False, False])

    return aligned


def run_bertopic_validation(df: pd.DataFrame, seed_df: pd.DataFrame) -> Dict[str, Any]:
    """
    Run seed-guided BERTopic on df and align the topics to the seed ontology.

    Returns aggregate artifacts only:
        topic_info: Topic/Count per topic (plus Name only when SAFE_MODE is off);
        aligned_topics: output of align_topics_to_ontology_safer;
        n_docs: number of input rows;
        n_topics: number of non-outlier topics (Topic != -1).
    SAFE_MODE: does NOT return topic Name strings.
    """
    _assert_columns(df, [ID_COL, TEXT_COL], name="df")
    _raise_if_public_run_on_text()

    labels, seeds, _seed_dedup = build_seed_topic_list(seed_df)
    docs = df[TEXT_COL].astype(str).tolist()
    embedder, embeddings = compute_embeddings(docs)

    model, _topics, _probs, info = fit_guided_bertopic(df, seeds, embedder, embeddings)
    aligned_topics = align_topics_to_ontology_safer(model, embedder, labels)

    # Topic names are built from corpus words and can leak content: drop them in SAFE_MODE.
    kept_cols = ["Topic", "Count"] if SAFE_MODE else ["Topic", "Count", "Name"]
    topic_info = info[kept_cols].copy()

    return {
        "topic_info": topic_info,
        "aligned_topics": aligned_topics,
        "n_docs": int(len(df)),
        "n_topics": int((info["Topic"] != -1).sum()),
    }


# =====================
# III. Symptom annotation (Gemini) — safe wrappers
# =====================

# Optional dependency: the Gemini SDK. If importing either name fails for any
# reason, both genai and types are left as None (the labelers below then fail
# inside their own try blocks and return the error result).
genai = None
types = None
try:
    from google import genai
    from google.genai import types
except Exception:
    genai, types = None, None


def gemini_single_label(
    client,
    MODEL_ID: str,
    message: str,
    temperature: float = 0.0,
    max_tokens: int = 2500,
) -> Tuple[int, str]:
    """
    Classify a single message into categories 0–11.

    Args:
        client: Gemini client exposing ``client.models.generate_content``.
        MODEL_ID: model name passed to ``generate_content``.
        message: message text, embedded verbatim at the end of the prompt.
        temperature: sampling temperature for the request.
        max_tokens: ``max_output_tokens`` for the request.

    Returns:
        (classification, reason). ``classification`` is 0 when no number follows a
        "Classification" line's colon, when the value is outside 0–11, or when the
        request/parsing fails. ``reason`` is the text after the "Reason" colon passed
        through ``_sanitize_reason`` ("No reason provided." if absent or empty), or
        "Error processing message." on failure.

    SAFE_MODE: never prints model output; the returned rationale goes through
    _sanitize_reason (defined outside this section).
    """
    _require_llm_enabled()
    _raise_if_public_run_on_text()

    system_instruction = (
        "These are the messages from patients sent to healthcare professionals. "
        "Rules: If a message contains both medical and non-medical issues, "
        "prioritize addressing the medical topics (Symptom Updates & Clinical Concerns)."
    )

    categories = (
        "Classify the message according to the following categories:\n"
        "1. Medication Issues\n"
        "2. Symptom Updates & Clinical Concerns\n"
        "3. Medical Equipment, Supplies, and Home Health\n"
        "4. Administrative Tasks\n"
        "5. Lab Test & Imaging\n"
        "6. Appointment Scheduling / Rescheduling / Cancelling\n"
        "7. Caregiver Support and Logistics\n"
        "8. Specialist Referral related issues\n"
        "9. General Communications (confirmation, gratitude)\n"
        "10. General Communications: non-medical/logistics\n"
        "11. General Communications (other)\n"
        "Otherwise 0.\n\n"
        "Output format:\n"
        "Classification: <0 to 11>\n"
        "Reason: <reason_text>"
    )

    full_prompt = f"{system_instruction}\n\n{categories}\nMessage:\n{message}"

    try:
        response = client.models.generate_content(
            model=MODEL_ID,
            contents=full_prompt,
            config=types.GenerateContentConfig(
                temperature=temperature,
                max_output_tokens=max_tokens,
            ),
        )

        # Guard against an empty/None text payload instead of crashing on .splitlines().
        content = extract_text_from_response(response) or ""

        classification = 0
        reasoning = "No reason provided."

        for raw_line in content.splitlines():
            # Tolerate markdown emphasis/headers such as "**Classification:** 3".
            line = raw_line.strip().lstrip("*#").strip()
            key = line.lower()
            # partition() never raises when ":" is missing. The previous
            # split(":", 1)[1] raised IndexError on a colon-less "Reason" line,
            # which discarded an already-parsed classification.
            value = line.partition(":")[2].strip().lstrip("*").strip()
            if key.startswith("classification"):
                nums = re.findall(r"\d+", value)
                classification = int(nums[0]) if nums else 0
            elif key.startswith("reason"):
                if value:
                    reasoning = value

        if classification < 0 or classification > 11:
            classification = 0

        return classification, _sanitize_reason(reasoning)

    except Exception as e:
        # Safe: do not print message or model output
        _log(f"[WARN] gemini_single_label failed: {type(e).__name__}")
        return 0, "Error processing message."


def gemini_multi_label_0_105(
    client,
    MODEL_ID: str,
    message: str,
    temperature: float = 0.0,
    max_tokens: int = 2500,
) -> Tuple[List[int], str]:
    """
    Multi-label (up to 3) in 0–105.

    Args:
        client: Gemini client exposing ``client.models.generate_content``.
        MODEL_ID: model name passed to ``generate_content``.
        message: message text, embedded verbatim at the end of the prompt.
        temperature: sampling temperature for the request.
        max_tokens: ``max_output_tokens`` for the request.

    Returns:
        (labels, reason). ``labels`` holds up to 3 distinct in-range (0–105)
        numbers, in order of appearance on the "Classification" line. It is [0]
        when none parse or the request/parsing fails. ``reason`` is the text after
        the "Reason" colon passed through ``_sanitize_reason`` ("No reason provided."
        if absent or empty), or "Error processing message." on failure.

    SAFE_MODE: never prints model output; the returned rationale goes through
    _sanitize_reason (defined outside this section).
    """
    _require_llm_enabled()
    _raise_if_public_run_on_text()

    # NOTE(review): the output format below shows a single value although rule 3
    # allows up to 3 labels. The parser accepts several numbers on that line.
    system_instruction = (
        "These are the messages from patients sent to healthcare professionals.\n"
        "Rules:\n"
        "1) If a message contains both medical and non-medical issues, prioritize medical topics.\n"
        "2) If, and only if, a message contains strictly non-medical topics, categorize it as 105.\n"
        "3) Prefer the most relevant category available. Up to 3 labels if truly necessary.\n"
        "Output format:\n"
        "Classification: <0 to 105>\n"
        "Reason: <reason_text>"
    )

    # Placeholder: the full 0–105 category list is kept outside this file.
    categories = "CATEGORIES: (see separate file / template)\n"

    full_prompt = f"{system_instruction}\n\n{categories}\nMessage:\n{message}"

    try:
        response = client.models.generate_content(
            model=MODEL_ID,
            contents=full_prompt,
            config=types.GenerateContentConfig(
                temperature=temperature,
                max_output_tokens=max_tokens,
            ),
        )

        # Guard against an empty/None text payload instead of crashing on .splitlines().
        content = extract_text_from_response(response) or ""

        classifications: List[int] = [0]
        reasoning = "No reason provided."

        for raw_line in content.splitlines():
            # Tolerate markdown emphasis/headers such as "**Classification:** 3, 17".
            line = raw_line.strip().lstrip("*#").strip()
            key = line.lower()
            if key.startswith("classification"):
                labels: List[int] = []
                for tok in re.findall(r"\d+", line):
                    n = int(tok)
                    if 0 <= n <= 105 and n not in labels:
                        labels.append(n)
                    if len(labels) == 3:
                        break
                classifications = labels or [0]
            elif key.startswith("reason"):
                # partition() never raises on a colon-less line (split(":", 1)[1] did).
                value = line.partition(":")[2].strip().lstrip("*").strip()
                if value:
                    reasoning = value

        # Every stored label is already range-checked above, so no re-filter is needed.
        return classifications, _sanitize_reason(reasoning)

    except Exception as e:
        _log(f"[WARN] gemini_multi_label_0_105 failed: {type(e).__name__}")
        return [0], "Error processing message."


# =====================
# IV. Dual ML pipeline (no path leakage; no text prints)
# =====================

if __name__ == "__main__":
    # Pipeline steps:
    # 1. Load the tables and build message→symptom edges.
    # 2. If torch_geometric is unavailable, run the Elastic Net baseline only and exit.
    # 3. Otherwise build the heterogeneous graph, train the GNN, rank symptoms,
    #    and optionally compare against Elastic Net.
    # Printed output is aggregate only (counts, shapes, summary statistics, symptom-level
    # rows). File paths are emitted only through _log_debug (non-SAFE_MODE, VERBOSE runs).

    # Load tables
    messages = load_messages(MESSAGES_CSV)
    symptoms = load_symptoms(SYMPTOMS_CSV)

    # Ensure unique node IDs
    messages = messages.drop_duplicates(subset="m_id", keep="last").copy()
    symptoms = symptoms.drop_duplicates(subset="symptom_id", keep="last").copy()

    keep_msg_cols = [c for c in ["p_id", "m_id", "text", "timestamp"] if c in messages.columns]
    messages = messages[keep_msg_cols].copy()

    p_ids_from_msgs = messages["p_id"].astype(str).unique().tolist()

    # Persons: demographics + comorbids
    person_demo, demo_cols = load_person_demo(PERSON_DEMO_CSV, p_ids=p_ids_from_msgs, one_hot=True)
    comorbids, icd_cols = load_comorbids(
        COMORBID_CSV, p_ids=p_ids_from_msgs,
        min_patients_per_code=ICD_MIN_PREV, max_codes=ICD_MAX_COLS
    )

    # The left join keeps every demographic row. fillna(0) then zero-fills every
    # missing value (including comorbid columns for persons without comorbid rows).
    persons = person_demo.merge(comorbids, on="p_id", how="left").fillna(0)
    person_feature_cols = demo_cols + icd_cols
    persons = persons.drop_duplicates(subset="p_id", keep="last").copy()

    # Temporal person features
    persons, person_feature_cols = add_temporal_person_features(messages, persons, person_feature_cols)

    # Build message→symptom edges
    msg_emb = sym_emb = None
    if USE_LLM_SCORES and LLM_SCORES_PATH.exists():
        # Do not print the path itself (no path leakage); it is shown only in verbose, non-SAFE_MODE runs.
        print("[INFO] Using precomputed LLM scores.")
        _log_debug(f"[DEBUG] LLM scores path: {LLM_SCORES_PATH}")
        ms_edges = message_to_symptom_edges_from_llm(
            messages, symptoms,
            llm_scores_path=LLM_SCORES_PATH,
            score_col=LLM_SCORE_COLUMN,
            is_proba=LLM_IS_PROBA,
            min_score=EDGE_MIN_SIM,
            top_k=TOP_K_SYMPTOMS,
        )

        # Optional: precompute embeddings (only in memory)
        msg_emb = embed_texts(messages["text"].fillna("").astype(str).tolist())
        sym_emb = embed_texts(
            (symptoms["symptom_name"].astype(str) + ": " + symptoms["description"].astype(str)).tolist()
        )

        if BLEND_WITH_EMBEDDINGS:
            # Union of LLM and embedding edges; a side missing from the union contributes weight 0.
            emb_edges, _, _ = message_to_symptom_edges_via_embeddings(
                messages, symptoms, top_k=TOP_K_SYMPTOMS, min_sim=EDGE_MIN_SIM
            )
            mix = ms_edges.merge(emb_edges, on=["m_id", "symptom_id"], how="outer", suffixes=("_llm", "_emb")).fillna(0.0)
            mix["weight"] = BLEND_ALPHA * mix["weight_llm"] + (1.0 - BLEND_ALPHA) * mix["weight_emb"]
            ms_edges = mix[["m_id", "symptom_id", "weight"]]
    else:
        # Reached both when USE_LLM_SCORES is off and when the scores file is missing.
        print("[INFO] LLM scores disabled or not found → using embeddings only.")
        ms_edges, msg_emb, sym_emb = message_to_symptom_edges_via_embeddings(
            messages, symptoms, top_k=TOP_K_SYMPTOMS, min_sim=EDGE_MIN_SIM
        )

    # Keep edges that point to surviving nodes (AFTER ms_edges exists)
    ms_edges = ms_edges[
        ms_edges["m_id"].astype(str).isin(messages["m_id"].astype(str)) &
        ms_edges["symptom_id"].astype(str).isin(symptoms["symptom_id"].astype(str))
    ].copy()

    # Negation down-weighting
    ms_edges = apply_negation_downweight(messages, symptoms, ms_edges)

    # Basic aggregate diagnostics (safe)
    print(f"[INFO] messages: n={len(messages):,}")
    print(f"[INFO] symptoms: n={len(symptoms):,}")
    print(f"[INFO] persons:  n={persons['p_id'].nunique():,}")
    print(f"[INFO] ms_edges: n={len(ms_edges):,} (avg edges/msg ≈ {len(ms_edges)/max(len(messages),1):.2f})")

    # If PyG unavailable, do EN baseline in-memory + print only
    if not _PYG_OK:
        print("[WARN] torch_geometric not installed → skipping GNN.")
        X_sym = (ms_edges.merge(messages[["m_id", "p_id"]], on="m_id", how="left")
                        .pivot_table(index="p_id", columns="symptom_id", values="weight", aggfunc="sum", fill_value=0.0))
        print(f"[INFO] Built person×symptom matrix in-memory: shape={X_sym.shape} (NOT EXPORTED)")

        if RUN_EN_BASELINE_QUICK:
            coefs, cv_auc = elastic_net_baseline(persons, X_sym, l1_ratio=EN_L1_RATIO, C=EN_C, cv=EN_CV)
            print(f"Elastic Net CV AUC (quick): {cv_auc:.3f}")

            top = coefs.abs().sort_values(ascending=False).head(20).rename("abs_coef").reset_index()
            top = top.rename(columns={"index": "symptom_id"}).merge(
                symptoms[["symptom_id", "symptom_name"]], on="symptom_id", how="left"
            )
            print("\nTop 20 EN |coef| symptoms (aggregate):")
            print(top[["symptom_id", "symptom_name", "abs_coef"]].to_string(index=False))
        else:
            print("[INFO] Quick EN baseline skipped.")
        raise SystemExit(0)

    # Build base graph
    hetero, id_maps = build_hetero_graph(
        messages, ms_edges, symptoms,
        persons_df=persons, person_feature_cols=person_feature_cols,
        msg_emb=msg_emb, sym_emb=sym_emb,
        add_message_knn=True, k_msg_sim=5, add_symptom_cooc=True,
    )

    # Build dx pairs from comorbid indicator matrix (one row per person/code with a positive indicator)
    if icd_cols:
        dx_pairs = (persons[["p_id"] + icd_cols]
                    .melt(id_vars=["p_id"], var_name="dx_id", value_name="has_code"))
        dx_pairs = dx_pairs.loc[dx_pairs["has_code"] > 0, ["p_id", "dx_id"]].copy()
        hetero, id_maps = add_dx_to_graph(hetero, dx_pairs, id_maps)

    # Person similarity edges
    hetero = add_person_similarity_edges_safe(hetero, persons, person_feature_cols, id_maps, k=5, min_sim=0.25)

    assert_graph_ok(hetero, id_maps)

    # Train GNN
    model, z, person_probs = train_gnn(hetero, persons, id_maps)

    # PRINT-ONLY person_probs aggregate summary (no IDs)
    person_probs = np.asarray(person_probs, dtype=float)
    print(f"[INFO] GNN person_probs (NOT EXPORTED): n={person_probs.size:,} "
          f"mean={person_probs.mean():.4f} sd={person_probs.std():.4f} "
          f"p50={np.quantile(person_probs, 0.50):.4f} p95={np.quantile(person_probs, 0.95):.4f}")

    # Symptom importance + EN baseline (computed in-memory)
    sym_rank = rank_symptoms(messages, ms_edges, hetero, z, id_maps).merge(symptoms, on="symptom_id", how="left")
    X_sym = build_person_symptom_matrix(messages, ms_edges, id_maps)

    if RUN_EN_BASELINE_QUICK:
        coefs, cv_auc = elastic_net_baseline(persons, X_sym, l1_ratio=EN_L1_RATIO, C=EN_C, cv=EN_CV)
        print(f"Elastic Net CV AUC (quick): {cv_auc:.3f}")

        coef_df = coefs.rename("elasticnet_coef").reset_index().rename(columns={"index": "symptom_id"})
        merged = sym_rank.merge(coef_df, on="symptom_id", how="left")
        merged["elasticnet_coef_abs"] = merged["elasticnet_coef"].abs()

        pipe_en, X_en, y_en = elastic_net_fit(persons, X_sym, l1_ratio=EN_L1_RATIO, C=EN_C)
        perm_imp = elastic_net_permutation_importance(
            pipe_en, X_en, y_en, feature_names=X_sym.columns.tolist(), n_repeats=10
        )
        perm_df = perm_imp.rename("en_perm_importance").reset_index().rename(columns={"index": "symptom_id"})
        merged = merged.merge(perm_df, on="symptom_id", how="left")

        def _z(col):
            # z-score of a merged column (NaN treated as 0; epsilon avoids division by zero)
            arr = merged[col].fillna(0.0).values
            mu, sd = arr.mean(), arr.std() + 1e-9
            return (arr - mu) / sd

        merged["event_assoc_score"] = _z("en_perm_importance")
    else:
        merged = sym_rank.copy()
        for col in ["elasticnet_coef", "elasticnet_coef_abs", "en_perm_importance", "event_assoc_score"]:
            merged[col] = np.nan
        print("[INFO] Quick EN baseline skipped.")

    # Print top symptoms (aggregate only; no text, no IDs beyond symptom_id)
    cols_show = [
        "symptom_id", "symptom_name",
        "coverage", "coverage_recency",
        "importance_score",
        "en_perm_importance",
        "event_assoc_score",
        "elasticnet_coef"
    ]
    cols_show = [c for c in cols_show if c in merged.columns]

    top20 = (merged.sort_values("importance_score", ascending=False)
                   .loc[:, cols_show]
                   .head(20)
                   .fillna(0.0))

