# Triage rules baseline

Prototype notebook for rule-based clinical NLP triage.

- Loads `data/lexicon_redflags.csv`
- Loads `data/notes_synthetic.csv`
- Applies rule-based scoring (lexicon token matching, with simple negation and historical-mention context filters)
- Exports `outputs/predictions.csv`
- Evaluates predictions vs. ground-truth `label`


In [None]:
import re
from pathlib import Path

import pandas as pd

# Inputs: the red-flag lexicon (one term per row in column "term") and the
# synthetic notes to score. Paths are relative to the notebook's directory.
lexicon = pd.read_csv("../data/lexicon_redflags.csv")
notes = pd.read_csv("../data/notes_synthetic.csv")

# Basic sanitation: discard missing lexicon entries and coerce the rest to str
# so that non-string cells (e.g. numeric terms) can be tokenized uniformly.
terms_raw = list(map(str, lexicon["term"].dropna().tolist()))

def tokenize(text: str) -> list[str]:
    """Lowercase *text* and split it into ASCII alphanumeric tokens.

    Every maximal run of ``[a-z0-9]`` (after lowercasing) becomes one token;
    punctuation, whitespace and non-ASCII characters act as separators and are
    dropped. Non-string inputs are converted with ``str()`` first.

    Missing values (``None`` or a float NaN, which is how pandas represents an
    empty CSV cell) yield an empty list instead of the literal tokens
    ``"none"`` / ``"nan"``.
    """
    # NaN is the only float that is not equal to itself.
    if text is None or (isinstance(text, float) and text != text):
        return []
    return re.findall(r"[a-z0-9]+", str(text).lower())

def tokenize_term(term: str) -> list[str]:
    """Tokenize a lexicon term by delegating to `tokenize`, so terms and note text share one tokenization."""
    return tokenize(term)

# Pre-tokenize lexicon terms once so matching is a plain token-sequence compare.
# Entries that tokenize to nothing are dropped. Entries that normalise to the
# same token sequence (e.g. "Chest pain" and "chest-pain") are kept only once:
# otherwise a single mention in a note would be counted as two separate hits.
# The first spelling encountered is the one retained as `original`.
TERMS = []
_seen_term_tokens = set()
for t in terms_raw:
    tt = tokenize_term(t)
    key = tuple(tt)  # lists are unhashable; use a tuple as the dedup key
    if not tt or key in _seen_term_tokens:
        continue
    _seen_term_tokens.add(key)
    TERMS.append((t, tt))  # (original, tokens)

# Minimal context cues (EN only for now)
# All cues are lowercase because they are compared against lowercased tokens.
# Single-token negation cues: any one of these in the look-back window before a
# term marks that occurrence as negated.
NEGATION_CUES = {
    "no", "not", "denies", "denied", "without", "negative"
}

# Multi-token cues handled by matching token sequences
# Each inner list must appear contiguously, in order, inside the look-back window.
# ["negative", "for"] is redundant in practice: "negative" alone is already a
# single-token cue above.
# NOTE(review): "rule out X" often signals suspicion rather than absence —
# confirm it belongs with the negation cues.
NEGATION_PHRASES = [
    ["negative", "for"],
    ["free", "of"],
    ["rule", "out"],
]

# Single-token cues marking a mention as historical (not a current finding).
HISTORICAL_CUES = {
    "history", "previous", "prior", "past", "resolved", "remote"
}

# Multi-token historical cues, matched as contiguous token sequences.
# ["in", "the", "past"] and ["history", "of"] are redundant in practice because
# "past" and "history" are already single-token cues above.
HISTORICAL_PHRASES = [
    ["years", "ago"],
    ["last", "year"],
    ["in", "the", "past"],
    ["hx", "of"],
    ["history", "of"],
]

