In [2]:
import os, re
from pathlib import Path
from collections import Counter
import numpy as np
import pandas as pd
from typing import Optional

In [3]:
"""
TRIAGE ALERTS FROM CONSOLIDATED ENTITIES (ESI-BASED)
==================================================

USAGE
-----
- Run from your project repo (or adjust ROOT).
- Expects a single consolidated CSV at: data/clean/all_entities.csv
  Required columns: row_idx|idx (one of these), label, text
  Optional columns: category, original_text, source
- Saves outputs to: artifacts/triage/
  - triage_alerts_all.csv  : all fired rules per row
  - triage_alerts_top.csv  : top rule(s) per row
  - triage_rule_diagnostics.csv : rule coverage counts + priority mix

WHAT THIS DOES
--------------
1) Loads entities (IDs always treated as strings; never positional index).
2) Builds triage features:
   - Severity mapping (including "10/10" and strong adjectives)
   - Acute/chronic flags from TEMPORAL_PATTERN
   - Laterality flags and bilateral detection
   - Domain lexicons (cardiac/neuro/infection/critical finding)
   - Duration parsing from free text (no external files)
   - Negation / uncertainty dampeners (e.g., "denies X")
   - Proximity signals (laterality near neuro deficits)
   - Optional category/source weighting if `category` exists
3) Applies ESI-based rules (ESI 1-5) using the above features
4) Arbitration: accumulates weighted rule scores per row and assigns
   final ESI level via thresholds
5) Writes CSVs and prints quick sanity diagnostics.

ESI LEVELS
----------
- ESI 1 (Immediate): Life-saving intervention needed
- ESI 2 (Emergency): High risk of deterioration, time-critical
- ESI 3 (Urgent): Stable but needs prompt assessment
- ESI 4 (Non-urgent): Very stable, can wait extended period
- ESI 5 (Minor): No acute threat, can wait extended period
"""


# =========================
# CONFIGURATION
# =========================
ROOT = Path.cwd().parent
DATA = ROOT / "data" / "clean"
ARTF = ROOT / "artifacts" / "triage"
ARTF.mkdir(parents=True, exist_ok=True)

# Input
ALL_ENTITIES_CSV = DATA / "all_entities.csv"

# Output
TRIAGE_ALL_CSV = ARTF / "triage_alerts_all.csv"
TRIAGE_TOP_CSV = ARTF / "triage_alerts_top.csv"
TRIAGE_DIAG_CSV = ARTF / "triage_rule_diagnostics.csv"

# ESI level thresholds for accumulated scores
ESI_THRESHOLDS = {
    'ESI_1': 4.0,   # Immediate life-saving intervention
    'ESI_2': 2.5,   # Emergency, high risk
    'ESI_3': 1.5,   # Urgent, needs prompt assessment
    'ESI_4': 0.8,   # Non-urgent, can wait
    # ESI_5 is default (anything below ESI_4 threshold)
}

# Additive weights per rule type
RULE_WEIGHTS = {
    'LIFE_THREAT': 2.0,      # ESI 1 rules
    'HIGH_RISK': 1.5,        # ESI 2 rules
    'URGENT': 1.0,           # ESI 3 rules
    'NON_URGENT': 0.6,       # ESI 4 rules
    'MINOR': 0.3,            # ESI 5 rules
}

# Category/source weighting (applies multiplicatively on each fired rule)
CATEGORY_WEIGHT = {
    'imaging':       1.30,
    'assessment':    1.20,
    'hpi':           1.00,
    'problem_list':  0.90,
}

# Negation / uncertainty phrases (affects negated_any())
NEG_TRIG = r'\b(no|denies|without|not|negative for|rule\s*out|unlikely|resolved|improved)\b'

# Severity boost patterns
SEVERITY_STRONG_WORDS = r'\b(excruciating|worst|life[- ]threatening|severely|unbearable|critical)\b'
PAIN_10OF10_MIN = 8  # treat >=8/10 as severe

# Proximity windows
PROXIMITY_WINDOW = 6  # words

# ESI -> numeric helper for sorting
ESI_SCORE = {"ESI_1": 5, "ESI_2": 4, "ESI_3": 3, "ESI_4": 2, "ESI_5": 1}

