# Potential Organ Donor Identifier

# Setup

In [None]:
import gc
import os
import sys

import matplotlib.pyplot as plt
import polars as pl

# Make project-local packages (e.g. `utils`) importable: the project root is taken
# to be the parent of the notebook's working directory.
project_root = os.path.dirname(os.getcwd())
sys.path.insert(0, project_root)

# These imports must come after the sys.path change above.
from utils.config import config
from utils.io import read_data
from utils.strobe_diagram import create_consort_diagram
from clifpy.utils.stitching_encounters import stitch_encounters

In [None]:
# Pull the site-specific settings used throughout the notebook and echo them.
site_name, tables_path, file_type = (
    config[key] for key in ("site_name", "tables_path", "file_type")
)
for _label, _value in (
    ("Site Name", site_name),
    ("Tables Path", tables_path),
    ("File Type", file_type),
):
    print(f"{_label}: {_value}")

In [None]:
# Cohort counts for each selection step, keyed by a step label prefixed with its
# order (e.g. "0_all_patients"); filled in by the cells below.
strobe_counts = {}

# Load data

In [None]:
# Read the CLIF tables needed for the cohort. Each table is stored as
# {tables_path}/clif_<table>.{file_type} and loaded with the project's read_data helper.
adt_filepath = f"{tables_path}/clif_adt.{file_type}"
hospitalization_filepath = f"{tables_path}/clif_hospitalization.{file_type}"
patient_filepath = f"{tables_path}/clif_patient.{file_type}"
resp_filepath = f"{tables_path}/clif_respiratory_support.{file_type}"
labs_filepath = f"{tables_path}/clif_labs.{file_type}"
micro_culture_filepath = f"{tables_path}/clif_microbiology_culture.{file_type}"
crrt_filepath = f"{tables_path}/clif_crrt_therapy.{file_type}"
hospital_dx_filepath = f"{tables_path}/clif_hospital_diagnosis.{file_type}"

adt_df = read_data(adt_filepath, file_type)
hospitalization_df = read_data(hospitalization_filepath, file_type)
patient_df = read_data(patient_filepath, file_type)
resp_df = read_data(resp_filepath, file_type)
labs_df = read_data(labs_filepath, file_type)
micro_culture = read_data(micro_culture_filepath, file_type)
crrt_therapy = read_data(crrt_filepath, file_type)
hospital_dx = read_data(hospital_dx_filepath, file_type)

# Step 0: every patient with at least one hospitalization.
total_patients = hospitalization_df["patient_id"].n_unique()
strobe_counts["0_all_patients"] = total_patients

# One row per unique patient; patient-level flags are left-joined onto this in later cells.
final_df = hospitalization_df.select("patient_id").unique()
strobe_counts

# Stitch encounters

In [None]:
# Link related hospitalizations into encounter blocks with clifpy. stitch_encounters
# works on pandas frames, so convert on the way in and map each of the three returned
# frames back to polars. time_interval=12 is presumably the maximum gap (hours) between
# hospitalizations that get stitched together -- confirm against the clifpy docs.
hosp_stitched, adt_stitched, encounter_mapping = map(
    pl.from_pandas,
    stitch_encounters(
        hospitalization=hospitalization_df.to_pandas(),
        adt=adt_df.to_pandas(),
        time_interval=12,
    ),
)

# Release the temporary pandas copies.
gc.collect()

# Identify decedents

In [None]:
# Step 1: find stitched hospitalizations that ended in death ("expired") or discharge to
# hospice, then add patient-level is_expired / is_hospice flags to final_df.

# Columns kept for each qualifying hospitalization; discharge_dttm is the time of death
# (expired) or of hospice discharge.
_HOSPITALIZATION_COLUMNS = [
    'patient_id',
    'hospitalization_id',
    'encounter_block',
    'admission_dttm',
    'discharge_dttm',
]


def _encounters_discharged_to(category):
    """Return stitched hospitalizations whose discharge_category equals *category* (case-insensitive)."""
    return hosp_stitched.filter(
        pl.col('discharge_category').str.to_lowercase() == category
    )


