# Figures for the ARES paper

Requires additional dependencies; install them with `pip install -e .[jupyter]`.

To run this notebook, you need the MIMIC-IV 2.2 dataset with the ED extension.
You also need to run the MEDS extraction pipeline first; refer to the README for instructions. Afterwards, adjust the paths below to match your setup.

In [None]:
from ethos.constants import PROJECT_ROOT

# Directory with the original MIMIC-IV 2.2 files (adjust to your setup).
mimic_dir = PROJECT_ROOT / "data/mimic-2.2"
# Directory with the output of the MEDS extraction pipeline (adjust to your setup).
mimic_meds_dir = PROJECT_ROOT / "data/mimic-2.2-meds-ed"

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

from enum import StrEnum

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
from matplotlib.offsetbox import AnchoredText
from pylatex import NoEscape, Table, Tabular, escape_latex

from ethos.constants import SpecialToken as ST
from ethos.inference.constants import Reason, Task
from ethos.metrics import preprocess_inference_results


def make_bold(s):
    """Return `s`, LaTeX-escaped and wrapped in a bold-text command, as a `NoEscape` string."""
    escaped = escape_latex(s)
    return NoEscape(rf"\textbf{{{escaped}}}")


# Display names of the dataset splits; the keys are the split column names used in the tables.
split_titles = {
    "train": "Train/Validation",
    "test": "Test",
    "total": "Total",
}

# Default number of bootstrap resamples used by `compute_metrics` for confidence intervals.
n_bootstraps = 10

sns.set_theme(context="paper", style="white")

# Colors (hex RGBA)
black_color = "#404040ff"
gray_color = "#b2b2b2ff"
orange_color = "#ff8533ff"
font_size = 18

# Matplotlib settings: bold labels and titles, a shared font size and a dark-gray text color.
plt.rcParams["axes.labelcolor"] = black_color
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["axes.titlecolor"] = black_color
plt.rcParams["axes.titlesize"] = font_size
plt.rcParams["axes.titleweight"] = "bold"
plt.rcParams["figure.labelsize"] = font_size
plt.rcParams["figure.labelweight"] = "bold"
plt.rcParams["font.family"] = "Roboto"  # has to be installed on the system
plt.rcParams["font.weight"] = "bold"
plt.rcParams["text.color"] = black_color
plt.rcParams["xtick.color"] = black_color
plt.rcParams["xtick.labelsize"] = font_size
plt.rcParams["ytick.color"] = black_color
plt.rcParams["ytick.labelsize"] = font_size

## Patient Demographics

In [None]:
# Requires the original MIMIC-IV 2.2 dataset, which we do not provide.
# If your files are in the csv.gz format, use `pl.read_csv` and change the file extensions accordingly.
patients_df = pl.read_parquet(mimic_dir / "hosp/patients.parquet")
admissions_df = pl.read_parquet(mimic_dir / "hosp/admissions.parquet")

# This file is created automatically by running the MEDS extraction pipeline; refer to the README.
# Optionally, it can be downloaded at https://github.com/ipolharvard/mimic4ed-benchmark/blob/main/scripts/data/subject_splits.parquet
subject_split_df = pl.read_parquet(mimic_meds_dir / "metadata/subject_splits.parquet")


def compute_split_counts(df: pl.DataFrame) -> pl.DataFrame:
    """Count rows per `code` within each split and append a `total` column.

    The result has one row per `code` and one column per key of `split_titles`,
    sorted by the last of those columns in descending order.
    """
    *fold_cols, total_col = split_titles
    counts_long = df.group_by("split", "code").agg(pl.count("code").alias("count"))
    counts_wide = counts_long.pivot("split", index="code")
    # `total` is the row-wise sum of the first two split columns.
    with_total = counts_wide.with_columns(total=pl.sum_horizontal(*fold_cols[:2]))
    return with_total.select("code", *fold_cols, total_col).sort(total_col, descending=True)


def add_percentages(df: pl.DataFrame) -> pl.DataFrame:
    """Replace each split column by the text "<count> (<percent>)".

    The count is formatted with thousands separators; the percentage is the
    value's share of its column sum, shown with one decimal place.
    """

    def count_with_share(name: str) -> pl.Expr:
        count_text = pl.col(name).map_elements(lambda s: f"{s:,}", return_dtype=pl.String)
        share = pl.col(name) / pl.sum(name)
        share_text = share.map_elements(lambda s: f" ({s * 100:.1f})", return_dtype=pl.String)
        return count_text + share_text

    return df.with_columns(count_with_share(name) for name in split_titles)

In [None]:
# Number of patients per split: every patient gets the same `code`, so the counts collapse to one row.
df_patient_num = patients_df.join(subject_split_df, on="subject_id").with_columns(
    code=pl.lit("Patient Number")
)
df_patient_num = compute_split_counts(df_patient_num)
# Format all count columns with thousands separators.
df_patient_num = df_patient_num.with_columns(
    pl.exclude("code").map_elements(lambda s: f"{s:,}", return_dtype=pl.String)
)
df_patient_num

In [None]:
# Mean and standard deviation of `anchor_age`, per split and over all patients ("total").
df_age = patients_df.join(subject_split_df, on="subject_id").with_columns(
    code=pl.lit("Mean Age (std)")
)
df_age = (
    pl.concat(
        (
            # Statistics per split.
            df_age.group_by("split", maintain_order=True).agg(
                pl.first("code"),
                pl.mean("anchor_age").alias("mean"),
                pl.std("anchor_age").alias("std"),
            ),
            # Statistics over all rows, labelled as the "total" split.
            df_age.group_by(split=pl.lit("total")).agg(
                pl.first("code"),
                pl.mean("anchor_age").alias("mean"),
                pl.std("anchor_age").alias("std"),
            ),
        )
    )
    # Render as "<mean> (<std>)" with one decimal place each.
    .select(
        "split",
        "code",
        value=pl.col("mean").map_elements(lambda s: f"{s:.1f}", return_dtype=pl.String)
        + pl.col("std").map_elements(lambda s: f" ({s:.1f})", return_dtype=pl.String),
    )
    .pivot("split", index="code")
)
df_age

In [None]:
# Gender per patient with its split; `df_gender_full` is reused later for the subgroup analysis.
df_gender_full = (
    patients_df.select("subject_id", "gender")
    .join(subject_split_df, on="subject_id")
    .rename({"gender": "code"})
)
# `replace_strict` raises on any value other than "F" or "M".
df_gender_full = df_gender_full.with_columns(
    pl.col("code").replace_strict({"F": "Female", "M": "Male"})
)
df_gender = compute_split_counts(df_gender_full)
df_gender = add_percentages(df_gender)
df_gender

In [None]:
from ethos.tokenize.mimic import DemographicData

# Race per subject; the right join keeps every subject of `subject_split_df`, and a missing race
# becomes "UNKNOWN".
df_race_full = (
    admissions_df.join(subject_split_df, on="subject_id", how="right")
    .rename({"race": "text_value"})
    .select("subject_id", pl.col("text_value").fill_null("UNKNOWN"), "split", code=pl.lit("RACE"))
)
# NOTE(review): `process_race` presumably maps `text_value` to codes of the form "RACE//<category>"
# (the prefix is stripped on the next line) — confirm, including whether it returns one row per subject.
df_race_full = DemographicData.process_race(df_race_full).drop("text_value")
df_race_full = df_race_full.with_columns(pl.col("code").str.slice(len("RACE//")).str.to_titlecase())
df_race = compute_split_counts(df_race_full)
df_race = add_percentages(df_race)
df_race

In [None]:
# Marital status per subject, taken from the first admission when ordered by `admittime`;
# a missing status becomes "UNKNOWN".
df_marital = (
    admissions_df.join(subject_split_df, on="subject_id", how="right")
    .rename({"marital_status": "code"})
    .sort("subject_id", "admittime")
    .group_by("subject_id", maintain_order=True)
    .agg(pl.first("code", "split"))
    .with_columns(pl.col("code").fill_null("UNKNOWN"))
)
df_marital = compute_split_counts(df_marital)
df_marital = df_marital.with_columns(pl.col("code").str.to_titlecase())
df_marital = add_percentages(df_marital)
df_marital

