In [4]:
import os, re
import pandas as pd
import numpy as np
from pathlib import Path

In [None]:
# ================= TRIAGE FROM CSVs (CONSOLIDATED ONLY, STRING-SAFE IDS) =================
# Loads consolidated all_entities.csv, builds triage features, runs rules, and saves alerts.
# No discovery patterns and no processed-duration CSVs needed.


# ---------- 0) Configure file paths ----------
ROOT = Path.cwd().parent
ALL_ENTITIES_CSV = ROOT / "data" / "clean" / "all_entities.csv"

# ---------- 1) Load entities (IDs as strings; never from index) ----------
def load_entities_from_csv(all_entities_csv):
    diag = []

    def _read_csv_str(path):
        return pd.read_csv(path, low_memory=False, dtype={"row_idx":"string", "idx":"string"})

    def _normalize(df, name):
        df = df.copy()

        # Drop any accidental saved index columns 
        drop_candidates = [c for c in df.columns if c.lower().startswith("unnamed:")]
        if drop_candidates:
            df.drop(columns=drop_candidates, inplace=True, errors="ignore")

        # Normalize id column to 'row_idx' (as string)
        if "row_idx" in df.columns:
            df["row_idx"] = df["row_idx"].astype("string").str.strip()
        elif "idx" in df.columns:
            df.rename(columns={"idx":"row_idx"}, inplace=True)
            df["row_idx"] = df["row_idx"].astype("string").str.strip()
        else:
            for alt in ["row", "id", "record_id", "source_idx", "orig_idx"]:
                if alt in df.columns:
                    df.rename(columns={alt:"row_idx"}, inplace=True)
                    df["row_idx"] = df["row_idx"].astype("string").str.strip()
                    break

        keep = [c for c in ["row_idx","label","text","category","original_text","source"] if c in df.columns]
        missing = set(["row_idx","label","text"]) - set(keep)
        if missing:
            diag.append((str(name), False, f"Missing required columns: {missing}", None))
            return None

        df = df[keep]
        df["label"] = df["label"].astype(str)
        df["text"]  = df["text"].astype(str)
        if "category" in df.columns:
            df["category"] = df["category"].astype(str)
        return df

    if all_entities_csv is None or not os.path.exists(all_entities_csv):
        print("\n[TRIAGE] all_entities.csv not found. Aborting triage.")
        return pd.DataFrame()

    try:
        df = _read_csv_str(all_entities_csv)
        entities_all = _normalize(df, all_entities_csv)
        if entities_all is None or entities_all.empty:
            print("\n[TRIAGE] all_entities.csv empty or invalid schema. Aborting triage.")
            return pd.DataFrame()
        print(f"\n[TRIAGE] Loaded consolidated: {all_entities_csv} shape={entities_all.shape}")
    except Exception as e:
        print(f"\n[TRIAGE] ERROR reading {all_entities_csv}: {e}")
        return pd.DataFrame()

    entities_all["row_idx"]   = entities_all["row_idx"].astype("string")
    entities_all["text_norm"] = entities_all["text"].str.lower()

    # Sanity: warn if any suspiciously short ids (likely positional)
    id_len = entities_all["row_idx"].str.len()
    short = (id_len <= 3).sum()
    if short:
        print(f"[TRIAGE][WARN] Found {short} rows with very short row_idx (<=3 chars). "
              f"Examples: {entities_all.loc[id_len <= 3, 'row_idx'].dropna().unique()[:10]}")
    return entities_all

entities_all = load_entities_from_csv(ALL_ENTITIES_CSV)
if entities_all.empty:
    raise SystemExit