def find_subsequence_positions(tokens: list[str], sub: list[str]) -> list[int]:
    """Locate every contiguous occurrence of *sub* inside *tokens*.

    Returns the ascending list of start indices (overlapping occurrences are
    all reported). An empty *sub*, or a *sub* longer than *tokens*, yields an
    empty list.
    """
    if not sub:
        return []
    width = len(sub)
    last_start = len(tokens) - width
    if last_start < 0:
        return []
    return [
        start
        for start in range(last_start + 1)
        if tokens[start:start + width] == sub
    ]

def window_tokens(tokens: list[str], start: int, end: int) -> list[str]:
    """Return ``tokens[start:end]`` with *start* clamped to 0 and *end* clamped to ``len(tokens)``.

    Clamping *start* keeps a look-back window that reaches before the first
    token from wrapping around to the end of the list.
    """
    size = len(tokens)
    lo = start if start > 0 else 0
    hi = end if end < size else size
    return tokens[lo:hi]

def has_phrase(window: list[str], phrases: list[list[str]]) -> bool:
    """Return True if at least one of *phrases* occurs contiguously in *window*.

    Each phrase is a list of tokens; matching is delegated to
    `find_subsequence_positions`. Evaluation stops at the first phrase found.
    """
    return any(find_subsequence_positions(window, phrase) for phrase in phrases)

def is_negated(tokens: list[str], term_start: int, term_len: int, window: int = 5) -> bool:
    """Report whether a negation cue appears shortly before a term occurrence.

    Looks at up to *window* tokens immediately preceding index *term_start*
    and returns True if any of them is in ``NEGATION_CUES`` or if any sequence
    from ``NEGATION_PHRASES`` occurs among them. Tokens after the term are
    never inspected; *term_len* is accepted but currently unused.
    """
    preceding = window_tokens(tokens, term_start - window, term_start)
    # Cheap single-token check first; phrases are only scanned if it fails.
    found_single_cue = not NEGATION_CUES.isdisjoint(preceding)
    return found_single_cue or has_phrase(preceding, NEGATION_PHRASES)

def is_historical(tokens: list[str], term_start: int, term_len: int, window: int = 8) -> bool:
    """Report whether a historical cue appears shortly before a term occurrence.

    Looks at up to *window* tokens immediately preceding index *term_start*
    and returns True if any of them is in ``HISTORICAL_CUES`` or if any
    sequence from ``HISTORICAL_PHRASES`` occurs among them. Cues that follow
    the term are not considered; *term_len* is accepted but currently unused.
    """
    preceding = window_tokens(tokens, term_start - window, term_start)
    # Cheap single-token check first; phrases are only scanned if it fails.
    found_single_cue = not HISTORICAL_CUES.isdisjoint(preceding)
    return found_single_cue or has_phrase(preceding, HISTORICAL_PHRASES)

def count_hits_with_filters(text: str) -> dict:
    """Tally lexicon terms found in *text*, split by context.

    Every lexicon term contributes at most once per note. The returned dict
    has four integer counters:

    - ``hits_raw``: terms that occur at least once, regardless of context.
    - ``hits_filtered``: terms with at least one occurrence that is neither
      negated nor historical.
    - ``hits_negated``: terms with no such clean occurrence and at least one
      negated occurrence.
    - ``hits_historical``: terms with no such clean occurrence and at least one
      historical occurrence.

    A term whose occurrences are all filtered out can increment both
    ``hits_negated`` and ``hits_historical``.
    """
    tokens = tokenize(text)
    counts = {
        "hits_raw": 0,
        "hits_filtered": 0,
        "hits_negated": 0,
        "hits_historical": 0,
    }

    for _original, term_tokens in TERMS:
        starts = find_subsequence_positions(tokens, term_tokens)
        if not starts:
            continue

        counts["hits_raw"] += 1
        width = len(term_tokens)
        saw_negated = False
        saw_historical = False

        for start in starts:
            negated = is_negated(tokens, start, width)
            historical = is_historical(tokens, start, width)
            if not (negated or historical):
                # One clean occurrence is enough to accept the term.
                counts["hits_filtered"] += 1
                break
            saw_negated = saw_negated or negated
            saw_historical = saw_historical or historical
        else:
            # Reached only when no occurrence was clean (loop did not break).
            if saw_negated:
                counts["hits_negated"] += 1
            if saw_historical:
                counts["hits_historical"] += 1

    return counts