In [None]:
# Stack all demographic tables; `group` names the table section and `code` the row within it.
# Single-row sections (patient number, age) get a null `code`, so they render without a sub-row.
dem_combined_df = pl.concat(
    [
        df_patient_num.with_columns(group=pl.lit("Patient Number"), code=None).cast(pl.String),
        df_age.with_columns(group=pl.lit("Mean Age (std.)"), code=None),
        df_gender.with_columns(group=pl.lit("Gender (%)")),
        df_race.with_columns(group=pl.lit("Race (%)")),
        df_marital.with_columns(group=pl.lit("Marital Status (%)")),
    ]
).select("group", "code", *split_titles.keys())
dem_combined_df

In [None]:
# Build the LaTeX demographics table from `dem_combined_df` and print its source.
table = Table()

tabular = Tabular("l" + "r" * len(split_titles))
tabular.append(NoEscape(r"\toprule"))
tabular.add_row("", *[make_bold(t) for t in split_titles.values()])
tabular.append(NoEscape(r"\toprule"))

last_group = ""
for row in dem_combined_df.rows():
    group, subgroup, values = row[0], row[1], row[2:]

    if subgroup is None:
        # Section without sub-rows: the group name itself labels the row.
        first_cell = make_bold(group.title())
    else:
        if last_group != group:
            # First sub-row of a new section: emit a separator and a header-only row.
            tabular.append(NoEscape(r"\midrule"))
            tabular.add_row(make_bold(group.title()), *[""] * len(values))
            last_group = group
        # Escape the data-derived label so LaTeX special characters cannot break the table.
        first_cell = NoEscape(r"\hspace{1em} " + escape_latex(subgroup))

    tabular.add_row(first_cell, *values)

tabular.append(NoEscape(r"\bottomrule"))
table.append(NoEscape(r"\centering"))

table.add_caption(
    NoEscape(
        r"\textbf{Demographic characteristics of the dataset analyzed in this study.}"
        " The table summarizes key demographic attributes of the dataset, stratified into "
        "Train/Validation, Test, and Total splits. Patient numbers, mean age (with standard deviation),"
        " and distribution across gender, race, and marital status are shown, with percentages "
        "provided in parentheses. The data highlights the representation of each subgroup within "
        "the splits, providing context for the population characteristics in the dataset."
    )
)
table.append(NoEscape(r"\label{tab:population-demographic}"))
table.append(tabular)

print(table.dumps())

## Preparation of ETHOS Results

In [None]:
class PaperTask(StrEnum):
    """Tasks evaluated in the paper.

    The string values are also used as path components when locating result files.
    """

    # tasks to showcase ARES
    HOSPITAL_MORTALITY = Task.HOSPITAL_MORTALITY
    ICU_ADMISSION = Task.ICU_ADMISSION
    # These two have no `Task` counterpart; they are derived from other tasks' inference results.
    PROLONGED_STAY = "prolonged_stay"
    COMPOSITE = "composite"
    # tasks from the ED benchmark paper
    ED_HOSPITALIZATION = Task.ED_HOSPITALIZATION
    ED_CRITICAL_OUTCOME = Task.ED_CRITICAL_OUTCOME
    ED_REPRESENTATION = Task.ED_REPRESENTATION


# Human-readable task names used as plot titles and table headers.
task_titles = {
    PaperTask.HOSPITAL_MORTALITY: "Hospital Mortality",
    PaperTask.ICU_ADMISSION: "ICU Admission",
    PaperTask.PROLONGED_STAY: "Prolonged Stay",
    PaperTask.COMPOSITE: "Composite (HM+IA+PS)",
    PaperTask.ED_HOSPITALIZATION: "Hospitalization At Triage",
    PaperTask.ED_CRITICAL_OUTCOME: "Critical Outcome\nWithin 12h At Triage",
    PaperTask.ED_REPRESENTATION: "ED Re-presentation\nWithin 72h",
}

# ETHOS inference results are read from `results_dir / <task> / results_fn`.
results_dir = PROJECT_ROOT / "results"
results_fn = "mimic_old_ed_layer_6_do_0.3_recent_b0y9njtw"

# Preprocessed ETHOS results per `PaperTask`, filled in by the cells below.
all_ethos_result_dfs = {}

### Hospital Mortality

In [None]:
# Positive outcome: the generated (`actual`) / true (`expected`) token is the death token.
# NOTE(review): how `filter_ambiguous` is applied is defined by `preprocess_inference_results` — confirm.
all_ethos_result_dfs[PaperTask.HOSPITAL_MORTALITY] = preprocess_inference_results(
    results_dir / Task.HOSPITAL_MORTALITY / results_fn,
    actual_expr=pl.col("actual").is_in([ST.DEATH]),
    expected_expr=pl.col("expected").is_in([ST.DEATH]),
    filter_ambiguous=(
        ~pl.col("actual").is_in([ST.TIMELINE_END]) & pl.col("stop_reason").is_in([Reason.GOT_TOKEN])
    ),
)

### ICU Admission

In [None]:
# Positive outcome: the generated (`actual`) / true (`expected`) token is the ICU admission token.
# NOTE(review): how `filter_ambiguous` is applied is defined by `preprocess_inference_results` — confirm.
all_ethos_result_dfs[PaperTask.ICU_ADMISSION] = preprocess_inference_results(
    results_dir / Task.ICU_ADMISSION / results_fn,
    actual_expr=pl.col("actual").is_in([ST.ICU_ADMISSION]),
    expected_expr=pl.col("expected").is_in([ST.ICU_ADMISSION]),
    filter_ambiguous=(
        ~pl.col("actual").is_in([ST.TIMELINE_END]) & pl.col("stop_reason").is_in([Reason.GOT_TOKEN])
    ),
)

### Prolonged Stay

In [None]:
from datetime import timedelta


def get_los_quantile(q: float) -> timedelta:
    """Return the `q`-quantile of the length of stay over `admissions_df`.

    The length of stay is the minimum of `dischtime` and `deathtime` minus
    `admittime`, after parsing the three columns from "%Y-%m-%d %H:%M:%S" strings.
    """
    time_cols = ("admittime", "dischtime", "deathtime")
    stay_end = pl.min_horizontal("dischtime", "deathtime")
    length_of_stay = stay_end - pl.col("admittime")
    parsed = admissions_df.lazy().with_columns(
        pl.col(*time_cols).str.to_datetime("%Y-%m-%d %H:%M:%S")
    )
    return parsed.select(length_of_stay.quantile(q)).collect().item()


# We define a prolonged stay as any stay longer than the 90th percentile of all lengths of stay.
print(get_los_quantile(0.9))
# We decided to round the cutoff up to 10 days.
prolonged_stay_cutoff = timedelta(days=10)

# Derived from the hospital mortality inference results: positive when the time to the
# generated / true token reaches the cutoff.
all_ethos_result_dfs[PaperTask.PROLONGED_STAY] = preprocess_inference_results(
    results_dir / PaperTask.HOSPITAL_MORTALITY / results_fn,
    actual_expr=pl.col("token_time") >= prolonged_stay_cutoff,
    expected_expr=pl.col("true_token_time") >= prolonged_stay_cutoff,
    filter_ambiguous=(
        ~pl.col("actual").is_in([ST.TIMELINE_END]) & pl.col("stop_reason").is_in([Reason.GOT_TOKEN])
    ),
)

### Composite: Mortality/ICU Admission/Prolonged Stay