# ---------- 2) Build features from entities (no external duration files) ----------
def build_features(entities_all):
    SEVERITY_MAP = {
        'mild': 1, 'minimal': 1, 'slight': 1, 'minor': 1,
        'moderate': 2, 'medium': 2, 'intermediate': 2,
        'severe': 3, 'significant': 3, 'marked': 3, 'extensive': 3, 'extreme': 3, 'critical': 4
    }

    CARDIAC_TERMS = set("""
        chest pain angina cardiac heart coronary mi myocardial ischemia infarction aortic valve atrium ventricle
        palpitations syncope dyspnea shortness of breath sob
    """.split())
    NEURO_TERMS = set("""
        stroke cva weakness paralysis paresthesia numbness tingling aphasia dysarthria seizure syncope facial droop
        confusion altered mental status headache thunderclap
    """.split())
    INFECTION_TERMS = set("""
        sepsis septic fever rigors chills endocarditis abscess cellulitis pneumonia purulent
    """.split())
    CRITICAL_FINDINGS = set("""
        hemorrhage aneurysm occlusion dissection pulmonary embolism pe deep vein thrombosis dvt pneumothorax
    """.split())

    def contains_any(term_set, keywords):
        return any(k in term_set for k in keywords)

    agg = entities_all.groupby('row_idx', dropna=False).agg(
        terms=('text_norm', lambda s: set(s)),
        labels=('label',    lambda s: set(s.astype(str))),
        categories=('category', lambda s: set(s.astype(str)) if 'category' in entities_all.columns else set())
    ).reset_index()

    # Severity score
    sever_df = entities_all[entities_all['label'] == 'SEVERITY'].copy()
    if not sever_df.empty:
        sever_df['score'] = sever_df['text_norm'].map(
            lambda t: max([SEVERITY_MAP.get(t, 0)] + [SEVERITY_MAP.get(w, 0) for w in re.findall(r'[a-z]+', t)])
        )
        max_sev = sever_df.groupby('row_idx', dropna=False)['score'].max().rename('severity_score')
    else:
        max_sev = pd.Series(dtype=float, name='severity_score')

    # Flags
    has_bilateral = (entities_all[entities_all['text_norm'].str.contains(r'\bbilateral\b', na=False)]
                     .groupby('row_idx', dropna=False).size().rename('has_bilateral').astype(bool))
    has_laterality = (entities_all[entities_all['label'] == 'LATERALITY']
                      .groupby('row_idx', dropna=False).size().rename('has_laterality').astype(bool))

    tmp = entities_all[entities_all['label'] == 'TEMPORAL_PATTERN']
    has_acute = (tmp[tmp['text_norm'].str.contains(r'\b(acute|sudden|abrupt|rapid)\b', na=False)]
                 .groupby('row_idx', dropna=False).size().rename('has_acute').astype(bool)) if not tmp.empty else pd.Series(dtype=bool, name='has_acute')
    has_chronic = (tmp[tmp['text_norm'].str.contains(r'\b(chronic|persistent|ongoing|continuous|months|years)\b', na=False)]
                   .groupby('row_idx', dropna=False).size().rename('has_chronic').astype(bool)) if not tmp.empty else pd.Series(dtype=bool, name='has_chronic')

    row_terms = entities_all.groupby('row_idx', dropna=False)['text_norm'].apply(set)
    cardiac_flag = row_terms.apply(lambda s: contains_any(s, CARDIAC_TERMS)).rename('has_cardiac_terms')
    neuro_flag   = row_terms.apply(lambda s: contains_any(s, NEURO_TERMS)).rename('has_neuro_terms')
    inf_flag     = row_terms.apply(lambda s: contains_any(s, INFECTION_TERMS)).rename('has_infection_terms')
    critical_finding_flag = row_terms.apply(lambda s: contains_any(s, CRITICAL_FINDINGS)).rename('has_critical_finding_term')

    feats = (agg
             .merge(max_sev, on='row_idx', how='left')
             .merge(has_bilateral, on='row_idx', how='left')
             .merge(has_laterality, on='row_idx', how='left')
             .merge(has_acute, on='row_idx', how='left')
             .merge(has_chronic, on='row_idx', how='left')
             .merge(cardiac_flag, on='row_idx', how='left')
             .merge(neuro_flag, on='row_idx', how='left')
             .merge(inf_flag, on='row_idx', how='left')
             .merge(critical_finding_flag, on='row_idx', how='left')
            )

    # No external duration files -> set column but leave NaN
    feats['duration_any_days'] = np.nan

    for col in ['has_bilateral','has_laterality','has_acute','has_chronic',
                'has_cardiac_terms','has_neuro_terms','has_infection_terms','has_critical_finding_term']:
        if col in feats.columns:
            feats[col] = feats[col].fillna(False)

    feats['severity_score'] = feats['severity_score'].fillna(0.0).astype(float)
    feats['row_idx'] = feats['row_idx'].astype("string")
    return feats

