## CRRT Cohort Check

Required checks for adult hospitalizations (age 18–119) admitted 2021–2024:

1. Definition I  : Hospitalizations that are on a ventilator for the first 24 hours of their first ICU stay.
2. Definition II : Hospitalizations that are on vasoactive medications during the first 24 hrs of their first ICU stay. 
3. Definition III: Hospitalizations that have stage 1 AKI, defined as either of:

    3a. an absolute increase in serum creatinine of ≥ 0.3 mg/dL within the first 48 hours after admission

    3b. a relative increase in serum creatinine of ≥ 50% (peak ≥ 1.5 × baseline) within the first 7 days after admission

## 00 Load libraries and core CLIF tables


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
import json

import pyCLIF
import pyCLIF_mimic
import waterfall
## Load the project's outlier-threshold configuration
outlier_config_path = Path('../config/outlier_config.json')
outlier_cfg = json.loads(outlier_config_path.read_text(encoding='utf-8'))

In [None]:
# Load the core CLIF tables (in this order: patient, hospitalization, adt)
patient, hospitalization, adt = (
    pyCLIF.load_data(table_name)
    for table_name in ('clif_patient', 'clif_hospitalization', 'clif_adt')
)

# Store ID columns as str so joins and isin() checks compare values of one type
hospitalization['hospitalization_id'] = hospitalization['hospitalization_id'].astype(str)
patient['patient_id'] = patient['patient_id'].astype(str)
adt['hospitalization_id'] = adt['hospitalization_id'].astype(str)

# De-duplicate each table (pyCLIF.remove_duplicates) on its expected key:
#   patient         -> patient_id
#   hospitalization -> hospitalization_id
#   adt             -> (hospitalization_id, hospital_id, in_dttm)
patient = pyCLIF.remove_duplicates(patient, ['patient_id'], 'patient')
hospitalization = pyCLIF.remove_duplicates(hospitalization, ['hospitalization_id'], 'hospitalization')
adt = pyCLIF.remove_duplicates(adt, ['hospitalization_id', 'hospital_id', 'in_dttm'], 'adt')

In [None]:
# Standardize all _dttm variables of the core tables to the site timezone
site_tz = pyCLIF.helper['timezone']
patient = pyCLIF.convert_datetime_columns_to_site_tz(patient, site_tz)
hospitalization = pyCLIF.convert_datetime_columns_to_site_tz(hospitalization, site_tz)
adt = pyCLIF.convert_datetime_columns_to_site_tz(adt, site_tz)

#### Hospitalizations

In [None]:
# Adults aged 18–119 admitted in calendar years 2021–2024 (both ends inclusive)
cohort = hospitalization[
    hospitalization['admission_dttm'].dt.year.between(2021, 2024)
    & hospitalization['age_at_admission'].between(18, 119)
]

In [None]:
# STROBE-style attrition counts, extended as each cohort step runs
strobe_counts = {
    "A_adult_hospitalizations_since_2021": cohort['hospitalization_id'].nunique(dropna=False),
}

#### ADT

In [None]:
# Restrict ADT to the hospitalization cohort. .copy() gives an independent frame, so the
# lower-casing below does not write into a slice of `adt`. That would be the
# SettingWithCopy / chained-assignment case, which the notebook-wide warnings filter
# would otherwise silence.
adt_cohort = adt[adt['hospitalization_id'].isin(cohort['hospitalization_id'])].copy()
adt_cohort['location_category'] = adt_cohort['location_category'].str.lower()

# Keep every ADT row of hospitalizations that had at least one ICU location
icu_hospitalization_ids = adt_cohort.loc[
    adt_cohort['location_category'] == 'icu', 'hospitalization_id'
].unique()
adt_filtered = adt_cohort[adt_cohort['hospitalization_id'].isin(icu_hospitalization_ids)]
strobe_counts["B_adult_hospitalizations_since_2021_with_icu"] = len(adt_filtered['hospitalization_id'].drop_duplicates())

In [None]:
# Notebook display: show the STROBE counts recorded so far
strobe_counts

In [None]:
# Keep only cohort hospitalizations that have at least one ICU ADT row
cohort = cohort.loc[cohort['hospitalization_id'].isin(adt_filtered['hospitalization_id'])]
n_cohort_ids = cohort['hospitalization_id'].nunique(dropna=False)
print("Final list of cohort ids", n_cohort_ids)

# Hourly Scaffold

In [None]:
# 1) define the 'end_time' for the sequence from vitals or outcome.
# Load vitals for the cohort; their first/last timestamps bound each hourly timeline.
vitals_cohort = pyCLIF.load_data('clif_vitals',
    filters={'hospitalization_id': cohort['hospitalization_id'].unique().tolist()}
)
# Cast IDs to str like the other tables. The scaffold built from these IDs is merged
# with first_icu_stays / hospitalization (str IDs), and pandas refuses to merge an
# int64 key against an object key.
vitals_cohort['hospitalization_id'] = vitals_cohort['hospitalization_id'].astype(str)
vitals_cohort = pyCLIF.convert_datetime_columns_to_site_tz(vitals_cohort, pyCLIF.helper['timezone'])
vitals_cohort = vitals_cohort.sort_values(['hospitalization_id', 'recorded_dttm'])

# Get first and last vitals timestamp for each hospitalization
vital_bounds = (
    vitals_cohort
    .groupby('hospitalization_id')
    .agg({
        'recorded_dttm': ['min', 'max']
    })
    .droplevel(0, axis=1)
    .rename(columns={'min': 'first_vital_dttm', 'max': 'last_vital_dttm'})
)

# Create hourly scaffold for each hospitalization: one row per hour from the first to the
# last vital. Rows fall at first_vital_dttm + k hours rather than on clock hours. Later
# joins key on (recorded_date, recorded_hour), i.e. the clock hour each row falls in.
# The step is a Timedelta rather than the 'H' alias, which is deprecated since
# pandas 2.2. A Timedelta is accepted by old and new pandas and steps identically.
one_hour = pd.Timedelta(hours=1)
hourly_scaffold = pd.DataFrame([
    (hosp_id, time)
    for hosp_id, start, end in zip(
        vital_bounds.index,
        vital_bounds['first_vital_dttm'],
        vital_bounds['last_vital_dttm']
    )
    for time in pd.date_range(start=start, end=end, freq=one_hour, tz=pyCLIF.helper['timezone'])
], columns=['hospitalization_id', 'recorded_dttm'])

# Add date and hour columns (site-local clock)
hourly_scaffold['recorded_date'] = hourly_scaffold['recorded_dttm'].dt.date
hourly_scaffold['recorded_hour'] = hourly_scaffold['recorded_dttm'].dt.hour

# Definition I

Hospitalizations that are on a ventilator for the first 24 hours of their first ICU stay.

Notes: 

- Use the ADT table to identify each hospitalization's first ICU stay (`location_category.lower() == "icu"`). Fields in the ADT table: hospitalization_id, location_category, in_dttm, out_dttm
- Use the Respiratory Support table to determine ventilation status during the first ICU stay; `device_category.lower() == "imv"` identifies invasive mechanical ventilation. Other fields in the table: hospitalization_id, recorded_dttm, device_category, mode_category
- Identify hospitalizations that were on vent for the first 24 hours of their first ICU stay


#### First ICU Stay

In [None]:
# Get the first ICU stay for each hospitalization 
# Convert location category to lowercase and filter for ICU
icu_stays = adt_filtered[adt_filtered['location_category'] == 'icu'].copy()
# Sort by hospitalization_id and in_dttm to get first ICU stay
icu_stays = icu_stays.sort_values(['hospitalization_id', 'in_dttm', 'out_dttm'])
# Get first ICU stay for each hospitalization
first_icu_stays = icu_stays.groupby('hospitalization_id').first().reset_index()

# Calculate duration of first ICU stay
first_icu_stays['icu_duration_hours'] = (
    first_icu_stays['out_dttm'] - first_icu_stays['in_dttm']
).dt.total_seconds() / 3600

first_icu_stays['icu_duration_days'] = first_icu_stays['icu_duration_hours'] / 24

# Check how many have at least 24 hours (1 day) ICU stay
icu_24h_plus = first_icu_stays[first_icu_stays['icu_duration_hours'] >= 24]
print(f"\nHospitalizations with ICU stay ≥ 24 hours: {len(icu_24h_plus)} ({len(icu_24h_plus)/len(first_icu_stays)*100:.1f}%)")

In [None]:
# Before any clinical data is joined, attach each hospitalization's first-ICU-stay
# boundaries to all of its scaffold rows. Left join: rows without a first ICU stay keep
# NaT in_dttm/out_dttm.
icu_bounds = first_icu_stays[['hospitalization_id', 'in_dttm', 'out_dttm']]
hourly_scaffold = pd.merge(hourly_scaffold, icu_bounds, how='left', on='hospitalization_id')