In [None]:
# Derived from the ICU admission inference results: positive when the token is an ICU admission
# or death token, or when the time to the token reaches `prolonged_stay_cutoff`.
all_ethos_result_dfs[PaperTask.COMPOSITE] = preprocess_inference_results(
    results_dir / PaperTask.ICU_ADMISSION / results_fn,
    actual_expr=pl.col("actual").is_in([ST.ICU_ADMISSION, ST.DEATH])
    | (pl.col("token_time") >= prolonged_stay_cutoff),
    expected_expr=pl.col("expected").is_in([ST.ICU_ADMISSION, ST.DEATH])
    | (pl.col("true_token_time") >= prolonged_stay_cutoff),
    filter_ambiguous=(
        ~pl.col("actual").is_in([ST.TIMELINE_END]) & pl.col("stop_reason").is_in([Reason.GOT_TOKEN])
    ),
)

### ED Hospitalization

In [None]:
# Positive prediction: the generated token is the admission token.
# NOTE(review): no `expected_expr` is passed here, so the function's default handling applies — confirm.
all_ethos_result_dfs[PaperTask.ED_HOSPITALIZATION] = preprocess_inference_results(
    results_dir / PaperTask.ED_HOSPITALIZATION / results_fn,
    actual_expr=pl.col("actual").is_in([ST.ADMISSION]),
)

### ED Critical Outcome Within 12h

In [None]:
# Positive prediction: an ICU admission or death token is generated.
# Positive label: `expected` is set and the true outcome occurred within 12 hours.
all_ethos_result_dfs[PaperTask.ED_CRITICAL_OUTCOME] = preprocess_inference_results(
    results_dir / PaperTask.ED_CRITICAL_OUTCOME / results_fn,
    actual_expr=pl.col("actual").is_in([ST.ICU_ADMISSION, ST.DEATH]),
    expected_expr=pl.col("expected") & (pl.col("true_token_time") <= pl.duration(hours=12)),
)

### ED Reattendance Within 72h

In [None]:
# Positive prediction: an ED admission token is generated.
# Positive label: `expected` is set and the true outcome occurred within 72 hours.
all_ethos_result_dfs[PaperTask.ED_REPRESENTATION] = preprocess_inference_results(
    results_dir / PaperTask.ED_REPRESENTATION / results_fn,
    actual_expr=pl.col("actual").is_in([ST.ED_ADMISSION]),
    expected_expr=pl.col("expected") & (pl.col("true_token_time") <= pl.duration(hours=72)),
)

## Preparation of MEDS-TAB Results

In [None]:
def get_meds_tab_results(task: PaperTask) -> pl.DataFrame:
    """Load the MEDS-Tab predictions of `task` from its parquet file.

    Only the subject id, label and predicted probability are kept, renamed to
    `patient_id`, `expected` and `actual`.
    """
    renames = {
        "subject_id": "patient_id",
        "boolean_value": "expected",
        "predicted_boolean_probability": "actual",
    }
    parquet_fp = (PROJECT_ROOT / "results/baseline_meds_tab" / task).with_suffix(".parquet")
    predictions = pl.read_parquet(parquet_fp).rename(renames)
    return predictions[list(renames.values())]


# MEDS-Tab predictions for every task; each task must have a parquet file in `results/baseline_meds_tab`.
all_meds_tab_result_dfs = {task: get_meds_tab_results(task) for task in PaperTask}

## Preparation of ED-Benchmark Processed Results

In [None]:
def extract_score_and_ci(s: str) -> dict:
    """Parse a "<value> (<lower>-<upper>)" string into a dict of three floats.

    The returned keys are `value`, `ci_lower` and `ci_upper`.
    """
    score_text, interval_text = s.split(" ")
    # Strip the enclosing parentheses before splitting the two bounds.
    bounds = interval_text[1:-1].split("-")
    parsed = {"value": float(score_text)}
    parsed["ci_lower"] = float(bounds[0])
    parsed["ci_upper"] = float(bounds[1])
    return parsed


# ED-Benchmark results per task, read from `results/baseline_ed_bench/<task>.csv` when that file exists.
# Each CSV row becomes a dict with lower-cased column names; the "threshold" and "runtime" columns
# are dropped and values containing "(" are parsed into value / confidence-interval dicts.
all_ed_bench_results = {
    task: [
        {
            label.lower(): extract_score_and_ci(value) if "(" in value else value
            for label, value in res.items()
            if label.lower() not in ["threshold", "runtime"]
        }
        for res in pl.read_csv(fp).to_dicts()
    ]
    for task in PaperTask
    if (fp := (PROJECT_ROOT / "results/baseline_ed_bench" / task).with_suffix(".csv")).exists()
}

## Functions for computing all metrics

In [None]:
from sklearn.metrics import average_precision_score, roc_auc_score, roc_curve

from ethos.metrics import compute_fitted_metrics

# Metrics reported in the paper: internal metric key -> display name.
metric_names = {
    "auc": "AUROC",
    "auprc": "AUPRC",
    "sensitivity": "Sensitivity",
    "specificity": "Specificity",
}


def compute_standard_metrics(y_true, y_pred):
    """Compute AUROC, AUPRC, sensitivity and specificity directly from the scores.

    Sensitivity and specificity are taken at the ROC point closest to (0, 1).
    """
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    # Euclidean distance of every ROC point to the ideal corner (FPR=0, TPR=1).
    distance_to_corner = np.sqrt((fpr - 0) ** 2 + (tpr - 1) ** 2)
    closest = np.argmin(distance_to_corner)

    metrics = {
        "auc": roc_auc_score(y_true, y_pred),
        "auprc": average_precision_score(y_true, y_pred),
    }
    metrics["sensitivity"] = tpr[closest]
    metrics["specificity"] = 1 - fpr[closest]
    return metrics


def compute_metrics(df, n_bootstraps=n_bootstraps, use_fit=True):
    """Compute the metrics listed in `metric_names` with bootstrap confidence intervals.

    Use `use_fit` when there is a low variety of y_probs to get the estimated curve;
    otherwise the metrics come from `compute_standard_metrics`.

    For every metric the result holds the point estimate on `df` (`value`), the
    2.5% and 97.5% quantiles of the bootstrap estimates (`ci_lower`, `ci_upper`)
    and the bootstrap estimates themselves (`bootstrap_values`). Resampling draws
    `len(df)` rows with replacement, seeded with 0..n_bootstraps-1.
    """
    metric_func = compute_fitted_metrics if use_fit else compute_standard_metrics

    def evaluate(frame):
        # Keep only the metrics that are reported.
        scores = metric_func(*frame["expected", "actual"])
        return {name: score for name, score in scores.items() if name in metric_names}

    point_estimates = {name: score.item() for name, score in evaluate(df).items()}
    bootstrap_df = pl.DataFrame(
        evaluate(df.sample(fraction=1, with_replacement=True, seed=seed))
        for seed in range(n_bootstraps)
    )

    summary = {}
    for name in metric_names:
        samples = bootstrap_df[name]
        summary[name] = {
            "value": point_estimates[name],
            "ci_lower": samples.quantile(0.025),
            "ci_upper": samples.quantile(0.975),
            "bootstrap_values": samples.to_list(),
        }
    return summary


def compute_cohen_d(col1: pl.Expr, col2: pl.Expr) -> pl.Expr:
    """Build the Cohen's d expression: mean difference divided by the pooled standard deviation."""
    n1, n2 = col1.count(), col2.count()
    # Sample variances weighted by their degrees of freedom.
    weighted_var = (col1.var(ddof=1) * (n1 - 1)) + (col2.var(ddof=1) * (n2 - 1))
    pooled_std = (weighted_var / ((n1 + n2) - 2)).sqrt()
    return (col1.mean() - col2.mean()) / pooled_std

## Evaluation of Tasks showcasing ARES

### Functions

In [None]:
# Tasks showcasing ARES, in the order used for table columns.
ares_tasks = [
    PaperTask.HOSPITAL_MORTALITY,
    PaperTask.ICU_ADMISSION,
    PaperTask.PROLONGED_STAY,
    PaperTask.COMPOSITE,
]


def join(df: pl.DataFrame, other: pl.DataFrame) -> pl.DataFrame:
    """Left-join `other` onto `df`, matching `df.patient_id` with `other.subject_id`."""
    join_keys = {"left_on": "patient_id", "right_on": "subject_id"}
    return df.join(other, how="left", **join_keys)


