# Triage rules baseline

Prototype notebook for rule-based clinical NLP triage.

- Loads `data/lexicon_redflags.csv`
- Loads `data/notes_synthetic.csv`
- Applies simple rule-based scoring (number of lexicon terms found in each note: 0 → `low`, 1 → `intermediate`, 2 or more → `high`)
- Exports `outputs/predictions.csv`
- Evaluates predictions vs. ground-truth `label`


In [None]:
import pandas as pd

# Load the red-flag lexicon and the synthetic notes (paths are relative to
# the notebook's working directory).
lexicon = pd.read_csv("../data/lexicon_redflags.csv")
notes = pd.read_csv("../data/notes_synthetic.csv")

# Build the lower-cased term list used by the substring matcher.
# - dropna(): a blank CSV cell is read as NaN, and str(NaN).lower() would
#   produce the literal term "nan", which then matches inside ordinary words
#   (e.g. "pregnant").
# - Whitespace-only terms are skipped because they would match almost any text.
#   Terms are otherwise left unstripped so deliberate padding is preserved.
# - dict.fromkeys() removes duplicates (including case variants) while keeping
#   the lexicon order, so one mention in a note is not counted as several hits.
terms = list(
    dict.fromkeys(
        term
        for term in (str(raw).lower() for raw in lexicon["term"].dropna())
        if term.strip()
    )
)

def predict_label(text, lexicon_terms=None):
    """Assign a triage label from how many lexicon terms occur in ``text``.

    Matching is a plain substring test against the lower-cased text, so a term
    also matches when it appears inside a longer word.

    Args:
        text: Note text. Non-string values are converted with ``str``.
            Missing values (``None``, NaN, ``pd.NA``) are treated as an empty
            note instead of being scored as the strings "none" / "nan".
        lexicon_terms: Optional iterable of already lower-cased terms to match.
            When omitted, the module-level ``terms`` list is used.

    Returns:
        "high" when two or more entries of the lexicon are found,
        "intermediate" when exactly one is found, and "low" otherwise.
    """
    if lexicon_terms is None:
        lexicon_terms = terms

    # A missing note cannot contain any red flag.
    if pd.api.types.is_scalar(text) and pd.isna(text):
        return "low"

    lowered = str(text).lower()
    # Empty terms are skipped: "" is a substring of every string.
    hits = sum(1 for term in lexicon_terms if term and term in lowered)

    if hits >= 2:
        return "high"
    if hits == 1:
        return "intermediate"
    return "low"

from pathlib import Path

# Score every note with the rule-based baseline.
notes["predicted_label"] = notes["text"].apply(predict_label)

# to_csv() does not create missing parent directories, so make sure the
# output folder exists before writing (no-op when it is already there).
predictions_path = Path("../outputs/predictions.csv")
predictions_path.parent.mkdir(parents=True, exist_ok=True)

# Export only the columns needed downstream; the ground-truth label is left out.
notes[["id", "text", "entity", "predicted_label"]].to_csv(
    predictions_path, index=False
)

print("Saved outputs/predictions.csv")


## Quick evaluation (baseline)

This evaluates the simple substring-matching baseline against the synthetic ground-truth `label` column and reports:

- Overall accuracy
- Confusion matrix (all entities)
- Accuracy by entity
- A few mismatches to inspect


In [None]:
# Guardrail: every metric below needs the ground-truth column, so fail fast
# with a clear message when it is absent.
if "label" not in notes.columns:
    raise ValueError("notes_synthetic.csv is missing required column: 'label'")

# Row-wise correctness, computed once and reused by all metrics below.
is_correct = notes["predicted_label"] == notes["label"]

# Overall accuracy
acc = is_correct.mean()
print(f"Overall accuracy: {acc:.3f} ({int(is_correct.sum())}/{len(notes)})")

# Confusion matrix (rows: ground truth, columns: prediction)
print("\nConfusion matrix (label x predicted_label):")
cm = pd.crosstab(
    notes["label"],
    notes["predicted_label"],
    rownames=["label"],
    colnames=["predicted_label"],
    dropna=False,
)
print(cm.to_string())

# Accuracy by entity, best first. Grouping the boolean Series directly avoids
# DataFrameGroupBy.apply, which warns on recent pandas when the applied
# function receives the grouping column.
print("\nAccuracy by entity:")
by_entity = is_correct.groupby(notes["entity"]).mean()
by_entity = by_entity.sort_values(ascending=False)
print(by_entity.to_string())

# Show a few mismatches for inspection
mismatches = notes.loc[~is_correct, ["id", "entity", "text", "label", "predicted_label"]]
print("\nSample mismatches (first 15):")
print(mismatches.head(15).to_string(index=False))