def _patient_flag(hospitalizations, flag_name):
    """Return one row per patient in *hospitalizations* with a constant True column *flag_name*."""
    return (
        hospitalizations
        .select(["patient_id"])
        .unique()
        .with_columns(pl.lit(True).alias(flag_name))
    )


expired_encounters_df = _encounters_discharged_to('expired')
hospice_encounters_df = _encounters_discharged_to('hospice')

expired_hospitalizations = expired_encounters_df.select(_HOSPITALIZATION_COLUMNS).unique()
hospice_hospitalizations = hospice_encounters_df.select(_HOSPITALIZATION_COLUMNS).unique()

expired_patients_n = expired_hospitalizations["patient_id"].n_unique()
strobe_counts["1_expired_patients_n"] = expired_patients_n

expired_patients_flag = _patient_flag(expired_hospitalizations, "is_expired")
hospice_patients_flag = _patient_flag(hospice_hospitalizations, "is_hospice")

# Left-join each flag; patients absent from a flag table get False.
for _flag_df, _flag_name in (
    (expired_patients_flag, "is_expired"),
    (hospice_patients_flag, "is_hospice"),
):
    final_df = (
        final_df
        .join(_flag_df, on="patient_id", how="left")
        .with_columns(pl.col(_flag_name).fill_null(False))
    )

strobe_counts

# IMV

In [None]:
# Step 2 - On invasive mechanical ventilation (IMV) at, or within 48 h before, death.

# Respiratory-support rows whose device_category is "imv" (case-insensitive).
imv_resp_encounters = resp_df.filter(pl.col("device_category").str.to_lowercase() == "imv")

# Expired hospitalizations with at least one IMV row (one output row per IMV record).
imv_expired = expired_hospitalizations.join(
    imv_resp_encounters.select(["hospitalization_id", "recorded_dttm"]),
    on="hospitalization_id",
    how="inner",
)

# Each patient's most recent IMV record. nulls_last=True is required: polars places
# nulls first by default even in a descending sort, so a row with a missing
# recorded_dttm would otherwise be taken by .first() instead of the real latest record.
resp_expired_latest_recorded_imv = (
    imv_expired
    .sort("recorded_dttm", descending=True, nulls_last=True)
    .group_by("patient_id")
    .agg(pl.all().first())
)

# Hours from the last IMV record to death (discharge_dttm); negative = recorded after discharge.
resp_expired_imv_hrs = resp_expired_latest_recorded_imv.with_columns(
    (
        (pl.col("discharge_dttm") - pl.col("recorded_dttm")).dt.total_seconds() / 3600
    ).alias("hr_2death_last_imv")
)

# Last IMV record at most 48 h before death (records at/after death are included).
resp_expired_cohort = resp_expired_imv_hrs.filter(pl.col('hr_2death_last_imv') <= 48)

imv_expired_patients = imv_expired["patient_id"].n_unique()
# Patients whose last IMV record is at or after the discharge (death) time.
imv_after_expire = resp_expired_imv_hrs.filter(pl.col('hr_2death_last_imv') <= 0)["patient_id"].n_unique()
imv_48hr_expire = resp_expired_cohort["patient_id"].n_unique()
strobe_counts["2_imv_48hr_expire"] = imv_48hr_expire
strobe_counts["2_imv_after_expire"] = imv_after_expire
strobe_counts["2_imv_expired_patients"] = imv_expired_patients

# imv_48hr_expire flag on final_df: True if the patient is in resp_expired_cohort, else False.
imv_48hr_expire_patients = (
    resp_expired_cohort
    .select(["patient_id"])
    .unique()
    .with_columns(pl.lit(True).alias("imv_48hr_expire"))
)

final_df = (
    final_df
    .join(imv_48hr_expire_patients, on="patient_id", how="left")
    .with_columns(pl.col("imv_48hr_expire").fill_null(False))
)

strobe_counts

# ICD Codes

The CALC criteria include the following ICD-10 categories as a qualifying cause of death:
- I20–I25: ischemic heart disease
- I60–I69: cerebrovascular disease
- V01–Y89: external causes (e.g., blunt trauma, gunshot wounds, overdose, suicide, drowning, asphyxiation)

