### **Llama3_8B (N = 58166, whole)**

In [None]:
# -----------------------------
# Prompt + parsing
# -----------------------------
# System-role message: llama_local_batch_fn sends this as the "system" turn of
# every chat it builds.
SYSTEM_INSTRUCTION = """You are a conservative clinical text rater.
Do not diagnose; only label explicit depressive symptoms in the patient's words."""

# User-role template: llama_local_batch_fn fills the single {text} placeholder
# with str.format, so any literal brace added to this template must be doubled
# ("{{" / "}}"). The requested output has three labelled fields
# (Classification / Evidence / Reason).
USER_TEMPLATE = """Task: Label depressive symptoms (1) vs not present (0) from the patient's message.
Use only explicit evidence from the text. If uncertain, label 0.

Depressive symptoms (label 1 only if explicitly present):
- persistent sadness/depressed mood
- hopelessness/worthlessness/guilt
- loss of interest/pleasure (anhedonia)
- suicidal ideation/self-harm (always 1)
- explicit statements of inability to function due to mood (not due to pain/illness)

Not depression (label 0):
- appointment scheduling or cancellations
- physical illness/pain without mood language
- mild frustration without mood symptoms

IMPORTANT: The following are examples that MUST be labeled 0:
Example A message: "Just took my temperature. 98.9. Bummer! I guess I'll lie low for today."
Explanation: "bummer" + resting does NOT equal depression.
Example B message: "I can’t come on 4/8, so put me down for 4/20."
Explanation: rescheduling does NOT indicate low motivation.
Example C message: "Doctor is aware that I am in acute pain ... MRI ... urgent care?"
Explanation: acute pain/medical urgency is physical distress, not depression.

Output format (exactly):
Classification: <0 or 1>
Evidence: <quote exact phrase(s) supporting the label; if 0 write "none">
Reason: <brief explanation tied to Evidence>
Message:
{text}"""

# Tune for your workload
# NOTE(review): none of the constants below is referenced in this section;
# presumably a driver loop elsewhere passes them to llama_local_batch_fn and
# uses the *_EVERY_BATCHES values for logging/checkpointing — confirm.
BATCH_SIZE = 16
# NOTE(review): the output format asks for three fields; verify that 45 new
# tokens is enough for the "Reason:" line not to be cut off.
MAX_NEW_TOKENS = 45
# llama_local_batch_fn enables sampling only when temperature > 0, so 0.0
# requests deterministic (non-sampled) decoding and TOP_P is then unused.
TEMPERATURE = 0.0
TOP_P = 0.9

# Prompt truncation cap (match llama_local_batch_fn default unless you want different)
MAX_PROMPT_LEN = 2048

LOG_EVERY_BATCHES = 20
SAVE_EVERY_BATCHES = 20