def compute_metrics_with_subgroup_break_down(result_dfs, **kwargs) -> pl.DataFrame:
    """Compute the AUC with its confidence interval for every ARES task and demographic subgroup.

    For each task in `ares_tasks` one "overall" row is produced, followed by one row
    per gender and per race subgroup. `kwargs` are forwarded to `compute_metrics`.
    The result has the columns task, group, subgroup, auc, ci_lower and ci_upper.
    """

    def auc_with_ci(frame):
        # Drop the trailing bootstrap values; keep (value, ci_lower, ci_upper).
        *auc_and_ci, _ = compute_metrics(frame, **kwargs)["auc"].values()
        return auc_and_ci

    rows = []
    for task in ares_tasks:
        task_df = result_dfs[task]
        rows.append((task, "overall", None, *auc_with_ci(task_df)))

        demographic_groups = (("gender", df_gender_full), ("race", df_race_full))
        for group_name, group_df in demographic_groups:
            joined = join(task_df, group_df)
            for subgroup in joined["code"].unique().sort():
                subgroup_df = joined.filter(code=subgroup)
                rows.append((task, group_name, subgroup, *auc_with_ci(subgroup_df)))

    return pl.DataFrame(
        rows,
        schema=["task", "group", "subgroup", "auc", "ci_lower", "ci_upper"],
        orient="row",
    )

### Compute metrics for ARES tasks

In [None]:
# This takes about 15 minutes for 100 bootstraps (45 seconds for 5 bootstraps).
ethos_ares_results = compute_metrics_with_subgroup_break_down(all_ethos_result_dfs)

In [None]:
# MEDS-Tab results hold predicted probabilities, so the metrics are computed without curve fitting.
meds_tab_ares_results = compute_metrics_with_subgroup_break_down(
    all_meds_tab_result_dfs, use_fit=False
)

### Table for ETHOS vs MEDS-Tab

In [None]:
# Build the LaTeX table comparing ETHOS and MEDS-Tab on the ARES tasks and print its source.
table = Table()
table.append(NoEscape(r"\centering"))

# One label column plus one column per task; also the width of the method header rows.
n_table_cols = len(ares_tasks) + 1

tabular = Tabular("l" + "c" * len(ares_tasks))
tabular.append(NoEscape(r"\toprule"))
tabular.add_row("", *[make_bold(task_titles[t]) for t in ares_tasks])
tabular.add_row(
    NoEscape(r"\textit{Prevalence} (\%)"),
    *[f"{all_ethos_result_dfs[task]['expected'].mean() * 100:.2f}" for task in ares_tasks],
)

for method, ares_results in [("ETHOS", ethos_ares_results), ("MEDS-Tab", meds_tab_ares_results)]:

    # Full-width header row naming the method.
    tabular.append(NoEscape(r"\toprule"))
    method_header = NoEscape(
        r"\multicolumn{" + str(n_table_cols) + r"}{c}{\textbf{\Large{" + method + "}}}"
    )
    tabular.add_row((method_header,), strict=False)
    tabular.append(NoEscape(r"\toprule"))

    # One row per (group, subgroup); each task column holds [auc, ci_lower, ci_upper].
    rows = (
        ares_results.select(
            "task", "group", "subgroup", values=pl.concat_list("auc", "ci_lower", "ci_upper")
        )
        .pivot("task", index=["group", "subgroup"], values="values")
        .rows()
    )
    last_group = ""
    for group, subgroup, *task_results in rows:
        formatted_results = [f"{v:.3f} [{low:.3f}, {high:.3f}]" for v, low, high in task_results]

        if subgroup is None:
            first_cell = make_bold(group.title())
        else:
            if last_group != group:
                # First sub-row of a new section: emit a separator and a header-only row.
                tabular.append(NoEscape(r"\midrule"))
                tabular.add_row(make_bold(group.title()), *[""] * len(formatted_results))
                last_group = group
            # Escape the data-derived label so LaTeX special characters cannot break the table.
            first_cell = NoEscape(r"\hspace{1em} " + escape_latex(subgroup))

        tabular.add_row(first_cell, *formatted_results)

tabular.append(NoEscape(r"\bottomrule"))

table.add_caption(
    NoEscape(
        r"\textbf{ETHOS performance on ARES tasks with a breakdown for demographic subgroups.}"
        r" This table presents the predictive performance (AUROC with 95\% confidence intervals) "
        "of ETHOS (top) and MEDS-Tab (bottom) for four critical clinical outcomes used in ARES: "
        "Hospital Mortality, ICU Admission, Prolonged Hospital Stay (>10 days), and a Composite "
        "Risk Score (HM+IA+PS). The prevalence rates of each outcome are provided for reference. "
        "Performance metrics are further stratified by gender and race to assess potential disparities "
        "in model performance across demographic subgroups."
    )
)
table.append(NoEscape(r"\label{tab:ares-results}"))
table.append(tabular)

print(table.dumps())

### Forest Plots

In [None]:
# Here, you might get some errors if the Roboto font is not installed on your system

# One row per (task, group, subgroup) with an [auc, ci_lower, ci_upper] list per model.
df = (
    ethos_ares_results.select(
        "task", "group", "subgroup", ETHOS=pl.concat_list("auc", "ci_lower", "ci_upper")
    )
    .join(
        meds_tab_ares_results.select(
            "task",
            "group",
            "subgroup",
            pl.concat_list("auc", "ci_lower", "ci_upper").alias("MEDS-Tab"),
        ),
        on=["task", "group", "subgroup"],
        # The "overall" rows have a null subgroup and must still be matched.
        join_nulls=True,
    )
    # Label the "overall" rows by their group name.
    .with_columns(subgroup=pl.coalesce("subgroup", "group").str.to_titlecase())
)

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8), sharex=False)
axes = axes.ravel()
lw = 3

for i, ((task,), task_df) in enumerate(df.group_by("task", maintain_order=True)):
    ax = axes[i]

    # Reverse so that the first subgroup ends up at the top of the panel.
    subgroups = list(reversed(task_df["subgroup"]))
    y_positions = range(len(subgroups))
    ax.set_yticks(list(y_positions))
    ax.set_yticklabels(subgroups)

    for model, marker, color in [("MEDS-Tab", "s", gray_color), ("ETHOS", "D", orange_color)]:
        for y, (m, lo, hi) in zip(y_positions, reversed(task_df[model])):
            # Confidence interval with end caps, then the point estimate.
            ax.plot([lo, hi], [y, y], color=color, lw=lw, alpha=0.7)
            ax.plot([lo, lo], [y - 0.3, y + 0.3], color=color, lw=lw, alpha=0.7)
            ax.plot([hi, hi], [y - 0.3, y + 0.3], color=color, lw=lw, alpha=0.7)
            # Label only one marker per model so the legend has a single entry for it.
            ax.plot(m, y, marker=marker, color=color, markersize=7, label=model if y == 0 else None)

    ax.set_title(task_titles[task])
    ax.grid(True)
    ax.legend()

fig.supxlabel("AUC score (95% CI)")
plt.tight_layout()

## Comparison of ETHOS and MEDS-Tab in subgroups of patients

In [None]:
from scipy.stats import bartlett


def get_auc_from_results(results: pl.DataFrame, model: str) -> pl.DataFrame:
    """Collect the subgroup AUCs of `results` into one list per demographic group and task.

    The "overall" rows are dropped and a constant `model` column is added.
    """
    auc_by_task = results.pivot("task", index=["group", "subgroup"], values="auc")
    subgroup_rows = auc_by_task.filter(pl.col("group") != "overall")
    per_group = subgroup_rows.group_by("group").agg(pl.exclude("subgroup"))
    return per_group.with_columns(model=pl.lit(model))