# Convert to UTC first to avoid DST issues, then floor, then convert back
def safe_floor_datetime(series, freq='s'):
    """Floor a datetime Series to ``freq`` without tripping over DST transitions.

    Flooring a tz-aware Series directly re-localizes the floored wall-clock time. That
    raises for wall-clock times that are ambiguous around a DST change (e.g. the repeated
    hour at fall-back). Flooring in UTC and converting back avoids it. tz-naive input has
    no DST to handle and is floored directly; previously it raised TypeError in
    ``tz_convert``.

    Parameters
    ----------
    series : pandas.Series
        datetime64 Series, tz-aware or tz-naive. NaT values stay NaT.
    freq : str or offset, default 's'
        Flooring frequency. The lower-case 's' alias is used because upper-case 'S' is
        deprecated since pandas 2.2; older pandas accepts 's' as well.

    Returns
    -------
    pandas.Series
        Floored Series, in the same timezone as the input (or tz-naive if it was naive).
    """
    tz = series.dt.tz
    if tz is None:
        return series.dt.floor(freq)
    floored_utc = series.dt.tz_convert('UTC').dt.floor(freq)
    return floored_utc.dt.tz_convert(tz)

# Floor the three timestamps to whole seconds (DST-safe) so comparisons are not skewed
# by sub-second noise
hourly_scaffold['recorded_dttm'] = safe_floor_datetime(hourly_scaffold['recorded_dttm'])
hourly_scaffold['in_dttm'] = safe_floor_datetime(hourly_scaffold['in_dttm'])
hourly_scaffold['out_dttm'] = safe_floor_datetime(hourly_scaffold['out_dttm'])

# in_icu: scaffold time lies within the first ICU stay, both ends inclusive (NaT -> 0)
hourly_scaffold['in_icu'] = hourly_scaffold['recorded_dttm'].between(
    hourly_scaffold['in_dttm'], hourly_scaffold['out_dttm']
).astype(int)

# Length of the first ICU stay, in hours
hourly_scaffold['icu_duration_hours'] = (
    (hourly_scaffold['out_dttm'] - hourly_scaffold['in_dttm']).dt.total_seconds() / 3600
)

# Hours elapsed since ICU admission at each scaffold time (negative before admission)
hourly_scaffold['hours_since_icu_admission'] = (
    (hourly_scaffold['recorded_dttm'] - hourly_scaffold['in_dttm']).dt.total_seconds() / 3600
)

# in_icu_24h: the stay lasted >= 24 h, and this row is in ICU within [0, 24) h of admission
hourly_scaffold['in_icu_24h'] = (
    hourly_scaffold['icu_duration_hours'].ge(24)
    & hourly_scaffold['in_icu'].eq(1)
    & hourly_scaffold['hours_since_icu_admission'].ge(0)
    & hourly_scaffold['hours_since_icu_admission'].lt(24)
).astype(int)

In [None]:
# NOTE(review): in_icu_24h is created with .astype(int), so it is never NaN and the
# notna() filter keeps every row. This first count is therefore the number of
# hospitalizations in the scaffold. Confirm whether a different count was intended.
patients_with_flag = hourly_scaffold.loc[
    hourly_scaffold['in_icu_24h'].notna(), 'hospitalization_id'
].nunique()
print(f"Patients with in_icu_24h flag: {patients_with_flag}")

# Hospitalizations with at least one scaffold row inside the first 24 h of a >=24 h ICU stay
patients_with_24h = hourly_scaffold.loc[
    hourly_scaffold['in_icu_24h'].eq(1), 'hospitalization_id'
].nunique()
print(f"Patients with in_icu_24h == 1: {patients_with_24h}")

strobe_counts["C_adult_hospitalizations_since_2021_with_icu_atleast_24hr"] = patients_with_24h

In [None]:
# Cross-check two counts of hospitalizations whose first ICU stay lasted >= 24 h:
#   icu_24h_plus      - computed from first_icu_stays (ICU hospitalizations from ADT)
#   patients_with_24h - computed from the hourly scaffold (hospitalizations with vitals)
# They can differ, e.g. when a hospitalization has no vitals and therefore no scaffold rows.
try:
    n_icu_24h_plus = len(icu_24h_plus)
    n_scaffold_24h = patients_with_24h
except NameError as err:
    # A NameError means a prerequisite cell was not run, not that the counts disagree
    print(f"WARNING: cannot compare ICU >=24h counts ({err}); run the First ICU Stay and in_icu_24h cells first")
else:
    if n_icu_24h_plus != n_scaffold_24h:
        print(f"WARNING: icu_24h_plus length ({n_icu_24h_plus}) does not match patients_with_24h ({n_scaffold_24h})")
    else:
        print(f"✓ icu_24h_plus length matches patients_with_24h: {n_scaffold_24h}")

In [None]:
# Independent (deep) copy of the final cohort; its IDs filter the clinical tables loaded below
cohort_final = cohort.copy(deep=True)

In [None]:
# Attach hospital admission time to every scaffold row (left join; the hospitalization
# table was de-duplicated on hospitalization_id, so no rows are multiplied)
admission_times = hospitalization.loc[:, ['hospitalization_id', 'admission_dttm']]
hourly_scaffold = pd.merge(hourly_scaffold, admission_times, how='left', on='hospitalization_id')

#### Respiratory Support

In [None]:
# load resp support 
rst_required_columns = [
    'hospitalization_id',
    'recorded_dttm',
    'device_name',
    'device_category',
    'mode_name', 
    'mode_category',
    'tracheostomy',
    'fio2_set',
    'lpm_set',
    'resp_rate_set',
    'peep_set',
    'resp_rate_obs',
    'tidal_volume_set', 
    'pressure_control_set',
    'pressure_support_set',
    'peak_inspiratory_pressure_set'

]

# 1) Load respiratory support
resp_support_raw = pyCLIF.load_data(
    'clif_respiratory_support',
    columns=rst_required_columns,
    filters={'hospitalization_id': cohort_final['hospitalization_id'].unique().tolist()}
)

resp_support = resp_support_raw.copy()
# Cast IDs to str like every other CLIF table in this notebook. This frame (presumably
# via the waterfall output) is later merged with first_icu_stays, whose IDs are str, and
# pandas refuses to merge an int64 key against an object key.
resp_support['hospitalization_id'] = resp_support['hospitalization_id'].astype(str)
resp_support['device_category'] = resp_support['device_category'].str.lower()
resp_support['mode_category'] = resp_support['mode_category'].str.lower()
resp_support['lpm_set'] = pd.to_numeric(resp_support['lpm_set'], errors='coerce')
resp_support['resp_rate_set'] = pd.to_numeric(resp_support['resp_rate_set'], errors='coerce')
resp_support['peep_set'] = pd.to_numeric(resp_support['peep_set'], errors='coerce')
resp_support['resp_rate_obs'] = pd.to_numeric(resp_support['resp_rate_obs'], errors='coerce')
resp_support = resp_support.sort_values(['hospitalization_id', 'recorded_dttm'])
# del resp_support_raw

# NOTE(review): despite the heading printed below, only the FiO2 scale check runs here;
# outlier_cfg is not applied in this cell.
print("\n=== Apply outlier thresholds ===\n")
resp_support['fio2_set'] = pd.to_numeric(resp_support['fio2_set'], errors='coerce')
# (Optional) If FiO2 is >1 on average => scale by /100
fio2_mean = resp_support['fio2_set'].mean(skipna=True)
# If the mean is greater than 1, divide 'fio2_set' by 100 (a NaN mean compares False)
if fio2_mean and fio2_mean > 1.0:
    # Only divide values greater than 1 to avoid re-dividing already correct values
    resp_support.loc[resp_support['fio2_set'] > 1, 'fio2_set'] = \
        resp_support.loc[resp_support['fio2_set'] > 1, 'fio2_set'] / 100
    print("Updated fio2_set to be between 0.21 and 1")
else:
    print("FIO2_SET mean=", fio2_mean, "is within the required range")

In [None]:
## Hospitalizations with at least one IMV respiratory-support row (case-insensitive match)
imv_mask = resp_support['device_category'].str.contains("imv", case=False, na=False)
resp_stitched_imv_ids = resp_support.loc[imv_mask, ['hospitalization_id']].drop_duplicates()

strobe_counts["D_adult_hospitalizations_since_2021_with_icu_imv"] = resp_stitched_imv_ids['hospitalization_id'].nunique(dropna=False)

# Every respiratory-support row (IMV or not) belonging to those hospitalizations
imv_hosp_ids = resp_stitched_imv_ids['hospitalization_id']
resp_support_filtered = resp_support.loc[
    resp_support['hospitalization_id'].isin(imv_hosp_ids)
].reset_index(drop=True)

# Cohort rows for the hospitalizations that remain after the IMV filter
all_ids = cohort.loc[cohort['hospitalization_id'].isin(resp_support_filtered['hospitalization_id'].unique())]

In [None]:
# Notebook display: STROBE counts including the IMV step
strobe_counts