[Reference](https://www.cms.gov/files/document/112020-opo-final-rule-cms-3380-f.pdf)

We also use ICD-10 codes to flag sepsis and cancer, which are potential contraindications to donation.

In [None]:
import polars as pl

# ---- 1) Normalize codes and flag cause / contraindication categories per diagnosis row ----
# Assumes hospital_dx has at least: hospitalization_id, diagnosis_code, diagnosis_code_format.
# Codes are lowercased with dots and whitespace removed (e.g. "I21.4" -> "i214"). Every flag
# is False for rows whose lowercased diagnosis_code_format is not "icd10" or "icd10cm".
_is_icd10 = pl.col("sys").is_in(["icd10", "icd10cm"])
_dx = pl.col("dx_norm")


def _icd10_only(condition):
    """Evaluate *condition* on ICD-10 rows; rows in any other code format get False."""
    return pl.when(_is_icd10).then(condition).otherwise(False)


hospital_dx_flags = (
    hospital_dx
    .with_columns([
        pl.col("diagnosis_code")
            .cast(pl.Utf8)
            .str.to_lowercase()
            .str.replace_all(r"[.\s]", "")
            .alias("dx_norm"),
        pl.col("diagnosis_code_format")
            .cast(pl.Utf8)
            .str.to_lowercase()
            .alias("sys"),
    ])
    .with_columns([
        # I20–I25: ischemic heart disease
        _icd10_only(_dx.str.contains(r"^i2[0-5]\w*$")).alias("icd10_ischemic"),

        # I60–I69: cerebrovascular disease
        _icd10_only(_dx.str.contains(r"^i6[0-9]\w*$")).alias("icd10_cerebro"),

        # V01–Y89: external causes (v01–v99, w00–w99, x00–x99, y00–y89; Y90–Y99 excluded)
        _icd10_only(
            _dx.str.contains(
                r"^(v0[1-9]\d?|v[1-9]\d|w\d{2}|x\d{2}|y0\d|y1\d|y2\d|y3\d|y4\d|y5\d|y6\d|y7\d|y8\d)\w*$"
            )
        ).alias("icd10_external"),

        # Sepsis: A40*, A41*, P36* (neonatal), T81.44* (post-procedural) and specific
        # organism/context codes listed exactly
        _icd10_only(
            pl.any_horizontal([
                _dx.str.starts_with("a40"),
                _dx.str.starts_with("a41"),
                _dx.is_in([
                    "a021", "a327", "a392", "a207", "a427", "a267", "a5486",
                    "b377", "o85", "o8604", "r6520", "r6521"
                ]),
                _dx.str.contains(r"^p36\w*$"),
                _dx.str.contains(r"^t8144\w*$"),
            ])
        ).alias("icd10_sepsis"),

        # Cancer: malignant neoplasms C00–C97 plus C7A/C7B (neuroendocrine)
        _icd10_only(
            pl.any_horizontal([
                _dx.str.contains(r"^c(0[0-9]|[1-8][0-9]|9[0-7])[a-z0-9]*$"),
                _dx.str.starts_with("c7a"),
                _dx.str.starts_with("c7b"),
            ])
        ).alias("icd10_cancer"),
    ])
    # Attach patient_id (left join keeps every diagnosis row, even without a match).
    .join(
        hospitalization_df.select(["hospitalization_id", "patient_id"]),
        on="hospitalization_id",
        how="left",
    )
)

In [None]:
# Diagnosis flags rolled up to the patient.
_CAUSE_FLAG_COLUMNS = [
    "icd10_ischemic",
    "icd10_cerebro",
    "icd10_external",
    "icd10_sepsis",
    "icd10_cancer",
]

# Patient-level flag is True if any of the patient's diagnosis rows carries it.
patient_cause_flags = hospital_dx_flags.group_by("patient_id").agg(
    [pl.col(name).any().alias(name) for name in _CAUSE_FLAG_COLUMNS]
)

# Attach to final_df; patients with no diagnosis rows get False for every flag.
final_df = (
    final_df
    .join(patient_cause_flags, on="patient_id", how="left")
    .with_columns([pl.col(name).fill_null(False) for name in _CAUSE_FLAG_COLUMNS])
)

# Age

In [None]:
# Age criterion: age at death <= 75. Only an upper bound is applied here.
relevant_cohort_with_birth = expired_hospitalizations.join(
    patient_df.select(['patient_id', 'birth_date']),
    on='patient_id',
    how='left',
)

# Age at death in years = days from birth_date to the death hospitalization's
# discharge_dttm, divided by 365.25.
_days_alive = (pl.col('discharge_dttm') - pl.col('birth_date')).dt.total_days()
relevant_cohort_with_deathage = relevant_cohort_with_birth.with_columns(
    (_days_alive / 365.25).alias('age_at_death')
)

_age_within_limit = pl.col('age_at_death') <= 75

# Per-patient flag: True if any of the patient's expired hospitalizations has age_at_death <= 75.
age_flag_df = relevant_cohort_with_deathage.group_by('patient_id').agg(
    [_age_within_limit.any().alias('age_75_less')]
)

# Patients without an expired hospitalization (no row in age_flag_df) get False.
final_df = final_df.join(age_flag_df, on='patient_id', how='left').with_columns(
    pl.col('age_75_less').fill_null(False)
)

age_relevant_cohort = relevant_cohort_with_deathage.filter(_age_within_limit)
age_relevant_cohort_n = age_relevant_cohort["patient_id"].n_unique()
strobe_counts["3_age_relevant_cohort_n"] = age_relevant_cohort_n
strobe_counts

# Labs

In [None]:
# Step 3 - Organ quality assessment, judged independently per organ, using only labs
# collected within the 48 hours before death (discharge_dttm of the death hospitalization).

# Step 3A - Kidney: latest creatinine < 4, collected within the 48 h window, AND no CRRT.
# Drop expired hospitalizations that have any CRRT record.
organ_assess_cohort = expired_hospitalizations.join(
    crrt_therapy.select(["hospitalization_id"]),
    on="hospitalization_id",
    how="anti",
)
# Age-eligible patients missing from the CRRT-free cohort (reported in the CONSORT note).
patients_crrt_n = len(set(age_relevant_cohort["patient_id"]) - set(organ_assess_cohort["patient_id"]))

creatinine = labs_df.filter(pl.col("lab_category").is_in(["creatinine"]))

# Left join: cohort hospitalizations without any creatinine keep one row with null lab columns.
relevant_creatinine = organ_assess_cohort.join(
    creatinine.select(["hospitalization_id", "lab_collect_dttm", "lab_value_numeric"]),
    on="hospitalization_id",
    how="left",
)

# Latest creatinine per hospitalization. The ascending sort puts null timestamps first,
# so .last() returns the most recent timestamped result whenever one exists.
latest_creatinine = (
    relevant_creatinine
    .sort("lab_collect_dttm")
    .group_by("hospitalization_id")
    .agg([
        pl.col("lab_collect_dttm").last().alias("latest_creatinine_dttm"),
        pl.col("lab_value_numeric").last().alias("latest_creatinine_value"),
    ])
)

kidney_assess = (
    organ_assess_cohort
    .join(latest_creatinine, on="hospitalization_id", how="left")
    .with_columns(
        (
            (pl.col("discharge_dttm") - pl.col("latest_creatinine_dttm")).dt.total_seconds() / 3600
        ).alias("hrs_before_discharge_latest_creatinine")
    )
)

# Only a creatinine collected within 48 h of death counts, matching the liver check, which
# also credits only labs in that window. Missing and stale (> 48 h) results are treated
# alike: neither passes. fill_null(False) makes the missing-lab case explicit rather than
# relying on null rows being dropped by the filter.
_recent_creatinine = pl.col("hrs_before_discharge_latest_creatinine") <= 48
kidney_eligible = kidney_assess.filter(
    ((pl.col("latest_creatinine_value") < 4) & _recent_creatinine).fill_null(False)
)

# Patients excluded for a recent creatinine >= 4 (the criterion is creatinine < 4).
filtered_out_patients_count = kidney_assess.filter(
    ((pl.col("latest_creatinine_value") >= 4) & _recent_creatinine).fill_null(False)
)["patient_id"].n_unique()
eligible_kidney_n = kidney_eligible["patient_id"].n_unique()


# Step 3B - Liver: total bilirubin < 4, AST < 700, AND ALT < 700, each taken from the
# latest result collected within 48 h before death (discharge_dttm).

# (lab_category in labs_df, value column name, collect-time column name) per liver lab.
_LIVER_LABS = [
    ("bilirubin_total", "latest_total_bilirubin", "latest_bili_collect_dttm"),
    ("ast", "latest_ast", "latest_ast_collect_dttm"),
    ("alt", "latest_alt", "latest_alt_collect_dttm"),
]

liver_labs = labs_df.filter(
    pl.col("lab_category").is_in([category for category, _, _ in _LIVER_LABS])
)

# Inner join: only cohort hospitalizations that have liver labs produce rows here.
# Hospitalizations without any come back with null lab columns through the left joins
# into liver_assess below. A left join here would add rows with a null lab_category,
# which the pivots would turn into a stray "null" column.
relevant_liver = organ_assess_cohort.join(
    liver_labs.select(["hospitalization_id", "lab_category", "lab_collect_dttm", "lab_value_numeric"]),
    on="hospitalization_id",
    how="inner",
)

# Latest result per (hospitalization, lab). Null timestamps sort first, so .last()
# returns the most recent timestamped result when one exists.
_latest_liver_long = (
    relevant_liver
    .sort("lab_collect_dttm")
    .group_by(["hospitalization_id", "lab_category"])
    .agg([
        pl.col("lab_collect_dttm").last().alias("latest_lab_dttm"),
        pl.col("lab_value_numeric").last().alias("latest_lab_value"),
    ])
)


def _pivot_latest_liver(values, name_position, dtype):
    """Pivot one aggregated column to one column per liver lab and give it its final name.

    A lab category that never occurs in the data still gets an all-null column of *dtype*,
    so the rename and the later column references cannot fail on a missing category.
    """
    wide = _latest_liver_long.pivot(values=values, index="hospitalization_id", on="lab_category")
    renames = {lab[0]: lab[name_position] for lab in _LIVER_LABS}
    wide = wide.with_columns([
        pl.lit(None, dtype=dtype).alias(category)
        for category in renames
        if category not in wide.columns
    ])
    return wide.select(["hospitalization_id", *renames]).rename(renames)


latest_liver_labs = _pivot_latest_liver(
    "latest_lab_value", 1, relevant_liver.schema["lab_value_numeric"]
)
latest_liver_lab_times = _pivot_latest_liver(
    "latest_lab_dttm", 2, relevant_liver.schema["lab_collect_dttm"]
)

# Attach the latest values and their collection times, plus hours before death per lab.
liver_assess = (
    organ_assess_cohort
    .join(latest_liver_labs, on="hospitalization_id", how="left")
    .join(latest_liver_lab_times, on="hospitalization_id", how="left")
    .with_columns([
        ((pl.col("discharge_dttm") - pl.col("latest_bili_collect_dttm")).dt.total_seconds() / 3600).alias("hrs_before_discharge_latest_bili"),
        ((pl.col("discharge_dttm") - pl.col("latest_ast_collect_dttm")).dt.total_seconds() / 3600).alias("hrs_before_discharge_latest_ast"),
        ((pl.col("discharge_dttm") - pl.col("latest_alt_collect_dttm")).dt.total_seconds() / 3600).alias("hrs_before_discharge_latest_alt"),
    ])
)

# Per-lab "recent and within limit" checks; a missing or untimed lab does not pass.
_bili_ok = ((pl.col("latest_total_bilirubin") < 4) & (pl.col("hrs_before_discharge_latest_bili") <= 48)).fill_null(False)
_ast_ok = ((pl.col("latest_ast") < 700) & (pl.col("hrs_before_discharge_latest_ast") <= 48)).fill_null(False)
_alt_ok = ((pl.col("latest_alt") < 700) & (pl.col("hrs_before_discharge_latest_alt") <= 48)).fill_null(False)

# The criterion is an AND: all three labs must pass, so a liver with, e.g., bilirubin
# >= 4 is not eligible just because its ALT is normal.
liver_eligible = liver_assess.filter(_bili_ok & _ast_ok & _alt_ok)

# Unique patients meeting each liver exclusion criterion (recent lab at/above its limit).
n_bili_gt4 = liver_assess.filter(
    (pl.col("latest_total_bilirubin") >= 4) & (pl.col("hrs_before_discharge_latest_bili") <= 48)
)["patient_id"].n_unique()

n_ast_gt700 = liver_assess.filter(
    (pl.col("latest_ast") >= 700) & (pl.col("hrs_before_discharge_latest_ast") <= 48)
)["patient_id"].n_unique()

n_alt_gt700 = liver_assess.filter(
    (pl.col("latest_alt") >= 700) & (pl.col("hrs_before_discharge_latest_alt") <= 48)
)["patient_id"].n_unique()

eligible_liver_n = liver_eligible["patient_id"].n_unique()

# A patient passes the organ-quality step when at least one organ (kidney or liver) is
# eligible. Both frames are cut down to the same columns so they can be stacked.
common_columns = ["hospitalization_id", "patient_id", "encounter_block", "discharge_dttm"]
overall_organ_eligible = pl.concat(
    [organ_frame.select(common_columns) for organ_frame in (kidney_eligible, liver_eligible)]
).unique()

overall_organ_eligible_patients = overall_organ_eligible["patient_id"].n_unique()
strobe_counts["4_organ_eligible_patients"] = overall_organ_eligible_patients

In [None]:
# Data-completeness report: patients in organ_assess_cohort without a usable lab in the
# last 48 h. "Missing" means the latest value is null or was collected more than 48 h
# before discharge.


def _n_patients_missing(frame, *lab_columns):
    """Count unique patients for whom every (value_col, hours_col) pair is missing or stale."""
    condition = None
    for value_col, hours_col in lab_columns:
        missing = (pl.col(value_col).is_null()) | (pl.col(hours_col) > 48)
        condition = missing if condition is None else condition & missing
    return frame.filter(condition)["patient_id"].n_unique()


_CREATININE_COLS = ("latest_creatinine_value", "hrs_before_discharge_latest_creatinine")
_BILI_COLS = ("latest_total_bilirubin", "hrs_before_discharge_latest_bili")
_AST_COLS = ("latest_ast", "hrs_before_discharge_latest_ast")
_ALT_COLS = ("latest_alt", "hrs_before_discharge_latest_alt")

creatinine_missing_patients_n = _n_patients_missing(kidney_assess, _CREATININE_COLS)
print(f"Number of patients missing creatinine labs in last 48h: {creatinine_missing_patients_n}")

bilirubin_missing_patients_n = _n_patients_missing(liver_assess, _BILI_COLS)
print(f"Number of patients missing bilirubin labs in last 48h: {bilirubin_missing_patients_n}")

ast_missing_patients_n = _n_patients_missing(liver_assess, _AST_COLS)
print(f"Number of patients missing AST labs in last 48h: {ast_missing_patients_n}")

alt_missing_patients_n = _n_patients_missing(liver_assess, _ALT_COLS)
print(f"Number of patients missing ALT labs in last 48h: {alt_missing_patients_n}")

# NOTE: despite its name, this counts patients missing ALL three liver labs.
missing_any_liver_lab_n = _n_patients_missing(liver_assess, _BILI_COLS, _AST_COLS, _ALT_COLS)
print(f"Number of patients missing ALL bilirubin, AST, and ALT labs in last 48h: {missing_any_liver_lab_n}")


# Microbiology

In [None]:
# Step 4A - No positive blood cultures within 48hrs.
# First inspect which organisms appear in blood cultures
# (fluid_category "blood_buffy", method_category "culture"), most frequent first.
_is_blood_culture = (pl.col("fluid_category") == "blood_buffy") & (pl.col("method_category") == "culture")
micro_culture_filtered = micro_culture.filter(_is_blood_culture)

organism_category_counts = (
    micro_culture_filtered["organism_category"]
    .value_counts()
    .sort("count", descending=True)
)
print(organism_category_counts)

In [None]:
# Step 4A - Count patients with negative blood cultures, and with no blood cultures at
# all, in the 48 h before death.

# Blood cultures only (same filter as the previous cell; repeated so this cell runs alone).
micro_culture_filtered = micro_culture.filter(
    (pl.col("fluid_category") == "blood_buffy") &
    (pl.col("method_category") == "culture")
)

# Show which organism_category values occur, to check how "negative" is encoded.
print("Organism category value counts:")
organism_category_counts = micro_culture_filtered["organism_category"].value_counts().sort("count", descending=True)
print(organism_category_counts)

# A culture counts as negative when organism_category contains "no_growth"
# (case-insensitive), is null, or is an empty string. Adjust after reviewing the counts.
_organism_lower = pl.col("organism_category").str.to_lowercase()
negative_cultures = micro_culture_filtered.filter(
    _organism_lower.str.contains("no_growth", literal=True)
    | pl.col("organism_category").is_null()
    | (_organism_lower == "")
)

print(f"\nTotal negative blood cultures: {len(negative_cultures)}")

# Cohort for the culture checks: all expired hospitalizations (swap in a narrower cohort,
# e.g. age_relevant_cohort, if desired).
cohort_for_micro = expired_hospitalizations


def _join_cultures_with_hours(cultures):
    """Left-join *cultures* (by hospitalization_id) onto cohort_for_micro and add
    hrs_before_discharge_culture = hours from collect_dttm to discharge_dttm."""
    return cohort_for_micro.join(
        cultures.select(["hospitalization_id", "collect_dttm"]),
        on="hospitalization_id",
        how="left",
    ).with_columns(
        (
            (pl.col("discharge_dttm") - pl.col("collect_dttm")).dt.total_seconds() / 3600
        ).alias("hrs_before_discharge_culture")
    )


# Collected within the 48 h before discharge; collections after discharge are excluded.
_collected_last_48h = (
    (pl.col("collect_dttm").is_not_null()) &
    (pl.col("hrs_before_discharge_culture") <= 48) &
    (pl.col("hrs_before_discharge_culture") >= 0)
)

negative_cultures_with_cohort = _join_cultures_with_hours(negative_cultures)
negative_cultures_48h = negative_cultures_with_cohort.filter(_collected_last_48h)

patients_with_negative_cultures_48h = negative_cultures_48h["patient_id"].n_unique()
print(f"Number of patients with negative blood cultures in last 48h: {patients_with_negative_cultures_48h}")

# Any blood culture (positive or negative) in the window.
all_cultures_with_cohort = _join_cultures_with_hours(micro_culture_filtered)
patients_with_cultures_48h = all_cultures_with_cohort.filter(_collected_last_48h)["patient_id"].unique()

# Patients without any culture in the window = cohort size - patients with one.
total_cohort_patients = cohort_for_micro["patient_id"].n_unique()
patients_with_cultures_48h_n = len(patients_with_cultures_48h)
patients_with_no_cultures_48h_n = total_cohort_patients - patients_with_cultures_48h_n

print(f"\nTotal patients in cohort: {total_cohort_patients}")
print(f"Number of patients with any blood culture in last 48h: {patients_with_cultures_48h_n}")
print(f"Number of patients with NO blood cultures in last 48h: {patients_with_no_cultures_48h_n}")

# Final cohort

In [None]:
# Disabled export: uncomment to write the de-duplicated
# (patient_id, hospitalization_id, encounter_block, age_at_death) rows of
# relevant_cohort_with_deathage to an intermediate parquet file.
# relevant_cohort_with_deathage.select([
#     "patient_id", 
#     "hospitalization_id", 
#     "encounter_block", 
#     "age_at_death"
# ]).unique().write_parquet("../output/intermediate/relevant_cohort_with_deathage.parquet")

In [None]:
# Helper for the CONSORT diagram below: look up how many patients were dropped for one
# broad diagnosis group. `df` is expected to have columns 'dx_broad' (group label) and
# 'n_patients_dropped'.
def get_drop_count(df, broad_label):
    """Return n_patients_dropped for *broad_label* as an int, or 0 if the label is absent.

    When the label appears on several rows, the first matching row is used.
    """
    matches = df.filter(pl.col('dx_broad') == broad_label)
    if len(matches) == 0:
        return 0
    return int(matches['n_patients_dropped'][0])

# Diagnosis-exclusion counts for the "Patients without exclusionary Dx" box, derived from
# the patient-level flags on final_df (one row per patient); nothing earlier in the
# notebook defines these names. Exclusions are applied in order (sepsis, then cancer),
# so each excluded patient is counted once and the split boxes sum to the total excluded.
# NOTE(review): no "other"/positive-culture exclusion is computed in this notebook, so
# get_drop_count(..., 'other') returns 0 -- TODO wire in once that step exists.
_imv_cohort = final_df.filter(pl.col("imv_48hr_expire"))
_n_dropped_sepsis = _imv_cohort.filter(pl.col("icd10_sepsis")).height
_n_dropped_cancer = _imv_cohort.filter(~pl.col("icd10_sepsis") & pl.col("icd10_cancer")).height
drop_counts_by_broad = pl.DataFrame({
    "dx_broad": ["sepsis", "cancer"],
    "n_patients_dropped": [_n_dropped_sepsis, _n_dropped_cancer],
})
resp_expired_cohort_filtered_n = _imv_cohort.filter(
    ~(pl.col("icd10_sepsis") | pl.col("icd10_cancer"))
).height

# NOTE(review): age_relevant_cohort_n and overall_organ_eligible_patients are computed on
# all expired hospitalizations, not on the preceding boxes, so the funnel below is not
# strictly nested -- confirm this is the intended presentation.
steps = [
    {
        'label': 'All Patients',
        'n': total_patients,
        'color': 'blue'
    },
    {
        'label': 'Deceased Patients',
        'note': '(not including Hospice)',
        'n': expired_patients_n,
        'color': 'blue',
    },
    {
        'label': 'Patients ever on IMV',
        'n': imv_expired_patients,
        'color': 'blue'
    },
    {
        'label': 'On IMV at death \n or 48hrs prior',
        'n': imv_48hr_expire,
        'color': 'blue',
        'split': [
            {
                'label': 'Patients on IMV after death',
                'n': imv_after_expire,
                'color': 'red'
            },
            {
                'label': 'Patients on IMV within\n 48hrs before or at death',
                'note': "Deceased patients who were on IMV within 48hrs before death",
                'n': imv_48hr_expire - imv_after_expire,
                'color': 'red'
            }
        ]
    },
    {
        'label': 'Patients without\n exclusionary Dx',
        'note': 'Dx excluding sepsis, cancers and positive cultures',
        'n': resp_expired_cohort_filtered_n,
        'color': 'blue',
        'split': [
            {
                'label': 'Excluded: Sepsis',
                'n': get_drop_count(drop_counts_by_broad, 'sepsis'),
                'color': 'red'
            },
            {
                'label': 'Excluded: Cancer',
                'n': get_drop_count(drop_counts_by_broad, 'cancer'),
                'color': 'red'
            },
            {
                'label': 'Excluded: Other Dx',
                'n': get_drop_count(drop_counts_by_broad, 'other'),
                'color': 'red'
            }
        ]
    },
    {
        'label': 'Patients aged 75 or less\n at death',
        'n': age_relevant_cohort_n,
        'color': 'blue'
    },
    {
        'label': 'Patients passing one or \nboth organ quality check',
        'n': overall_organ_eligible_patients,
        'color': 'blue',
        'split': [
            {
                'label': 'Eligible Kidneys',
                'note': f" excluded on crrt {patients_crrt_n} patients followed by\n creatinine >4 {filtered_out_patients_count} patients",
                'n': eligible_kidney_n,
                'color': 'red'
            },
            {
                'label': 'Eligible Livers',
                'note': f"excluded \n bilirubin > 4 patients: {n_bili_gt4} | AST> 700 patients: {n_ast_gt700} | ALT> 700 patients: {n_alt_gt700} ",
                'n': eligible_liver_n,
                'color': 'red'
            }
        ]
    },
]

fig = create_consort_diagram(
    steps,
    title="COHORT SELECTION: Potential Organ Donors"
)
# savefig does not create missing directories, so make sure the output folder exists.
_consort_output_dir = "../output/final"
os.makedirs(_consort_output_dir, exist_ok=True)
fig.savefig(f"{_consort_output_dir}/cohort_strobe.png", bbox_inches="tight", dpi=300)
plt.show()