# For every task and demographic group, run Bartlett's test on the subgroup AUCs of ETHOS
# versus those of MEDS-Tab; `pvalue` is the second element of the `bartlett` result.
(
    pl.concat(
        [
            get_auc_from_results(ethos_ares_results, "ETHOS"),
            get_auc_from_results(meds_tab_ares_results, "MEDS-Tab"),
        ]
    )
    .unpivot(index=["group", "model"], variable_name="task")
    .pivot("model", index=["task", "group"])
    .with_columns(
        pvalue=pl.struct("ETHOS", "MEDS-Tab").map_elements(
            lambda d: bartlett(d["ETHOS"], d["MEDS-Tab"])[1], return_dtype=pl.Float64
        )
    )
)

## Evaluation of ED Benchmark Tasks

### Functions

In [None]:
def gather_ed_bench_results(task: PaperTask) -> dict:
    """Collect the metrics of all models for `task`, keyed by model name.

    The ED-Benchmark entries come first, followed by "MEDS-Tab" and "ETHOS (ours)",
    whose metrics are computed from their result dataframes.
    """
    results = {}
    for res in all_ed_bench_results[task]:
        results[res["model"]] = {metric: res[metric] for metric in metric_names}

    results["MEDS-Tab"] = compute_metrics(all_meds_tab_result_dfs[task], use_fit=False)
    results["ETHOS (ours)"] = compute_metrics(all_ethos_result_dfs[task], use_fit=True)
    return results


# Per-task model results, filled by `construct_latex_table` and reused for the forest plots.
all_ed_bench_task_results = {}


def format_results(results: dict) -> str:
    """Render a metric dict as "value [ci_lower, ci_upper]" with three decimals each."""
    value, low, high = (results[key] for key in ("value", "ci_lower", "ci_upper"))
    return "{:.3f} [{:.3f}, {:.3f}]".format(value, low, high)


def construct_latex_table(task: PaperTask, caption: str) -> str:
    """Return the LaTeX source of the model comparison table for `task`.

    The table has one row per model and one column per metric in `metric_names`,
    each cell formatted by `format_results`. As a side effect, the gathered
    results are stored in `all_ed_bench_task_results[task]`.
    """
    all_ed_bench_task_results[task] = gather_ed_bench_results(task)
    # Built outside of an f-string: reusing double quotes inside a double-quoted
    # f-string is a syntax error before Python 3.12.
    label = "tab:" + task.replace("_", "-")
    return (
        pl.DataFrame(
            [
                (model, *[format_results(metric_results[metric]) for metric in metric_names.keys()])
                for model, metric_results in all_ed_bench_task_results[task].items()
            ],
            orient="row",
            schema=[" ", *metric_names.values()],
        )
        .to_pandas()
        .to_latex(
            index=False,
            column_format="lcccc",
            escape=True,
            caption=caption,
            label=label,
        )
    )

### Table of ED Hospitalization

In [None]:
# NOTE: the metric values quoted in the caption are hard-coded; update them when the results change.
print(
    construct_latex_table(
        PaperTask.ED_HOSPITALIZATION,
        f"{make_bold('Prediction of Hospitalization At Triage.')} Performance comparison of various models for predicting hospitalization at triage, "
        r"evaluated using AUROC, AUPRC, sensitivity, and specificity (95\% confidence intervals in brackets). The thresholds for"
        " sensitivity and specificity were determined by finding the operating point on the ROC curve closest to (0,1). ETHOS "
        "demonstrates superior performance across all metrics, achieving the highest AUROC (0.912), AUPRC (0.887), sensitivity (0.849), "
        "and specificity (0.820), outperforming all other methods, including traditional scoring systems and machine learning models.",
    )
)

### Table of Critical Outcome Within 12h

In [None]:
# NOTE: the metric values quoted in the caption are hard-coded; update them when the results change.
print(
    construct_latex_table(
        PaperTask.ED_CRITICAL_OUTCOME,
        f"{make_bold('Prediction of Critical Outcome Within 12h At Triage.')} Performance comparison of various models for predicting critical outcomes "
        r"within 12 hours of triage, evaluated using AUROC, AUPRC, sensitivity, and specificity (95\% confidence intervals in brackets). The thresholds "
        "for sensitivity and specificity were determined by finding the operating point on the ROC curve closest to (0,1). ETHOS achieves the "
        "highest performance across most of the metrics, with an AUROC of 0.937, AUPRC of 0.649, sensitivity of 0.858, and specificity of 0.863, substantially "
        "outperforming all other methods, including traditional scoring systems and machine learning models.",
    )
)

### Table of ED Re-presentation Within 72h

In [None]:
# NOTE: the metric values quoted in the caption are hard-coded; update them when the results change.
print(
    construct_latex_table(
        PaperTask.ED_REPRESENTATION,
        f"{make_bold('Prediction of Emergency Department Re-presentation Within 72h.')} Performance comparison of various models for predicting emergency department"
        r" re-presentation within 72 hours, evaluated using AUROC, AUPRC, sensitivity, and specificity (95\% confidence intervals in brackets). The thresholds for "
        "sensitivity and specificity were determined by finding the operating point on the ROC curve closest to (0,1). ETHOS demonstrates superior performance, "
        "achieving the highest AUROC (0.740), AUPRC (0.199), sensitivity (0.659), and specificity (0.696), outperforming all other methods and showcasing its effectiveness for this challenging task.",
    )
)


### Forest Plots

In [None]:
# Forest plot of the AUC (with its confidence interval) of every model, one panel per ED task.
n_rows, n_cols = 2, 2
fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(10, 9))
flat_axes = axes.ravel()
lw = 3

for i, (task, results) in enumerate(all_ed_bench_task_results.items()):
    ax = flat_axes[i]
    y_positions = list(reversed(range(len(results))))

    # Order the models by ascending AUC; with the reversed positions the lowest AUC is on top.
    results = {
        model: res["auc"]
        for model, res in sorted(results.items(), key=lambda x: x[1]["auc"]["value"])
    }

    for y, auc_score in zip(y_positions, results.values()):
        m, lo, hi = auc_score["value"], auc_score["ci_lower"], auc_score["ci_upper"]

        # Confidence interval with end caps, then the point estimate.
        ax.plot([lo, hi], [y, y], color=orange_color, lw=lw)
        ax.plot([lo, lo], [y - 0.3, y + 0.3], color=orange_color, lw=lw)
        ax.plot([hi, hi], [y - 0.3, y + 0.3], color=orange_color, lw=lw)
        ax.plot(m, y, marker="D", color=orange_color, markersize=6)

    ax.set_yticks(y_positions)
    ax.set_yticklabels(results.keys())
    ax.grid(True)
    ax.set_title(task_titles[task])

# Hide the panels that did not receive a task. Slicing by the number of tasks does not
# depend on the loop variable, so it also works when there are no results at all.
for unused_ax in flat_axes[len(all_ed_bench_task_results):]:
    unused_ax.set_visible(False)

fig.supxlabel("AUC score (95% CI)")
plt.tight_layout()

## Tokenized MIMIC dataset statistics

In [None]:
# This part uses the tokenized MIMIC dataset; refer to the README for instructions on how to create it.
dataset_dir = PROJECT_ROOT / "data/tokenized_datasets/mimic_ed"

# Token counts per code for each fold; the total sums the counts of both folds per code.
test_counts = pl.read_csv(dataset_dir / "test/code_counts.csv")
train_counts = pl.read_csv(dataset_dir / "train/code_counts.csv")
total_counts = pl.concat([test_counts, train_counts]).group_by("code").agg(pl.sum("count"))

### Simple PHT Stats

In [None]:
from ethos.datasets import TimelineDataset

# Timeline lengths per fold, computed as differences between consecutive patient offsets.
# NOTE(review): this yields len(patient_offsets) - 1 lengths; confirm that the offsets include a
# final end offset, otherwise the last timeline of each fold is not counted.
lengths = []
for fold in ["train", "test"]:
    patient_offsets = TimelineDataset(dataset_dir / fold).patient_offsets.numpy()
    lengths.append(pl.Series(patient_offsets[1:] - patient_offsets[:-1]))
train_pht_lengths, test_pht_lengths = lengths

In [None]:
def eval_timeline_lengths(series, func):
    """Apply `func` to `series` and return its result."""
    outcome = func(series)
    return outcome