# =========================
# LEXICONS — TUNE TERMS
# =========================
SEVERITY_MAP = {
    'mild': 1, 'minimal': 1, 'slight': 1, 'minor': 1,
    'moderate': 2, 'medium': 2, 'intermediate': 2,
    'severe': 3, 'significant': 3, 'marked': 3, 'extensive': 3, 
    'extreme': 4, 'critical': 4, 'life threatening': 4, 'unbearable': 4
}

# ESI 1 - Immediate life threats
LIFE_THREAT_TERMS = set("""
    arrest cardiac arrest respiratory arrest apnea agonal unresponsive massive hemorrhage 
    exsanguinating airway compromise intubation
""".split())

# ESI 2 - High risk conditions
CARDIAC_CRITICAL = set("""
    chest pain angina cardiac heart coronary mi myocardial ischemia infarction stemi nstemi
    aortic dissection tamponade cardiogenic shock
""".split())

NEURO_CRITICAL = set("""
    stroke cva tpa thrombolysis weakness paralysis aphasia dysarthria seizure status epilepticus
    altered mental status ams confusion thunderclap subarachnoid hemorrhage
""".split())

RESPIRATORY_CRITICAL = set("""
    dyspnea shortness breath sob respiratory distress hypoxia cyanosis stridor
""".split())

# ESI 3 - Urgent but stable
MODERATE_TERMS = set("""
    abdominal pain syncope fracture dislocation laceration burn dehydration vomiting
    fever infection cellulitis pneumonia uti pyelonephritis
""".split())

# ESI 4/5 - Non-urgent/minor
MINOR_TERMS = set("""
    medication refill cold symptoms cough congestion rash itch constipation follow up
    suture removal wound check
""".split())

# Critical findings that suggest higher ESI levels
CRITICAL_FINDINGS = set("""
    hemorrhage aneurysm occlusion dissection pulmonary embolism pe deep vein thrombosis dvt 
    pneumothorax tension pneumothorax mass effect herniation shock sepsis septic
""".split())

LATERALITY_WORDS = {'left','right','bilateral'}
NEURO_FOCAL_WORDS = {'weakness','paralysis','numbness','tingling','aphasia','dysarthria','hemianopsia','facial'}

# =========================
# 1) LOADING & NORMALIZATION
# =========================
def load_entities_from_csv(all_entities_csv: Path) -> pd.DataFrame:
    def _read_csv_str(path):
        return pd.read_csv(path, low_memory=False, dtype={"row_idx":"string", "idx":"string"})

    def _normalize(df: pd.DataFrame, name: Path) -> Optional[pd.DataFrame]:
        df = df.copy()

        # Drop accidental index columns
        drop_candidates = [c for c in df.columns if c.lower().startswith("unnamed:")]
        if drop_candidates:
            df.drop(columns=drop_candidates, inplace=True, errors="ignore")

        # Normalize id column to 'row_idx' (as string)
        if "row_idx" in df.columns:
            df["row_idx"] = df["row_idx"].astype("string").str.strip()
        elif "idx" in df.columns:
            df.rename(columns={"idx":"row_idx"}, inplace=True)
            df["row_idx"] = df["row_idx"].astype("string").str.strip()
        else:
            for alt in ["row", "id", "record_id", "source_idx", "orig_idx"]:
                if alt in df.columns:
                    df.rename(columns={alt:"row_idx"}, inplace=True)
                    df["row_idx"] = df["row_idx"].astype("string").str.strip()
                    break

        keep = [c for c in ["row_idx","label","text","category","original_text","source"] if c in df.columns]
        missing = set(["row_idx","label","text"]) - set(keep)
        if missing:
            print(f"[TRIAGE][ERR] {name}: missing required columns: {missing}")
            return None

        df = df[keep]
        df["label"] = df["label"].astype(str)
        df["text"]  = df["text"].astype(str)
        if "category" in df.columns:
            df["category"] = df["category"].astype(str)
        return df

    if all_entities_csv is None or not os.path.exists(all_entities_csv):
        print("\n[TRIAGE] all_entities.csv not found. Aborting triage.")
        return pd.DataFrame()

    try:
        df = _read_csv_str(all_entities_csv)
        entities_all = _normalize(df, all_entities_csv)
        if entities_all is None or entities_all.empty:
            print("\n[TRIAGE] all_entities.csv empty or invalid schema. Aborting triage.")
            return pd.DataFrame()
        print(f"\n[TRIAGE] Loaded consolidated: {all_entities_csv} shape={entities_all.shape}")
    except Exception as e:
        print(f"\n[TRIAGE] ERROR reading {all_entities_csv}: {e}")
        return pd.DataFrame()

    entities_all["row_idx"]   = entities_all["row_idx"].astype("string")
    entities_all["text_norm"] = entities_all["text"].str.lower()

    # Sanity check for short IDs
    id_len = entities_all["row_idx"].str.len()
    short = (id_len <= 3).sum()
    if short:
        print(f"[TRIAGE][WARN] Found {short} rows with very short row_idx (<=3 chars). "
              f"Examples: {entities_all.loc[id_len <= 3, 'row_idx'].dropna().unique()[:10]}")
    return entities_all

