# Triage rules baseline

Prototype notebook for rule-based clinical NLP triage.

- Loads `data/lexicon_redflags.csv`
- Loads `data/notes_synthetic.csv`
- Applies simple rule-based scoring: whole-token phrase matching of lexicon terms (2+ matched terms → `high`, 1 → `intermediate`, 0 → `low`)
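- Adds minimal negation handling (heuristic: a term is not counted when a negation cue such as "no", "denies" or "without" appears within 3 tokens before it)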
- Exports `outputs/predictions.csv`
- Evaluates predictions vs. ground-truth `label`


In [None]:
import re
import pandas as pd

lexicon = pd.read_csv("../data/lexicon_redflags.csv")
notes = pd.read_csv("../data/notes_synthetic.csv")

# Normalize lexicon terms (keep as phrases): trim and lowercase, then drop empty
# strings and duplicates (first occurrence wins, order preserved). Without the
# de-duplication, a term listed twice (e.g. "Chest pain" and "chest pain") would
# be scored twice for a single mention and inflate hits_count.
terms = [
    term
    for term in dict.fromkeys(
        lexicon["term"].dropna().astype(str).str.strip().str.lower()
    )
    if term
]

# Negation cues for the minimal negation heuristic. A lexicon term preceded
# (within a small token window) by one of these tokens is treated as negated
# rather than affirmed, e.g. "no chest pain" or "denies syncope".
NEGATIONS = set("no not denies denied without negative neg never".split())

# Token pattern for already-lowercased text: runs of ASCII letters, digits and
# apostrophes. Every other character (spaces, hyphens, slashes, periods) separates tokens.
TOKEN_RE = re.compile(r"[a-z0-9']+")

def tokenize(text: str) -> list[str]:
    """Convert *text* to a lowercase string and return its ``TOKEN_RE`` tokens."""
    normalized = str(text).lower()
    return TOKEN_RE.findall(normalized)

def phrase_in_tokens_at(tokens: list[str], phrase_tokens: list[str]) -> list[int]:
    """Return every start index at which ``phrase_tokens`` occurs contiguously in ``tokens``.

    Overlapping occurrences are all reported. An empty ``tokens`` or an empty
    ``phrase_tokens`` yields no matches.
    """
    if not (tokens and phrase_tokens):
        return []
    width = len(phrase_tokens)
    last_start = len(tokens) - width
    return [
        pos
        for pos in range(last_start + 1)
        if tokens[pos : pos + width] == phrase_tokens
    ]

def is_negated(tokens: list[str], start_idx: int, window: int = 3) -> bool:
    """Return True if a ``NEGATIONS`` cue occurs among the ``window`` tokens before ``start_idx``.

    Only tokens strictly before the phrase start are inspected, and the window
    is clipped at the beginning of ``tokens``.
    """
    preceding = tokens[max(start_idx - window, 0) : start_idx]
    return not NEGATIONS.isdisjoint(preceding)

# Sentence-ending punctuation followed by whitespace or end of text. Tokenization
# discards punctuation, so without this split a cue such as "No fever." would
# negate a term at the start of the next sentence ("No fever. Chest pain.").
# Decimals such as "38.5" are not split because their "." is followed by a digit.
_SENTENCE_BREAK_RE = re.compile(r"[.!?]+(?=\s|$)")

# Conjunctions that close a negation scope inside one sentence
# ("no fever but chest pain" -> chest pain is affirmed).
_SCOPE_TERMINATORS = frozenset({"but", "however"})


def _negation_scopes(text: str) -> list[list[str]]:
    """Split a note into token runs that a negation cue may not reach across.

    The text is cut at sentence-ending punctuation. Each piece is tokenized with
    ``tokenize`` and cut again at scope-terminating conjunctions, which are dropped.
    Empty runs are omitted.
    """
    scopes: list[list[str]] = []
    for sentence in _SENTENCE_BREAK_RE.split(str(text)):
        current: list[str] = []
        for tok in tokenize(sentence):
            if tok in _SCOPE_TERMINATORS:
                if current:
                    scopes.append(current)
                current = []
            else:
                current.append(tok)
        if current:
            scopes.append(current)
    return scopes


def count_hits_with_negation(text: str, terms: list[str]) -> tuple[int, int]:
    """Count lexicon terms found in a note, split into affirmed and negated hits.

    Each distinct term is counted at most once per note (baseline behavior).
    Terms are compared after tokenization, so "Chest pain" and "chest-pain" are
    the same term.
    - A term is a positive hit if at least one occurrence is not negated.
    - Otherwise it is a negated hit (every occurrence is negated).

    An occurrence is negated when ``is_negated`` finds a cue within 3 tokens
    before it *inside the same scope* (same sentence and same but/however
    clause). Matches also cannot span those boundaries, so terms that themselves
    contain "but" or "however" never match.

    Args:
        text: Note text. Non-string values are converted with ``str``.
        terms: Lexicon phrases. Empty entries, and entries that tokenize to
            nothing, are skipped.

    Returns:
        ``(positive_hits, negated_hits)``.
    """
    scopes = _negation_scopes(text)
    pos_hits = 0
    neg_hits = 0
    seen: set[tuple[str, ...]] = set()

    for term in terms:
        if not term:
            continue
        phrase_tokens = tokenize(term)
        key = tuple(phrase_tokens)
        # Skip terms that tokenize to nothing, and duplicates of an already-scored
        # term. Otherwise one mention would be counted once per duplicate.
        if not phrase_tokens or key in seen:
            continue
        seen.add(key)

        # If any occurrence is non-negated -> count as positive
        any_pos = False
        any_neg = False
        for scope in scopes:
            for start in phrase_in_tokens_at(scope, phrase_tokens):
                if is_negated(scope, start, window=3):
                    any_neg = True
                else:
                    any_pos = True

        if any_pos:
            pos_hits += 1
        elif any_neg:
            neg_hits += 1

    return pos_hits, neg_hits

def predict_label_from_hits(hits: int) -> str:
    """Map a positive-hit count to a triage label.

    Two or more hits give "high" and exactly one gives "intermediate".
    Anything else (0, negatives, non-integer values) gives "low".
    """
    return "high" if hits >= 2 else ("intermediate" if hits == 1 else "low")

from pathlib import Path

# Apply scoring: each note -> (positive_hits, negated_hits). The predicted label
# is derived from the positive (non-negated) hits only.
hit_counts = notes["text"].apply(lambda t: count_hits_with_negation(t, terms))
notes["hits_count"] = hit_counts.apply(lambda x: x[0])
notes["negated_hits_count"] = hit_counts.apply(lambda x: x[1])
notes["predicted_label"] = notes["hits_count"].apply(predict_label_from_hits)

# Export. DataFrame.to_csv does not create missing parent directories, so make
# sure ../outputs exists first instead of failing with an OSError.
output_path = "../outputs/predictions.csv"
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
notes[["id", "text", "entity", "hits_count", "negated_hits_count", "predicted_label"]].to_csv(
    output_path, index=False
)

print(f"Saved {output_path}")
print("\nNegation heuristic: ignores terms when a negation token appears within 3 tokens before the phrase.")
print("Total positive hits:", int(notes["hits_count"].sum()))
print("Total negated hits:", int(notes["negated_hits_count"].sum()))


## Quick evaluation (baseline)

This evaluates the simple baseline against the synthetic ground truth `label`.

- Overall accuracy
- Confusion matrix (all entities)
- Accuracy by entity
- The first 15 mismatches, for manual inspection


In [None]:
# Guardrail: evaluation needs the ground-truth `label` column; fail loudly if it is absent.
if "label" not in notes.columns:
    raise ValueError("notes_synthetic.csv is missing required column: 'label'")

# Per-note correctness, computed once and reused below.
correct = notes["predicted_label"] == notes["label"]

# Overall accuracy
acc = correct.mean()
print(f"Overall accuracy: {acc:.3f} ({int(correct.sum())}/{len(notes)})")

# Confusion matrix in severity order. crosstab only emits classes that were
# observed (and sorts them alphabetically), so reindex to show every class as
# both a row and a column, with missing cells filled with 0. Any unexpected label
# values (e.g. a stray "High ") are appended in first-seen order so they stay visible.
canonical_labels = ["low", "intermediate", "high"]
observed_labels = pd.concat([notes["label"], notes["predicted_label"]]).dropna().unique()
label_order = canonical_labels + [lab for lab in observed_labels if lab not in canonical_labels]

print("\nConfusion matrix (label x predicted_label):")
cm = pd.crosstab(notes["label"], notes["predicted_label"], rownames=["label"], colnames=["predicted_label"], dropna=False)
cm = cm.reindex(index=label_order, columns=label_order, fill_value=0).rename_axis(
    index="label", columns="predicted_label"
)
print(cm.to_string())

# Accuracy by entity: group the correctness Series directly. This avoids
# DataFrameGroupBy.apply over whole sub-frames, which is slower and triggers
# pandas' include_groups deprecation warning.
print("\nAccuracy by entity:")
by_entity = correct.groupby(notes["entity"]).mean()
by_entity = by_entity.sort_values(ascending=False)
print(by_entity.to_string())

# Show a few mismatches for inspection
mismatches = notes.loc[~correct, ["id", "entity", "text", "label", "predicted_label", "hits_count", "negated_hits_count"]]
print("\nSample mismatches (first 15):")
print(mismatches.head(15).to_string(index=False))