In [None]:
n_imv_cohort_ids = all_ids['hospitalization_id'].nunique(dropna=False)
print("Final list of ids adult_hospitalizations_since_2021_with_icu_imv", n_imv_cohort_ids)

In [None]:
import warnings
# Ignore FutureWarnings attributed to pandas modules (e.g. from inside the waterfall routine)
warnings.filterwarnings("ignore", category=FutureWarning, module="pandas")

# Run the project's respiratory-support waterfall processing on the IMV hospitalizations,
# then bring the result's datetime columns back to the site timezone
processed_resp_support = waterfall.process_resp_support_waterfall(
    resp_support_filtered,
    id_col="hospitalization_id",
    verbose=True,
)
processed_resp_support = pyCLIF.convert_datetime_columns_to_site_tz(
    processed_resp_support, pyCLIF.helper['timezone']
)

In [None]:
# Inner join: keep respiratory rows only for hospitalizations with a first ICU stay,
# carrying that stay's in/out times alongside
icu_bounds_for_vent = first_icu_stays[['hospitalization_id', 'in_dttm', 'out_dttm']]
vent_records = pd.merge(processed_resp_support, icu_bounds_for_vent, how='inner', on='hospitalization_id')

# on_vent = 1 for IMV rows, 0 for every other device (including missing)
vent_records['on_vent'] = vent_records['device_category'].str.lower().eq('imv').astype(int)

When aggregating flags to the hourly level, the last value recorded within each hour is used, on the assumption that it best represents the patient's status going into the next hour.

In [None]:
# One vent status per (hospitalization, clock hour). .copy() makes this an independent
# frame, so the column assignments below do not go through pandas' chained-assignment /
# SettingWithCopy path on a slice of vent_records.
cohort_hourly = vent_records[['hospitalization_id', 'recorded_dttm', 'on_vent']].copy()

# Create recorded_date and recorded_hour columns (site-local clock)
cohort_hourly['recorded_date'] = cohort_hourly['recorded_dttm'].dt.date
cohort_hourly['recorded_hour'] = cohort_hourly['recorded_dttm'].dt.hour

# Chronological sort, then keep the last vent status charted within each hour
cohort_hourly_agg = (
    cohort_hourly
    .sort_values(['hospitalization_id', 'recorded_dttm'])
    .groupby(['hospitalization_id', 'recorded_date', 'recorded_hour'])
    .agg({
        'on_vent': 'last',        # Last vent status in hour
    })
    .reset_index()
)

In [None]:
# Left-join hourly vent status onto the scaffold by clock hour; scaffold hours without
# a respiratory record get NaN on_vent
hour_keys = ['hospitalization_id', 'recorded_date', 'recorded_hour']
final_df = pd.merge(hourly_scaffold, cohort_hourly_agg, how='left', on=hour_keys)

In [None]:
#  forward fill missing hours
# Carry each hospitalization's last known values forward into later rows. This fills
# every non-key column (on_vent and also the ICU/admission columns). Rows before a
# hospitalization's first non-null value stay NaN. groupby().ffill() does not sort, so
# this assumes final_df is already chronological within each hospitalization (the
# scaffold is built in time order and the left merge above keeps that order).
final_df = (
    final_df
    .set_index('hospitalization_id')
    .groupby('hospitalization_id')
    .ffill()
    .reset_index()
)

# Collapse to one row per (hospitalization, date, hour). Duplicate keys arise when two
# scaffold rows map to the same local clock hour (presumably a repeated hour at a DST
# fall-back). groupby().last() takes the last non-null value per column.
final_df = (
    final_df
    .sort_values(['hospitalization_id', 'recorded_date', 'recorded_hour'])
    .groupby(['hospitalization_id', 'recorded_date', 'recorded_hour'])
    .last()
    .reset_index()
)

In [None]:
# identify hosp satisfying def 1
# Minimum hourly vent status over the first 24 h of the first ICU stay; def_1 = 1 when it is 1.
# NOTE(review): groupby min skips NaN, so hours with no respiratory record yet do not
# count against the patient. Only windows that are entirely NaN yield NaN (def_1 = 0).
first_24h_rows = final_df.loc[final_df['in_icu_24h'].eq(1)]
def_1_status = first_24h_rows.groupby('hospitalization_id', as_index=False)['on_vent'].min()
def_1_status['def_1'] = def_1_status['on_vent'].eq(1).astype(int)

In [None]:
# Count patients with ventilation data in first 24h of ICU
patients_with_vent_data_24h = int(def_1_status['on_vent'].notna().sum())
strobe_counts["D1_patients_with_vent_data_in_first_24h_icu"] = patients_with_vent_data_24h

# Count patients meeting Definition 1 (on ventilator for first 24h)
patients_meeting_def1 = int(def_1_status['def_1'].sum())
strobe_counts["D2_patients_meeting_definition_1_24h_ventilation"] = patients_meeting_def1

# Count patients without ventilation data in first 24h
patients_without_vent_data_24h = int(def_1_status['on_vent'].isna().sum())
strobe_counts["D3_patients_without_vent_data_in_first_24h_icu"] = patients_without_vent_data_24h


def _pct(numerator, denominator):
    """Return numerator / denominator * 100, or NaN when the denominator is 0.

    Keeps the summary printable (shown as 'nan%') when a site has no eligible patients
    or no ventilation data, instead of raising ZeroDivisionError.
    """
    return numerator / denominator * 100 if denominator else float('nan')


n_eligible_def1 = len(def_1_status)

# Print summary
print(f"\n=== DEFINITION 1 SUMMARY ===")
print(f"Total eligible patients (≥24h ICU): {n_eligible_def1}")
print(f"With ventilation data in first 24h: {patients_with_vent_data_24h} ({_pct(patients_with_vent_data_24h, n_eligible_def1):.1f}%)")
print(f"Without ventilation data in first 24h: {patients_without_vent_data_24h} ({_pct(patients_without_vent_data_24h, n_eligible_def1):.1f}%)")
print(f"Meeting Definition 1 (24h ventilation): {patients_meeting_def1} ({_pct(patients_meeting_def1, n_eligible_def1):.1f}%)")
print(f"Meeting Definition 1 among those with vent data: {patients_meeting_def1}/{patients_with_vent_data_24h} ({_pct(patients_meeting_def1, patients_with_vent_data_24h):.1f}%)")

# Verify totals
print(f"\nVerification:")
print(f"With vent data + Without vent data = {patients_with_vent_data_24h + patients_without_vent_data_24h} (should equal {n_eligible_def1})")

In [None]:
# One row per hospitalization present in final_df
all_hosp_ids = (
    final_df[['hospitalization_id']]
    .drop_duplicates()
    .reset_index(drop=True)
)

# Attach def_1; hospitalizations without a scored first-24h window count as 0
all_defs = all_hosp_ids.merge(
    def_1_status.loc[:, ['hospitalization_id', 'def_1']],
    how='left',
    on='hospitalization_id',
)
all_defs['def_1'] = all_defs['def_1'].fillna(0).astype(int)

# Definition II
Hospitalizations that are on vasoactive medications during the first 24 hrs of their first ICU stay.


#### Medication Admin Continuous

- Filter down to the required meds and the cohort
- For each hour, flag whether each of these medications was recorded at a positive dose, then combine the per-medication flags into a single hourly vasoactive flag.

In [None]:
# Columns pulled from the continuous medication administration table
meds_required_columns = [
    'hospitalization_id', 'admin_dttm', 'med_name',
    'med_category', 'med_dose', 'med_dose_unit',
]
# Vasoactive medications (matched against med_category) used for Definition II
meds_of_interest = [
    'norepinephrine', 'epinephrine', 'phenylephrine', 'vasopressin',
    'dopamine', 'angiotensin', 'dobutamine',
]

meds_filters = dict(
    hospitalization_id=cohort_final['hospitalization_id'].unique().tolist(),
    med_category=meds_of_interest,
)
meds = pyCLIF.load_data(
    'clif_medication_admin_continuous',
    columns=meds_required_columns,
    filters=meds_filters,
)

# Normalise types: str IDs, lower-case units, site-timezone timestamps, numeric doses
meds['hospitalization_id'] = meds['hospitalization_id'].astype(str)
meds['med_dose_unit'] = meds['med_dose_unit'].str.lower()
meds = pyCLIF.convert_datetime_columns_to_site_tz(meds, pyCLIF.helper['timezone'])
meds['med_dose'] = pd.to_numeric(meds['med_dose'], errors='coerce')

# Clock-hour keys for joining onto the hourly frame
admin_times = meds['admin_dttm']
meds['recorded_date'] = admin_times.dt.date
meds['recorded_hour'] = admin_times.dt.hour

strobe_counts["E_adult_hospitalizations_since_2021_icu_meds"] = meds['hospitalization_id'].nunique(dropna=False)
strobe_counts

In [None]:
# Rows for the medications of interest (the load filter already restricts to these)
meds_filtered = meds.loc[meds['med_category'].isin(meds_of_interest)].copy()

