In [None]:
from pathlib import Path
import pandas as pd
from src.data import load_data, filter_notes, filter_patients, compute_labels, TraumaDataset
from tableone import TableOne
from tqdm import tqdm
from pqdm.processes import pqdm

# Root directory of the trauma registry / clinical-note extracts.
# Defaults to the cluster location; set TRAUMA_DATA_DIR to run elsewhere
# without editing the notebook.
import os

data_dir = Path(os.environ.get("TRAUMA_DATA_DIR", "/gpfs/data/benjamin-lab/Trauma_LLM/data"))

In [None]:
# Load every raw source table, then restrict the registry to the analysis
# cohort using the registry, demographics and discharge tables.
src = load_data(data_dir)
cohort_inputs = {table_name: src[table_name] for table_name in ("registry", "demo", "disch")}
patients = filter_patients(**cohort_inputs)

In [None]:
# Sanity check: the source provides few tertiary notes compared with the
# number of records in the registry. Count registry patients, those whose MRN
# appears exactly once in the registry, and how many of those have a tertiary note.
registry_mrns = src["registry"]["Medical Record #"]
print("Unique patients in registry", registry_mrns.nunique())

non_dupe_mrns = registry_mrns[~registry_mrns.duplicated(keep=False)]
print("Registry patients with only one stay", len(non_dupe_mrns))

tert_mrns = src["tert"]["MRN"]
print(
    "Registry patients with tertiary note (note may come from any encounter)",
    tert_mrns[tert_mrns.isin(non_dupe_mrns)].nunique(),
)

In [None]:
# Harmonize raw EHR race / ethnicity strings into the Table 1 reporting
# categories. Values not listed pass through `.replace` unchanged.
RACE_MAP = {
    "White": "White",
    "Asian/Mideast Indian": "Asian",
    "Black or African-American": "Black",
    "None of the above": "Other",
    "Unknown or Patient unable to respond": "Unknown",
    "More than one Race": "Other",
    "Patient declines to respond": "Unknown",
    "Other Asian": "Asian",
    "American Indian or Alaska Native": "AIAN",
    "Native Hawaiian/Other Pacific Islander": "NHPI",
    "Asian Indian": "Asian",
    "Other Pacific Islander": "NHPI",
    "Filipino": "Asian",
    "Chinese": "Asian",
}

ETHNICITY_MAP = {
    "Not Hispanic, Latino/a, or Spanish origin": "Not Hispanic or Latino",
    "Patient declines to respond": "Unknown",
    "Hispanic or Latino": "Hispanic or Latino",
    "Unknown or Patient unable to respond": "Unknown",
    "Other Hispanic, Latino/a, or Spanish origin": "Hispanic or Latino",
    "Mexican, Mexican American, or Chicano/a": "Hispanic or Latino",
    "Cuban": "Hispanic or Latino",
    "Puerto Rican": "Hispanic or Latino",
}


def note_counts_by_mrn(notes, note_types):
    """Count notes per patient and note source.

    Parameters
    ----------
    notes : pd.DataFrame
        One row per note, with "MRN" and "SOURCE" columns.
    note_types : sequence of str
        Expected note sources; fixes the output columns and their order.

    Returns
    -------
    pd.DataFrame
        Indexed by MRN, one int column per note type named
        "NOTE (<TYPE upper-cased>)". A note type with no notes gets an
        all-zero column, so downstream Table 1 columns always exist.

    Raises
    ------
    ValueError
        If ``notes`` contains a SOURCE not covered by ``note_types`` (its
        counts would otherwise be dropped silently).
    """
    note_cols = [f"NOTE ({t.upper()})" for t in note_types]
    counts = (
        notes.groupby(["MRN", "SOURCE"])
        .size()
        .unstack(fill_value=0)
        # Relabel each column from its own source value, so a label can never
        # end up on another source's counts.
        .rename(columns=lambda source: f"NOTE ({str(source).upper()})")
    )
    unexpected = sorted(set(counts.columns) - set(note_cols))
    if unexpected:
        raise ValueError(f"Unexpected note sources {unexpected}; expected {note_cols}")
    return counts.reindex(columns=note_cols, fill_value=0).astype(int)