# =========================
# 2) FEATURE BUILDING
# =========================
def build_features(entities_all: pd.DataFrame) -> pd.DataFrame:
    # Aggregate base containers
    agg = entities_all.groupby('row_idx', dropna=False).agg(
        terms=('text_norm', lambda s: set(s)),
        labels=('label',    lambda s: set(s.astype(str))),
        categories=('category', lambda s: set(s.astype(str)) if 'category' in entities_all.columns else set())
    ).reset_index()

    # Full row text for regex-based features
    row_text_full = entities_all.groupby('row_idx', dropna=False)['text_norm'] \
                                .apply(lambda s: " ".join(s)).rename('row_text_full')

    # Severity score from SEVERITY label + boosters
    sever_df = entities_all[entities_all['label'] == 'SEVERITY'].copy()

    def _base_sev_score(t: str) -> int:
        base = max([SEVERITY_MAP.get(t, 0)] + [SEVERITY_MAP.get(w, 0) for w in re.findall(r'[a-z]+', t)])
        if re.search(SEVERITY_STRONG_WORDS, t):
            base = max(base, 4)
        m = re.search(r'(\d{1,2})\s*/\s*10', t)
        if m and int(m.group(1)) >= PAIN_10OF10_MIN:
            base = max(base, 4)
        return base

    if not sever_df.empty:
        sever_df['score'] = sever_df['text_norm'].map(_base_sev_score)
        max_sev = sever_df.groupby('row_idx', dropna=False)['score'].max().rename('severity_score')
    else:
        max_sev = pd.Series(dtype=float, name='severity_score')

    # Laterality flags
    has_bilateral = (entities_all[entities_all['text_norm'].str.contains(r'\bbilateral\b', na=False)]
                     .groupby('row_idx', dropna=False).size().rename('has_bilateral').astype(bool))
    has_laterality = (entities_all[entities_all['label'] == 'LATERALITY']
                      .groupby('row_idx', dropna=False).size().rename('has_laterality').astype(bool))

    # Temporal flags
    tmp = entities_all[entities_all['label'] == 'TEMPORAL_PATTERN']
    if not tmp.empty:
        tmp_txt = tmp['text_norm']
        has_acute = (tmp_txt.str.contains(r'\b(acute|sudden|abrupt|rapid|immediate|emergent)\b', na=False)
                     .groupby(tmp['row_idx'], dropna=False).any().rename('has_acute'))
        has_chronic = (tmp_txt.str.contains(r'\b(chronic|persistent|ongoing|continuous|months|years)\b', na=False)
                       .groupby(tmp['row_idx'], dropna=False).any().rename('has_chronic'))
    else:
        has_acute   = pd.Series(dtype=bool, name='has_acute')
        has_chronic = pd.Series(dtype=bool, name='has_chronic')

    # Domain term flags for ESI levels
    row_terms = entities_all.groupby('row_idx', dropna=False)['text_norm'].apply(set)
    
    # ESI 1 - Life threats
    life_threat_flag = row_terms.apply(lambda s: any(t in s for t in LIFE_THREAT_TERMS)).rename('has_life_threat')
    
    # ESI 2 - High risk
    cardiac_crit_flag = row_terms.apply(lambda s: any(t in s for t in CARDIAC_CRITICAL)).rename('has_cardiac_critical')
    neuro_crit_flag = row_terms.apply(lambda s: any(t in s for t in NEURO_CRITICAL)).rename('has_neuro_critical')
    resp_crit_flag = row_terms.apply(lambda s: any(t in s for t in RESPIRATORY_CRITICAL)).rename('has_respiratory_critical')
    
    # ESI 3 - Moderate/urgent
    moderate_flag = row_terms.apply(lambda s: any(t in s for t in MODERATE_TERMS)).rename('has_moderate_terms')
    
    # ESI 4/5 - Minor
    minor_flag = row_terms.apply(lambda s: any(t in s for t in MINOR_TERMS)).rename('has_minor_terms')
    
    # Critical findings
    critical_finding_flag = row_terms.apply(lambda s: any(t in s for t in CRITICAL_FINDINGS)).rename('has_critical_finding')

    # Duration parsing
    DUR_PATS = [
        (r'(\d+)\s*day', 1), (r'(\d+)\s*week', 7), (r'(\d+)\s*month', 30),
        (r'(\d+)\s*d\b', 1), (r'(\d+)\s*w\b', 7), (r'(\d+)\s*m(o|th)\b', 30),
    ]
    def parse_duration_days(t: str) -> float:
        best = np.nan
        for pat, mult in DUR_PATS:
            m = re.search(pat, t)
            if m:
                val = float(m.group(1)) * mult
                best = val if np.isnan(best) else max(best, val)
        return best
    duration_any_days = row_text_full.map(parse_duration_days).rename('duration_any_days')

    # Negation detection
    def negated_any(text: str, terms: set[str], window: int = 3) -> bool:
        for term in terms:
            if re.search(fr'{NEG_TRIG}(?:\W+\w+){{0,{window}}}\b{re.escape(term)}\b', text):
                return True
        return False

    neg_life_threat = row_text_full.map(lambda t: negated_any(t, LIFE_THREAT_TERMS)).rename('neg_life_threat')
    neg_cardiac = row_text_full.map(lambda t: negated_any(t, CARDIAC_CRITICAL)).rename('neg_cardiac_critical')
    neg_neuro = row_text_full.map(lambda t: negated_any(t, NEURO_CRITICAL)).rename('neg_neuro_critical')
    neg_resp = row_text_full.map(lambda t: negated_any(t, RESPIRATORY_CRITICAL)).rename('neg_respiratory_critical')

    # Proximity: laterality near neuro terms
    def proximity(text: str, left: set[str], right: set[str], window: int = PROXIMITY_WINDOW) -> bool:
        words = re.findall(r'[a-z]+', text.lower())
        pos = {}
        for i, w in enumerate(words):
            pos.setdefault(w, []).append(i)
        for l in left:
            for i in pos.get(l, []):
                for r in right:
                    if any(0 < abs(i-j) <= window for j in pos.get(r, [])):
                        return True
        return False
    lat_neuro_close = row_text_full.map(lambda t: proximity(t, LATERALITY_WORDS, NEURO_FOCAL_WORDS)) \
                                   .rename('lat_neuro_close')

    # Category weight
    if 'category' in entities_all.columns:
        cat_weight = entities_all.groupby('row_idx', dropna=False)['category'].apply(
            lambda s: max(CATEGORY_WEIGHT.get(c.lower(), 1.0) for c in set(map(str, s)))
        ).rename('category_weight')
    else:
        cat_weight = pd.Series(1.0, index=agg['row_idx']).rename('category_weight')

    # Assemble features
    feats = (agg
             .merge(max_sev, on='row_idx', how='left')
             .merge(has_bilateral, on='row_idx', how='left')
             .merge(has_laterality, on='row_idx', how='left')
             .merge(has_acute, on='row_idx', how='left')
             .merge(has_chronic, on='row_idx', how='left')
             .merge(life_threat_flag, on='row_idx', how='left')
             .merge(cardiac_crit_flag, on='row_idx', how='left')
             .merge(neuro_crit_flag, on='row_idx', how='left')
             .merge(resp_crit_flag, on='row_idx', how='left')
             .merge(moderate_flag, on='row_idx', how='left')
             .merge(minor_flag, on='row_idx', how='left')
             .merge(critical_finding_flag, on='row_idx', how='left')
             .merge(duration_any_days, on='row_idx', how='left')
             .merge(neg_life_threat, on='row_idx', how='left')
             .merge(neg_cardiac, on='row_idx', how='left')
             .merge(neg_neuro, on='row_idx', how='left')
             .merge(neg_resp, on='row_idx', how='left')
             .merge(lat_neuro_close, on='row_idx', how='left')
             .merge(cat_weight, on='row_idx', how='left')
             )

    # Fill boolean columns
    bool_cols = ['has_bilateral','has_laterality','has_acute','has_chronic',
                 'has_life_threat','has_cardiac_critical','has_neuro_critical',
                 'has_respiratory_critical','has_moderate_terms','has_minor_terms',
                 'has_critical_finding','neg_life_threat','neg_cardiac_critical',
                 'neg_neuro_critical','neg_respiratory_critical','lat_neuro_close']
    
    for col in bool_cols:
        if col in feats.columns:
            feats[col] = feats[col].fillna(False)

    feats['severity_score'] = feats['severity_score'].fillna(0.0).astype(float)
    feats['row_idx'] = feats['row_idx'].astype("string")
    return feats