# Per-row indicator per medication: 1 when the row is that medication at a positive,
# non-missing dose, else 0
has_positive_dose = meds_filtered['med_dose'].gt(0.0) & meds_filtered['med_dose'].notna()
for med in meds_of_interest:
    meds_filtered[med + '_flag'] = (
        meds_filtered['med_category'].eq(med) & has_positive_dose
    ).astype(int)

# Hourly max: a flag is 1 if that medication had any positive-dose row within the hour
flag_columns = [f'{m}_flag' for m in meds_of_interest]
meds_flags = (
    meds_filtered
    .groupby(['hospitalization_id', 'recorded_date', 'recorded_hour'])[flag_columns]
    .max()
    .reset_index()
)

# Any vasoactive medication in the hour
meds_flags['vasoactive_meds_flag'] = meds_flags[flag_columns].max(axis=1)

In [None]:
# Attach the hourly vasoactive flag, then carry each hospitalization's last known status
# forward in time (hours before its first medication row stay NaN)
meds_hour_keys = ['hospitalization_id', 'recorded_date', 'recorded_hour']
final_df = pd.merge(
    final_df,
    meds_flags.loc[:, meds_hour_keys + ['vasoactive_meds_flag']],
    how='left',
    on=meds_hour_keys,
)
final_df = final_df.sort_values(meds_hour_keys)
final_df['vasoactive_meds_flag'] = final_df.groupby('hospitalization_id')['vasoactive_meds_flag'].ffill()

In [None]:
# Definition II: minimum hourly vasoactive flag over the first 24 h of the first ICU
# stay; def_2 = 1 only when every hour that has a value shows vasoactive meds.
# NOTE(review): groupby min skips NaN, so hours before the first medication record do
# not count against the patient.
def_2_status = (
    final_df.loc[final_df['in_icu_24h'].eq(1)]
    .groupby('hospitalization_id', as_index=False)['vasoactive_meds_flag']
    .min()
)
def_2_status['def_2'] = def_2_status['vasoactive_meds_flag'].eq(1).astype(int)

# Hospitalizations with any medication value in the window, and those meeting Definition II
strobe_counts['E1_patients_with_meds_data_in_first_24h_icu'] = int(def_2_status['vasoactive_meds_flag'].notna().sum())
strobe_counts['E2_patients_meeting_definition_2_24h_vasoactive_meds'] = int(def_2_status['def_2'].sum())

strobe_counts

In [None]:
# Add def_2; hospitalizations absent from def_2_status count as 0
def_2_flags = def_2_status.loc[:, ['hospitalization_id', 'def_2']]
all_defs = pd.merge(all_defs, def_2_flags, how='left', on='hospitalization_id')
all_defs['def_2'] = all_defs['def_2'].fillna(0).astype(int)

# Definition III

Hospitalizations that have stage I AKI defined as 

    3a. an absolute increase in serum creatinine of ≥ 0.3 mg/dL within the first 48 hours after admission
    3b. a relative increase in serum creatinine of ≥ 50% (peak ≥ 1.5 × baseline) within the first 7 days after admission

Notes:

- Baseline creatinine is the first recorded creatinine value since hospital admission

- The change in serum creatinine is calculated by comparing the baseline value with the maximum value within 48 hours of admission for 3a, and within 7 days of admission for 3b

- This definition is satisfied when the patient satisfies either 3a or 3b or both


#### Labs- Creatinine

In [None]:
labs_required_columns = [
    'hospitalization_id', 'lab_result_dttm', 'lab_name',
    'lab_category', 'lab_value', 'lab_value_numeric',
]
labs_of_interest = ['creatinine']

# Creatinine results for the final cohort
labs_filters = dict(
    hospitalization_id=cohort_final['hospitalization_id'].unique().tolist(),
    lab_category=labs_of_interest,
)
labs = pyCLIF.load_data('clif_labs', columns=labs_required_columns, filters=labs_filters)
print("unique encounters in labs", pyCLIF.count_unique_encounters(labs))

# Normalise: str IDs, chronological order, site timezone, numeric values
labs['hospitalization_id'] = labs['hospitalization_id'].astype(str)
labs = labs.sort_values(by=['hospitalization_id', 'lab_result_dttm'])
labs = pyCLIF.convert_datetime_columns_to_site_tz(labs, pyCLIF.helper['timezone'])
labs['lab_value_numeric'] = pd.to_numeric(labs['lab_value_numeric'], errors='coerce')

# Clock-hour keys for joining onto the hourly frame
result_times = labs['lab_result_dttm']
labs['recorded_hour'] = result_times.dt.hour
labs['recorded_date'] = result_times.dt.date

creatinine = (
    labs[['hospitalization_id', 'recorded_date', 'recorded_hour', 'lab_value_numeric']]
    .rename(columns={'lab_value_numeric': 'creatinine'})
    .sort_values(by=['hospitalization_id', 'recorded_date', 'recorded_hour', 'creatinine'])
)

strobe_counts["E_adult_hospitalizations_since_2021_icu_creatinine"] = creatinine['hospitalization_id'].nunique(dropna=False)
strobe_counts

In [None]:
# Attach creatinine to the hourly frame. Several results can fall in the same clock hour,
# and a plain left join would then repeat that hour (one row per result). That breaks
# final_df's one-row-per-hour layout. Collapse to the hourly maximum first.
# Definition 3 below is computed from `labs` directly, not from this column.
creatinine_hourly = (
    creatinine
    .groupby(['hospitalization_id', 'recorded_date', 'recorded_hour'], as_index=False)['creatinine']
    .max()
)
final_df = final_df.merge(
    creatinine_hourly,
    on=['hospitalization_id', 'recorded_date', 'recorded_hour'],
    how='left',
    validate='many_to_one',  # guard: each hour key must match at most one creatinine row
)

In [None]:
# Notebook display: column dtypes of final_df after the creatinine join
final_df.dtypes

In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# AKI (Definition 3) using hospital admission anchor + window maxima
# Baseline: first creatinine at/after admission_dttm
# Maxima:   max creatinine within 0–48h and 0–7d after admission_dttm
# ──────────────────────────────────────────────────────────────────────────────
import numpy as np
import pandas as pd

# Stage 1 AKI thresholds (Definition III)
AKI_ABS_RISE_48H = 0.3   # rise from baseline to 48 h peak, mg/dL (def_3a)
AKI_RATIO_7D = 1.5       # 7 d peak / baseline ratio, i.e. a 50% rise (def_3b)
# Creatinine values are decimals that binary floats cannot represent exactly. For
# example, 2.3 - 2.0 evaluates to 0.2999999999999998, and ratios suffer likewise. An
# exact >= would therefore miss rises lying exactly on a threshold, so allow a tiny
# tolerance far below reporting precision.
AKI_FLOAT_TOL = 1e-9

# Validate inputs with explicit errors (assert statements are skipped under python -O)
if 'admission_dttm' not in final_df.columns:
    raise KeyError("final_df must contain admission_dttm")
if not pd.api.types.is_datetime64_any_dtype(final_df['admission_dttm']):
    raise TypeError("admission_dttm must be datetime")

# One admission row per hospitalization
admit = (final_df[['hospitalization_id','admission_dttm']]
         .drop_duplicates('hospitalization_id', keep='first'))

# Keep plausible creatinine (0.1–20.0 inclusive, presumably mg/dL), rename for clarity
labs_cr = (labs.loc[
              labs['lab_value_numeric'].between(0.1, 20.0, inclusive='both')
           , ['hospitalization_id','lab_result_dttm','lab_value_numeric']]
           .rename(columns={'lab_value_numeric':'creatinine'})
           .copy())

# Align labs to admission clock (hospital admission)
labs_cr = labs_cr.merge(admit, on='hospitalization_id', how='inner')

# Use only labs at/after hospital admission
labs_cr = labs_cr[labs_cr['lab_result_dttm'] >= labs_cr['admission_dttm']].copy()

# Baseline = FIRST value after admission
labs_cr.sort_values(['hospitalization_id','lab_result_dttm'], inplace=True)
baseline = (labs_cr.groupby('hospitalization_id', as_index=False)
                  .first()[['hospitalization_id','creatinine']]
                  .rename(columns={'creatinine':'baseline_creatinine'}))

# Window membership (half-open intervals measured from admission)
labs_cr['in_48h'] = labs_cr['lab_result_dttm'] < (labs_cr['admission_dttm'] + pd.Timedelta(hours=48))
labs_cr['in_7d']  = labs_cr['lab_result_dttm'] < (labs_cr['admission_dttm'] + pd.Timedelta(days=7))

# Window maxima (include the baseline value itself)
max48 = (labs_cr.loc[labs_cr['in_48h']]
                .groupby('hospitalization_id', as_index=False)['creatinine']
                .max()
                .rename(columns={'creatinine':'peak_creatinine_48h'}))

max7  = (labs_cr.loc[labs_cr['in_7d']]
                .groupby('hospitalization_id', as_index=False)['creatinine']
                .max()
                .rename(columns={'creatinine':'peak_creatinine_7d'}))