features_df = build_features(entities_all)

# ---------- 3) Rules & scoring ----------
def _ev(row, pick=5):
    return ", ".join(sorted(list(row.get('terms', set())))[:pick])

def rule_R1_cardiac_red_flags(row):
    fired = bool(row.get('has_cardiac_terms')) and (bool(row.get('has_acute')) or row.get('severity_score',0) >= 3)
    return (True, "Cardiac red flags (acute or severe). Consider immediate evaluation.", _ev(row)) if fired else (False, "", "")

def rule_R2_neuro_deficit(row):
    fired = bool(row.get('has_neuro_terms')) and (bool(row.get('has_laterality')) or row.get('severity_score',0) >= 3 or bool(row.get('has_acute')))
    return (True, "Neurological deficit indicators detected (focal/severe/acute).", _ev(row)) if fired else (False, "", "")

def rule_R3_critical_imaging_terms(row):
    fired = bool(row.get('has_critical_finding_term'))
    return (True, "Critical finding term present (e.g., hemorrhage/occlusion/PE/DVT).", _ev(row)) if fired else (False, "", "")

def rule_R4_sepsis_or_severe_infection(row):
    dur = row.get('duration_any_days')
    dur_ok = pd.notna(dur) and dur >= 3
    fired = bool(row.get('has_infection_terms')) and (row.get('severity_score',0) >= 3 or dur_ok)
    return (True, "Infection red flags (severe/prolonged).", _ev(row)) if fired else (False, "", "")

def rule_R5_prolonged_severe_symptoms(row):
    dur = row.get('duration_any_days')
    dur_ok = pd.notna(dur) and dur >= 14
    fired = dur_ok and (row.get('severity_score',0) >= 2)
    return (True, "Prolonged and moderate-to-severe symptoms.", _ev(row)) if fired else (False, "", "")

def rule_R6_bilateral_with_severity(row):
    fired = bool(row.get('has_bilateral')) and (row.get('severity_score',0) >= 2)
    return (True, "Bilateral involvement with moderate+ severity.", _ev(row)) if fired else (False, "", "")

RULES = [
    ("R1", "CRITICAL", rule_R1_cardiac_red_flags),
    ("R2", "CRITICAL", rule_R2_neuro_deficit),
    ("R3", "CRITICAL", rule_R3_critical_imaging_terms),
    ("R4", "URGENT",   rule_R4_sepsis_or_severe_infection),
    ("R5", "URGENT",   rule_R5_prolonged_severe_symptoms),
    ("R6", "URGENT",   rule_R6_bilateral_with_severity),
]
PRIORITY_SCORE = {"CRITICAL": 3, "URGENT": 2, "ROUTINE": 1}

def apply_rules(features_df):
    if features_df.empty:
        return pd.DataFrame(), pd.DataFrame()
    alerts = []
    for _, r in features_df.iterrows():
        for rule_id, priority, fn in RULES:
            fired, msg, ev = fn(r)
            if fired:
                alerts.append({
                    "row_idx": str(r["row_idx"]),
                    "rule_id": rule_id,
                    "priority": priority,
                    "priority_score": PRIORITY_SCORE[priority],
                    "message": msg,
                    "evidence": ev
                })
    if not alerts:
        return pd.DataFrame(), pd.DataFrame()
    alerts_df = pd.DataFrame(alerts).drop_duplicates(subset=["row_idx","rule_id","message"])
    top = (alerts_df.sort_values(["row_idx","priority_score"], ascending=[True, False])
                    .groupby("row_idx").head(1).reset_index(drop=True))
    return alerts_df, top

alerts_df, best_alert_per_row = apply_rules(features_df)