# =========================
# 3) ESI RULES
# =========================
def _ev(row, pick=5):
    """Compact evidence string"""
    return ", ".join(sorted(list(row.get('terms', set())))[:pick])

# ESI 1 Rules - Immediate life-saving intervention needed
def rule_ESI1_arrest(row):
    fired = bool(row.get('has_life_threat')) and not row.get('neg_life_threat', False)
    msg = "Life-threatening condition requiring immediate intervention (arrest/massive hemorrhage/airway compromise)"
    return (fired, msg, _ev(row))

def rule_ESI1_critical_finding(row):
    crit = bool(row.get('has_critical_finding'))
    acute = bool(row.get('has_acute'))
    fired = crit and acute and row.get('severity_score', 0) >= 4
    msg = "Critical finding with acute onset and extreme severity"
    return (fired, msg, _ev(row))

# ESI 2 Rules - High risk of deterioration
def rule_ESI2_cardiac_emergency(row):
    base = bool(row.get('has_cardiac_critical')) and not row.get('neg_cardiac_critical', False)
    sev = row.get('severity_score', 0) >= 3
    acute = bool(row.get('has_acute'))
    fired = base and (acute or sev)
    msg = "Cardiac emergency - high risk of deterioration (chest pain/MI/dissection)"
    return (fired, msg, _ev(row))