# Statistics over the timeline lengths: (group, subgroup) label -> function of the lengths series.
timeline_lengths_exprs = {
    ("Tokens", None): lambda s: s.sum(),
    ("Timelines", None): lambda s: s.len(),
    ("Timeline Lengths", "Longest"): lambda s: s.max(),
    ("Timeline Lengths", "Q3"): lambda s: s.quantile(0.75),
    ("Timeline Lengths", "Median"): lambda s: s.median(),
    ("Timeline Lengths", "Mean"): lambda s: s.mean(),
    ("Timeline Lengths", "Q1"): lambda s: s.quantile(0.25),
    ("Timeline Lengths", "Shortest"): lambda s: s.min(),
    ("Timeline Lengths", "Unique"): lambda s: s.n_unique(),
}

# Inputs in table column order: train, test, total.
timeline_lengths_dfs = [
    train_pht_lengths,
    test_pht_lengths,
    pl.concat([train_pht_lengths, test_pht_lengths]),
]

In [None]:
def eval_code_counts(df, func):
    """Apply `func` to the `code` column expression, sum the result over `df` and return the scalar."""
    aggregated = func(pl.col("code")).sum()
    return df.select(aggregated).item()


# Token categories: (group, subgroup) label -> function of the `code` expression.
# `eval_code_counts` sums the result, so each boolean filter yields the number of matching codes.
filtering_exprs = {
    ("Unique Timeline Tokens", None): lambda s: s.len(),
    # Codes starting with "=" or containing "-", excluding codes that contain "//".
    ("Timeline Tokens Encoding", "Time Intervals"): lambda s: (
        s.str.starts_with("=") | s.str.contains("-")
    )
    & ~s.str.contains("//"),
    ("Timeline Tokens Encoding", "Quantiles"): lambda s: s.str.starts_with("Q")
    & ~s.str.contains("//"),
    ("Timeline Tokens Encoding", "Medications"): lambda s: s.str.starts_with("ATC//"),
    ("Timeline Tokens Encoding", "Diagnoses"): lambda s: s.str.starts_with("ICD//CM"),
    ("Timeline Tokens Encoding", "Procedures"): lambda s: s.str.starts_with("ICD//PCS"),
    ("Timeline Tokens Encoding", "Labs"): lambda s: s.str.starts_with("LAB//"),
    ("Timeline Tokens Encoding", "Vitals"): lambda s: s.str.starts_with("VITAL//"),
    ("Timeline Tokens Encoding", "HCPCS"): lambda s: s.str.starts_with("HCPCS//"),
    ("Timeline Tokens Encoding", "Inpatient Stays"): lambda s: s.str.starts_with("HOSPITAL_")
    | s.str.starts_with("ICU_")
    | s.str.starts_with("DISCHARGE_")
    | s.str.starts_with("INSURANCE")
    | s.str.starts_with("ADMISSION_"),
    ("Timeline Tokens Encoding", "Emergency Department"): lambda s: s.str.starts_with("ED_"),
    ("Timeline Tokens Encoding", "DRGs"): lambda s: s.str.starts_with("DRG"),
    ("Timeline Tokens Encoding", "BMI"): lambda s: s.str.starts_with("BMI//"),
}

# Inputs in table column order: train, test, total.
code_count_dfs = [train_counts, test_counts, total_counts]

In [None]:
# One row per statistic: its (group, subgroup) label followed by the train, test and total values.
# NOTE(review): the schema declares the value columns as int, while the mean, median and quantile
# statistics can be fractional — confirm how these values are converted.
general_pht_numbers_df = pl.DataFrame(
    [
        (*label, *[func(df, expr) for df in dfs])
        for func, exprs, dfs in [
            (eval_timeline_lengths, timeline_lengths_exprs, timeline_lengths_dfs),
            (eval_code_counts, filtering_exprs, code_count_dfs),
        ]
        for label, expr in exprs.items()
    ],
    schema={"group": str, "subgroup": str, "train": int, "test": int, "total": int},
    orient="row",
)

# Build the LaTeX table and print its source.
table = Table()
tabular = Tabular("lrrr")
tabular.append(NoEscape(r"\toprule"))
tabular.add_row("", *[make_bold(s) for s in split_titles.values()])
tabular.append(NoEscape(r"\toprule"))

last_group = ""
for group, subgroup, *values in general_pht_numbers_df.rows():

    if subgroup is None:
        # Stand-alone statistic: separate it from the previous row, except for the very first row.
        first_cell = make_bold(group)
        if last_group != "":
            tabular.append(NoEscape(r"\midrule"))
    else:
        if last_group != group:
            # First sub-row of a new section: emit a separator and a header-only row.
            if last_group != "":
                tabular.append(NoEscape(r"\midrule"))
            tabular.add_row((make_bold(group),), strict=False)
        first_cell = NoEscape(r"\hspace{1em} " + escape_latex(subgroup))

    tabular.add_row(first_cell, *(f"{v:,}" for v in values))
    last_group = group

tabular.append(NoEscape(r"\bottomrule"))
table.append(NoEscape(r"\centering"))

table.add_caption(
    NoEscape(
        r"\textbf{Summary of Token and Timeline Statistics.} "
        "This table presents a comprehensive overview of the token and timeline data in the training, test, and combined datasets."
        " Key metrics include the total number of tokens and timelines, along with statistics on timeline lengths such as the longest timeline, median, mean, and shortest timeline."
        " The number of unique timeline tokens is also reported. The final section breaks down the encoding of timeline tokens into categories, such as time intervals, "
        "quantiles, medications, diagnoses, procedures, laboratory results, vitals, and other clinical features. "
        "This summary highlights the diversity and complexity of the tokenized data used in the study."
    )
)
table.append(NoEscape(r"\label{tab:simple-pht-stats}"))
table.append(tabular)

print(table.dumps())

### Detailed Token Statistics

In [None]:
# Column labels of the token statistics table: total token count and number of distinct codes.
count_col, unique_col = "Count", "#Unique"


def get_token_contribution(df: pl.DataFrame) -> pl.DataFrame:
    return (
        df.group_by(
            pl.when(
                pl.col("code").str.starts_with("ATC")
                & ~pl.col("code").str.starts_with("ATC//4//")
                & ~pl.col("code").str.starts_with("ATC//SFX//")
            )
            .then(pl.lit("ATC"))
            .when(pl.col("code").str.slice(0, 3).is_in(["ICD", "ATC"]))
            .then(
                pl.col("code").str.slice(0, 3)
                + pl.lit("_")
                + pl.col("code").str.split("//").list.get(1, null_on_oob=True)
            )
            .otherwise(pl.col("code").str.split("//").list.get(0))
            .alias("code")
        )
        .agg(pl.sum("count").alias(count_col), pl.count("count").alias(unique_col))
        .sort(count_col, descending=True)
    )


# Column-group order of the pivoted table. It is derived from `split_titles`
# because those titles are the labels written to the `split` column below; a
# separately maintained list can drift from them, and any label missing from
# the mapping is turned into NaN by the sort key, leaving its position to the
# NaN-placement rules instead of the intended order.
fold_order = {title: position for position, title in enumerate(split_titles.values())}
total_title = split_titles["total"]

# One (Count, #Unique) column pair per split, code groups as rows, ordered by
# the total token count.
count_results_df = (
    pl.concat(
        [
            get_token_contribution(df).with_columns(split=pl.lit(label))
            for label, df in zip(split_titles.values(), [train_counts, test_counts, total_counts])
        ]
    )
    .to_pandas()
    .pivot(index="code", columns="split", values=[count_col, unique_col])
    # pivot puts the value name on the outer column level; move the split there.
    .swaplevel(axis=1)
    .sort_index(axis=1, key=lambda index: index.map(fold_order), level=0)
    .sort_values((total_title, count_col), ascending=False)
)

count_results_df.columns.names = ["", ""]
count_results_df.index.name = "Code Group"