def predict_label_from_hits(hits: int) -> str:
    """Map a filtered hit count to a triage label.

    Two or more hits give ``"high"``, exactly one gives ``"intermediate"``,
    and anything else gives ``"low"``.
    """
    if hits >= 2:
        return "high"
    return "intermediate" if hits == 1 else "low"

# Score every note. Building the frame directly from the list of per-note dicts
# avoids the slow `Series.apply(pd.Series)` expansion, and naming the columns
# explicitly keeps the hit columns present even when `notes` has no rows.
HIT_COLUMNS = ["hits_raw", "hits_filtered", "hits_negated", "hits_historical"]
stats = pd.DataFrame(
    notes["text"].apply(count_hits_with_filters).tolist(),
    index=notes.index,  # keep row alignment with `notes` for the concat below
    columns=HIT_COLUMNS,
)
notes = pd.concat([notes, stats], axis=1)
notes["predicted_label"] = notes["hits_filtered"].apply(predict_label_from_hits)

# `DataFrame.to_csv` does not create missing parent directories, so make sure
# the output folder exists before writing.
predictions_path = Path("../outputs/predictions.csv")
predictions_path.parent.mkdir(parents=True, exist_ok=True)
notes[["id", "text", "entity", "label", "hits_raw", "hits_filtered", "hits_negated", "hits_historical", "predicted_label"]].to_csv(
    predictions_path, index=False
)

print("Saved ../outputs/predictions.csv")
print("\nLabel distribution (predicted):")
print(notes["predicted_label"].value_counts(dropna=False).to_string())


## Quick evaluation (baseline)

This evaluates the baseline against the synthetic ground truth `label`.

- Overall accuracy
- Confusion matrix (all entities)
- Accuracy by entity
- A few mismatches to inspect (with hit diagnostics)


In [None]:
# Guardrails: only run evaluation if ground-truth label exists
if "label" not in notes.columns:
    raise ValueError("notes_synthetic.csv is missing required column: 'label'")

# Row-wise correctness mask, computed once and reused by every metric below.
# Rows with a missing label compare as False, i.e. they count as mismatches.
is_correct = notes["predicted_label"] == notes["label"]

# Overall accuracy
acc = is_correct.mean()
print(f"Overall accuracy: {acc:.3f} ({int(is_correct.sum())}/{len(notes)})")

# Confusion matrix
print("\nConfusion matrix (label x predicted_label):")
cm = pd.crosstab(
    notes["label"],
    notes["predicted_label"],
    rownames=["label"],
    colnames=["predicted_label"],
    dropna=False,
)
print(cm.to_string())

# Accuracy by entity
# Grouping the boolean mask by the entity column is a vectorized mean per group;
# it avoids `DataFrame.groupby(...).apply(lambda ...)`, which runs a Python
# callback per group and emits a deprecation warning on pandas >= 2.2.
print("\nAccuracy by entity:")
by_entity = is_correct.groupby(notes["entity"]).mean()
by_entity = by_entity.sort_values(ascending=False)
print(by_entity.to_string())

# Show a few mismatches for inspection
cols = ["id", "entity", "text", "label", "predicted_label", "hits_raw", "hits_filtered", "hits_negated", "hits_historical"]
mismatches = notes.loc[~is_correct, cols]
print("\nSample mismatches (first 20):")
print(mismatches.head(20).to_string(index=False))

# Optional: quick look at cases where raw hits were reduced by filters
reduced = notes.loc[notes["hits_raw"] > notes["hits_filtered"], cols]
print("\nSample reduced-by-filters (first 20):")
print(reduced.head(20).to_string(index=False))