def rule_ESI2_neuro_emergency(row):
    base = bool(row.get('has_neuro_critical')) and not row.get('neg_neuro_critical', False)
    focal = bool(row.get('lat_neuro_close')) or bool(row.get('has_laterality'))
    acute = bool(row.get('has_acute'))
    fired = base and (focal or acute)
    msg = "Neurological emergency - time-critical (stroke/seizure/AMS)"
    return (fired, msg, _ev(row))

def rule_ESI2_respiratory_emergency(row):
    base = bool(row.get('has_respiratory_critical')) and not row.get('neg_respiratory_critical', False)
    sev = row.get('severity_score', 0) >= 3
    fired = base and sev
    msg = "Respiratory emergency - risk of deterioration"
    return (fired, msg, _ev(row))

# ESI 3 Rules - Urgent but stable
def rule_ESI3_moderate_condition(row):
    base = bool(row.get('has_moderate_terms'))
    sev = row.get('severity_score', 0) >= 2
    fired = base and sev
    msg = "Urgent condition requiring prompt assessment (abdominal pain/fracture/infection)"
    return (fired, msg, _ev(row))

def rule_ESI3_prolonged_symptoms(row):
    dur = row.get('duration_any_days')
    dur_ok = pd.notna(dur) and dur >= 7
    sev = row.get('severity_score', 0) >= 2
    fired = dur_ok and sev and not bool(row.get('has_minor_terms'))
    msg = "Prolonged moderate symptoms requiring workup (≥7 days)"
    return (fired, msg, _ev(row))

def rule_ESI3_bilateral_involvement(row):
    bilateral = bool(row.get('has_bilateral'))
    sev = row.get('severity_score', 0) >= 2
    not_minor = not bool(row.get('has_minor_terms'))
    fired = bilateral and sev and not_minor
    msg = "Bilateral involvement with moderate severity"
    return (fired, msg, _ev(row))

# ESI 4 Rules - Non-urgent, stable
def rule_ESI4_stable_minor_injury(row):
    has_mod = bool(row.get('has_moderate_terms'))
    low_sev = row.get('severity_score', 0) <= 1
    chronic = bool(row.get('has_chronic'))
    fired = has_mod and low_sev and not bool(row.get('has_acute'))
    msg = "Stable minor injury/condition - can wait (minor laceration/burn)"
    return (fired, msg, _ev(row))

# ESI 5 Rules - Minor, no acute threat
def rule_ESI5_minor_complaint(row):
    fired = bool(row.get('has_minor_terms'))
    msg = "Minor complaint - no acute threat (medication refill/cold symptoms)"
    return (fired, msg, _ev(row))