print(
    count_results_df.map(lambda s: f"{s:,}").to_latex(
        column_format="l" + "c" * len(count_results_df.columns),
        label="tab:token-stats",
        caption="\\textbf{Token Statistics}. The table provides a detailed breakdown of the total number"
        " of tokens and unique tokens for each code group in the training, test, and combined datasets."
        " Each code group represents a specific type of information, such as laboratory results (LAB),"
        r" clinical classifications (e.g., ATC, ICD\_CM), time intervals (e.g., 15m-45m, 12h-18h), and"
        " other key features like BMI, vitals, or discharge locations. The statistics summarize the diversity"
        r" (\#Unique) and frequency (Count) of tokens across datasets, offering insights into the distribution"
        " and variability of features used in the modeling process.",
        multicolumn_format="c",
        escape=True,
        longtable=True,
    )
)

## Data sources used to generate the datasets

In [None]:
import yaml

# Parse the event configuration YAML into `event_configs`.
event_config_path = PROJECT_ROOT / "scripts/meds/mimic/configs/event_configs-ed.yaml"
with event_config_path.open("r") as config_file:
    event_configs = yaml.safe_load(config_file)


def load_column_names(fp):
    """Return the column names of the table at `fp`, reading zero data rows.

    Files whose suffix is ".parquet" are opened with `pl.read_parquet`;
    everything else is opened with `pl.read_csv`.
    """
    if fp.suffix == ".parquet":
        reader = pl.read_parquet
    else:
        reader = pl.read_csv
    header_only = reader(fp, n_rows=0)
    return header_only.columns


# Column names that `filter_columns` always drops from its result.
columns_dont_really_used = ["drg_severity", "priority", "drg_mortality", "language", "emar_seq"]


def filter_columns(columns, event_config):
    """List the entries of an event config that name columns of a table.

    `event_config` maps event names to mappings of field name -> value, where
    a value is a string, a list of strings, or None. "time_format" fields and
    None values are ignored, and a "col(" prefix with its closing character is
    stripped from each remaining entry.

    Returns the entries, in config order and with repeats kept, that are
    present in `columns` and not listed in `columns_dont_really_used`.
    """
    referenced = []
    for event_values in event_config.values():
        for col_code, col in event_values.items():
            if col_code == "time_format" or col is None:
                continue
            entries = col if isinstance(col, list) else [col]
            for entry in entries:
                # "col(name)" -> "name"; anything else is taken verbatim.
                referenced.append(entry[4:-1] if entry.startswith("col(") else entry)

    kept = []
    for name in referenced:
        if name in columns and name not in columns_dont_really_used:
            kept.append(name)
    return kept


# One row per configured source table that exists under `mimic_dir`: the
# "/"-separated parts of its config key plus the columns the config refers to.
source_rows = []
for s in sorted(event_configs.keys()):
    # Candidate files for this key, parquet first.
    existing_files = []
    for sfx in [".parquet", ".csv.gz"]:
        candidate = (mimic_dir / s).with_suffix(sfx)
        if candidate.exists():
            existing_files.append(candidate)
    if not existing_files:
        continue

    used_columns = filter_columns(load_column_names(existing_files[0]), event_configs[s])
    source_rows.append((*s.split("/"), used_columns))

# NOTE(review): "supgroup" looks like a typo for "subgroup"; it is kept because
# the name is part of the frame's schema and may be referenced elsewhere.
data_src_df = pl.DataFrame(
    source_rows,
    schema=["group", "supgroup", "columns"],
    orient="row",
)
data_src_df

In [None]:
# Render `data_src_df` (rows of: group, subgroup, columns) as a two-column
# LaTeX table and print its source.
tabular = Tabular("ll")
tabular.append(NoEscape(r"\toprule"))
tabular.add_row("Data Source", "Used Columns")
tabular.append(NoEscape(r"\toprule"))

last_group = ""
for group, subgroup, columns in data_src_df.rows():
    if subgroup is not None:
        if last_group != group:
            # New group: rule (except before the first group), then a header row.
            if last_group != "":
                tabular.append(NoEscape(r"\midrule"))
            tabular.add_row((make_bold(group), ""))
            last_group = group
        first_cell = NoEscape(r"\hspace{1em} " + escape_latex(subgroup))
        tabular.append(NoEscape(r"\vspace{0.2em}"))
    else:
        first_cell = make_bold(group)

    # Column names are laid out three per line inside a left-aligned makecell.
    chunks = [columns[start : start + 3] for start in range(0, len(columns), 3)]
    cell_lines = [escape_latex(", ".join(chunk)) for chunk in chunks]
    columns_cell = NoEscape(r"\makecell[l]{" + r"\\".join(cell_lines) + r"}")

    tabular.add_row(first_cell, columns_cell)

tabular.append(NoEscape(r"\bottomrule"))

caption = NoEscape(
    r"\textbf{Overview of the data sources and their corresponding columns used in this work from the MIMIC-IV database and its extension MIMIC-IV-ED.}"
    " The table groups the data into three main categories: ED (Emergency Department), hosp (Hospital), and ICU (Intensive Care Unit)."
    " For each category, the associated tables and the specific columns extracted for the study are listed,"
    r" highlighting key variables relevant to patient care and outcomes, such as identifiers (e.g., stay\_id, hadm\_id),"
    " timestamps (e.g., intime, charttime), and clinical observations (e.g., vitalsign, labresults)."
    " These selections were guided by the objectives of the study to comprehensively model patient trajectories and outcomes."
)

# Assemble the float: centering, caption, label, then the tabular body.
table = Table()
table.append(NoEscape(r"\centering"))
table.add_caption(caption)
table.append(NoEscape(r"\label{tab:data-sources}"))
table.append(tabular)

print(table.dumps())

## Figure of AUROC for all tasks

In [None]:
def get_proper_attributes(score_name: str) -> tuple[str, str, str]:
    """Map a score name to its display label and curve attribute names.

    The lookup is case-insensitive. Returns `(label, x_attribute, y_attribute)`.

    Raises:
        ValueError: if `score_name` is neither "auc" nor "auprc".
    """
    attributes_by_score = {
        "auc": ("AUC", "fpr_values", "tpr_values"),
        "auprc": ("AUPRC", "recall_values", "precision_values"),
    }
    key = score_name.lower()
    if key not in attributes_by_score:
        raise ValueError(f"Unknown score name: {score_name}")
    return attributes_by_score[key]


def compute_ci(df: pl.DataFrame, score_type: str = "auc", n_bootstraps: int = n_bootstraps):
    """Bootstrap a 95% interval for a score and a 95% band for its curve.

    The "expected" and "actual" columns of `df` are resampled with replacement
    `n_bootstraps` times (seeds 0..n_bootstraps-1, so results are repeatable).
    Each resample is passed to `compute_fitted_metrics`; the result is indexed
    with `score_type` for the score and with the attribute names from
    `get_proper_attributes` for the curve, which is interpolated onto a fixed
    100-point grid over [0, 1].

    Both the score interval and the curve band are the 2.5th/97.5th
    percentiles over the resamples, so the band matches the
    "95% Confidence Interval" label it is plotted under.

    Returns:
        `(score_interval, grid, band_lower, band_upper)`, where
        `score_interval` is a length-2 array and the other three are arrays on
        the shared grid.

    NOTE(review): `np.interp` expects increasing x-coordinates; confirm that
    holds for the "auprc" attributes before calling this with anything other
    than "auc".
    """
    mean_fpr = np.linspace(0, 1, 100)
    aucs, tprs = [], []
    _, x_attr, y_attr = get_proper_attributes(score_type)
    for seed in range(n_bootstraps):
        res_fit = compute_fitted_metrics(
            *df["expected", "actual"].sample(fraction=1, with_replacement=True, seed=seed)
        )
        aucs.append(res_fit[score_type])
        tpr = np.interp(mean_fpr, res_fit[x_attr], res_fit[y_attr])
        # Pin the first grid point so every resampled curve starts at 0.
        tpr[0] = 0.0
        tprs.append(tpr)

    tprs_lower, tprs_upper = np.percentile(tprs, [2.5, 97.5], axis=0)
    return np.percentile(aucs, [2.5, 97.5]), mean_fpr, tprs_lower, tprs_upper