# Assemble per-encounter AKI table (every admission; missing labs -> NaN)
def_3_status = (admit[['hospitalization_id']]
        .merge(baseline, on='hospitalization_id', how='left')
        .merge(max48,   on='hospitalization_id', how='left')
        .merge(max7,    on='hospitalization_id', how='left'))

# Flags (NaNs -> 0), compared with a floating-point tolerance
creat_rise_48h = def_3_status['peak_creatinine_48h'] - def_3_status['baseline_creatinine']
creat_ratio_7d = def_3_status['peak_creatinine_7d'] / def_3_status['baseline_creatinine']

def_3_status['def_3a'] = (
    (def_3_status['baseline_creatinine'].notna()) &
    (def_3_status['peak_creatinine_48h'].notna()) &
    (creat_rise_48h >= AKI_ABS_RISE_48H - AKI_FLOAT_TOL)
).astype(int)

def_3_status['def_3b'] = (
    (def_3_status['baseline_creatinine'].notna()) &
    (def_3_status['baseline_creatinine'] > 0) &
    (def_3_status['peak_creatinine_7d'].notna()) &
    (creat_ratio_7d >= AKI_RATIO_7D - AKI_FLOAT_TOL)
).astype(int)

def_3_status['def_3'] = ((def_3_status['def_3a'] == 1) | (def_3_status['def_3b'] == 1)).astype(int)

# Merge back to final_df (drop these columns first so re-running the cell does not duplicate them)
final_df = final_df.drop(columns=[c for c in ['baseline_creatinine',
                                              'peak_creatinine_48h',
                                              'peak_creatinine_7d',
                                              'def_3a','def_3b','def_3']
                                  if c in final_df], errors='ignore')
final_df = final_df.merge(
    def_3_status[['hospitalization_id','baseline_creatinine',
          'peak_creatinine_48h','peak_creatinine_7d','def_3a','def_3b','def_3']],
    on='hospitalization_id', how='left'
)

# STROBE-style counts
patients_with_baseline  = int(def_3_status['baseline_creatinine'].notna().sum())
patients_with_48h_data  = int(def_3_status['peak_creatinine_48h'].notna().sum())
patients_with_7d_data   = int(def_3_status['peak_creatinine_7d'].notna().sum())
patients_meeting_def3a  = int(def_3_status['def_3a'].sum())
patients_meeting_def3b  = int(def_3_status['def_3b'].sum())
patients_meeting_def3   = int(def_3_status['def_3'].sum())

print("\n=== DEFINITION 3 RESULTS (hospital admission anchor) ===")
print(f"Patients with baseline creatinine: {patients_with_baseline}")
print(f"Patients with 48h creatinine data: {patients_with_48h_data}")
print(f"Patients with 7d creatinine data : {patients_with_7d_data}")
print(f"Meeting def_3a (0.3 mg/dL in 48h): {patients_meeting_def3a}")
print(f"Meeting def_3b (50% in 7d)      : {patients_meeting_def3b}")
print(f"Meeting def_3 (either)           : {patients_meeting_def3}")

# Update strobe counts dict (if you use it downstream)
strobe_counts["F1_patients_with_baseline_creatinine"] = patients_with_baseline
strobe_counts["F2_patients_with_48h_creatinine_data"] = patients_with_48h_data
strobe_counts["F3_patients_with_7d_creatinine_data"]  = patients_with_7d_data
strobe_counts["F4_patients_meeting_definition_3a_48h_aki"] = patients_meeting_def3a
strobe_counts["F5_patients_meeting_definition_3b_7d_aki"]  = patients_meeting_def3b
strobe_counts["F6_patients_meeting_definition_3_aki"]      = patients_meeting_def3


In [None]:
# Add def_3; hospitalizations absent from def_3_status count as 0
def_3_flags = def_3_status.loc[:, ['hospitalization_id', 'def_3']]
all_defs = pd.merge(all_defs, def_3_flags, how='left', on='hospitalization_id')
all_defs['def_3'] = all_defs['def_3'].fillna(0).astype(int)

In [None]:
# Mortality flag from the hospitalization discharge category: 1 when the lower-cased
# category is 'expired' or 'hospice', else 0 (missing categories -> 0)
mortality_data = hospitalization.loc[:, ['hospitalization_id', 'discharge_category']].copy()
discharge_lower = mortality_data['discharge_category'].str.lower()
mortality_data['mortality'] = discharge_lower.isin(['expired', 'hospice']).astype(int)

# Attach to all_defs (left join keeps every hospitalization already in all_defs)
all_defs = pd.merge(
    all_defs,
    mortality_data[['hospitalization_id', 'mortality']],
    how='left',
    on='hospitalization_id',
)

In [None]:
 crrt_columns = [
    'hospitalization_id', 
    'recorded_dttm',
    'crrt_mode_name',
    'crrt_mode_category',
]
 crrt_df = pyCLIF.load_data(
        'clif_crrt_therapy',
        columns=crrt_columns,
        filters={'hospitalization_id': cohort_final['hospitalization_id'].unique().tolist()}
    )

# Create CRRT flag - patients who have any CRRT record
crrt_patients = crrt_df['hospitalization_id'].unique()
crrt_flag = pd.DataFrame({
    'hospitalization_id': all_defs['hospitalization_id'].unique()
})
crrt_flag['crrt'] = crrt_flag['hospitalization_id'].isin(crrt_patients).astype(int)

# Join CRRT data with all_defs
all_defs = all_defs.merge(
    crrt_flag[['hospitalization_id', 'crrt']], 
    on='hospitalization_id', 
    how='left'
)

all_defs['crrt'] = all_defs['crrt'].fillna(0).astype(int)

# Summary

In [None]:
import matplotlib.pyplot as plt
from upsetplot import UpSet, from_indicators
import pandas as pd
import os
import numpy as np

# Make sure the output folder exists before anything is written into it.
os.makedirs('../output/final', exist_ok=True)

# Definition flags per hospitalization. A missing flag is treated as "not met".
definition_flags = ['def_1', 'def_2', 'def_3']
summary_df = (
    all_defs[['hospitalization_id'] + definition_flags]
    .drop_duplicates()
    .fillna(0)
)
# UpSet indicators must be boolean.
for flag in definition_flags:
    summary_df[flag] = summary_df[flag].astype(bool)

# Draw into a large pre-created figure so the UpSet layout has room.
fig = plt.figure(figsize=(16, 12))
upset_data = from_indicators(definition_flags, data=summary_df.set_index('hospitalization_id'))
upset = UpSet(
    upset_data,
    subset_size='count',
    show_counts=True,
    sort_by='cardinality',
    element_size=50,   # larger dots
    with_lines=True,   # connecting lines between intersecting sets
)
upset.plot(fig=fig)

# Spacing, title and font sizes are applied after plotting so they act on UpSet's axes.
plt.subplots_adjust(left=0.2, bottom=0.2, right=0.95, top=0.85, hspace=0.3, wspace=0.3)
plt.suptitle('Clinical Definition Intersections', fontsize=16, y=0.95)
for ax in fig.get_axes():
    text_artists = [ax.title, ax.xaxis.label, ax.yaxis.label,
                    *ax.get_xticklabels(), *ax.get_yticklabels()]
    for artist in text_artists:
        artist.set_fontsize(12)

plt.savefig('../output/final/definition_overlap_upset_plot.png', dpi=300, bbox_inches='tight')
plt.show()

# Exclusive regions of the three-way overlap, in reporting order:
# (label, description, required (def_1, def_2, def_3) values).
combination_specs = [
    ('def_1 only', '24h ventilation only', (True, False, False)),
    ('def_2 only', '24h vasoactives only', (False, True, False)),
    ('def_3 only', 'AKI only', (False, False, True)),
    ('def_1 & def_2', '24h ventilation + 24h vasoactives', (True, True, False)),
    ('def_1 & def_3', '24h ventilation + AKI', (True, False, True)),
    ('def_2 & def_3', '24h vasoactives + AKI', (False, True, True)),
    ('def_1 & def_2 & def_3', 'All three conditions', (True, True, True)),
    ('None', 'No conditions met', (False, False, False)),
]
flag_columns = ('def_1', 'def_2', 'def_3')

combinations = []
for label, description, pattern in combination_specs:
    # A row is in the region only when every flag equals the required value.
    in_region = np.logical_and.reduce(
        [summary_df[col].to_numpy() == wanted for col, wanted in zip(flag_columns, pattern)]
    )
    combinations.append({
        'Combination': label,
        'Description': description,
        'Count': int(in_region.sum()),
    })

# Summary table with the share of all hospitalizations in each region
combo_df = pd.DataFrame(combinations)
combo_df['Percentage'] = (combo_df['Count'] / len(summary_df) * 100).round(1)

print("Summary Table of Definition Combinations:")
print("=" * 60)
print(combo_df.to_string(index=False))

