# Potential Organ Donor Identifier

# Setup

In [None]:
import sys 
import os
import polars as pl 
import matplotlib.pyplot as plt
import pandas as pd

from utils.config import config
from utils.io import read_data
from utils.strobe_diagram import create_consort_diagram
from clifpy.utils.stitching_encounters import stitch_encounters
from utils.outlier_handler import apply_outlier_handling
import gc

In [None]:
from pathlib import Path

# Site-specific settings are read from the shared project config.
site_name, tables_path, file_type, project_root = (
    config[setting]
    for setting in ('site_name', 'tables_path', 'file_type', 'project_root')
)

# Put the project root at the front of the module search path.
sys.path.insert(0, project_root)

print(f"Site Name: {site_name}")
print(f"Tables Path: {tables_path}")
print(f"File Type: {file_type}")

# Directory layout derived from the project root.
PROJECT_ROOT = Path(project_root)
UTILS_DIR = PROJECT_ROOT / "utils"
OUTPUT_DIR = PROJECT_ROOT / "output"
OUTPUT_FINAL_DIR, OUTPUT_INTERMEDIATE_DIR = OUTPUT_DIR / "final", OUTPUT_DIR / "intermediate"

In [None]:
# Running tally of cohort sizes at each selection step; later cells add one key per step.
strobe_counts = {}

# Load data

In [None]:
# Load the three core tables (ADT, hospitalization, patient) from the configured tables directory.
adt_filepath = f"{tables_path}/clif_adt.{file_type}"
adt_df = read_data(adt_filepath, file_type)

hospitalization_filepath = f"{tables_path}/clif_hospitalization.{file_type}"
hospitalization_df = read_data(hospitalization_filepath, file_type)

patient_filepath = f"{tables_path}/clif_patient.{file_type}"
patient_df = read_data(patient_filepath, file_type)

In [None]:
# STROBE step 0: number of distinct patients in the patient table.
total_patients = patient_df["patient_id"].n_unique()
strobe_counts.update({"0_all_patients": total_patients})
strobe_counts

# Identify decedents

In [None]:
# A hospitalization counts as a death when its discharge category is "expired" (case-insensitive).
is_expired = pl.col('discharge_category').str.to_lowercase() == 'expired'
all_decedents_df = hospitalization_df.filter(is_expired)

In [None]:
# Plain Python lists of the patient and hospitalization ids found on expired hospitalizations.
all_decedent_patient_ids = all_decedents_df.get_column('patient_id').to_list()
all_decedent_hosp_ids = all_decedents_df.get_column('hospitalization_id').to_list()

# Stitch Encounters

In [None]:
# Warn when (patient_id, hospitalization_id) pairs repeat in the hospitalization table.
hosp_key_cols = ["patient_id", "hospitalization_id"]
n_distinct_hosp_rows = hospitalization_df.unique(subset=hosp_key_cols).height
if hospitalization_df.height != n_distinct_hosp_rows:
    print("Warning: hospitalization_df contains duplicate (patient_id, hospitalization_id) rows.")

# Warn when (hospitalization_id, in_dttm) pairs repeat in the ADT table.
# The check only runs when both identifier columns are present.
adt_unique_cols = [col for col in [ "hospitalization_id", "in_dttm"] if col in adt_df.columns]
has_both_adt_keys = len(adt_unique_cols) >= 2
if has_both_adt_keys and adt_df.height != adt_df.unique(subset=adt_unique_cols).height:
    print("Warning: adt_df contains duplicate rows for identifier columns:", adt_unique_cols)

In [None]:
# Keep every hospitalization belonging to a decedent patient (not only the death
# hospitalization), then rebuild the hospitalization id list from that subset and
# restrict the ADT table to the same hospitalizations.
belongs_to_decedent = pl.col("patient_id").is_in(all_decedent_patient_ids)
hospitalization_df_subset = hospitalization_df.filter(belongs_to_decedent)

all_decedent_hosp_ids = hospitalization_df_subset.get_column('hospitalization_id').to_list()

in_decedent_hosp = pl.col("hospitalization_id").is_in(all_decedent_hosp_ids)
adt_df_subset = adt_df.filter(in_decedent_hosp)

In [None]:
# stitch_encounters works on pandas frames, so convert on the way in and back to polars on the way out.
hosp_stitched, adt_stitched, encounter_mapping = stitch_encounters(
    hospitalization=hospitalization_df_subset.to_pandas(),
    adt=adt_df_subset.to_pandas(),
    time_interval=12,
)

# encounter_block is stored as Int32 in all three stitched outputs.
block_as_int32 = pl.col("encounter_block").cast(pl.Int32)
hosp_stitched = pl.from_pandas(hosp_stitched).with_columns(block_as_int32)
adt_stitched = pl.from_pandas(adt_stitched).with_columns(block_as_int32)

# Drop stitched hospitalizations that have no rows in the stitched ADT table
# (only when both frames carry a hospitalization_id column).
if all("hospitalization_id" in frame.columns for frame in (hosp_stitched, adt_stitched)):
    adt_hosp_ids = adt_stitched["hospitalization_id"].unique()
    hosp_stitched = hosp_stitched.filter(pl.col("hospitalization_id").is_in(adt_hosp_ids))

encounter_mapping = pl.from_pandas(encounter_mapping).with_columns(block_as_int32)

gc.collect()

In [None]:
# Stitched hospitalizations whose discharge category is "expired" (case-insensitive).
expired_encounter = pl.col('discharge_category').str.to_lowercase() == 'expired'
decedents_df = hosp_stitched.filter(expired_encounter)

# Encounter-level columns carried forward; discharge_dttm is the discharge time of the
# death hospitalization.
encounter_cols = [
    'patient_id',
    'hospitalization_id',
    'encounter_block',
    'admission_dttm',
    'discharge_dttm',
    "age_at_admission",
    "discharge_category",
    "admission_type_category",
]
final_df = (
    decedents_df
    .select(encounter_cols)
    .with_columns(
        pl.col("discharge_category", "admission_type_category").str.to_lowercase()
    )
    .unique()
)

# Attach death_dttm and demographic categories from the patient table.
demog_cols = ['patient_id', 'death_dttm', 'race_category', 'sex_category','ethnicity_category' ]
patient_demographics = patient_df.select(demog_cols)
final_df = final_df.join(patient_demographics, on='patient_id', how='left')

# STROBE step 1: distinct decedent patients.
decedents_df_n = final_df["patient_id"].n_unique()
strobe_counts.update({"1_decedents_df_n": decedents_df_n})
strobe_counts

# Final outcome dttm

In [None]:
# Load vitals, passing the decedent hospitalization ids as a filter on hospitalization_id.
vitals_filepath = f"{tables_path}/clif_vitals.{file_type}"
vitals_df = read_data(
      vitals_filepath,
      file_type,
      filter_ids=all_decedent_hosp_ids,
      id_column='hospitalization_id'
  )

In [None]:
# Run the project's outlier handler on the vitals table before it is summarised.
vitals_df = apply_outlier_handling(vitals_df, 'vitals')

In [None]:
# Order vitals chronologically within each hospitalization so that `.last()` below
# refers to the most recently recorded value.
vitals_df = vitals_df.sort(['hospitalization_id', 'recorded_dttm'])


def _last_vital_value(category, alias):
    """Build an expression for the last `vital_value` recorded under `category`."""
    of_category = pl.col('vital_category') == category
    return pl.col('vital_value').filter(of_category).last().alias(alias)


# Per hospitalization: first/last vital timestamp plus the last weight and height.
recorded = pl.col('recorded_dttm')
vitals_first_last = vitals_df.group_by('hospitalization_id').agg(
    recorded.min().alias('first_recorded_vital_dttm'),
    recorded.max().alias('last_recorded_vital_dttm'),
    _last_vital_value('weight_kg', 'last_weight_kg'),
    _last_vital_value('height_cm', 'last_height_cm'),
)

# BMI = weight (kg) / height (m) squared.
height_m = pl.col('last_height_cm') / 100
vitals_first_last = vitals_first_last.with_columns(
    (pl.col('last_weight_kg') / height_m ** 2).alias('bmi')
)