def plot_auc(
    df: pl.DataFrame,
    score_type: str = "auc",
    title: str = "",
    text_upper: bool = True,
    n_bootstraps=n_bootstraps,
):
    """Draw one score curve with its annotation on the current axes.

    The curve comes from `compute_fitted_metrics` applied to the "expected"
    and "actual" columns of `df`. When `score_type` is exactly "auc", the band
    returned by `compute_ci` is shaded behind the curve and the bootstrapped
    score interval is added to the annotation.

    Args:
        df: frame with "expected" and "actual" columns.
        score_type: key of the score in the fitted metrics ("auc" or "auprc").
        title: axes title.
        text_upper: anchor the annotation in the upper-right corner if true,
            otherwise in the lower-right corner.
        n_bootstraps: number of resamples forwarded to `compute_ci`.
    """
    res_fit = compute_fitted_metrics(df["expected"], df["actual"])
    score_name, x_attr, y_attr = get_proper_attributes(score_type)
    with_ci = score_type == "auc"

    # Drawing order is band, markers, line; it determines the legend entry order.
    if with_ci:
        (ci_lower, ci_upper), *band = compute_ci(df, score_type, n_bootstraps=n_bootstraps)
        plt.fill_between(
            *band,
            color=black_color,
            alpha=0.5,
            label="95% Confidence Interval",
        )

    xs, ys = res_fit[x_attr], res_fit[y_attr]
    plt.scatter(xs, ys, marker="X", color=black_color, s=100, label="Unique Thresholds")
    plt.plot(xs, ys, color=orange_color, lw=5, label=f"{score_name} curve")

    plt.grid(visible=False)
    ax = plt.gca()
    ax.set(ylim=(-0.01, 1.01), xlim=(-0.01, 1.01), title=title)

    # Annotation lines: score, optional interval, sample size with positive rate.
    score_line = f"{score_name}: {res_fit[score_type]:.3f}"
    size_line = f"N: {len(df):,} ({df['expected'].mean():.1%} pos.)"
    if with_ci:
        text = [score_line, f"{score_name} CI: [{ci_lower:.3f}, {ci_upper:.3f}]", size_line]
    else:
        text = [score_line, size_line]

    vertical = "upper" if text_upper else "lower"
    annotation = AnchoredText(
        "\n".join(text),
        loc=f"{vertical} right",
        pad=0.1,
        borderpad=0.1,
        frameon=False,
        prop=dict(size=font_size, color=black_color),
    )
    ax.add_artist(annotation)


def plot_ethos_curves(score_type: str = "auc", text_upper: tuple[int, ...] = (), **kwargs):
    """Draw one `plot_auc` panel per task of `all_ethos_result_dfs` on a 2x4 grid.

    Args:
        score_type: forwarded to `plot_auc`.
        text_upper: indices of the panels whose annotation is anchored in the
            upper-right corner; all other panels use the lower-right corner.
        **kwargs: forwarded unchanged to `plot_auc`.

    Tick marks and axis labels are kept only on the left column (y) and the
    bottom row (x). The legend is attached to the last drawn panel and grid
    cells without a task are hidden.
    """
    n_rows, n_cols = 2, 4
    size = 18
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(size, size * n_rows / n_cols))

    ax = None
    for i, (task, df) in enumerate(all_ethos_result_dfs.items()):
        ax = axes[i // n_cols, i % n_cols]
        plt.sca(ax)

        plot_auc(
            df, score_type=score_type, title=task_titles[task], text_upper=i in text_upper, **kwargs
        )

        if i % n_cols != 0:
            ax.set_ylabel("")
            ax.set_yticks([])

        if i < n_cols:
            ax.set_xlabel("")
            ax.set_xticks([])

    # The legend hangs off the last populated panel, in reverse drawing order.
    # Skipped when there was nothing to plot.
    if ax is not None:
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(
            handles[::-1],
            labels[::-1],
            loc="upper left",
            frameon=False,
            fontsize=font_size,
            bbox_to_anchor=(1.05, 0.8),
        )

    # Hide the grid cells that did not receive a task.
    for unused_ax in axes.flatten()[len(all_ethos_result_dfs) :]:
        unused_ax.set_visible(False)

    fig.subplots_adjust(wspace=0.1, hspace=0.2)

In [None]:
# ROC panels for every task; `n_bootstraps` reaches `plot_auc` through **kwargs.
# With the default empty `text_upper`, every annotation sits in the lower-right corner.
plot_ethos_curves(score_type="auc", n_bootstraps=n_bootstraps)

In [None]:
# Precision-recall panels for every task; only panels 0 and 6 place their
# annotation in the upper-right corner, the rest use the lower-right corner.
plot_ethos_curves(score_type="auprc", text_upper=(0, 6))

## Figure of Calibration Curves for all tasks

In [None]:
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss

# Calibration curve per task of `all_ethos_result_dfs` on a 2x4 grid, each with
# a bootstrapped 95% band and the Brier score.
n_rows, n_cols = 2, 4
size = 18
fig, axes = plt.subplots(n_rows, n_cols, figsize=(size, size * n_rows / n_cols))

n_bins = 10

ax = None
for i, (task, df) in enumerate(all_ethos_result_dfs.items()):
    plt.sca(ax := axes[i // n_cols, i % n_cols])

    frac_pos, mean_pred = calibration_curve(*df["expected", "actual"], n_bins=n_bins)

    # Recompute the curve on resampled data (seeded, so repeatable).
    bootstrapped_fracs = np.zeros((n_bootstraps, len(mean_pred)))
    for seed in range(n_bootstraps):
        frac_bs, mean_pred_bs = calibration_curve(
            *df["expected", "actual"].sample(fraction=1, with_replacement=True, seed=seed),
            n_bins=n_bins,
        )
        # A resample can yield a different number of bins; in that case its
        # curve is interpolated onto the bin positions of the full data.
        bootstrapped_fracs[seed] = (
            frac_bs
            if len(frac_bs) == len(frac_pos)
            else np.interp(mean_pred, mean_pred_bs, frac_bs)
        )

    ci_lower, ci_upper = np.percentile(bootstrapped_fracs, [2.5, 97.5], axis=0)
    plt.fill_between(
        mean_pred,
        ci_lower,
        ci_upper,
        color=gray_color,
        label="95% Confidence Interval",
    )

    plt.plot([0, 1], [0, 1], linestyle="--", color=black_color, label="Perfect Calibration")
    plt.plot(
        mean_pred,
        frac_pos,
        color=orange_color,
        lw=5,
        label="ETHOS Calibration",
    )
    plt.xlim([-0.01, 1.01])
    plt.ylim([-0.01, 1.01])
    plt.title(task_titles[task])
    plt.grid(False)

    at = AnchoredText(
        f"Brier score: {brier_score_loss(df['expected'], df['actual']):.3f}",
        loc="lower right",
        pad=0,
        borderpad=0.1,
        frameon=False,
        prop=dict(size=font_size, color=black_color),
    )
    ax.add_artist(at)

    # Keep ticks only on the left column (y) and the bottom row (x).
    if i % n_cols != 0:
        ax.set_ylabel("")
        ax.set_yticks([])

    if i < n_cols:
        ax.set_xlabel("")
        ax.set_xticks([])

# The legend hangs off the last populated panel, in reverse drawing order.
# Skipped when there was nothing to plot.
if ax is not None:
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(
        handles[::-1],
        labels[::-1],
        loc="center right",
        frameon=False,
        fontsize=font_size,
        bbox_to_anchor=(2.25, 0.8),
    )

# Hide the grid cells that did not receive a task.
for unused_ax in axes.flatten()[len(all_ethos_result_dfs) :]:
    unused_ax.set_visible(False)

fig.supxlabel("Mean predicted probability")
fig.supylabel("Fraction of positives", x=0.07)
fig.subplots_adjust(wspace=0.1, hspace=0.2)