@torch.no_grad()
def llama_local_batch_fn(
    texts: List[str],
    max_new_tokens: int = 64,
    temperature: float = 0.0,
    top_p: float = 0.9,
    max_length: int = 2048,
    return_raw: bool = False,
) -> Tuple[List[int], List[str], Optional[List[str]]]:
    """Classify a batch of patient messages with the local chat model.

    Each text is inserted into ``USER_TEMPLATE``, paired with
    ``SYSTEM_INSTRUCTION``, rendered through the tokenizer's chat template and
    generated in a single ``model.generate`` call. The function relies on the
    module-level ``tokenizer``, ``model`` and
    ``_parse_classification_and_reason``.

    Args:
        texts: Raw message texts, one prompt per element.
        max_new_tokens: Generation budget per prompt.
        temperature: Sampling is enabled only when this is > 0; otherwise
            ``temperature`` and ``top_p`` are passed to ``generate`` as None.
        top_p: Nucleus-sampling parameter, used only when sampling is enabled.
        max_length: Token cap applied to every rendered prompt.
        return_raw: When True, also return the decoded completions.

    Returns:
        ``(classifications, reasons, completions)``. ``classifications`` holds
        0/1 per text (any parser result other than 0 or 1 is mapped to 0),
        ``reasons`` holds the stripped reason strings, and ``completions`` is
        the list of decoded completions, or None unless ``return_raw`` is set.
    """
    # Nothing to generate: skip the tokenizer and model entirely.
    if not texts:
        return [], [], ([] if return_raw else None)

    prompts = [
        tokenizer.apply_chat_template(
            [
                {"role": "system", "content": SYSTEM_INSTRUCTION},
                {"role": "user", "content": USER_TEMPLATE.format(text=t)},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        for t in texts
    ]

    # The chat template already emits the special tokens (including BOS for
    # Llama-3), so they must not be added a second time when tokenizing the
    # rendered string.
    # NOTE(review): truncation keeps the head of an over-long prompt under the
    # usual right-side truncation, which would drop the message and the
    # assistant header — confirm max_length is never reached in practice.
    # NOTE(review): batched generation with a decoder-only model assumes
    # tokenizer.padding_side == "left" — confirm where the tokenizer is set up.
    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
        add_special_tokens=False,
    ).to(model.device)

    do_sample = temperature > 0.0

    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature if do_sample else None,
        top_p=top_p if do_sample else None,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )

    # generate() returns the padded input followed by the new tokens, so the
    # completion of every row starts at the padded input width. Slicing at the
    # per-row count of non-pad tokens (attention_mask.sum) would leak the tail
    # of shorter, left-padded prompts into the decoded completion.
    input_width = inputs["input_ids"].shape[1]

    classifications, reasons = [], []
    completions = [] if return_raw else None

    for row in out:
        completion = tokenizer.decode(row[input_width:], skip_special_tokens=True).strip()
        c, r = _parse_classification_and_reason(completion)
        # Conservative default: anything that did not parse to 0/1 counts as 0.
        classifications.append(int(c) if c in (0, 1) else 0)
        reasons.append((r or "").strip())
        if return_raw:
            completions.append(completion)

    return classifications, reasons, completions



### **Accuracy computation**

In [None]:
def compute_metrics_at_threshold(df, threshold, prevalence=0.25, eps=0.5):
    """Patient-level screening metrics for one positive-message threshold.

    Patients with at least ``threshold`` total messages are eligible; an
    eligible patient is flagged when ``positive_messages >= threshold``.

    Args:
        df: Frame with ``total_messages``, ``positive_messages`` and ``label``
            columns (labels are cast to int and clipped to 0/1).
        threshold: Minimum message count, used for both eligibility and
            flagging.
        prevalence: Assumed prevalence used for the adjusted PPV/NPV.
        eps: Constant added to every confusion-matrix cell before the odds
            ratio and its confidence interval are computed.

    Returns:
        Dict with the threshold, cohort sizes, confusion-matrix cells,
        sensitivity, specificity, prevalence-adjusted PPV/NPV and the odds
        ratio with its 95% interval. Ratio fields are NaN when nobody is
        eligible.

    Raises:
        KeyError: If a required column is absent.
    """
    absent = {"total_messages", "positive_messages", "label"}.difference(df.columns)
    if absent:
        raise KeyError(f"Missing required columns: {sorted(absent)}")

    cohort = df.loc[df["total_messages"] >= threshold].copy()
    n_eligible = int(len(cohort))

    if n_eligible == 0:
        # Nobody qualifies at this threshold: zero counts, undefined ratios.
        undefined = dict.fromkeys(
            (
                "sensitivity",
                "specificity",
                "ppv_adj",
                "npv_adj",
                "odds_ratio",
                "or_ci_low",
                "or_ci_high",
            ),
            np.nan,
        )
        return {
            "threshold": threshold,
            "n_eligible": 0,
            "n_flagged": 0,
            "tp": 0, "fp": 0, "fn": 0, "tn": 0,
            **undefined,
        }

    truth = cohort["label"].astype(int).clip(0, 1)
    flagged = (cohort["positive_messages"] >= threshold).astype(int)

    tn, fp, fn, tp = confusion_matrix(truth, flagged, labels=[0, 1]).ravel()

    n_pos = tp + fn
    n_neg = tn + fp
    sensitivity = tp / n_pos if n_pos else np.nan
    specificity = tn / n_neg if n_neg else np.nan

    # Bayes re-weighting of sensitivity/specificity to the assumed prevalence.
    flagged_mass = sensitivity * prevalence + (1 - specificity) * (1 - prevalence)
    if flagged_mass:
        ppv_adj = (sensitivity * prevalence) / flagged_mass
    else:
        ppv_adj = np.nan

    cleared_mass = (1 - sensitivity) * prevalence + specificity * (1 - prevalence)
    if cleared_mass:
        npv_adj = (specificity * (1 - prevalence)) / cleared_mass
    else:
        npv_adj = np.nan

    # eps keeps the odds ratio and its log-scale standard error finite when a
    # cell is zero.
    tp_c, fp_c, fn_c, tn_c = (cell + eps for cell in (tp, fp, fn, tn))
    odds_ratio = (tp_c * tn_c) / (fp_c * fn_c)

    se_log_or = math.sqrt(1/tp_c + 1/fp_c + 1/fn_c + 1/tn_c)
    or_ci_low = math.exp(math.log(odds_ratio) - 1.96 * se_log_or)
    or_ci_high = math.exp(math.log(odds_ratio) + 1.96 * se_log_or)

    return {
        "threshold": int(threshold),
        "n_eligible": n_eligible,
        "n_flagged": int(flagged.sum()),
        "tp": int(tp), "fp": int(fp), "fn": int(fn), "tn": int(tn),
        "sensitivity": sensitivity,
        "specificity": specificity,
        "ppv_adj": ppv_adj,
        "npv_adj": npv_adj,
        "odds_ratio": odds_ratio,
        "or_ci_low": or_ci_low,
        "or_ci_high": or_ci_high,
    }