def rule_ESI5_chronic_stable(row):
    chronic = bool(row.get('has_chronic'))
    low_sev = row.get('severity_score', 0) <= 1
    no_acute = not bool(row.get('has_acute'))
    fired = chronic and low_sev and no_acute and not bool(row.get('has_critical_finding'))
    msg = "Chronic stable condition - routine follow-up"
    return (fired, msg, _ev(row))

# Rule definitions with ESI levels
RULES = [
    # ESI 1 - Immediate
    ("ESI1_A", "ESI_1", rule_ESI1_arrest, "LIFE_THREAT"),
    ("ESI1_B", "ESI_1", rule_ESI1_critical_finding, "LIFE_THREAT"),
    
    # ESI 2 - Emergency
    ("ESI2_A", "ESI_2", rule_ESI2_cardiac_emergency, "HIGH_RISK"),
    ("ESI2_B", "ESI_2", rule_ESI2_neuro_emergency, "HIGH_RISK"),
    ("ESI2_C", "ESI_2", rule_ESI2_respiratory_emergency, "HIGH_RISK"),
    
    # ESI 3 - Urgent
    ("ESI3_A", "ESI_3", rule_ESI3_moderate_condition, "URGENT"),
    ("ESI3_B", "ESI_3", rule_ESI3_prolonged_symptoms, "URGENT"),
    ("ESI3_C", "ESI_3", rule_ESI3_bilateral_involvement, "URGENT"),
    
    # ESI 4 - Non-urgent
    ("ESI4_A", "ESI_4", rule_ESI4_stable_minor_injury, "NON_URGENT"),
    
    # ESI 5 - Minor
    ("ESI5_A", "ESI_5", rule_ESI5_minor_complaint, "MINOR"),
    ("ESI5_B", "ESI_5", rule_ESI5_chronic_stable, "MINOR"),
]

# =========================
# 4) SCORING & ARBITRATION
# =========================
def apply_rules(features_df: pd.DataFrame, top_k: int = 1):
    if features_df.empty:
        return pd.DataFrame(), pd.DataFrame()

    rows, scores, details = [], {}, {}
    for _, r in features_df.iterrows():
        rid = str(r['row_idx'])
        scores.setdefault(rid, 0.0)
        details.setdefault(rid, [])
        
        for rule_id, esi_level, fn, rule_type in RULES:
            fired, msg, ev = fn(r)
            if fired:
                weight = RULE_WEIGHTS.get(rule_type, 1.0) * r.get('category_weight', 1.0)
                scores[rid] += weight
                rows.append({
                    "row_idx": rid,
                    "rule_id": rule_id,
                    "esi_level": esi_level,
                    "esi_score": ESI_SCORE[esi_level],
                    "weight": weight,
                    "message": msg,
                    "evidence": ev
                })
                details[rid].append(f"{rule_id}:{msg}")

    if not rows:
        # No rules fired - default to ESI 5
        default_rows = []
        for rid in features_df['row_idx']:
            default_rows.append({
                "row_idx": str(rid),
                "rule_id": "DEFAULT",
                "esi_level": "ESI_5",
                "esi_score": ESI_SCORE["ESI_5"],
                "weight": 0.0,
                "message": "No specific alerts - default to ESI 5",
                "evidence": ""
            })
        alerts_df = pd.DataFrame(default_rows)
        final_df = pd.DataFrame([{"row_idx": str(rid), "score_sum": 0.0, 
                                  "final_esi_level": "ESI_5", "reasons": "Default"} 
                                 for rid in features_df['row_idx']])
    else:
        alerts_df = pd.DataFrame(rows).drop_duplicates(subset=["row_idx","rule_id","message"])
        
        # Final ESI arbitration by accumulated score
        arb = []
        for rid in features_df['row_idx']:
            rid = str(rid)
            tot = scores.get(rid, 0.0)
            
            if tot >= ESI_THRESHOLDS['ESI_1']:
                final_esi = "ESI_1"
            elif tot >= ESI_THRESHOLDS['ESI_2']:
                final_esi = "ESI_2"
            elif tot >= ESI_THRESHOLDS['ESI_3']:
                final_esi = "ESI_3"
            elif tot >= ESI_THRESHOLDS['ESI_4']:
                final_esi = "ESI_4"
            else:
                final_esi = "ESI_5"
            
            reasons = details.get(rid, ["No alerts"])
            arb.append({
                "row_idx": rid, 
                "score_sum": tot, 
                "final_esi_level": final_esi, 
                "reasons": "; ".join(reasons) if reasons else "No specific alerts"
            })
        final_df = pd.DataFrame(arb)

    # Choose top_k alerts per row (highest esi_score then weight)
    top = (alerts_df.sort_values(["row_idx","esi_score","weight"], ascending=[True, False, False])
                    .groupby("row_idx").head(top_k).reset_index(drop=True))

    return alerts_df.merge(final_df, on="row_idx", how="left"), top.merge(final_df, on="row_idx", how="left")