def describe_data(window, note_types=("hp", "op", "tert")):
    """Print cohort / note summaries and build Table 1 for one note window.

    Parameters
    ----------
    window : int
        Forwarded to ``filter_notes`` (callers use 24 and 48 — presumably
        hours; TODO confirm in src.data).
    note_types : sequence of str
        Note sources to include; each becomes a "NOTE (<TYPE>)" Table 1 row.

    Returns
    -------
    dict[str, TableOne]
        Table 1 grouped by "MORTALITY" and by "ISS (TERCILE)", keyed by the
        group-by column name.

    Notes
    -----
    Reads the notebook-level ``patients`` and ``src`` objects.
    """
    note_types = list(note_types)
    notes = filter_notes(
        patients=patients,
        window=window,
        note_types=note_types,
        **src,
    )

    print("Count number of patients")
    print(notes["MRN"].nunique())
    print()

    print("Count number of notes by note type")
    print(notes.groupby("SOURCE").size())
    print()

    print("Count number of patients with a given note type")
    print(notes.groupby("SOURCE")["MRN"].nunique())
    print()

    df = patients[patients["MRN"].isin(notes["MRN"])].sort_values("MRN").reset_index(drop=True)
    labels, _ = compute_labels(df)

    # Age in years and length of stay in whole days, both from ED arrival.
    df["AGE"] = (df["ED_ARRIVAL_DTTM"] - df["DOB"]).dt.days / 365
    df["LOS"] = (df["HSP_DC_DTTM"] - df["ED_ARRIVAL_DTTM"]).dt.days
    df["MORTALITY"] = labels["hospital_mortality"].replace({0: "Alive", 1: "Deceased"})
    df["ISS (TERCILE)"] = labels["iss_tercile"].replace(
        {0: "[0,25]", 1: "(25,50]", 2: "(50,75]"}
    )
    df["RACE"] = df["RACE"].fillna("Unknown").replace(RACE_MAP)
    df["ETHNICITY"] = df["ETHNICITY"].fillna("Unknown").replace(ETHNICITY_MAP)

    # Attach note counts by MRN rather than by row position: a positional
    # concat silently misassigns counts if MRN order or multiplicity differs.
    note_counts = note_counts_by_mrn(notes, note_types)
    data = df.join(note_counts, on="MRN")
    note_cols = list(note_counts.columns)

    columns = ["GENDER", "RACE", "ETHNICITY", "AGE", "MORTALITY", "ISS", "ISS (TERCILE)", "LOS", *note_cols]
    # Per-patient note counts are reported as categorical distributions.
    categorical = ["GENDER", "RACE", "ETHNICITY", "MORTALITY", "ISS (TERCILE)", *note_cols]
    continuous = ["AGE", "ISS", "LOS"]

    return {
        groupby_col: TableOne(
            data=data,
            columns=columns,
            categorical=categorical,
            continuous=continuous,
            nonnormal=list(continuous),
            groupby=groupby_col,
            missing=False,
            decimals={"ISS": 0, "LOS": 0},
            order={"ISS (TERCILE)": ["[0,25]", "(25,50]", "(50,75]"]},
        )
        for groupby_col in ["MORTALITY", "ISS (TERCILE)"]
    }

In [None]:
# Table 1 objects (keyed by group-by column) built from the 24h note window.
tb1s_24h = describe_data(24)

In [None]:
# Table 1 objects (keyed by group-by column) built from the 48h note window.
tb1s_48h = describe_data(48)

In [None]:
# Full Table 1 (48h window) stratified by ISS group.
iss_tab1_48h = tb1s_48h["ISS (TERCILE)"]
iss_tab1_48h.tableone

In [None]:
# Only the per-patient note-count rows of the 48h ISS-stratified Table 1.
note_rows = [f"NOTE ({note_label}), n (%)" for note_label in ("HP", "OP", "TERT")]
tb1s_48h["ISS (TERCILE)"].tableone.loc[note_rows]

In [None]:
# Export every Table 1 to figs/. The directory is created first because
# writing a CSV into a missing directory fails on a fresh checkout.
fig_dir = Path("figs")
fig_dir.mkdir(parents=True, exist_ok=True)

table1_exports = {
    "tab1-mort-24h.csv": tb1s_24h["MORTALITY"],
    "tab1-iss-24h.csv": tb1s_24h["ISS (TERCILE)"],
    "tab1-mort-48h.csv": tb1s_48h["MORTALITY"],
    "tab1-iss-48h.csv": tb1s_48h["ISS (TERCILE)"],
}
for csv_name, table1 in table1_exports.items():
    table1.to_csv(str(fig_dir / csv_name))

In [None]:
def make_trauma_dataset(window):
    """Build the TraumaDataset used for note token-length profiling.

    All settings except ``window`` are shared between the 24h and 48h
    datasets. ``note_types`` is a fresh list per call so the two datasets
    never share a mutable argument. ``debug_len=True`` appears to make
    indexing return a token length (see the pqdm cell) — TODO confirm in
    src.data.
    """
    return TraumaDataset(
        data_dir=data_dir,
        window=window,
        note_types=["hp", "op", "tert"],
        tokenizer_name="whaleloops/clinicalmamba-130m-hf",
        context_length=16384,
        debug_len=True,
    )


ds24h = make_trauma_dataset(24)
ds48h = make_trauma_dataset(48)

In [None]:
# Worker functions must live at notebook (module) level so the process pool
# can pickle them by name; each worker reads the dataset from forked globals.
def get_24h_len(x):
    """Return item ``x`` of ``ds24h`` (its token length when debug_len=True — assumed)."""
    return ds24h[x]


def get_48h_len(x):
    """Return item ``x`` of ``ds48h`` (its token length when debug_len=True — assumed)."""
    return ds48h[x]


# pqdm's default exception_behaviour="ignore" would put exception objects in
# the result list, and the describe() cells would then summarize garbage
# instead of failing. Fail fast instead.
lens24h = pqdm(list(range(len(ds24h))), get_24h_len, n_jobs=8, exception_behaviour="immediate", desc="24h lengths")
lens48h = pqdm(list(range(len(ds48h))), get_48h_len, n_jobs=8, exception_behaviour="immediate", desc="48h lengths")

In [None]:
# Distribution of 24h-window token lengths, with a dense upper tail to help
# choose a context length.
pcts_24h = [0.25, 0.50, 0.75, 0.80, 0.90, 0.95, 0.96, 0.97, 0.98, 0.985, 0.99]
pd.Series(lens24h).describe(percentiles=pcts_24h)

In [None]:
# Distribution of 48h-window token lengths, with a dense upper tail to help
# choose a context length.
pcts_48h = [0.25, 0.50, 0.75, 0.80, 0.90, 0.95, 0.96, 0.97, 0.98, 0.985, 0.99]
pd.Series(lens48h).describe(percentiles=pcts_48h)