In [None]:
# Sweep every integer threshold from 1 up to the largest per-patient message
# count and collect one metrics row per threshold.
max_t = int(patient_df["total_messages"].max())

results = [
    compute_metrics_at_threshold(patient_df, candidate, prevalence=0.25)
    for candidate in range(1, max_t + 1)
]

# One row per threshold, one column per metric key.
perf_df = pd.DataFrame(data=results)


### **Staged screening analysis**

In [None]:
import numpy as np
import pandas as pd
from typing import Optional, Dict, Any

def compute_stage2_lead_time(df_msg: pd.DataFrame, t2: int, id_col: str = "a_id") -> pd.DataFrame:
    """Per-case lead time between reaching ``t2`` flagged messages and diagnosis.

    Only rows with ``label == 1`` are used. Messages are ordered by time within
    each patient and the timestamp of the message that brings the running count
    of ``predictions == 1`` up to ``t2`` is taken as the reach date.

    Args:
        df_msg: Message-level frame with ``created_time``, ``dep_start``,
            ``label`` and ``predictions`` columns plus ``id_col``. The input is
            not modified.
        t2: Cumulative number of flagged messages to reach.
        id_col: Patient identifier column.

    Returns:
        One row per case with ``id_col``, ``dx_date``, ``reach_date_t2`` and
        ``days_earlier_t2`` (``dx_date - reach_date_t2`` in days; missing when
        the case never reaches ``t2``).
    """
    msgs = df_msg.copy()
    for ts_col in ("created_time", "dep_start"):
        msgs[ts_col] = pd.to_datetime(msgs[ts_col])

    # Cases only, in chronological order per patient, with a running count of
    # flagged messages.
    case_msgs = msgs[msgs["label"] == 1].copy()
    case_msgs["flag_msg"] = (case_msgs["predictions"] == 1).astype(int)
    case_msgs = case_msgs.sort_values([id_col, "created_time"])
    case_msgs["cum_pos"] = case_msgs.groupby(id_col)["flag_msg"].cumsum()

    # Diagnosis date: first dep_start value per patient.
    dx = case_msgs.groupby(id_col, as_index=False)["dep_start"].first()
    dx = dx.rename(columns={"dep_start": "dx_date"})

    # Reach date: earliest message at which the running count is >= t2.
    at_or_past_t2 = case_msgs.loc[case_msgs["cum_pos"] >= t2]
    first_hit = at_or_past_t2.groupby(id_col, as_index=False).first()
    reach = first_hit[[id_col, "created_time"]].rename(
        columns={"created_time": "reach_date_t2"}
    )

    # Left merge keeps cases that never reach t2 (their reach date is missing).
    out = dx.merge(reach, on=id_col, how="left")
    out["days_earlier_t2"] = (out["dx_date"] - out["reach_date_t2"]).dt.days
    return out