# ---------- 4) Save & print ----------
if not alerts_df.empty:
    alerts_df.to_csv("triage_alerts_all.csv", index=False)
    best_alert_per_row.to_csv("triage_alerts_top.csv", index=False)
    print("\n[TRIAGE] Saved triage_alerts_all.csv and triage_alerts_top.csv")

    # Quick sanity output
    print("\n[TRIAGE] row_idx length distribution (first 10 bins):")
    print(best_alert_per_row["row_idx"].astype(str).str.len().value_counts().sort_index().head(10))
    print("\n=== Top alert per row (sample) ===")
    print(best_alert_per_row.head(15))
else:
    print("\n[TRIAGE] No alerts fired.")


[TRIAGE] Loaded consolidated: /Users/anhche/Desktop/Project/Clinical_NLP/Clinical_NLP_Project/data/clean/all_entities.csv shape=(2113057, 6)
[TRIAGE][WARN] Found 11581 rows with very short row_idx (<=3 chars). Examples: <StringArray>
['958', '431', '558', '499', '390', '643', '629', '781', '637', '406']
Length: 10, dtype: string


  has_acute = (tmp[tmp['text_norm'].str.contains(r'\b(acute|sudden|abrupt|rapid)\b', na=False)]
  has_chronic = (tmp[tmp['text_norm'].str.contains(r'\b(chronic|persistent|ongoing|continuous|months|years)\b', na=False)]
  feats[col] = feats[col].fillna(False)



[TRIAGE] Saved triage_alerts_all.csv and triage_alerts_top.csv

[TRIAGE] row_idx length distribution (first 10 bins):
row_idx
2       3
3      36
4     469
5    4696
6    5328
Name: count, dtype: int64

=== Top alert per row (sample) ===
   row_idx rule_id  priority  priority_score  \
0   100017      R2  CRITICAL               3   
1   100018      R1  CRITICAL               3   
2    10003      R2  CRITICAL               3   
3   100052      R3  CRITICAL               3   
4   100070      R1  CRITICAL               3   
5   100108      R2  CRITICAL               3   
6   100111      R3  CRITICAL               3   
7   100134      R3  CRITICAL               3   
8    10017      R3  CRITICAL               3   
9   100191      R3  CRITICAL               3   
10  100193      R2  CRITICAL               3   
11  100204      R2  CRITICAL               3   
12  100253      R1  CRITICAL               3   
13  100265      R1  CRITICAL               3   
14  100281      R1  CRITICAL             

In [7]:
alerts_df

Unnamed: 0,row_idx,rule_id,priority,priority_score,message,evidence
0,100017,R2,CRITICAL,3,Neurological deficit indicators detected (foca...,"42, 42 y/o, 42 y/o female, 42 y/o female patie..."
1,100018,R1,CRITICAL,3,Cardiac red flags (acute or severe). Consider ...,"10 ml, 157 mg, 57, 57 y/o, 57 y/o female"
2,10003,R2,CRITICAL,3,Neurological deficit indicators detected (foca...,"58, 58 y/o, 58 y/o female, 58 y/o female patie..."
3,10003,R3,CRITICAL,3,"Critical finding term present (e.g., hemorrhag...","58, 58 y/o, 58 y/o female, 58 y/o female patie..."
4,100052,R3,CRITICAL,3,"Critical finding term present (e.g., hemorrhag...","24, 24 y/o, 24 y/o male, 24 y/o male patient, ..."
...,...,...,...,...,...,...
13878,99909,R2,CRITICAL,3,Neurological deficit indicators detected (foca...,"0.7 cm, 1.4 cm, 54, 54 y/o, 54 y/o male"
13879,99926,R1,CRITICAL,3,Cardiac red flags (acute or severe). Consider ...,"25, 25 y/o, 25 y/o female, 25 y/o female patie..."
13880,99970,R1,CRITICAL,3,Cardiac red flags (acute or severe). Consider ...,"40, 40 y/o, 40 y/o female, 40 y/o female patie..."
13881,99970,R6,URGENT,2,Bilateral involvement with moderate+ severity.,"40, 40 y/o, 40 y/o female, 40 y/o female patie..."