final_df = final_df.join(vitals_first_last, on='hospitalization_id', how='left')

# final_death_dttm is death_dttm, falling back to the last recorded vital when death_dttm is missing.
final_df = final_df.with_columns(
    pl.when(pl.col("death_dttm").is_null())
      .then(pl.col("last_recorded_vital_dttm"))
      .otherwise(pl.col("death_dttm"))
      .alias("final_death_dttm")
)

# Inpatient decedents

Identify inpatient encounters - location must be ed, ward, stepdown, icu at last_recorded_vital_dttm

In [None]:
# Location categories treated as inpatient.
# NOTE(review): this list is not referenced by the cells visible here; the location
# flags below spell out the same four categories as literals — confirm it is still needed.
eligible_locations = ['ed', 'ward', 'stepdown', 'icu']

In [None]:
# Report decedent hospitalizations that have no row at all in the ADT table.
decedent_hosp_in_adt = set(adt_df.get_column('hospitalization_id').to_list())
missing_in_adt = {
    hosp_id for hosp_id in all_decedent_hosp_ids if hosp_id not in decedent_hosp_in_adt
}

if not missing_in_adt:
    print(f"✓ All {len(all_decedent_hosp_ids)} decedent hospitalizations present in ADT table")
else:
    print(f"Warning: {len(missing_in_adt)} hospitalization(s) missing in ADT table")
    print(f"Missing hospitalization_ids: {missing_in_adt}")

In [None]:
# ADT rows of decedent hospitalizations, latest out_dttm first, so `.first()` within each
# hospitalization describes the row that sorts first in descending out_dttm order.
decedent_adt = (
    adt_df
    .filter(pl.col('hospitalization_id').is_in(all_decedent_hosp_ids))
    .sort('out_dttm', descending=True)
)

location = pl.col('location_category')
location_lower = location.str.to_lowercase()

last_location_per_hosp = decedent_adt.group_by('hospitalization_id').agg(
    location.first().alias('last_location_category'),
    pl.col('location_name').first().alias('last_location_name'),
    pl.col('out_dttm').first().alias('last_location_out_dttm'),
    # ever_<loc>: True when any ADT row of the hospitalization is in that location.
    *[
        (location_lower == loc).any().alias(f'ever_{loc}')
        for loc in ('icu', 'ward', 'ed', 'stepdown')
    ],
    location.unique().sort().alias('all_locations'),
)

In [None]:
# Attach the last-location columns and ever_<location> flags to each decedent row.
# A left join keeps decedents with no matching ADT rows; their flags stay null.
final_df = final_df.join(
    last_location_per_hosp,
    on='hospitalization_id',
    how='left'
)

In [None]:
# True when the hospitalization was ever in ICU, ward, ED or stepdown; null flags count as False.
ever_in_inpatient_location = (
    pl.col('ever_icu').fill_null(False)
    | pl.col('ever_ward').fill_null(False)
    | pl.col('ever_ed').fill_null(False)
    | pl.col('ever_stepdown').fill_null(False)
)

# How many hospitalizations never appear in any of the four locations.
n_all_locs_false = final_df.filter(~ever_in_inpatient_location).height
print(f"Number of hospitalizations where all four location flags are False/null: {n_all_locs_false}")

# The working cohort keeps only hospitalizations with at least one such location.
final_cohort_df = final_df.filter(ever_in_inpatient_location)

In [None]:
# Row-level id lists for the inpatient decedent cohort (one entry per final_cohort_df row).
all_decedent_inpatient_patient_ids = final_cohort_df.select('patient_id').to_series().to_list()
all_decedent_inpatient_hosp_ids = final_cohort_df.select('hospitalization_id').to_series().to_list()

# STROBE step 2 counts distinct patients, like every other STROBE step in this notebook.
# Taking len() of the row-level list would count a patient once per cohort row.
strobe_counts["2_inpatient_decedents"] = final_cohort_df["patient_id"].n_unique()
strobe_counts

In [None]:
# Display the stitched ADT column names for reference before computing length of stay.
adt_stitched.columns

# ADT

In [None]:
# Hospital length of stay, first admission location and first-ICU length of stay per
# stitched encounter_block, attached to final_cohort_df.

# Restrict the stitched ADT table to cohort hospitalizations, lower-case location_category,
# and order rows chronologically. group_by keeps row order within each group, so `.first()`
# below is the earliest location rather than whichever row happens to come first.
adt_in_cohort = (
    adt_stitched
    .filter(pl.col("hospitalization_id").is_in(all_decedent_inpatient_hosp_ids))
    .with_columns(pl.col("location_category").str.to_lowercase())
    .sort("in_dttm", nulls_last=True)
)

# Per encounter_block: earliest in_dttm, latest out_dttm, and the earliest location.
# hospital_length_of_stay_days uses total_days(), i.e. whole days.
hosp_admission_summary = (
    adt_in_cohort
    .group_by("encounter_block")
    .agg([
        pl.col("in_dttm").min().alias("min_in_dttm"),
        pl.col("out_dttm").max().alias("max_out_dttm"),
        pl.col("location_category").first().alias("first_admission_location"),
    ])
    .with_columns(
        (pl.col("max_out_dttm") - pl.col("min_in_dttm"))
        .dt.total_days()
        .alias("hospital_length_of_stay_days")
    )
)

final_cohort_df = final_cohort_df.join(
    hosp_admission_summary.select([
        "encounter_block",
        "first_admission_location",
        "hospital_length_of_stay_days",
    ]),
    on="encounter_block",
    how="left",
)

# ICU rows only.
icu_df = adt_in_cohort.filter(pl.col("location_category") == "icu")

# Earliest ICU in_dttm per encounter_block.
first_icu_in = (
    icu_df
    .group_by("encounter_block")
    .agg(pl.col("in_dttm").min().alias("first_icu_in_dttm"))
)

# Collapse ICU rows to one per (encounter_block, in_dttm), keeping the latest out_dttm.
# Without this, ICU rows sharing an in_dttm would multiply rows in the join below and
# then duplicate rows in final_cohort_df.
icu_stays = (
    icu_df
    .group_by(["encounter_block", "in_dttm"])
    .agg(pl.col("out_dttm").max())
)

# out_dttm of the first ICU stay and its length in fractional days.
icu_summary = (
    first_icu_in.join(
        icu_stays,
        left_on=["encounter_block", "first_icu_in_dttm"],
        right_on=["encounter_block", "in_dttm"],
        how="left",
    )
    .with_columns([
        pl.col("out_dttm").alias("first_icu_out_dttm"),
        ((pl.col("out_dttm") - pl.col("first_icu_in_dttm")).dt.total_seconds() / (3600*24)).alias("first_icu_los_days"),
    ])
    .select([
        "encounter_block", "first_icu_in_dttm", "first_icu_out_dttm", "first_icu_los_days"
    ])
)

final_cohort_df = final_cohort_df.join(
    icu_summary.select(["encounter_block", "first_icu_los_days"]),
    on="encounter_block",
    how="left",
)

# hosp_admission_summary now holds hospital LOS and first_admission_location;
# icu_summary holds the first ICU stay and its LOS.


# Age

In [None]:
# Age criterion: age at death <= 75 years.
# age_at_death = (final_death_dttm - birth_date) in days / 365.25, i.e. fractional years.
# NOTE(review): with fractional years, a patient aged 75 years and a few months fails
# `<= 75` — confirm this matches the intended reading of the age cut-off.
birth_dates = patient_df.select(['patient_id', 'birth_date'])
age_in_years = (pl.col('final_death_dttm') - pl.col('birth_date')).dt.total_days() / 365.25
final_cohort_df = (
    final_cohort_df
    .join(birth_dates, on='patient_id', how='left')
    .with_columns(age_in_years.alias('age_at_death'))
)

# Patient-level flag: True when any of the patient's rows has age_at_death <= 75.
age_flag_df = final_cohort_df.group_by('patient_id').agg(
    (pl.col('age_at_death') <= 75).any().alias('age_75_less')
)

# Attach the flag to every row; a missing flag is treated as False.
final_cohort_df = (
    final_cohort_df
    .join(age_flag_df, on='patient_id', how='left')
    .with_columns(pl.col('age_75_less').fill_null(False))
)