def staged_metrics_eligible(
    person_df: pd.DataFrame,
    t1: int = 1,
    t2: int = 30,
    df_msg_for_lead: Optional[pd.DataFrame] = None,
    id_col: str = "a_id",
    total_msg_col: str = "total_messages",
    require_pre_dx_msgs_for_lead: bool = False,
) -> Dict[str, Any]:
    """Summarise a two-stage (monitor, then escalate) screening rule.

    Patients with at least ``t2`` total messages are eligible. Stage 1
    ("monitored") requires ``positive_messages >= t1``; stage 2 ("escalated")
    requires ``positive_messages >= t2`` among the monitored patients.

    Args:
        person_df: Patient-level frame with ``id_col``, ``total_msg_col``,
            ``positive_messages`` and ``label`` columns. Not modified.
        t1: Stage-1 threshold on positive messages.
        t2: Stage-2 threshold on positive messages; also the eligibility
            threshold on total messages.
        df_msg_for_lead: Optional message-level frame used to compute lead
            times for escalated cases via ``compute_stage2_lead_time``.
        id_col: Patient identifier column.
        total_msg_col: Column holding the total message count.
        require_pre_dx_msgs_for_lead: When True, lead times are kept only for
            cases with at least ``t2`` messages dated on or before
            ``dep_start``.

    Returns:
        Dict with the same keys whether or not anyone is eligible: thresholds,
        cohort sizes, monitored/escalated counts split by case/control,
        escalation rates, stage-2 PPV, and the number and median of
        non-negative lead times (days) among escalated cases.

    Raises:
        ValueError: If a required column is missing from ``person_df`` or,
            when lead times are computed, from ``df_msg_for_lead``.
    """
    dfp = person_df.copy()

    for col in [id_col, total_msg_col, "positive_messages", "label"]:
        if col not in dfp.columns:
            raise ValueError(f"Missing required column in person_df: '{col}'")

    eligible = dfp[dfp[total_msg_col] >= t2].copy()

    if eligible.empty:
        return {
            "t1": t1, "t2": t2,
            "n_eligible": 0,
            "n_monitored": 0, "n_escalated": 0,
            "n_monitored_cases": 0, "n_monitored_controls": 0,
            "n_escalated_cases": 0, "n_escalated_controls": 0,
            "escalation_overall": np.nan,
            "escalation_cases": np.nan,
            "escalation_controls": np.nan,
            "ppv_stage2": np.nan,
            "n_cases_with_lead_time_t2": 0,
            "median_days_earlier_t2": np.nan,
        }

    eligible["stage1"] = (eligible["positive_messages"] >= t1).astype(int)
    monitored = eligible[eligible["stage1"] == 1].copy()

    monitored["stage2"] = (monitored["positive_messages"] >= t2).astype(int)
    escalated = monitored[monitored["stage2"] == 1].copy()

    n_monitored = len(monitored)
    n_escalated = len(escalated)

    mon_cases = monitored[monitored["label"] == 1]
    mon_controls = monitored[monitored["label"] == 0]
    esc_cases = escalated[escalated["label"] == 1]
    esc_controls = escalated[escalated["label"] == 0]

    # Escalation rates are conditional on having been monitored.
    escalation_overall = n_escalated / n_monitored if n_monitored else np.nan
    escalation_cases = len(esc_cases) / len(mon_cases) if len(mon_cases) else np.nan
    escalation_controls = len(esc_controls) / len(mon_controls) if len(mon_controls) else np.nan

    ppv_stage2 = escalated["label"].mean() if n_escalated else np.nan

    median_days_earlier_t2 = np.nan
    n_cases_reached_t2_pre_dx = 0

    if df_msg_for_lead is not None and len(esc_cases):
        for col in [id_col, "created_time", "dep_start", "label", "predictions"]:
            if col not in df_msg_for_lead.columns:
                raise ValueError(f"Missing required column in df_msg_for_lead: '{col}'")

        lead_tbl = compute_stage2_lead_time(df_msg_for_lead, t2=t2, id_col=id_col)

        esc_case_ids = set(esc_cases[id_col].unique())
        lead_tbl = lead_tbl[lead_tbl[id_col].isin(esc_case_ids)].copy()

        # Keep only cases that reached t2 on or before the diagnosis date.
        lead_tbl = lead_tbl.dropna(subset=["days_earlier_t2"])
        lead_tbl = lead_tbl[lead_tbl["days_earlier_t2"] >= 0]

        if require_pre_dx_msgs_for_lead:
            d = df_msg_for_lead.copy()
            d["created_time"] = pd.to_datetime(d["created_time"])
            d["dep_start"] = pd.to_datetime(d["dep_start"])
            d = d[d["label"] == 1].copy()

            # Number of messages each case sent on or before diagnosis.
            pre_dx_counts = (
                d[d["created_time"] <= d["dep_start"]]
                .groupby(id_col)
                .size()
                .rename("pre_dx_total_messages")
                .reset_index()
            )

            lead_tbl = lead_tbl.merge(pre_dx_counts, on=id_col, how="left")
            lead_tbl["pre_dx_total_messages"] = lead_tbl["pre_dx_total_messages"].fillna(0).astype(int)
            lead_tbl = lead_tbl[lead_tbl["pre_dx_total_messages"] >= t2]

        n_cases_reached_t2_pre_dx = len(lead_tbl)
        if n_cases_reached_t2_pre_dx:
            median_days_earlier_t2 = float(lead_tbl["days_earlier_t2"].median())

    # Same key set as the empty-cohort branch, so rows from different
    # thresholds can be stacked into one table without missing columns.
    return {
        "t1": t1,
        "t2": t2,
        "n_eligible": int(len(eligible)),
        "n_monitored": int(n_monitored),
        "n_escalated": int(n_escalated),
        "n_monitored_cases": int(len(mon_cases)),
        "n_monitored_controls": int(len(mon_controls)),
        "n_escalated_cases": int(len(esc_cases)),
        "n_escalated_controls": int(len(esc_controls)),
        "escalation_overall": escalation_overall,
        "escalation_cases": escalation_cases,
        "escalation_controls": escalation_controls,
        "ppv_stage2": float(ppv_stage2) if pd.notna(ppv_stage2) else np.nan,
        "n_cases_with_lead_time_t2": int(n_cases_reached_t2_pre_dx),
        "median_days_earlier_t2": median_days_earlier_t2,
    }