# Inclusive per-definition totals: (flag column, JSON key prefix, printed label).
total_specs = [
    ('def_1', 'def_1_24h_ventilation', '24h ventilation'),
    ('def_2', 'def_2_24h_vasoactives', '24h vasoactives'),
    ('def_3', 'def_3_aki', 'AKI'),
]
individual_totals = {}
for col, prefix, _ in total_specs:
    individual_totals[prefix + '_count'] = int(summary_df[col].sum())
    individual_totals[prefix + '_percentage'] = float(summary_df[col].mean() * 100)
individual_totals['total_hospitalizations'] = int(len(summary_df))

print("\n\nIndividual Definition Totals:")
print("=" * 40)
for col, prefix, label in total_specs:
    count = individual_totals[prefix + '_count']
    share = individual_totals[prefix + '_percentage']
    print(f"{col} ({label}): {count} ({share:.1f}%)")
print(f"Total hospitalizations: {individual_totals['total_hospitalizations']}")

# Persist the combination table (CSV) and the totals (JSON)
combo_df.to_csv('../output/final/definition_combinations_summary.csv', index=False)

import json
with open('../output/final/individual_definition_totals.json', 'w') as f:
    json.dump(individual_totals, f, indent=2)

# Save the final strobe counts.
def _json_default(obj):
    """
    Fallback for json.dump: convert numpy values into plain Python objects.

    numpy arrays become lists. numpy scalars (np.int64, np.float32, np.bool_, ...) and any
    other object exposing ``.item()`` become the equivalent Python scalar. Anything else
    raises TypeError, as json.dump does by default.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if hasattr(obj, 'item'):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Serialise first so that an unserialisable value raises before the file is opened,
# instead of leaving a truncated strobe_counts.json behind.
strobe_counts_json = json.dumps(strobe_counts, indent=2, default=_json_default)
with open('../output/final/strobe_counts.json', 'w', encoding='utf-8') as f:
    f.write(strobe_counts_json)

print(f"\n\nFiles saved to output/final/:")
print("- definition_overlap_upset_plot.png")
print("- definition_combinations_summary.csv")
print("- individual_definition_totals.json")
print("- strobe_counts.json")

In [None]:
# Boolean membership of each hospitalization in definitions 1, 2 and 3.
A, B, C = (summary_df[col].astype(bool) for col in ('def_1', 'def_2', 'def_3'))

# EXCLUSIVE regions (match the UpSet bars): region -> required (def_1, def_2, def_3).
region_membership = {
    '1_only':   (True,  False, False),
    '2_only':   (False, True,  False),
    '3_only':   (False, False, True),
    '1&2_only': (True,  True,  False),
    '1&3_only': (True,  False, True),
    '2&3_only': (False, True,  True),
    'all3':     (True,  True,  True),
    'none':     (False, False, False),
}
excl = {
    region: int(((A == in_a) & (B == in_b) & (C == in_c)).sum())
    for region, (in_a, in_b, in_c) in region_membership.items()
}

# INCLUSIVE pairs (exclusive pair + triple overlap) and single-set sizes.
incl = {
    f'{p}∩{q}_incl': excl[f'{p}&{q}_only'] + excl['all3']
    for p, q in (('1', '2'), ('1', '3'), ('2', '3'))
}
incl['all3'] = excl['all3']
for set_idx, members in (('1', A), ('2', B), ('3', C)):
    incl[f'|{set_idx}|'] = int(members.sum())

# Sanity checks: nothing is raised when the counts are mutually consistent.
assert sum(excl.values()) == len(summary_df)             # exclusive bins partition N
for pair_key in ('1∩2_incl', '1∩3_incl', '2∩3_incl'):
    assert incl['all3'] <= incl[pair_key]                # all3 ≤ each inclusive pair
assert excl['1&2_only'] <= incl['|1|'] and excl['1&2_only'] <= incl['|2|']


# Venn

In [None]:
from matplotlib_venn import venn3
import matplotlib.pyplot as plt
A, B, C = (summary_df[name].astype(bool) for name in ('def_1', 'def_2', 'def_3'))

# venn3 takes region sizes ordered by the (def_1, def_2, def_3) membership pattern:
# 100, 010, 110, 001, 101, 011, 111.
region_codes = ('100', '010', '110', '001', '101', '011', '111')
subsets = tuple(
    int(((A == (code[0] == '1')) & (B == (code[1] == '1')) & (C == (code[2] == '1'))).sum())
    for code in region_codes
)
venn3(subsets=subsets, set_labels=('def_1: 24h vent','def_2: 24h vaso','def_3: AKI'))
plt.savefig('../output/final/venn.png', dpi=300, bbox_inches='tight')
plt.show()


# CONSORT

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import FancyBboxPatch
import numpy as np

def create_consort_diagram(strobe_counts):
    """
    Create a CONSORT flow diagram from the strobe_counts data.

    Boxes are drawn as rounded FancyBboxPatch rectangles. Arrows are drawn with
    ``ax.annotate`` and stop a fixed gap short of each box's visible border.

    Parameters
    ----------
    strobe_counts : dict
        Cohort counts keyed by the strobe labels read below (for example
        'A_adult_hospitalizations_since_2021' or 'F6_patients_meeting_definition_3_aki').

    Returns
    -------
    tuple
        ``(fig, ax)``: the figure and the single axes the diagram was drawn on. Saving
        and showing the figure is left to the caller.

    Raises
    ------
    KeyError
        If one of the counts read below is missing from ``strobe_counts``.
    """

    # Create figure and axis with equal margins on all sides
    fig, ax = plt.subplots(1, 1, figsize=(14, 10))
    ax.set_xlim(-1, 11)  # Extended left and right margins
    ax.set_ylim(0, 12)
    ax.axis('off')

    # Box styling - all white with black borders. The rounded style draws the border
    # `box_pad` data units outside the requested rectangle.
    box_pad = 0.1
    box_style = f"round,pad={box_pad}"

    # Track all box positions for arrow drawing
    boxes = {}

    def create_box(x, y, width, height, text, box_id=None, fontsize=10, fontweight='normal'):
        """Draw a white box centred on (x, y) with centred text and return its geometry."""
        box = FancyBboxPatch(
            (x - width/2, y - height/2), width, height,
            boxstyle=box_style,
            facecolor='white',
            edgecolor='black',
            linewidth=1.5
        )
        ax.add_patch(box)

        ax.text(x, y, text, ha='center', va='center',
                fontsize=fontsize, fontweight=fontweight, wrap=True)

        # Store the *visible* edges (rectangle plus style padding) so that arrows
        # anchored on them keep their gap from the drawn border instead of the gap
        # being swallowed by the padding.
        box_info = {
            'x': x,
            'y': y,
            'width': width,
            'height': height,
            'left': x - width/2 - box_pad,
            'right': x + width/2 + box_pad,
            'top': y + height/2 + box_pad,
            'bottom': y - height/2 - box_pad,
        }

        if box_id:
            boxes[box_id] = box_info

        return box_info

    def create_arrow(from_box, to_box, from_point='bottom_center', to_point='top_center', style='->', lw=2):
        """
        Draw an arrow between two boxes returned by ``create_box``.

        ``from_point`` is one of 'bottom_center', 'bottom_left', 'bottom_right' or 'right'.
        ``to_point`` is one of 'top_center' or 'left'.
        Any other anchor name raises ValueError.
        """
        gap = 0.1  # Gap between the arrow ends and the visible box border

        start_anchors = {
            'bottom_center': lambda b: (b['x'], b['bottom'] - gap),
            'bottom_left': lambda b: (b['x'] - b['width'] * 0.2, b['bottom'] - gap),
            'bottom_right': lambda b: (b['x'] + b['width'] * 0.2, b['bottom'] - gap),
            'right': lambda b: (b['right'] + gap, b['y']),
        }
        end_anchors = {
            'top_center': lambda b: (b['x'], b['top'] + gap),
            'left': lambda b: (b['left'] - gap, b['y']),
        }
        # Fail loudly on an unknown anchor instead of an UnboundLocalError later on.
        if from_point not in start_anchors:
            raise ValueError(f"Unknown from_point {from_point!r}; expected one of {sorted(start_anchors)}")
        if to_point not in end_anchors:
            raise ValueError(f"Unknown to_point {to_point!r}; expected one of {sorted(end_anchors)}")

        x1, y1 = start_anchors[from_point](from_box)
        x2, y2 = end_anchors[to_point](to_box)

        ax.annotate('', xy=(x2, y2), xytext=(x1, y1),
                    arrowprops=dict(arrowstyle=style, lw=lw, color='black'))

    # Main title
    ax.text(5, 11.5, 'CONSORT Flow Diagram: Adult Hospitalizations Since 2021',
            ha='center', va='center', fontsize=16, fontweight='bold')

    # Define consistent box height
    box_height = 0.7

    # Level 1: Total Adult Hospitalizations
    box1 = create_box(5, 10.3, 3, box_height,
               f"Total Adult Hospitalizations\nn = {strobe_counts['A_adult_hospitalizations_since_2021']:,}",
               'total', fontsize=11, fontweight='bold')

    # Level 2: ICU Admissions
    box2 = create_box(5, 8.8, 3, box_height,
               f"ICU Admissions\nn = {strobe_counts['B_adult_hospitalizations_since_2021_with_icu']:,}",
               'icu', fontsize=11, fontweight='bold')

    # Arrow from Level 1 to Level 2
    create_arrow(box1, box2)

    # Level 3: Three branches
    box3_left = create_box(2, 7.2, 2.2, box_height,
               f"ICU with IMV\nn = {strobe_counts['D_adult_hospitalizations_since_2021_with_icu_imv']:,}",
               'imv', fontsize=10)

    box3_center = create_box(5, 7.2, 2.2, box_height,
               f"ICU with Medications\nn = {strobe_counts['E_adult_hospitalizations_since_2021_icu_meds']:,}",
               'meds', fontsize=10)

    box3_right = create_box(8, 7.2, 2.2, box_height,
               f"ICU with Creatinine\nn = {strobe_counts['E_adult_hospitalizations_since_2021_icu_creatinine']:,}",
               'creat', fontsize=10)

    # Branching arrows from ICU to three branches
    create_arrow(box2, box3_left)
    create_arrow(box2, box3_center)
    create_arrow(box2, box3_right)

    # Level 4: Definitions
    box4_left = create_box(2, 5.5, 2.2, box_height + 0.1,
               f"Definition 1\nOn vent for first 24 hrs\nof first ICU stay\nn = {strobe_counts['D2_patients_meeting_definition_1_24h_ventilation']:,}",
               'def1', fontsize=9)

    box4_center = create_box(5, 5.5, 2.2, box_height + 0.1,
               f"Definition 2\nVasoactive meds in\nfirst 24h\nn = {strobe_counts['E2_patients_meeting_definition_2_24h_vasoactive_meds']:,}",
               'def2', fontsize=9)

    box4_right = create_box(8, 5.5, 2.2, box_height + 0.1,
               f"Definition 3\nCreatinine criteria\n(Either 3a or 3b)\nn = {strobe_counts['F6_patients_meeting_definition_3_aki']:,}",
               'def3', fontsize=9)

    # Arrows from Level 3 to Level 4
    create_arrow(box3_left, box4_left)
    create_arrow(box3_center, box4_center)
    create_arrow(box3_right, box4_right)

    # Level 5: Sub-definitions for Definition 3
    box5_left = create_box(6.5, 3.7, 2.2, box_height + 0.1,
               f"Definition 3a\n0.3 mg/dl increase\nin 48h\nn = {strobe_counts['F4_patients_meeting_definition_3a_48h_aki']:,}",
               'def3a', fontsize=9)

    box5_right = create_box(9.5, 3.7, 2.2, box_height + 0.1,
               f"Definition 3b\n50% increase\nin 7d\nn = {strobe_counts['F5_patients_meeting_definition_3b_7d_aki']:,}",
               'def3b', fontsize=9)

    # Arrows from Definition 3 to sub-definitions (split from bottom)
    create_arrow(box4_right, box5_left, from_point='bottom_left')
    create_arrow(box4_right, box5_right, from_point='bottom_right')

    # Exclusion box: hospitalizations without an ICU admission
    excluded_icu = strobe_counts['A_adult_hospitalizations_since_2021'] - strobe_counts['B_adult_hospitalizations_since_2021_with_icu']

    exclusion_box = create_box(8.5, 10.3, 1.8, 0.5,
                               f'Excluded: No ICU\nn = {excluded_icu:,}',
                               'exclusion', fontsize=9)

    # Arrow to exclusion box
    create_arrow(box1, exclusion_box, from_point='right', to_point='left')

    plt.tight_layout()

    return fig, ax

# Build the CONSORT diagram from the cohort counts, write it to disk, then display it.
fig, ax = create_consort_diagram(strobe_counts)

consort_path = '../output/final/consort_flow_diagram_clean.png'
plt.savefig(consort_path, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
plt.show()

print("Clean CONSORT Flow Diagram created and saved to:")
print(f"- {consort_path}")

# TableOne

In [None]:
import pandas as pd
import numpy as np

# ========================================
# TABLE ONE: Patient Characteristics by Definition Groups (Enhanced)
# ========================================

print("Creating Enhanced Table One...")

# Definition flags and outcomes (mortality, CRRT) for each all_defs row. Missing values become 0.
patient_summary = all_defs[['hospitalization_id', 'def_1', 'def_2', 'def_3', 'mortality', 'crrt']].copy()
patient_summary = patient_summary.fillna(0)

# Merge with hospitalization data for patient_id and age at admission
patient_demo = patient_summary.merge(
    hospitalization[['hospitalization_id', 'patient_id', 'age_at_admission']],
    on='hospitalization_id',
    how='left'
)

# Merge with patient data for sex, race, ethnicity
patient_demo = patient_demo.merge(
    patient[['patient_id', 'sex_category', 'race_category', 'ethnicity_category']],
    on='patient_id',
    how='left'
)

# Get ICU length of stay from first_icu_stays
icu_los_summary = first_icu_stays[['hospitalization_id', 'icu_duration_hours', 'icu_duration_days']].copy()
patient_demo = patient_demo.merge(icu_los_summary, on='hospitalization_id', how='left')

# Baseline creatinine (from final_df if available), collapsed to one value per
# hospitalization (first non-null) before merging. drop_duplicates() on
# (id, value) would keep several rows for an id whose rows differ (e.g. NaN vs. a
# number), and every extra row would duplicate that hospitalization in Table One.
if 'baseline_creatinine' in final_df.columns:
    baseline_creat_summary = (
        final_df.groupby('hospitalization_id', as_index=False)['baseline_creatinine'].first()
    )
    patient_demo = patient_demo.merge(baseline_creat_summary, on='hospitalization_id', how='left')
else:
    patient_demo['baseline_creatinine'] = np.nan

# Vasoactive medication flag: per-hospitalization max over final_df rows (i.e. "any" for a
# 0/1 flag). Falls back to 0 when final_df has no such column.
if 'vasoactive_meds_flag' in final_df.columns:
    med_summary = final_df[['hospitalization_id', 'vasoactive_meds_flag']].groupby('hospitalization_id').max().reset_index()
    patient_demo = patient_demo.merge(med_summary, on='hospitalization_id', how='left')
else:
    patient_demo['vasoactive_meds_flag'] = 0

# Create definition groups for comparison
# (def_1 met, def_2 met, def_3 met) -> exclusive Table One group label.
_DEFINITION_GROUP_LABELS = {
    (True, True, True): 'All three (def_1+2+3)',
    (True, True, False): 'Vent + Vaso (def_1+2)',
    (True, False, True): 'Vent + AKI (def_1+3)',
    (False, True, True): 'Vaso + AKI (def_2+3)',
    (True, False, False): 'Ventilation only (def_1)',
    (False, True, False): 'Vasoactive only (def_2)',
    (False, False, True): 'AKI only (def_3)',
    (False, False, False): 'No definitions',
}


def create_def_groups(row):
    """
    Return the exclusive definition-combination label for one row.

    ``row`` must provide 'def_1', 'def_2' and 'def_3'. A definition counts as met only
    when its value equals 1 (so 1, 1.0 and True all qualify).
    """
    met = tuple(bool(row[col] == 1) for col in ('def_1', 'def_2', 'def_3'))
    return _DEFINITION_GROUP_LABELS[met]

# Assign every row to exactly one exclusive definition-combination group.
patient_demo['definition_group'] = patient_demo.apply(create_def_groups, axis='columns')

# Function to create table one statistics
def _describe_quartiles(values):
    """
    Return ``(Q1, median, Q3)`` of a Series using ``Series.describe()``.

    A statistic that describe() does not produce (e.g. non-numeric input) or that is NaN
    (e.g. an all-missing column) is reported as 0. An empty Series gives ``(0, 0, 0)``.
    """
    if values.empty:
        return 0, 0, 0
    stats = values.describe()

    def _pick(key):
        val = stats.get(key, 0)
        return 0 if pd.isna(val) else val

    return _pick('25%'), _pick('50%'), _pick('75%')


def _percent(count, total):
    """Return ``count`` as a percentage of ``total``, or 0 when ``total`` is 0."""
    return (count / total * 100) if total > 0 else 0


def create_table_one(df=None):
    """
    Compute raw Table One statistics for the whole cohort and each definition group.

    Parameters
    ----------
    df : pandas.DataFrame, optional
        One row per hospitalization with the columns 'definition_group', 'patient_id',
        'age_at_admission', 'sex_category', 'race_category', 'ethnicity_category',
        'mortality', 'crrt', 'icu_duration_days', 'vasoactive_meds_flag' and
        'baseline_creatinine'. Defaults to the module-level ``patient_demo``.

    Returns
    -------
    pandas.DataFrame
        One row per group: 'Overall' first, then the definition groups sorted by name.
        Counts (*_N), percentages of the group's hospitalizations (*_Pct) and
        median/Q1/Q3 statistics are included. Missing quartile statistics are reported as 0.
    """
    if df is None:
        df = patient_demo

    groups = ['Overall'] + sorted(df['definition_group'].unique().tolist())
    table_data = []

    for group in groups:
        data = df if group == 'Overall' else df[df['definition_group'] == group]

        # Basic counts
        n_hospitalizations = len(data)
        n_patients = data['patient_id'].nunique()

        # Continuous variables: (Q1, median, Q3)
        age_q1, age_median, age_q3 = _describe_quartiles(data['age_at_admission'])
        los_q1, los_median, los_q3 = _describe_quartiles(data['icu_duration_days'])
        creat_q1, creat_median, creat_q3 = _describe_quartiles(data['baseline_creatinine'])

        # Categorical distributions (counted once per column)
        sex_counts = data['sex_category'].value_counts()
        race_counts = data['race_category'].value_counts()
        ethnicity_counts = data['ethnicity_category'].value_counts()
        female_n = sex_counts.get('Female', 0)
        male_n = sex_counts.get('Male', 0)
        white_n = race_counts.get('White', 0)
        black_n = race_counts.get('Black or African American', 0)
        hispanic_n = ethnicity_counts.get('Hispanic', 0)

        # Outcome / treatment counts (the sum of an empty column is 0)
        mortality_count = int(data['mortality'].sum())
        crrt_count = int(data['crrt'].sum())
        med_use = int(data['vasoactive_meds_flag'].sum())

        table_data.append({
            'Group': group,
            'N_Hospitalizations': n_hospitalizations,
            'N_Patients': n_patients,
            'Age_Median': age_median,
            'Age_Q1': age_q1,
            'Age_Q3': age_q3,
            'Female_N': female_n,
            'Female_Pct': _percent(female_n, n_hospitalizations),
            'Male_N': male_n,
            'Male_Pct': _percent(male_n, n_hospitalizations),
            'White_N': white_n,
            'White_Pct': _percent(white_n, n_hospitalizations),
            'Black_N': black_n,
            'Black_Pct': _percent(black_n, n_hospitalizations),
            'Hispanic_N': hispanic_n,
            'Hispanic_Pct': _percent(hispanic_n, n_hospitalizations),
            'Mortality_N': mortality_count,
            'Mortality_Pct': _percent(mortality_count, n_hospitalizations),
            'CRRT_N': crrt_count,
            'CRRT_Pct': _percent(crrt_count, n_hospitalizations),
            'ICU_LOS_Median': los_median,
            'ICU_LOS_Q1': los_q1,
            'ICU_LOS_Q3': los_q3,
            'Vasoactive_Meds_N': med_use,
            'Vasoactive_Meds_Pct': _percent(med_use, n_hospitalizations),
            'Creatinine_Median': creat_median,
            'Creatinine_Q1': creat_q1,
            'Creatinine_Q3': creat_q3,
        })

    return pd.DataFrame(table_data)

# Create the table
# Raw (unformatted) Table One statistics: one row for 'Overall' plus one per definition group.
table_one = create_table_one()

# Format the table for better presentation
def format_table_one(df):
    """
    Render raw Table One statistics as display strings.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw statistics, one row per group, with a 'Group' column plus the *_N / *_Pct
        and *_Median / *_Q1 / *_Q3 columns produced by create_table_one().

    Returns
    -------
    pandas.DataFrame
        One row per characteristic: a 'Characteristic' column and one string column per
        group. Median/IQR cells whose median is not greater than 0 are shown as "N/A".
    """
    def n_pct(prefix):
        # "count (pct)" with the percentage to one decimal place
        return lambda rec: f"{int(rec[prefix + '_N'])} ({rec[prefix + '_Pct']:.1f})"

    def median_iqr(prefix, spec):
        # "median [Q1, Q3]" using the given format spec, or N/A for a non-positive median
        def render(rec):
            if rec[prefix + '_Median'] > 0:
                return (f"{rec[prefix + '_Median']:{spec}} "
                        f"[{rec[prefix + '_Q1']:{spec}}, {rec[prefix + '_Q3']:{spec}}]")
            return "N/A"
        return render

    # Display order of the characteristics and how each cell is rendered
    formatters = [
        ('N (Hospitalizations)', lambda rec: str(int(rec['N_Hospitalizations']))),
        ('N (Unique Patients)', lambda rec: str(int(rec['N_Patients']))),
        ('Age, median [Q1, Q3]', median_iqr('Age', '.1f')),
        ('Female, n (%)', n_pct('Female')),
        ('Male, n (%)', n_pct('Male')),
        ('White, n (%)', n_pct('White')),
        ('Black, n (%)', n_pct('Black')),
        ('Hispanic, n (%)', n_pct('Hispanic')),
        ('Mortality, n (%)', n_pct('Mortality')),
        ('CRRT, n (%)', n_pct('CRRT')),
        ('ICU Length of Stay (days), median [Q1, Q3]', median_iqr('ICU_LOS', '.1f')),
        ('Vasoactive Medications, n (%)', n_pct('Vasoactive_Meds')),
        ('Baseline Creatinine, median [Q1, Q3]', median_iqr('Creatinine', '.2f')),
    ]

    group_records = [rec for _, rec in df.iterrows()]
    table_rows = []
    for label, render in formatters:
        entry = {'Characteristic': label}
        entry.update({rec['Group']: render(rec) for rec in group_records})
        table_rows.append(entry)

    return pd.DataFrame(table_rows)

# Create formatted table
formatted_table_one = format_table_one(table_one)

# Display the table
print("Table One: Patient Characteristics by Definition Groups (Enhanced)")
print("=" * 90)
print(formatted_table_one.to_string(index=False))

# Save the tables
import os
os.makedirs('../output/final', exist_ok=True)

table_one.to_csv('../output/final/table_one_raw_enhanced.csv', index=False)
formatted_table_one.to_csv('../output/final/table_one_formatted_enhanced.csv', index=False)

print(f"\nTables saved to:")
print("- ../output/final/table_one_raw_enhanced.csv")
print("- ../output/final/table_one_formatted_enhanced.csv")

# Create a simplified version with key comparisons
simplified_groups = ['Overall', 'No definitions', 'Ventilation only (def_1)',
                    'Vasoactive only (def_2)', 'AKI only (def_3)',
                    'Vent + Vaso (def_1+2)', 'All three (def_1+2+3)']

# Filter to only include columns that exist
available_groups = [col for col in simplified_groups if col in formatted_table_one.columns]
simplified_table = formatted_table_one[['Characteristic'] + available_groups]

print(f"\nSimplified Table One:")
print("=" * 70)
print(simplified_table.to_string(index=False))

simplified_table.to_csv('../output/final/table_one_simplified_enhanced.csv', index=False)

# Summary statistics. Table One counts patient_demo rows as hospitalizations, so report
# that count and the number of unique patients separately.
n_hospitalizations_total = len(patient_demo)
print(f"\nSummary:")
print(f"Total hospitalizations analyzed: {n_hospitalizations_total}")
print(f"Unique patients analyzed: {patient_demo['patient_id'].nunique()}")
print(f"Definition groups found: {len(patient_demo['definition_group'].unique())}")

# Overall mortality rate (0 for an empty cohort instead of dividing by zero)
overall_mortality = patient_demo['mortality'].sum()
overall_mortality_rate = (overall_mortality / n_hospitalizations_total * 100) if n_hospitalizations_total else 0.0
print(f"Overall mortality: {int(overall_mortality)} ({overall_mortality_rate:.1f}%)")

# Overall CRRT rate
overall_crrt = patient_demo['crrt'].sum()
overall_crrt_rate = (overall_crrt / n_hospitalizations_total * 100) if n_hospitalizations_total else 0.0
print(f"Overall CRRT: {int(overall_crrt)} ({overall_crrt_rate:.1f}%)")

# Overall ICU LOS
overall_icu_los = patient_demo['icu_duration_days'].median()
print(f"Overall ICU LOS median: {overall_icu_los:.1f} days")

print("\nDefinition group distribution:")
group_counts = patient_demo['definition_group'].value_counts().sort_values(ascending=False)
# Per-group outcome totals, computed once instead of re-filtering patient_demo per group
group_outcomes = patient_demo.groupby('definition_group')[['mortality', 'crrt']].sum()
for group, count in group_counts.items():
    pct = count / n_hospitalizations_total * 100
    group_mortality = group_outcomes.at[group, 'mortality']
    group_mortality_rate = (group_mortality / count * 100) if count > 0 else 0
    group_crrt = group_outcomes.at[group, 'crrt']
    group_crrt_rate = (group_crrt / count * 100) if count > 0 else 0
    print(f"  {group}: {count} ({pct:.1f}%) - Mortality: {int(group_mortality)} ({group_mortality_rate:.1f}%) - CRRT: {int(group_crrt)} ({group_crrt_rate:.1f}%)")