# STROBE step 3: distinct patients meeting the age criterion.
age_relevant_cohort = final_cohort_df.filter(pl.col('age_75_less'))
age_relevant_cohort_n = age_relevant_cohort["patient_id"].n_unique()
strobe_counts.update({"3_age_relevant_cohort_n": age_relevant_cohort_n})
strobe_counts

# ICD Codes

The CALC criteria include the following causes of death:
- I20–I25: ischemic heart disease
- I60–I69: cerebrovascular disease
- V01–Y89: external causes (e.g., blunt trauma, gunshot wounds, overdose, suicide, drowning, asphyxiation)

[Reference](https://www.cms.gov/files/document/112020-opo-final-rule-cms-3380-f.pdf)

We also flag sepsis and cancer contraindications using ICD-10 codes; the specific codes are listed in `utils/icd10_contraindications.csv`.

In [None]:
# hospial_dx_filepath = f"{tables_path}/clif_hospital_diagnosis.{file_type}"
# hospital_dx = read_data(
#     hospial_dx_filepath,
#     file_type,
#     filter_ids=all_decedent_inpatient_hosp_ids,
#     id_column='hospitalization_id'
# )

# # Join on hospitalization_id to add patient_id from final_cohort_df to hospital_dx
# hospital_dx = (
#     hospital_dx.join(
#         final_cohort_df.select(['hospitalization_id', 'patient_id']),
#         on='hospitalization_id',
#         how='left'
#     )
# )

# # Show how many hosp ids from all_decedent_inpatient_hosp_ids are present in hospital_dx
# present_hosp_ids = set(hospital_dx['hospitalization_id'].unique())
# requested_hosp_ids = set(all_decedent_inpatient_hosp_ids)
# n_present = len(present_hosp_ids & requested_hosp_ids)
# n_requested = len(requested_hosp_ids)
# print(f"Hospitalization IDs present in hospital_dx: {n_present} out of {n_requested}")

# # Add these counts to strobe_counts
# strobe_counts["5_present_inpatient_hospitalization_ids_in_hospital_dx"] = n_present

# # Add count: how many age_relevant_cohort patients are present in hospital_dx
# age_relevant_patient_ids = set(age_relevant_cohort['patient_id'].unique())
# hospital_dx_patient_ids = set(hospital_dx['patient_id'].unique())
# n_age_relevant_in_hospital_dx = len(age_relevant_patient_ids & hospital_dx_patient_ids)
# strobe_counts["5_age_relevant_in_hospital_dx"] = n_age_relevant_in_hospital_dx

# strobe_counts



# The hospital_dx -> final_cohort_df join is skipped here because it crashes with large data.
# patient_id is attached later instead, by the hospitalization_df join inside the DuckDB ICD query.

# The hospital_dx coverage counts are skipped too, because the .unique() calls crash.
# They are for reporting only and are not needed by the analysis.
print("Skipping hospitalization ID counts to avoid kernel crash")
print("(Counts will be skipped in strobe_counts)")

# ICD code processing continues in the next cell.

In [None]:
import duckdb

# Load hospital diagnoses for the inpatient decedent hospitalizations (no polars join here,
# to avoid the crash seen with large data).
hospial_dx_filepath = f"{tables_path}/clif_hospital_diagnosis.{file_type}"
hospital_dx = read_data(
    hospial_dx_filepath,
    file_type,
    filter_ids=all_decedent_inpatient_hosp_ids,
    id_column='hospitalization_id'
)

print(f"✓ Loaded hospital_dx: {len(hospital_dx)} rows")

# ---- Load contraindications ----
# Codes are normalised the same way as diagnosis codes in the query below:
# lower-cased, with dots and whitespace removed.
contraindications_df = pl.read_csv(str(UTILS_DIR / "icd10_contraindications.csv"))
contraindication_codes = (
    contraindications_df
    .select(
        pl.col("ICD-10-CM")
        .cast(pl.Utf8)
        .str.to_lowercase()
        .str.replace_all(r"[.\s]", "")
        .alias("code_norm")
    )
    .to_series()
    .to_list()
)

print(f"Loaded {len(contraindication_codes)} contraindication codes")

# Small pandas frames that DuckDB reads by variable name inside the query.
contraindication_codes_df = pd.DataFrame({'code': contraindication_codes})
all_ids_df = pd.DataFrame({'hospitalization_id': list(all_decedent_inpatient_hosp_ids)})

# The query reads the already-loaded `hospital_dx` frame (the same way it reads
# `hospitalization_df`) instead of calling read_parquet on the file path. This avoids a second
# read of the table and works for any configured file_type, not just parquet.
#
# Steps: normalise codes -> flag each diagnosis row (CALC cause ranges I20-I25, I60-I69,
# V01-Y89, plus exact-match contraindication codes) -> attach patient_id -> OR the flags
# together per patient.
query = """
WITH hospital_dx_normalized AS (
    SELECT
        *,
        LOWER(REGEXP_REPLACE(diagnosis_code, '[.\\s]', '', 'g')) AS dx_norm,
        LOWER(diagnosis_code_format) AS sys
    FROM hospital_dx
    WHERE hospitalization_id IN (SELECT hospitalization_id FROM all_ids_df)
),
hospital_dx_flags AS (
    SELECT
        hospitalization_id,
        CASE
            WHEN sys IN ('icd10', 'icd10cm') AND REGEXP_MATCHES(dx_norm, '^i2[0-5]\\w*$') THEN true
            ELSE false
        END AS icd10_ischemic,
        CASE
            WHEN sys IN ('icd10', 'icd10cm') AND REGEXP_MATCHES(dx_norm, '^i6[0-9]\\w*$') THEN true
            ELSE false
        END AS icd10_cerebro,
        CASE
            WHEN sys IN ('icd10', 'icd10cm') AND REGEXP_MATCHES(dx_norm, '^(v0[1-9]|v[1-9]\\d|w\\d{2}|x\\d{2}|y[0-8]\\d)\\w*$') THEN true
            ELSE false
        END AS icd10_external,
        CASE
            WHEN sys IN ('icd10', 'icd10cm') AND dx_norm IN (SELECT code FROM contraindication_codes_df) THEN true
            ELSE false
        END AS icd10_contraindication
    FROM hospital_dx_normalized
),
hospital_dx_with_patient AS (
    SELECT h.*, hosp.patient_id
    FROM hospital_dx_flags h
    LEFT JOIN hospitalization_df hosp ON h.hospitalization_id = hosp.hospitalization_id
)
SELECT
    patient_id,
    BOOL_OR(icd10_ischemic) AS icd10_ischemic,
    BOOL_OR(icd10_cerebro) AS icd10_cerebro,
    BOOL_OR(icd10_external) AS icd10_external,
    BOOL_OR(icd10_contraindication) AS icd10_contraindication
FROM hospital_dx_with_patient
WHERE patient_id IS NOT NULL
GROUP BY patient_id
"""

print("Processing ICD flags with DuckDB (FIXED REGEX)...")
patient_cause_flags_pd = duckdb.sql(query).df()
patient_cause_flags = pl.from_pandas(patient_cause_flags_pd)

print(f"✓ Processed {len(patient_cause_flags)} patients")
print(f"\nFlag counts:")
print(f"  icd10_ischemic: {patient_cause_flags['icd10_ischemic'].sum()}")
print(f"  icd10_cerebro: {patient_cause_flags['icd10_cerebro'].sum()}")
print(f"  icd10_external: {patient_cause_flags['icd10_external'].sum()}")
print(f"  icd10_contraindication: {patient_cause_flags['icd10_contraindication'].sum()}")

In [None]:
# Attach the patient-level ICD flags; patients with no flag rows are treated as False.
icd_flag_cols = [
    "icd10_ischemic",
    "icd10_cerebro",
    "icd10_external",
    "icd10_contraindication",
]
final_cohort_df = (
    final_cohort_df
    .join(patient_cause_flags, on="patient_id", how="left")
    .with_columns(pl.col(icd_flag_cols).fill_null(False))
)

# CALC cause: ischemic OR cerebrovascular OR external cause (age/location not applied here).
has_calc_cause = pl.col("icd10_ischemic") | pl.col("icd10_cerebro") | pl.col("icd10_external")

calc_cause_n = final_cohort_df.filter(has_calc_cause)["patient_id"].n_unique()
strobe_counts["calc_cause"] = calc_cause_n

# CALC cause AND no contraindication.
without_contraindication = ~pl.col("icd10_contraindication")
calc_cause_no_contraindication_n = (
    final_cohort_df.filter(has_calc_cause & without_contraindication)["patient_id"].n_unique()
)
strobe_counts["calc_cause_no_contraindication"] = calc_cause_no_contraindication_n

# CALC Criteria

CMS adopts the Cause, Age, and Location-consistent (CALC) method to define “death consistent with organ donation” for donor-potential calculations:

- **Age**: deaths ≤75 years
- **Location**: inpatient deaths (death occurs in the hospital)
- **Cause** (ICD-10-CM, inclusion ranges):
    - I20–I25: ischemic heart disease
    - I60–I69: cerebrovascular disease
    - V01–Y89: external causes (e.g., blunt trauma, gunshot wounds, overdose, suicide, drowning, asphyxiation)


[Reference](https://www.cms.gov/files/document/112020-opo-final-rule-cms-3380-f.pdf)

In [None]:
# calc_flag = age criterion AND any CALC cause AND no contraindication.
meets_cause = pl.col('icd10_ischemic') | pl.col('icd10_cerebro') | pl.col('icd10_external')
meets_calc = pl.col('age_75_less') & meets_cause & ~pl.col('icd10_contraindication')
final_cohort_df = final_cohort_df.with_columns(meets_calc.alias('calc_flag'))

# STROBE: distinct patients with calc_flag set.
calc_qualified_n = final_cohort_df.filter(pl.col('calc_flag'))['patient_id'].n_unique()
strobe_counts.update({"calc_qualified": calc_qualified_n})

print(f"\nCALC flag qualified: {calc_qualified_n} patients")

In [None]:
# Sanity check for the ischemic-heart-disease pattern: take up to 20 diagnosis codes that
# start with 'I2', show the raw code, its normalised form and code system, and whether the
# I20-I25 regex flags it.
# NOTE(review): the query calls read_parquet on the diagnosis file path, so it assumes
# file_type is parquet — confirm for sites using another file_type.
test_case_fixed = f"""
WITH test_data AS (
    SELECT 
        diagnosis_code,
        LOWER(REGEXP_REPLACE(diagnosis_code, '[.\\s]', '', 'g')) AS dx_norm,
        LOWER(diagnosis_code_format) AS sys
    FROM read_parquet('{hospial_dx_filepath}')
    WHERE diagnosis_code LIKE 'I2%'
    LIMIT 20
)
SELECT
    diagnosis_code,
    dx_norm,
    sys,
    CASE 
        WHEN sys IN ('icd10', 'icd10cm') AND REGEXP_MATCHES(dx_norm, '^i2[0-5]\\w*$') THEN true 
        ELSE false 
    END AS icd10_ischemic_flag
FROM test_data
"""

# Print the sampled rows and how many of them were flagged.
test_fixed = duckdb.sql(test_case_fixed).df()
print(test_fixed)
print(f"\nTrue count: {test_fixed['icd10_ischemic_flag'].sum()}")

In [None]:
# Display the STROBE counts collected so far.
strobe_counts

# IMV

In [None]:
# Load respiratory support, passing the inpatient decedent hospitalization ids as a filter.
# NOTE(review): resp_df is not referenced by the IMV cell that follows, which queries the
# same file through DuckDB — confirm whether this load is still needed.
resp_filepath = f"{tables_path}/clif_respiratory_support.{file_type}"
resp_df = read_data(
      resp_filepath,
      file_type,
      filter_ids=all_decedent_inpatient_hosp_ids,
      id_column='hospitalization_id'
  )

In [None]:
# IMV - use DuckDB to avoid loading large file into Polars
# Flags patients whose latest IMV record falls between 48 hours before and 24 hours after
# final_death_dttm.
print("Processing IMV data with DuckDB...")

# Create temp dataframe with needed columns from final_cohort
# (a pandas frame that the query below reads by variable name)
final_cohort_for_imv = final_cohort_df.select([
    "hospitalization_id", "patient_id", "encounter_block", "final_death_dttm"
]).to_pandas()

# Query steps:
#   imv_data               - IMV rows (device_category, case-insensitive) for cohort hospitalizations
#   imv_with_death         - adds hours from each IMV record to final_death_dttm
#                            (positive = record precedes death)
#   latest_imv_per_patient - ranks each patient's IMV records, latest first
#   final SELECT           - keeps the latest record only if it lies in the [-24h, 48h] window
# NOTE(review): the query calls read_parquet, so it assumes file_type is parquet — confirm.
imv_query = f"""
WITH imv_data AS (
    SELECT 
        hospitalization_id,
        recorded_dttm,
        device_category
    FROM read_parquet('{tables_path}/clif_respiratory_support.{file_type}')
    WHERE LOWER(device_category) = 'imv'
        AND hospitalization_id IN (SELECT hospitalization_id FROM final_cohort_for_imv)
),
imv_with_death AS (
    SELECT 
        i.hospitalization_id,
        i.recorded_dttm,
        f.patient_id,
        f.encounter_block,
        f.final_death_dttm,
        EXTRACT(EPOCH FROM (f.final_death_dttm - i.recorded_dttm)) / 3600 AS hr_2death_last_imv
    FROM imv_data i
    INNER JOIN final_cohort_for_imv f ON i.hospitalization_id = f.hospitalization_id
),
-- CHANGED: Get latest record FIRST from ALL records (not just within 48h)
latest_imv_per_patient AS (
    SELECT 
        patient_id,
        hospitalization_id,
        encounter_block,
        final_death_dttm,
        recorded_dttm,
        hr_2death_last_imv,
        ROW_NUMBER() OVER (
            PARTITION BY patient_id 
            ORDER BY recorded_dttm DESC, hospitalization_id ASC
        ) AS rn
    FROM imv_with_death  -- ← Using ALL records, not pre-filtered
)
-- CHANGED: Apply time window filter AFTER selecting latest
SELECT 
    patient_id,
    hospitalization_id,
    encounter_block,
    final_death_dttm,
    recorded_dttm,
    hr_2death_last_imv
FROM latest_imv_per_patient
WHERE rn = 1  -- Get the latest record first
    AND hr_2death_last_imv <= 48   -- Then filter to time window
    AND hr_2death_last_imv >= -24
"""

# Run the query and bring the result back into polars.
resp_expired_cohort = duckdb.sql(imv_query).df()
resp_expired_cohort = pl.from_pandas(resp_expired_cohort)

# STROBE step 6: distinct patients whose latest IMV record is inside the window.
imv_48hr_expire = resp_expired_cohort["patient_id"].n_unique()
print(f"✓ Patients on IMV within 48h of death: {imv_48hr_expire}")

strobe_counts["6_imv_48hr_expire"] = imv_48hr_expire

# Create flag
# (one row per qualifying patient, with imv_48hr_expire = True)
imv_48hr_expire_patients = resp_expired_cohort.select(["patient_id"]).unique()
imv_48hr_expire_patients = imv_48hr_expire_patients.with_columns(
    pl.lit(True).alias("imv_48hr_expire")
)

# Left join so non-qualifying patients stay in the cohort; their missing flag becomes False.
final_cohort_df = final_cohort_df.join(imv_48hr_expire_patients, on="patient_id", how="left")
final_cohort_df = final_cohort_df.with_columns(
    pl.col("imv_48hr_expire").fill_null(False)
)

print(f"✓ IMV processing complete")

In [None]:
# The IMV cell normally attaches imv_48hr_expire already. Joining the flag table a second
# time would add a duplicate `imv_48hr_expire_right` column, so only join when it is missing.
if "imv_48hr_expire" not in final_cohort_df.columns:
    final_cohort_df = final_cohort_df.join(imv_48hr_expire_patients, on="patient_id", how="left")

# Patients without a qualifying IMV record get False rather than null.
final_cohort_df = final_cohort_df.with_columns(
    pl.col("imv_48hr_expire").fill_null(False))

# Organ quality check

Apply the potential organ quality assessment check (independent assessment) using the last recorded lab values, as defined by CMS:
* Kidney: creatinine recorded, creatinine < 4, AND not on CRRT
* Liver: total bilirubin, AST, and ALT recorded, with total bilirubin < 4, AST < 700, AND ALT < 700
* BMI <= 50

In [None]:
# crrt_filepath = f"{tables_path}/clif_crrt_therapy.{file_type}"
# labs_filepath = f"{tables_path}/clif_labs.{file_type}"
# crrt_therapy = read_data(
#       crrt_filepath,
#       file_type,
#       filter_ids=all_decedent_inpatient_hosp_ids,
#       id_column='hospitalization_id'
#   )

# labs_df = read_data(
#       labs_filepath,
#       file_type,
#       filter_ids=all_decedent_inpatient_hosp_ids,
#       id_column='hospitalization_id'
#   )

# labs_df = apply_outlier_handling(labs_df, 'labs')

# # ============================================
# # Prepare final cohort with timing info
# # ============================================
# final_cohort_for_labs = final_cohort_df.select([
#     "patient_id",
#     "hospitalization_id",
#     "final_death_dttm",
# ])

# # ============================================
# # Filter labs to those recorded BEFORE final_death_dttm
# # ============================================
# labs_before_last_vital = (
#     labs_df
#     .join(
#         final_cohort_for_labs,
#         on='hospitalization_id',
#         how='inner'
#     )
#     .filter(pl.col('lab_collect_dttm') <= pl.col('final_death_dttm'))
# )

# # ============================================
# # Get CREATININE - last value before last vital
# # ============================================
# creatinine_labs = labs_before_last_vital.filter(
#     pl.col('lab_category') == 'creatinine'
# )

# latest_creatinine = (
#     creatinine_labs
#     .sort('lab_collect_dttm')
#     .group_by('hospitalization_id')
#     .agg([
#         pl.col('lab_collect_dttm').last().alias('creatinine_dttm'),
#         pl.col('lab_value_numeric').last().alias('creatinine_value')
#     ])
# )

# # ============================================
# # Get LIVER LABS - last values before last vital
# # ============================================
# liver_labs_categories = labs_before_last_vital.filter(
#     pl.col('lab_category').is_in(['bilirubin_total', 'ast', 'alt'])
# )

# # Get values pivoted by category
# latest_liver_values = (
#     liver_labs_categories
#     .sort('lab_collect_dttm')
#     .group_by(['hospitalization_id', 'lab_category'])
#     .agg(pl.col('lab_value_numeric').last().alias('lab_value'))
#     .pivot(
#         values='lab_value',
#         index='hospitalization_id',
#         on='lab_category'
#     )
#     .rename({
#         'bilirubin_total': 'bilirubin_total_value',
#         'ast': 'ast_value',
#         'alt': 'alt_value'
#     })
# )

# # Get collection datetimes for each lab
# latest_liver_datetimes = (
#     liver_labs_categories
#     .sort('lab_collect_dttm')
#     .group_by(['hospitalization_id', 'lab_category'])
#     .agg(pl.col('lab_collect_dttm').last().alias('lab_dttm'))
#     .pivot(
#         values='lab_dttm',
#         index='hospitalization_id',
#         on='lab_category'
#     )
#     .rename({
#         'bilirubin_total': 'bilirubin_total_dttm',
#         'ast': 'ast_dttm',
#         'alt': 'alt_dttm'
#     })
# )

# # ============================================
# # Combine all lab values into one dataframe
# # ============================================
# organ_labs = (
#     final_cohort_for_labs
#     .join(latest_creatinine, on='hospitalization_id', how='left')
#     .join(latest_liver_values, on='hospitalization_id', how='left')
#     .join(latest_liver_datetimes, on='hospitalization_id', how='left')
#     .select([
#         'patient_id',
#         'creatinine_value',
#         'creatinine_dttm',
#         'bilirubin_total_value',
#         'bilirubin_total_dttm',
#         'ast_value',
#         'ast_dttm',
#         'alt_value',
#         'alt_dttm'
#     ])
# )

# # print(organ_labs.head())
# print(f"\nOrgan labs summary:")
# print(f"  Patients with creatinine: {organ_labs.filter(pl.col('creatinine_value').is_not_null())['patient_id'].n_unique()}")
# print(f"  Patients with bilirubin: {organ_labs.filter(pl.col('bilirubin_total_value').is_not_null())['patient_id'].n_unique()}")
# print(f"  Patients with AST: {organ_labs.filter(pl.col('ast_value').is_not_null())['patient_id'].n_unique()}")
# print(f"  Patients with ALT: {organ_labs.filter(pl.col('alt_value').is_not_null())['patient_id'].n_unique()}")

# # ============================================
# # Check for CRRT within 48 hours before death
# # ============================================

# # Join CRRT with final cohort to get final_death_dttm
# crrt_with_death_time = (
#     crrt_therapy
#     .join(
#         final_cohort_df.select(['hospitalization_id', 'final_death_dttm']),
#         on='hospitalization_id',
#         how='inner'
#     )
# )

# # Filter to CRRT recorded before death
# crrt_before_death = crrt_with_death_time.filter(
#     pl.col('recorded_dttm') <= pl.col('final_death_dttm')
# )

# # Check if within 48 hours of death
# crrt_48h_before_death = (
#     crrt_before_death
#     .with_columns(
#         (
#             (pl.col('final_death_dttm') - pl.col('recorded_dttm')).dt.total_seconds() / 3600
#         ).alias('hrs_before_death')
#     )
#     .filter(
#         (pl.col('hrs_before_death') <= 48) & 
#         (pl.col('hrs_before_death') >= 0)
#     )
# )

# # Create flag: any CRRT within 48h of death
# on_crrt_flag = (
#     crrt_48h_before_death
#     .select('hospitalization_id')
#     .unique()
#     .with_columns(pl.lit(True).alias('on_crrt_48h_before_death'))
# )

# # Join flag to final_cohort_df
# final_cohort_df = final_cohort_df.join(
#     on_crrt_flag,
#     on='hospitalization_id',
#     how='left'
# )

# # Fill nulls with False
# final_cohort_df = final_cohort_df.with_columns(
#     pl.col('on_crrt_48h_before_death').fill_null(False)
# )

# # Count for tracking
# on_crrt_n = final_cohort_df.filter(pl.col('on_crrt_48h_before_death'))['patient_id'].n_unique()
# print(f"Patients on CRRT within 48h before death: {on_crrt_n}")

In [None]:
import duckdb

# Flag cohort hospitalizations that have any CRRT record in the 48 hours up to
# and including final_death_dttm. The CRRT table is scanned by DuckDB so the
# full file is not loaded into Polars.
# NOTE(review): the query calls read_parquet() regardless of file_type —
# confirm non-parquet sites are not expected to run this cell.
crrt_filepath = f"{tables_path}/clif_crrt_therapy.{file_type}"
print("Processing CRRT data with DuckDB...")

# The SQL below refers to this pandas frame by its Python variable name, so the
# variable name and the table name in the query must stay in sync.
final_cohort_for_crrt = final_cohort_df.select(
    ["hospitalization_id", "final_death_dttm"]
).to_pandas()

crrt_query = f"""
WITH crrt_data AS (
    SELECT 
        hospitalization_id,
        recorded_dttm
    FROM read_parquet('{crrt_filepath}')
    WHERE hospitalization_id IN (SELECT hospitalization_id FROM final_cohort_for_crrt)
),
crrt_with_death AS (
    SELECT 
        c.hospitalization_id,
        c.recorded_dttm,
        f.final_death_dttm,
        EXTRACT(EPOCH FROM (f.final_death_dttm - c.recorded_dttm)) / 3600 AS hrs_before_death
    FROM crrt_data c
    INNER JOIN final_cohort_for_crrt f ON c.hospitalization_id = f.hospitalization_id
    WHERE c.recorded_dttm <= f.final_death_dttm
),
crrt_within_48h AS (
    SELECT 
        hospitalization_id
    FROM crrt_with_death
    WHERE hrs_before_death <= 48 AND hrs_before_death >= 0
)
SELECT DISTINCT hospitalization_id
FROM crrt_within_48h
"""

# One row per hospitalization with qualifying CRRT, converted back to Polars.
crrt_48h_result = pl.from_pandas(duckdb.sql(crrt_query).df())

# Every hospitalization returned by the query gets the flag set to True ...
on_crrt_flag = crrt_48h_result.with_columns(
    pl.lit(True).alias('on_crrt_48h_before_death')
)

# ... and every cohort row without a match ends up False after the left join.
final_cohort_df = (
    final_cohort_df
    .join(on_crrt_flag, on='hospitalization_id', how='left')
    .with_columns(pl.col('on_crrt_48h_before_death').fill_null(False))
)

# Unique patients carrying the flag, for tracking
on_crrt_n = (
    final_cohort_df
    .filter(pl.col('on_crrt_48h_before_death'))
    .get_column('patient_id')
    .n_unique()
)
print(f"✓ Patients on CRRT within 48h before death: {on_crrt_n}")
print(f"✓ CRRT processing complete")

<!-- Organ labs summary:
  Patients with creatinine: 6188
  Patients with bilirubin: 5960
  Patients with AST: 5912
  Patients with ALT: 5948 -->

In [None]:
# Latest pre-death organ labs (creatinine, total bilirubin, AST, ALT) for each
# cohort hospitalization, computed in DuckDB so the labs file is not loaded
# into Polars, then joined onto final_cohort_df.
#
# The labs are selected per hospitalization, so the result keeps
# hospitalization_id and is joined back on (patient_id, hospitalization_id).
# Joining on patient_id alone would attach every lab row of a patient to every
# cohort row of that patient, multiplying rows and mixing labs across
# hospitalizations whenever a patient has more than one cohort row.
# NOTE(review): the query calls read_parquet() regardless of file_type —
# confirm non-parquet sites are not expected to run this cell.
labs_filepath = f"{tables_path}/clif_labs.{file_type}"
print("Processing Labs data with DuckDB...")

# The SQL below refers to this pandas frame by its Python variable name.
final_cohort_for_labs = final_cohort_df.select([
    "patient_id",
    "hospitalization_id",
    "final_death_dttm",
]).to_pandas()

labs_query = f"""
WITH labs_data AS (
    SELECT 
        hospitalization_id,
        lab_collect_dttm,
        lab_category,
        lab_value_numeric
    FROM read_parquet('{labs_filepath}')
    WHERE hospitalization_id IN (SELECT hospitalization_id FROM final_cohort_for_labs)
),
labs_with_death AS (
    SELECT 
        l.hospitalization_id,
        l.lab_collect_dttm,
        l.lab_category,
        l.lab_value_numeric,
        f.patient_id,
        f.final_death_dttm
    FROM labs_data l
    INNER JOIN final_cohort_for_labs f ON l.hospitalization_id = f.hospitalization_id
    WHERE l.lab_collect_dttm <= f.final_death_dttm
),
-- Get latest creatinine per hospitalization
latest_creatinine AS (
    SELECT 
        hospitalization_id,
        lab_value_numeric AS creatinine_value,
        lab_collect_dttm AS creatinine_dttm
    FROM (
        SELECT 
            hospitalization_id,
            lab_value_numeric,
            lab_collect_dttm,
            ROW_NUMBER() OVER (PARTITION BY hospitalization_id ORDER BY lab_collect_dttm DESC) AS rn
        FROM labs_with_death
        WHERE lab_category = 'creatinine'
    ) ranked
    WHERE rn = 1
),
-- Get latest liver labs per hospitalization
latest_liver AS (
    SELECT 
        hospitalization_id,
        MAX(CASE WHEN lab_category = 'bilirubin_total' THEN lab_value_numeric END) AS bilirubin_total_value,
        MAX(CASE WHEN lab_category = 'bilirubin_total' THEN lab_collect_dttm END) AS bilirubin_total_dttm,
        MAX(CASE WHEN lab_category = 'ast' THEN lab_value_numeric END) AS ast_value,
        MAX(CASE WHEN lab_category = 'ast' THEN lab_collect_dttm END) AS ast_dttm,
        MAX(CASE WHEN lab_category = 'alt' THEN lab_value_numeric END) AS alt_value,
        MAX(CASE WHEN lab_category = 'alt' THEN lab_collect_dttm END) AS alt_dttm
    FROM (
        SELECT 
            hospitalization_id,
            lab_category,
            lab_value_numeric,
            lab_collect_dttm,
            ROW_NUMBER() OVER (PARTITION BY hospitalization_id, lab_category ORDER BY lab_collect_dttm DESC) AS rn
        FROM labs_with_death
        WHERE lab_category IN ('bilirubin_total', 'ast', 'alt')
    ) ranked
    WHERE rn = 1
    GROUP BY hospitalization_id
),
-- Combine all labs, keyed by (patient_id, hospitalization_id)
organ_labs AS (
    SELECT DISTINCT
        f.patient_id,
        f.hospitalization_id,
        c.creatinine_value,
        c.creatinine_dttm,
        l.bilirubin_total_value,
        l.bilirubin_total_dttm,
        l.ast_value,
        l.ast_dttm,
        l.alt_value,
        l.alt_dttm
    FROM final_cohort_for_labs f
    LEFT JOIN latest_creatinine c ON f.hospitalization_id = c.hospitalization_id
    LEFT JOIN latest_liver l ON f.hospitalization_id = l.hospitalization_id
)
SELECT * FROM organ_labs
"""

organ_labs_result = duckdb.sql(labs_query).df()
organ_labs = pl.from_pandas(organ_labs_result)

print(f"✓ Organ labs loaded: {len(organ_labs)} patients")
print(f"  Patients with creatinine: {organ_labs.filter(pl.col('creatinine_value').is_not_null())['patient_id'].n_unique()}")
print(f"  Patients with bilirubin: {organ_labs.filter(pl.col('bilirubin_total_value').is_not_null())['patient_id'].n_unique()}")
print(f"  Patients with AST: {organ_labs.filter(pl.col('ast_value').is_not_null())['patient_id'].n_unique()}")
print(f"  Patients with ALT: {organ_labs.filter(pl.col('alt_value').is_not_null())['patient_id'].n_unique()}")

# Join on both identifiers: organ_labs has at most one row per
# (patient_id, hospitalization_id), so this left join cannot add rows.
final_cohort_df = final_cohort_df.join(
    organ_labs,
    on=['patient_id', 'hospitalization_id'],
    how='left',
    suffix='_organlab'
)
print(f"✓ Labs processing complete")

In [None]:
# The labs cell above already left-joins organ_labs onto final_cohort_df.
# Repeating that join unconditionally appends duplicate '*_organlab' columns
# (and multiplies rows when organ_labs has several rows for one patient), so
# only organ_labs columns that are still absent from the cohort are joined here.
missing_organ_lab_cols = [
    col for col in organ_labs.columns if col not in final_cohort_df.columns
]
if missing_organ_lab_cols:
    final_cohort_df = final_cohort_df.join(
        organ_labs.select(['patient_id', *missing_organ_lab_cols]),
        on='patient_id',
        how='left',
        suffix='_organlab'
    )
print(f"Final cohort with organ labs shape: {final_cohort_df.shape}")

In [None]:
# ============================================
# Organ quality assessment flags
# ============================================
# Each criterion requires its input value to be present, so a missing lab or
# BMI makes the corresponding flag False rather than null.

# Kidney: creatinine recorded and < 4, and no CRRT in the 48h before death
kidney_expr = (
    pl.col('creatinine_value').is_not_null()
    & (pl.col('creatinine_value') < 4)
    & ~pl.col('on_crrt_48h_before_death')
)

# Liver: total bilirubin, AST and ALT all recorded and all below their limits
liver_expr = (
    (pl.col('bilirubin_total_value').is_not_null() & (pl.col('bilirubin_total_value') < 4))
    & (pl.col('ast_value').is_not_null() & (pl.col('ast_value') < 700))
    & (pl.col('alt_value').is_not_null() & (pl.col('alt_value') < 700))
)

# BMI: recorded and <= 50
bmi_expr = pl.col('bmi').is_not_null() & (pl.col('bmi') <= 50)

final_cohort_df = final_cohort_df.with_columns(
    kidney_expr.alias('kidney_eligible'),
    liver_expr.alias('liver_eligible'),
    bmi_expr.alias('bmi_eligible'),
)

# Overall pass = (kidney OR liver) AND BMI. Computed in a second call because
# it reads the flag columns created just above.
final_cohort_df = final_cohort_df.with_columns(
    ((pl.col('kidney_eligible') | pl.col('liver_eligible')) & pl.col('bmi_eligible'))
    .alias('organ_check_pass')
)

# Unique-patient counts per flag, for STROBE tracking
organ_flag_counts = {
    flag: final_cohort_df.filter(pl.col(flag))['patient_id'].n_unique()
    for flag in ('kidney_eligible', 'liver_eligible', 'bmi_eligible', 'organ_check_pass')
}
kidney_eligible_n = organ_flag_counts['kidney_eligible']
liver_eligible_n = organ_flag_counts['liver_eligible']
bmi_eligible_n = organ_flag_counts['bmi_eligible']
organ_check_pass_n = organ_flag_counts['organ_check_pass']

strobe_counts.update({
    "organ_kidney_eligible": kidney_eligible_n,
    "organ_liver_eligible": liver_eligible_n,
    "organ_bmi_eligible": bmi_eligible_n,
    "organ_check_pass": organ_check_pass_n,
})

print(f"\nOrgan Quality Assessment:")
print(f"  Kidney eligible: {kidney_eligible_n} patients")
print(f"  Liver eligible: {liver_eligible_n} patients")
print(f"  BMI eligible: {bmi_eligible_n} patients")
print(f"  Overall organ check pass: {organ_check_pass_n} patients")

# Microbiology

Flag hospitalizations with no positive blood culture in the 48 hours before death (cultures were negative, or none were collected).

In [None]:
# Load the microbiology culture table through read_data, passing the decedent
# inpatient hospitalization ids as the id filter.
# NOTE(review): the DuckDB cell below re-reads the same file and does not use
# `micro_culture` — confirm whether this in-memory load is still needed.
micro_culture_filepath = f"{tables_path}/clif_microbiology_culture.{file_type}"
micro_culture = read_data(
    micro_culture_filepath,
    file_type,
    filter_ids=all_decedent_inpatient_hosp_ids,
    id_column='hospitalization_id',
)

In [None]:
# Microbiology - use DuckDB to avoid Polars crashes
#
# Builds `no_positive_culture_48hrs` for every cohort hospitalization:
#   * only blood cultures (fluid_category 'blood_buffy', method_category
#     'culture') with a non-null collect_dttm are considered;
#   * a culture counts when it was collected 0-48 hours before final_death_dttm;
#   * a culture is negative when organism_category is null, empty, or contains
#     'no_growth'; anything else is positive;
#   * the flag is True when the hospitalization has no positive culture in that
#     window (including when it has no cultures at all).
# NOTE(review): the query calls read_parquet() regardless of file_type —
# confirm non-parquet sites are not expected to run this cell.
print("Processing microbiology data with DuckDB...")

# The SQL below refers to this pandas frame by its Python variable name.
final_cohort_for_micro = final_cohort_df.select([
    'hospitalization_id', 'final_death_dttm'
]).to_pandas()

micro_query = f"""
WITH blood_cultures AS (
    SELECT 
        hospitalization_id,
        collect_dttm,
        organism_category
    FROM read_parquet('{tables_path}/clif_microbiology_culture.{file_type}')
    WHERE fluid_category = 'blood_buffy'
        AND method_category = 'culture'
        AND hospitalization_id IN (SELECT hospitalization_id FROM final_cohort_for_micro)
),
cultures_with_death AS (
    SELECT 
        b.hospitalization_id,
        b.collect_dttm,
        b.organism_category,
        f.final_death_dttm,
        EXTRACT(EPOCH FROM (f.final_death_dttm - b.collect_dttm)) / 3600 AS hrs_before_death
    FROM blood_cultures b
    INNER JOIN final_cohort_for_micro f ON b.hospitalization_id = f.hospitalization_id
    WHERE b.collect_dttm IS NOT NULL
),
cultures_48h AS (
    SELECT 
        *,
        CASE 
            WHEN LOWER(organism_category) LIKE '%no_growth%' 
                OR organism_category IS NULL 
                OR LOWER(organism_category) = '' 
            THEN true 
            ELSE false 
        END AS is_negative_culture
    FROM cultures_with_death
    WHERE hrs_before_death >= 0 AND hrs_before_death <= 48
),
positive_cultures AS (
    SELECT DISTINCT hospitalization_id
    FROM cultures_48h
    WHERE is_negative_culture = false
)
-- DISTINCT guarantees one flag row per hospitalization even if the cohort
-- frame repeats a hospitalization_id, so the left join below cannot add rows.
SELECT DISTINCT
    f.hospitalization_id,
    CASE WHEN p.hospitalization_id IS NULL THEN true ELSE false END AS no_positive_culture_48hrs
FROM final_cohort_for_micro f
LEFT JOIN positive_cultures p ON f.hospitalization_id = p.hospitalization_id
"""

no_positive_culture_flag_pd = duckdb.sql(micro_query).df()
no_positive_culture_flag = pl.from_pandas(no_positive_culture_flag_pd)

# Join flag to final_cohort_df
final_cohort_df = final_cohort_df.join(
    no_positive_culture_flag,
    on='hospitalization_id',
    how='left'
)

# Defensive: any cohort row left without a flag is treated as not cleared
final_cohort_df = final_cohort_df.with_columns(
    pl.col('no_positive_culture_48hrs').fill_null(False)
)

# Count for STROBE tracking
no_positive_culture_n = final_cohort_df.filter(pl.col('no_positive_culture_48hrs'))['patient_id'].n_unique()
positive_culture_n = final_cohort_df.filter(~pl.col('no_positive_culture_48hrs'))['patient_id'].n_unique()

strobe_counts["no_positive_culture_48hrs"] = no_positive_culture_n
strobe_counts["positive_culture_48hrs"] = positive_culture_n

print(f"\nBlood Culture Results:")
print(f"  Patients with no positive cultures in last 48h: {no_positive_culture_n}")
print(f"  Patients with positive cultures in last 48h: {positive_culture_n}")

In [None]:
# Display the column names currently present on final_cohort_df (notebook output only).
final_cohort_df.columns

# CLIF Eligible Donor

Medically eligible potential deceased abdominal organ donor (CLIF-eligible-donors):  


* From ALL inpatient deaths (ensure death location = ED, ward, stepdown, ICU)
* Age < 75
* On invasive mechanical ventilation
    * IF death date/time available: within 48h of death
    * IF no death date/time available: at time of last recorded vital signs
* No contraindications
* CLIF Microbiology_culture:
    * No positive blood cultures within 2 days - 'no_positive_culture_48hrs'
* Hospital diagnosis (ICD based) - 'icd10_contraindication':
    * Cancer
    * Severe sepsis
* Pass the potential organ quality assessment check (independent assessment) using last recorded lab values, as defined by CMS - 'organ_check_pass':
    * Kidney: recorded creatinine, Cr < 4 AND not on CRRT
    * Liver: recorded TB, AST, ALT and
        * Total bilirubin < 4
        * AST < 700
        * ALT < 700
    * BMI <= 50

In [None]:
# ============================================
# Create CLIF-eligible-donors flag
# ============================================
# A row is CLIF-eligible only when every criterion below holds. The numbering
# follows the criteria list in the markdown cell above.
clif_eligible_expr = (
    pl.col('age_75_less')                  # 2. Age < 75
    & pl.col('imv_48hr_expire')            # 3. On IMV (within 48h of death)
    & ~pl.col('icd10_contraindication')    # 4. No cancer / severe sepsis diagnosis
    & pl.col('no_positive_culture_48hrs')  # 5. No positive blood cultures within 48h
    & pl.col('organ_check_pass')           # 6. Organ quality: (kidney OR liver) AND BMI
)

final_cohort_df = final_cohort_df.with_columns(
    clif_eligible_expr.alias('clif_eligible_donors')
)

# Unique eligible patients, for STROBE tracking
clif_eligible_n = (
    final_cohort_df
    .filter(pl.col('clif_eligible_donors'))
    .get_column('patient_id')
    .n_unique()
)
strobe_counts["clif_eligible_donors"] = clif_eligible_n

# Patient assessments

In [None]:
# Load the patient assessments table through read_data, passing the decedent
# inpatient hospitalization ids as the id filter.
# NOTE(review): the DuckDB cell below re-reads the same file and does not use
# `patient_assessments_df` — confirm whether this in-memory load is still needed.
patient_assessments_filepath = f"{tables_path}/clif_patient_assessments.{file_type}"
patient_assessments_df = read_data(
    patient_assessments_filepath,
    file_type,
    filter_ids=all_decedent_inpatient_hosp_ids,
    id_column='hospitalization_id',
)

In [None]:
# Patient assessments - use DuckDB to avoid crashes
#
# For each cohort hospitalization, pick the GCS total and the RASS value whose
# recorded_dttm is closest to final_death_dttm (absolute time difference, so
# the chosen record may fall on either side of the death time), then attach
# them to the cohort as gcs_total_value / rass_value.
# NOTE(review): the query calls read_parquet() regardless of file_type —
# confirm non-parquet sites are not expected to run this cell.
print("Processing patient assessments with DuckDB...")

# The SQL below refers to this pandas frame by its Python variable name.
final_cohort_for_assessments = final_cohort_df.select(
    ["hospitalization_id", "final_death_dttm"]
).to_pandas()

assessments_query = f"""
WITH assessments_filtered AS (
    SELECT 
        hospitalization_id,
        recorded_dttm,
        LOWER(assessment_category) AS assessment_category,
        numerical_value
    FROM read_parquet('{tables_path}/clif_patient_assessments.{file_type}')
    WHERE LOWER(assessment_category) IN ('gcs_total', 'rass')
        AND numerical_value IS NOT NULL
        AND hospitalization_id IN (SELECT hospitalization_id FROM final_cohort_for_assessments)
),
with_death_time AS (
    SELECT 
        a.hospitalization_id,
        a.recorded_dttm,
        a.assessment_category,
        a.numerical_value,
        f.final_death_dttm,
        ABS(EXTRACT(EPOCH FROM (f.final_death_dttm - a.recorded_dttm))) AS abs_time_to_death
    FROM assessments_filtered a
    INNER JOIN final_cohort_for_assessments f ON a.hospitalization_id = f.hospitalization_id
),
closest_per_category AS (
    SELECT 
        hospitalization_id,
        assessment_category,
        numerical_value,
        ROW_NUMBER() OVER (
            PARTITION BY hospitalization_id, assessment_category 
            ORDER BY abs_time_to_death
        ) AS rn
    FROM with_death_time
)
SELECT 
    hospitalization_id,
    MAX(CASE WHEN assessment_category = 'gcs_total' THEN numerical_value END) AS gcs_total_value,
    MAX(CASE WHEN assessment_category = 'rass' THEN numerical_value END) AS rass_value
FROM closest_per_category
WHERE rn = 1
GROUP BY hospitalization_id
"""

# One row per hospitalization that has at least one GCS/RASS record
patient_gcs_rass_pd = duckdb.sql(assessments_query).df()
patient_gcs_rass = pl.from_pandas(patient_gcs_rass_pd)

print(f"✓ Processed assessments for {len(patient_gcs_rass)} hospitalizations")

# Hospitalizations without any matching assessment keep nulls in both columns
final_cohort_df = final_cohort_df.join(
    patient_gcs_rass, on='hospitalization_id', how='left'
)

print("✓ Patient assessments processing complete")

In [None]:
# ================================================================================
# ENSURE PATIENT-LEVEL ANALYSIS
# ================================================================================
# Reduce final_cohort_df to exactly one row per patient_id: drop the
# encounter-level id columns, deduplicate on patient_id if needed, then verify.
print("\n" + "="*80)
print("FINALIZING PATIENT-LEVEL COHORT")
print("="*80)

# Step 1: drop whichever encounter-level identifiers are present
print("Step 1: Removing encounter-level identifiers (hospitalization_id, encounter_block)...")
columns_to_drop = [
    col
    for col in ('hospitalization_id', 'encounter_block')
    if col in final_cohort_df.columns
]

if not columns_to_drop:
    print("✓ No encounter-level identifiers found to drop")
else:
    final_cohort_df = final_cohort_df.drop(columns_to_drop)
    print(f"✓ Dropped: {', '.join(columns_to_drop)}")

# Step 2: collapse to one row per patient when duplicates exist
print("\nStep 2: Ensuring one row per patient...")
n_patients_before = final_cohort_df['patient_id'].n_unique()
n_rows_before = final_cohort_df.height

if n_patients_before == n_rows_before:
    print(f"✓ Already unique: {n_patients_before:,} patients = {n_rows_before:,} rows")
else:
    print(f"WARNING: {n_rows_before:,} rows but only {n_patients_before:,} unique patients!")
    print("Deduplicating to ensure one row per patient...")

    # Keep the last row encountered for each patient.
    # NOTE(review): 'last' is by row position and the frame is not sorted in
    # this cell — confirm that this really selects the most recent data.
    final_cohort_df = final_cohort_df.unique(subset=['patient_id'], keep='last')

    n_rows_after = final_cohort_df.height
    print(f"✓ Deduplicated: {n_rows_before:,} rows → {n_rows_after:,} rows")
    print(f"Removed {n_rows_before - n_rows_after:,} duplicate rows")

# Step 3: fail loudly if any patient still has more than one row
n_patients_final = final_cohort_df['patient_id'].n_unique()
n_rows_final = final_cohort_df.height
assert n_patients_final == n_rows_final, \
    f"CRITICAL: Still have duplicates! {n_rows_final} rows but {n_patients_final} unique patients"

print(f"\n✓ Final verification passed: {n_patients_final:,} unique patients")
print(f"Final cohort shape: {final_cohort_df.shape}")
print("="*80 + "\n")

In [None]:
# Persist the patient-level cohort (intermediate output) and the STROBE counts
# (final output). The target directories are created first so the writes do not
# fail on a fresh checkout; mkdir is a no-op when they already exist.
OUTPUT_INTERMEDIATE_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_FINAL_DIR.mkdir(parents=True, exist_ok=True)

final_cohort_df.write_parquet(str(OUTPUT_INTERMEDIATE_DIR / "final_cohort_df.parquet"))
# strobe_counts is a flat dict, written as a single-row CSV (one column per key)
pd.DataFrame([strobe_counts]).to_csv(str(OUTPUT_FINAL_DIR / "strobe_counts.csv"), index=False)

# Table One

In [None]:
from utils.table_one import create_table_one

# Build Table One from the patient-level cohort; the final output directory is
# passed to the helper as a string path.
table_one = create_table_one(
    final_cohort_df,
    output_dir=str(OUTPUT_FINAL_DIR),
)

# Visualizations

In [None]:
# Display the STROBE counts collected so far (notebook output only).
strobe_counts

In [None]:
from utils.cohort_visualizations import create_all_visualizations

# Generate the cohort visualizations from the patient-level cohort; the final
# output directory is passed to the helper as a string path.
summary_df = create_all_visualizations(
    final_cohort_df,
    output_dir=str(OUTPUT_FINAL_DIR),
)

# STROBE

In [None]:
from utils.strobe_diagram import create_strobe_diagrams_for_cohorts

# Build the STROBE diagrams for the cohorts, asking the helper to save both the
# figures and the CSVs under the final output directory.
results = create_strobe_diagrams_for_cohorts(
    final_cohort_df,
    output_dir=str(OUTPUT_FINAL_DIR),
    save_figures=True,
    save_csvs=True,
)

# Stage data for the two cohorts, read from the 'CALC' and 'CLIF' entries
calc_results = results['CALC']
calc_stages = calc_results['stages']
clif_results = results['CLIF']
clif_stages = clif_results['stages']