### **Message-level (message-volume adjusted = exposure-adjusted) performance**

In [None]:
from statsmodels.stats.proportion import proportion_confint
import numpy as np

# Confusion-matrix cells at the message level
pred_pos = comb["predictions"] == 1
pred_neg = comb["predictions"] == 0
is_case = comb["label"] == 1
is_control = comb["label"] == 0

TP = (pred_pos & is_case).sum()
FN = (pred_neg & is_case).sum()
FP = (pred_pos & is_control).sum()
TN = (pred_neg & is_control).sum()

n_case_msgs = TP + FN
n_control_msgs = TN + FP

# Sensitivity with its Wilson 95% interval (undefined without case messages)
if n_case_msgs > 0:
    sensitivity = TP / n_case_msgs
    sens_ci_low, sens_ci_high = proportion_confint(
        TP, n_case_msgs, alpha=0.05, method="wilson"
    )
else:
    sensitivity = np.nan
    sens_ci_low, sens_ci_high = np.nan, np.nan

# Specificity with its Wilson 95% interval (undefined without control messages)
if n_control_msgs > 0:
    specificity = TN / n_control_msgs
    spec_ci_low, spec_ci_high = proportion_confint(
        TN, n_control_msgs, alpha=0.05, method="wilson"
    )
else:
    specificity = np.nan
    spec_ci_low, spec_ci_high = np.nan, np.nan