# =========================
# 5) DIAGNOSTICS
# =========================
def write_diagnostics(alerts_df: pd.DataFrame, final_top: pd.DataFrame, path: Path):
    if alerts_df.empty:
        return
    
    # Rule coverage
    diag_rules = (alerts_df.groupby(['rule_id','esi_level']).size()
                  .reset_index(name='n')
                  .sort_values('n', ascending=False))
    
    # ESI distribution
    if not final_top.empty and 'final_esi_level' in final_top.columns:
        esi_dist = final_top['final_esi_level'].value_counts().sort_index()
        diag_rules = pd.concat([
            diag_rules,
            pd.DataFrame([{"rule_id": "=== ESI DISTRIBUTION ===", "esi_level": "", "n": ""}]),
            esi_dist.reset_index().rename(columns={'index': 'rule_id', 'final_esi_level': 'n'})
        ])
    
    diag_rules.to_csv(path, index=False)
    
    # Print ESI distribution
    if not final_top.empty and 'final_esi_level' in final_top.columns:
        mix = final_top['final_esi_level'].value_counts(normalize=False).sort_index()
        print("\n[TRIAGE] Final ESI distribution:")
        print(mix.to_string())

# =========================
# 6) MAIN
# =========================
def main():
    entities_all = load_entities_from_csv(ALL_ENTITIES_CSV)
    if entities_all.empty:
        print("[TRIAGE] No entities loaded — nothing to do.")
        return

    features_df = build_features(entities_all)
    alerts_df, best_alert_per_row = apply_rules(features_df, top_k=1)

    if alerts_df.empty:
        print("\n[TRIAGE] No alerts generated.")
        return

    # Save outputs
    alerts_df.to_csv(TRIAGE_ALL_CSV, index=False)
    best_alert_per_row.to_csv(TRIAGE_TOP_CSV, index=False)
    print(f"\n[TRIAGE] Saved:\n - {TRIAGE_ALL_CSV}\n - {TRIAGE_TOP_CSV}")

    # ID length sanity check
    print("\n[TRIAGE] row_idx length distribution (first 10 bins):")
    print(best_alert_per_row["row_idx"].astype(str).str.len().value_counts().sort_index().head(10))

    print("\n=== Top alert per row (sample) ===")
    display_cols = ['row_idx', 'rule_id', 'final_esi_level', 'score_sum', 'message']
    print(best_alert_per_row[display_cols].head(15).to_string(index=False))

    # Write diagnostics
    write_diagnostics(alerts_df, best_alert_per_row, TRIAGE_DIAG_CSV)
    print(f"\n[TRIAGE] Diagnostics saved: {TRIAGE_DIAG_CSV}")

if __name__ == "__main__":
    main()


[TRIAGE] Loaded consolidated: /Users/anhche/Desktop/Project/Clinical_NLP/Clinical_NLP_Project/data/clean/all_entities.csv shape=(2473487, 6)
[TRIAGE][WARN] Found 9968 rows with very short row_idx (<=3 chars). Examples: <StringArray>
['958', '431', '558', '499', '390', '643', '629', '781', '637', '406']
Length: 10, dtype: string

[TRIAGE] Saved:
 - /Users/anhche/Desktop/Project/Clinical_NLP/Clinical_NLP_Project/artifacts/triage/triage_alerts_all.csv
 - /Users/anhche/Desktop/Project/Clinical_NLP/Clinical_NLP_Project/artifacts/triage/triage_alerts_top.csv

[TRIAGE] row_idx length distribution (first 10 bins):
row_idx
2        6
3       91
4      941
5     9961
6    11087
Name: count, dtype: int64

=== Top alert per row (sample) ===
row_idx rule_id final_esi_level  score_sum                                                                   message
 100017  ESI2_B           ESI_3       1.50               Neurological emergency - time-critical (stroke/seizure/AMS)
 100018  ESI2_A           