# Predictive values (undefined when their denominator is empty)
n_pred_pos = TP + FP
n_pred_neg = TN + FN
PPV = TP / n_pred_pos if n_pred_pos > 0 else np.nan
NPV = TN / n_pred_neg if n_pred_neg > 0 else np.nan

# Odds ratio with a Woolf interval; 0.5 is added to every cell
# (Haldane–Anscombe correction) only when at least one cell is zero.
cells = np.array([TP, FP, FN, TN], dtype=float)
use_cc = np.any(cells == 0)

TPc, FPc, FNc, TNc = cells + (0.5 if use_cc else 0.0)

odds_ratio = (TPc * TNc) / (FPc * FNc)
se_log_or = np.sqrt(1/TPc + 1/TNc + 1/FPc + 1/FNc)
log_or = np.log(odds_ratio)
ci_lower = np.exp(log_or - 1.96 * se_log_or)
ci_upper = np.exp(log_or + 1.96 * se_log_or)

cc_note = " [continuity-corrected]" if use_cc else ""

print(f"Sensitivity: {sensitivity:.4f} (95% CI {sens_ci_low:.4f}–{sens_ci_high:.4f})")
print(f"Specificity: {specificity:.4f} (95% CI {spec_ci_low:.4f}–{spec_ci_high:.4f})")
print(f"PPV: {PPV:.4f}")
print(f"NPV: {NPV:.4f}")
print(f"Odds Ratio: {odds_ratio:.4f} (95% CI {ci_lower:.4f}–{ci_upper:.4f})" + cc_note)


### **Performance by time window**

In [None]:
import pandas as pd

def days_earlier_by_threshold(df: pd.DataFrame, t: int, id_col: str = "a_id") -> pd.DataFrame:
    """Days between reaching ``t`` flagged messages and diagnosis, per case.

    Only rows with ``label == 1`` contribute. Within each patient, messages are
    ordered by ``created_time`` and the reach date is the timestamp of the
    message at which the running count of ``predictions == 1`` first reaches
    ``t``.

    Args:
        df: Message-level frame with ``created_time``, ``dep_start``,
            ``label`` and ``predictions`` columns plus ``id_col``. The input is
            not modified.
        t: Cumulative number of flagged messages to reach.
        id_col: Patient identifier column.

    Returns:
        One row per case with ``id_col``, ``dx_date``, ``reach_date`` and
        ``days_earlier`` (``dx_date - reach_date`` in days; missing when the
        case never reaches ``t``).
    """
    parsed = df.assign(
        created_time=pd.to_datetime(df["created_time"]),
        dep_start=pd.to_datetime(df["dep_start"]),
    )

    timeline = parsed.loc[parsed["label"].eq(1)].copy()
    timeline["flag_msg"] = timeline["predictions"].eq(1).astype(int)
    timeline = timeline.sort_values([id_col, "created_time"])
    # Running number of flagged messages per patient, in chronological order.
    timeline["cum_pos"] = timeline.groupby(id_col)["flag_msg"].cumsum()

    hits = timeline.loc[timeline["cum_pos"] >= t]
    reach = hits.groupby(id_col, as_index=False).first()
    reach = reach[[id_col, "created_time"]].rename(columns={"created_time": "reach_date"})

    dx = timeline.groupby(id_col, as_index=False)["dep_start"].first()
    dx = dx.rename(columns={"dep_start": "dx_date"})

    # Left merge keeps cases that never accumulate t flagged messages.
    out = dx.merge(reach, on=id_col, how="left")
    out["days_earlier"] = (out["dx_date"] - out["reach_date"]).dt.days
    return out

# Lead time at a threshold of 5 flagged messages
early_t5 = days_earlier_by_threshold(df, t=5)

# Keep cases that reached the threshold, and did so on or before the diagnosis date
reached_t5 = early_t5["days_earlier"].notna()
on_or_before_dx = early_t5["days_earlier"] >= 0
flagged_t5 = early_t5.loc[reached_t5 & on_or_before_dx].copy()

summary = flagged_t5["days_earlier"].describe()
print(summary)
