# Epidemiology of CRRT

Author: Kaveri Chhikara

This script identifies the cohort using CLIF 2.1 tables

 
                        🚨Code will break if the following requirements are not satisfied🚨  
#### Requirements

* Required table filenames should be `clif_patient`, `clif_hospitalization`, `clif_adt`, `clif_vitals`, `clif_labs`, `clif_medication_admin_continuous`, `clif_respiratory_support`, `clif_crrt_therapy`, `clif_hospital_diagnosis`
* Within each table, the following variables and categories are required.

| Table Name | Required Variables | Required Categories |
| --- | --- | --- |
| **clif_patient** | `patient_id`, `race_category`, `ethnicity_category`, `sex_category`, `death_dttm`, `language_category` | - |
| **clif_hospitalization** | `patient_id`, `hospitalization_id`, `admission_dttm`, `discharge_dttm`, `age_at_admission`, `admission_type_category`, `discharge_category` | - |
| **clif_adt** |  `hospitalization_id`, `hospital_id`,`in_dttm`, `out_dttm`, `location_category`, `location_type` | - |
| **clif_vitals** | `hospitalization_id`, `recorded_dttm`, `vital_category`, `vital_value` | heart_rate, respiratory_rate, sbp, dbp, map, spo2, weight_kg, height_cm |
| **clif_labs** | `hospitalization_id`, `lab_result_dttm`, `lab_category`, `lab_value`, `lab_value_numeric` | sodium, potassium, chloride, bicarbonate, bun, creatinine, glucose_serum, calcium_total, lactate, magnesium, ph_arterial, ph_venous, po2_arterial |
| **clif_medication_admin_continuous** | `hospitalization_id`, `admin_dttm`, `med_name`, `med_category`, `med_dose`, `med_dose_unit` | norepinephrine, epinephrine, phenylephrine, vasopressin, dopamine, angiotensin, dobutamine, milrinone, isoproterenol |
| **clif_respiratory_support** | `hospitalization_id`, `recorded_dttm`, `device_category`, `mode_category`, `tracheostomy`, `fio2_set`, `lpm_set`, `resp_rate_set`, `peep_set`, `resp_rate_obs`, `tidal_volume_set`, `pressure_control_set`, `pressure_support_set`, `peak_inspiratory_pressure_set`, `tidal_volume_obs` | - |
| **clif_crrt_therapy** | `hospitalization_id`, `recorded_dttm`, `crrt_mode_name`, `crrt_mode_category`, `device_id`, `blood_flow_rate`, `pre_filter_replacement_fluid_rate`, `post_filter_replacement_fluid_rate`, `dialysate_flow_rate`, `ultrafiltration_out` | - |
| **clif_hospital_diagnosis** | `hospitalization_id`, `diagnosis_code`, `present_on_admission` | - |


## Setup

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import shutil
from datetime import datetime, timedelta
import json
import warnings
import pyarrow

warnings.filterwarnings('ignore')

In [None]:
import pyCLIF
import waterfall
# Read the outlier configuration (JSON) into `outlier_cfg`; it is consulted
# later in the notebook when implausible vital-sign values are nulled out.
outlier_config_path = '../config/outlier_config.json'
with open(outlier_config_path, mode='r', encoding='utf-8') as config_file:
    outlier_cfg = json.load(config_file)

In [None]:
# Output location returned by the project helper
output_folder = pyCLIF.setup_output_folders()
# Running tally of cohort sizes after each inclusion/exclusion step (STROBE flow)
strobe_counts = dict()

## Load Core Tables

In [None]:
print("\n=== Loading Core Tables ===")

# --- patient ---
patient = pyCLIF.load_data('clif_patient')
print(f"Loaded patient table: {len(patient)} rows")

# --- hospitalization ---
hospitalization = pyCLIF.load_data('clif_hospitalization')
print(f"Loaded hospitalization table: {len(hospitalization)} rows")

# --- ADT ---
adt = pyCLIF.load_data('clif_adt')
print(f"Loaded ADT table: {len(adt)} rows")

# Cast every identifier column to str so joins between tables compare like with like
id_columns_by_table = (
    (patient, ['patient_id']),
    (hospitalization, ['hospitalization_id', 'patient_id']),
    (adt, ['hospitalization_id']),
)
for table, id_columns in id_columns_by_table:
    for id_column in id_columns:
        table[id_column] = table[id_column].astype(str)

# Move datetime columns to the site's timezone (as configured in pyCLIF.helper)
site_tz = pyCLIF.helper['timezone']
patient = pyCLIF.convert_datetime_columns_to_site_tz(patient, site_tz)
hospitalization = pyCLIF.convert_datetime_columns_to_site_tz(hospitalization, site_tz)
adt = pyCLIF.convert_datetime_columns_to_site_tz(adt, site_tz)

# De-duplicate each table on its key columns
patient = pyCLIF.remove_duplicates(patient, ['patient_id'], 'patient')
hospitalization = pyCLIF.remove_duplicates(hospitalization, ['hospitalization_id'], 'hospitalization')
adt = pyCLIF.remove_duplicates(adt, ['hospitalization_id', 'hospital_id', 'in_dttm'], 'adt')

# Cohort Identification

1. Adults
2. Admitted between January 1, 2018 and December 31, 2024
3. Receiving CRRT

#### (A) Date and Age filters

In [None]:
# Date filter: admissions from 2018-01-01 00:00 through the end of 2024-12-31.
# The upper bound is written as an exclusive "< 2025-01-01": a bare date string
# such as '2024-12-31' is interpreted as midnight at the START of that day, so
# "<= '2024-12-31'" would silently drop every admission later on December 31.
admission_time = hospitalization['admission_dttm']
date_mask = (admission_time >= '2018-01-01') & (admission_time < '2025-01-01')

# Age filter: adults (>= 18)
age_mask = (hospitalization['age_at_admission'] >= 18)

# Apply filters based on site
if pyCLIF.helper['site_name'].lower() == 'mimic':
    # The date window is deliberately not applied for the MIMIC site
    print("skipping date filter for MIMIC")
    hospitalization_cohort = hospitalization[age_mask].copy()
else:
    hospitalization_cohort = hospitalization[date_mask & age_mask].copy()

strobe_counts['A_after_date_age_filter'] = hospitalization_cohort['hospitalization_id'].nunique()
print(f"Hospitalizations after date & age filter: {strobe_counts['A_after_date_age_filter']}")

# Also track total adult hospitalizations (age filter only, no date window)
total_adult = hospitalization.loc[age_mask, 'hospitalization_id'].nunique()
strobe_counts['A_total_adult_hospitalizations'] = total_adult
print(f"Total adult hospitalizations (no date filter): {total_adult}")

In [None]:
# Distinct hospitalization IDs of the date/age-filtered cohort, in order of first appearance
cohort_ids = hospitalization_cohort['hospitalization_id'].drop_duplicates().tolist()

#### (B) Stitch hospitalizations

Combine multiple `hospitalization_ids` into a single `encounter_block` for patients who transfer between hospital campuses or return soon after discharge. Hospitalizations that have a gap of **6 hours or less** between the discharge dttm and admission dttm are put in one encounter block.

In [None]:
# ADT rows that belong to the date/age-filtered hospitalizations
in_cohort = adt['hospitalization_id'].isin(cohort_ids)
adt_cohort = adt.loc[in_cohort].copy()

# Report how many admission/discharge timestamps are missing before stitching
n_missing_admission = hospitalization_cohort['admission_dttm'].isna().sum()
n_missing_discharge = hospitalization_cohort['discharge_dttm'].isna().sum()
print("\nMissing values in admission_dttm:", n_missing_admission)
print("Missing values in discharge_dttm:", n_missing_discharge)

In [None]:
# STEP B: link hospitalizations into an 'encounter_block' via pyCLIF.stitch_encounters.
# time_interval=6 is the linking window passed to the helper
# (presumably hours, matching the 6-hour rule described for this step).
print("\n=== STEP B: Stitch encounters ===\n")
stitched_cohort = pyCLIF.stitch_encounters(
    hospitalization_cohort,
    adt_cohort,
    time_interval=6,
)

In [None]:
# stitched_cohort holds patient_id, hospitalization_id, encounter_block, the
# discharge category and other ADT variables; rows repeat per location category.
# The next steps need exactly one row per (patient_id, encounter_block).
block_key = ['patient_id', 'encounter_block']
stitched_unique = stitched_cohort.drop_duplicates(subset=block_key)[block_key]

n_hosp_before = stitched_cohort['hospitalization_id'].nunique()
n_blocks_after = stitched_unique['encounter_block'].nunique()
n_linked = n_hosp_before - n_blocks_after

strobe_counts['B_before_stitching'] = n_hosp_before
strobe_counts['B_after_stitching'] = n_blocks_after
strobe_counts['B_stitched_hosp_ids'] = n_linked

print(f"Number of unique hospitalizations before stitching: {n_hosp_before}")
print(f"Number of unique encounter blocks after stitching: {n_blocks_after}")
print(f"Number of linked hospitalization ids: {n_linked}")

In [None]:
# One row per distinct combination of patient / hospitalization / encounter block,
# carrying the admission and discharge attributes used by later steps.
id_mapping_columns = [
    'patient_id', 'hospitalization_id', 'encounter_block',
    'admission_dttm', 'discharge_dttm', 'age_at_admission',
    'admission_type_category', 'discharge_category',
]
all_ids = stitched_cohort[id_mapping_columns].drop_duplicates()

# Distinct-value counts for the three identifier columns only
print("\nUnique values in each column:")
for col in id_mapping_columns[:3]:
    print(f"\n{col}:")
    print(all_ids[col].nunique())

#### (C) Identify CRRT patients

In [None]:
# Load CRRT therapy table for these hospitalizations
crrt_columns = [
    'hospitalization_id',
    'recorded_dttm',
    'crrt_mode_name',
    'crrt_mode_category',
    'blood_flow_rate',
    'pre_filter_replacement_fluid_rate',
    'post_filter_replacement_fluid_rate',
    'dialysate_flow_rate',
    'ultrafiltration_out'
]

# Only the load itself is guarded. Errors in the processing below are not load
# failures and should surface with their own traceback instead of being
# reported as "Error loading CRRT therapy data".
try:
    crrt_df = pyCLIF.load_data(
        'clif_crrt_therapy',
        columns=crrt_columns,
        filters={'hospitalization_id': cohort_ids}
    )
except Exception as e:
    print(f"Error loading CRRT therapy data: {e}")
    # Chain the original exception so the root cause stays visible
    raise SystemExit("Stopping execution due to error loading CRRT therapy data") from e

print(f"Loaded CRRT therapy data: {len(crrt_df)} rows")
# Ensure hospitalization_id is string
crrt_df['hospitalization_id'] = crrt_df['hospitalization_id'].astype(str)
# Sort by hospitalization_id and recorded_dttm
crrt_df = crrt_df.sort_values(['hospitalization_id', 'recorded_dttm'])
# Get unique hospitalizations with CRRT
hosp_ids_with_crrt = crrt_df['hospitalization_id'].unique()
strobe_counts['C_hospitalizations_with_crrt'] = len(hosp_ids_with_crrt)
print(f"Hospitalizations with CRRT: {len(hosp_ids_with_crrt)}")

# Summary of CRRT modes (row counts per mode; missing modes are not counted)
if 'crrt_mode_category' in crrt_df.columns:
    mode_counts = crrt_df['crrt_mode_category'].value_counts()
    print("\nCRRT Mode Distribution (N rows):")
    for mode, count in mode_counts.items():
        print(f"  {mode}: {count}")

In [None]:
# Descriptive statistics of each CRRT flow parameter, per crrt_mode_category,
# computed on the raw (pre-waterfall) CRRT table.
numeric_columns = [
    'blood_flow_rate',
    'pre_filter_replacement_fluid_rate',
    'post_filter_replacement_fluid_rate',
    'dialysate_flow_rate',
    'ultrafiltration_out'
]

# Coerce the flow parameters to numbers; unparseable entries become NaN
for col in numeric_columns:
    crrt_df[col] = pd.to_numeric(crrt_df[col], errors='coerce')

# Parse recorded_dttm; unparseable entries become NaT
crrt_df['recorded_dttm'] = pd.to_datetime(crrt_df['recorded_dttm'], errors='coerce')

# (output column name, aggregation) pairs applied to every parameter
summary_stats = [
    ('total_N', 'size'),
    ('min', 'min'),
    ('max', 'max'),
    ('first_quantile', lambda x: x.quantile(0.25)),
    ('median', lambda x: x.quantile(0.5)),
    ('third_quantile', lambda x: x.quantile(0.75)),
    ('missing_values', lambda x: x.isna().sum()),
    ('mean', 'mean'),
    ('std', 'std'),
]

# One summary frame per parameter, tagged with the parameter name
summary_list = [
    crrt_df.groupby('crrt_mode_category')[col]
    .agg(summary_stats)
    .reset_index()
    .assign(parameter=col)
    for col in numeric_columns
]

# Stack the per-parameter summaries and put the columns in a readable order
column_order = [
    'crrt_mode_category', 'parameter',
    'total_N', 'missing_values',
    'min', 'first_quantile', 'median', 'third_quantile', 'max',
    'mean', 'std',
]
summary_crrt_by_mode = pd.concat(summary_list, ignore_index=True)[column_order]

# Save to CSV
summary_crrt_by_mode.to_csv('../output/final/summary_crrt_by_mode_category.csv', index=False)

# Simple row count per mode, saved alongside the detailed summary
mode_counts = crrt_df['crrt_mode_category'].value_counts().reset_index()
mode_counts.columns = ['crrt_mode_category', 'count']
mode_counts.to_csv('../output/final/crrt_mode_category_counts.csv', index=False)

##### CRRT Waterfall Algorithm

Numeric columns tracked:
• blood_flow_rate
• pre_filter_replacement_fluid_rate 
• post_filter_replacement_fluid_rate
• dialysate_flow_rate
• ultrafiltration_out

Waterfall algorithm steps:

1. Fill crrt_mode_category based on non-missing numerical columns:

   * If all numeric columns are non-NA → CVVHDF
   * If only dialysate_flow_rate is NA → CVVH
   * If pre_filter_replacement_fluid_rate and post_filter_replacement_fluid_rate are NA →  CVVHD
   * If pre_filter_replacement_fluid_rate, post_filter_replacement_fluid_rate and dialysate_flow_rate are NA →  SCUF
   * If SCUF sandwiched between two other modes, then check if dialysate_flow_rate == 0, blood_flow_rate>0, ultrafiltration_out>0 → True SCUF otherwise, update it to the previous crrt_mode_category

2. Fill forward the crrt_mode_category column- only if the time gap is 3h or less across the missing rows

3. Create episode IDs:

   * Episode definition: One continuous run of the same CRRT setup on same patient
   * Gap threshold: Maximum allowed gap within an episode (e.g. "2h", "3h", "90min")
     before starting a new episode. This notebook uses "3h".
   * A new `crrt_episode_id` starts whenever  
       • `crrt_mode_category` changes **OR**  
       • the gap between successive rows exceeds *gap_thresh* ("3h" in this notebook).

4. Fill forward numeric variables within each episode ID

5. Replace invalid parameters with NA based on CRRT mode. Assuming someone inputted wrong modality. Infer modality
   (Note: This step needs verification)


In [None]:
# Re-execute the waterfall module before using it (importlib.reload returns the
# reloaded module), then run the CRRT waterfall on the raw CRRT table.
import importlib
waterfall = importlib.reload(waterfall)
processed_crrt_df = waterfall.process_crrt_waterfall(
    crrt_df,
    id_col="hospitalization_id",
    gap_thresh="3h",
    fix_islands=True,
    wipe_unused=True,
    verbose=True,
)


In [None]:
# Create a summary table for each crrt_mode_category after applying waterfall.
# Everything in this cell is computed from processed_crrt_df (the waterfall
# output) so that the "*_post_waterfall" files really describe the processed data.
numeric_columns = [
    'blood_flow_rate',
    'pre_filter_replacement_fluid_rate',
    'post_filter_replacement_fluid_rate',
    'dialysate_flow_rate',
    'ultrafiltration_out'
]

# Create summary for each numeric column by crrt_mode_category
summary_list = []

for col in numeric_columns:
    summary = processed_crrt_df.groupby('crrt_mode_category')[col].agg([
        ('total_N', 'size'),
        ('min', 'min'),
        ('max', 'max'),
        ('first_quantile', lambda x: x.quantile(0.25)),
        ('median', lambda x: x.quantile(0.5)),
        ('third_quantile', lambda x: x.quantile(0.75)),
        ('missing_values', lambda x: x.isna().sum()),
        ('mean', 'mean'),
        ('std', 'std')
    ]).reset_index()
    summary['parameter'] = col
    summary_list.append(summary)

# Combine all summaries
summary_crrt_by_mode = pd.concat(summary_list, ignore_index=True)

# Reorder columns for better readability
column_order = ['crrt_mode_category', 'parameter',
                'total_N', 'missing_values',
                'min', 'first_quantile', 'median', 'third_quantile', 'max',
                'mean', 'std']
summary_crrt_by_mode = summary_crrt_by_mode[column_order]

# Save to CSV
summary_crrt_by_mode.to_csv('../output/final/summary_crrt_by_mode_category_post_waterfall.csv', index=False)

# Row count per mode AFTER the waterfall. This must read processed_crrt_df;
# counting crrt_df here would write pre-waterfall counts into the
# post-waterfall file.
mode_counts = processed_crrt_df['crrt_mode_category'].value_counts().reset_index()
mode_counts.columns = ['crrt_mode_category', 'count']
mode_counts.to_csv('../output/final/crrt_mode_category_counts_post_waterfall.csv', index=False)

In [None]:
# processed_crrt_df.drop(columns=[#'crrt_episode_id', 
#                                'blood_flow_missing_after_ffill',
#                                'post_filter_replacement_fluid_rate_unexpected',
#                                'pre_filter_replacement_fluid_rate_unexpected', 
#                                'dialysate_flow_rate_unexpected'
#                                ],
#                       inplace=True)

In [None]:
## CRRT units check
# Expected unit and plausible [min, max] range of the median of each CRRT flow
# parameter. Built from (parameter, min, max) rows; every entry uses "mL/hr".
unit_expectations = {
    parameter: {"unit": "mL/hr", "min": lower, "max": upper}
    for parameter, lower, upper in (
        ("blood_flow_rate", 1000, 20000),
        ("pre_filter_replacement_fluid_rate", 0, 10000),
        ("post_filter_replacement_fluid_rate", 0, 10000),
        ("dialysate_flow_rate", 500, 10000),
        ("ultrafiltration_out", 0, 5000),
    )
}

def check_crrt_units(df, expectations, out_folder):
    """Compare the median of each CRRT variable with its expected range.

    For every variable in ``expectations`` one message is printed and written
    to ``<out_folder>/final/crrt_unit_warnings.txt``. The log file is
    overwritten on every call.

    Args:
        df: DataFrame holding the CRRT variables as columns.
        expectations: Mapping of column name to a dict with the keys
            ``"unit"``, ``"min"`` and ``"max"``.
        out_folder: Base output directory; the ``final`` sub-directory is
            created if it does not exist.

    Returns:
        None.
    """
    final_dir = os.path.join(out_folder, "final")
    os.makedirs(final_dir, exist_ok=True)
    log_path = os.path.join(final_dir, "crrt_unit_warnings.txt")

    # Mode "w" truncates the previous log. The encoding is fixed to UTF-8
    # because the "outside expected" message contains an em dash, which a
    # locale-dependent default encoding may be unable to write.
    with open(log_path, "w", encoding="utf-8") as log_file:
        for var, spec in expectations.items():
            if var not in df:
                # `in` on a DataFrame tests the column labels
                msg = f"{var!r} not found in CRRT dataframe, skipping."
            else:
                ser = df[var].dropna()
                if ser.empty:
                    msg = f"{var!r} has no non-null values, skipping."
                else:
                    med = ser.median()
                    lo, hi = spec["min"], spec["max"]
                    if lo <= med <= hi:
                        msg = f"{var!r} median={med:.1f} within expected [{lo}-{hi}] {spec['unit']}."
                    else:
                        msg = (
                            f"{var!r} median={med:.1f} outside expected "
                            f"[{lo}-{hi}] {spec['unit']}—please inspect units."
                        )
            print(msg)
            log_file.write(msg + "\n")

# ─────────────────────────────────────────────────────────────────────────────
# Usage: run on the waterfall-processed CRRT table (processed_crrt_df)
# ─────────────────────────────────────────────────────────────────────────────
check_crrt_units(processed_crrt_df, unit_expectations, output_folder)

In [None]:
# Put the processed CRRT timestamps in the site timezone, then attach the
# encounter_block of each hospitalization (left join keeps every CRRT row).
site_tz = pyCLIF.helper['timezone']
processed_crrt_df = pyCLIF.convert_datetime_columns_to_site_tz(processed_crrt_df, site_tz)
block_lookup = all_ids[['hospitalization_id', 'encounter_block']]
crrt_stitched = pd.merge(processed_crrt_df, block_lookup, how='left', on='hospitalization_id')

In [None]:
# Restrict the ID mapping to hospitalizations that have at least one CRRT row
has_crrt_record = all_ids['hospitalization_id'].isin(crrt_stitched['hospitalization_id'])
all_ids = all_ids.loc[has_crrt_record]

In [None]:
# Report the number of distinct encounter blocks and hospitalizations in all_ids
for label, id_column in (("blocks: ", 'encounter_block'), ("hosps: ", 'hospitalization_id')):
    print(label, pyCLIF.count_unique_encounters(all_ids, id_column))

In [None]:
# crrt_stitched comes from a LEFT join, so a CRRT row whose hospitalization has
# no encounter_block carries NaN. nunique() skips NaN, whereas len(unique())
# would count it as one extra "block".
strobe_counts['C_blocks_with_crrt'] = crrt_stitched['encounter_block'].nunique()

In [None]:
strobe_counts

#### (D) Exclude patients on ESRD 

ESRD is identified from ICD-9/ICD-10 diagnosis codes that were present on admission (i.e., ESRD that pre-dates the hospitalization).

In [None]:
# Pull hospital diagnoses for the hospitalizations currently in all_ids
# diagnoses_columns = ['hospitalization_id', 'diagnosis_code', 'diagnosis_code_type'] 
cohort_hosp_ids = all_ids['hospitalization_id'].tolist()
hospital_diagnosis = pyCLIF.load_data(
    'clif_hospital_diagnosis',
    filters={'hospitalization_id': cohort_hosp_ids},
)

In [None]:
# The join key must be a string here as well: all_ids['hospitalization_id'] is
# str, and merging it with a numeric ID column either fails or matches nothing.
hospital_diagnosis['hospitalization_id'] = hospital_diagnosis['hospitalization_id'].astype(str)

# Attach the encounter_block of each hospitalization (left join keeps every diagnosis row)
hospital_diagnosis_stitched = hospital_diagnosis.merge(
    all_ids[['hospitalization_id', 'encounter_block']],
    on='hospitalization_id',
    how='left',
)

# Strip the dots so codes compare against the dot-free lists below (N18.6 -> N186).
# regex=False is required: '.' is a regex wildcard, and the default value of
# `regex` differs between pandas versions.
hospital_diagnosis_stitched['diagnosis_code'] = (
    hospital_diagnosis_stitched['diagnosis_code'].str.replace('.', '', regex=False)
)

In [None]:
# Standardize present_on_admission to an int8 flag (1 = present on admission, 0 = otherwise)
if 'present_on_admission' in hospital_diagnosis_stitched.columns:
    # Text forms accepted for each flag value. '1.0' / '0.0' are included because
    # a numeric column that contains missing values is stored as float, and
    # astype(str) then renders 1 as '1.0' — which would otherwise fall through to 0.
    poa_flag_map = {
        'yes': 1, 'y': 1, 'true': 1, '1': 1, '1.0': 1,
        'no': 0, 'n': 0, 'false': 0, '0': 0, '0.0': 0,
    }

    hospital_diagnosis_stitched['present_on_admission'] = (
        hospital_diagnosis_stitched['present_on_admission']
        .astype(str)      # handle any source dtype uniformly
        .str.strip()      # tolerate padded values such as ' Y'
        .str.lower()
        .map(poa_flag_map)
        .fillna(0)        # anything unrecognised (including missing) counts as 0
        .astype(np.int8)
    )
hospital_diagnosis_stitched.dtypes

In [None]:
# Define ESRD diagnosis codes (dot-free, matching the normalised diagnosis_code column)
esrd_codes = [
    # ICD-10
    'Z992',    # Dependence on renal dialysis
    'Z9115',   # Patient's noncompliance with renal dialysis
    'Z91158',  # Patient's noncompliance with renal dialysis (alternate code)
    'I120',    # Hypertensive chronic kidney disease with stage 5 CKD or ESRD
    'N186',    # End stage renal disease
    'I132',    # Hypertensive heart and chronic kidney disease with heart failure and ESRD
    'I1311',   # Hypertensive heart and chronic kidney disease with heart failure and stage 5 CKD
    # ICD-9
    '5856',    # End stage renal disease
    '40391',   # Hypertensive chronic kidney disease, unspecified, with chronic kidney disease stage V or end stage renal disease
    '40311',   # Hypertensive chronic kidney disease, benign, with chronic kidney disease stage V or end stage renal disease
    'V4511',   # Renal dialysis status
    'V4512'    # Noncompliance with renal dialysis
]

# Debug helpers, kept for reference:
# print("Unique diagnosis codes in data:", hospital_diagnosis_stitched['diagnosis_code'].unique()[:20], "...")
# print("\nNumber of rows matching ESRD codes:", hospital_diagnosis_stitched['diagnosis_code'].isin(esrd_codes).sum())
# print("\nSample of matching rows:")
# print(hospital_diagnosis_stitched[hospital_diagnosis_stitched['diagnosis_code'].isin(esrd_codes)].head())

is_esrd_code = hospital_diagnosis_stitched['diagnosis_code'].isin(esrd_codes)

# Decide whether present_on_admission carries usable information. Once the flag
# has been standardised to 0/1 with missing values filled as 0, an "all NA"
# test can never be true; a column without a single 1 is therefore treated as
# "no POA information" so that the fallback below is actually reachable.
poa_available = (
    'present_on_admission' in hospital_diagnosis_stitched.columns
    and bool((hospital_diagnosis_stitched['present_on_admission'] == 1).any())
)

if poa_available:
    # ESRD that was present on admission
    esrd_mask = is_esrd_code & (hospital_diagnosis_stitched['present_on_admission'] == 1)
else:
    # No usable present_on_admission information: use all ESRD diagnoses
    esrd_mask = is_esrd_code

hosp_ids_with_esrd = hospital_diagnosis_stitched.loc[esrd_mask, 'hospitalization_id'].unique()
blocks_with_esrd = hospital_diagnosis_stitched.loc[esrd_mask, 'encounter_block'].dropna().unique()

strobe_counts['D_hospitalizations_with_esrd'] = len(hosp_ids_with_esrd)
# Count encounter BLOCKS here, not hospitalizations
strobe_counts['D_encounter_blocks_with_esrd'] = len(blocks_with_esrd)

# Create cohort subset excluding hospitalizations with ESRD.
# NOTE(review): the exclusion is applied per hospitalization_id, so other
# hospitalizations of a stitched block that contains an ESRD hospitalization
# remain in the cohort — confirm whether block-level exclusion is intended.
all_ids = all_ids[~all_ids['hospitalization_id'].isin(hosp_ids_with_esrd)]
crrt_stitched = crrt_stitched[~crrt_stitched['hospitalization_id'].isin(hosp_ids_with_esrd)]
strobe_counts['D_encounter_blocks_without_esrd'] = len(all_ids['encounter_block'].unique())  # Count blocks without ESRD
strobe_counts['D_hospitalizations_without_esrd'] = len(all_ids['hospitalization_id'].unique())  # Count hospitalizations without ESRD

strobe_counts

In [None]:
# AKI Codes Sanity check: share of the remaining (non-ESRD) encounter blocks
# that carry at least one acute-kidney-injury diagnosis code.

# AKI codes (dot-free, matching the normalised diagnosis_code column)
aki_codes = [
    # ICD-10 codes for acute kidney injury
    'N170', 'N171', 'N172', 'N178', 'N179',  # Acute kidney failure codes
    'R34',   # Anuria and oliguria
    'N990',  # Post-procedural kidney failure
    'T795',  # Traumatic anuria
    '5845',  # ICD9 Acute kidney failure with lesion of tubular necrosis
    '5849',  # ICD9- Acute kidney failure, unspecified
    "5848"   # ICD9 - Acute kidney failure with other specified pathological lesion in kidney
]

# AKI flag over the full diagnosis table
aki_mask = hospital_diagnosis_stitched['diagnosis_code'].isin(aki_codes)

# Diagnosis rows of encounter blocks still in the cohort after the ESRD exclusion
non_esrd_encounters = hospital_diagnosis_stitched[
    hospital_diagnosis_stitched['encounter_block'].isin(all_ids['encounter_block'])
]

# Build the row filter on the SAME frame it is applied to; indexing the subset
# with a mask computed on the full table relies on implicit index realignment.
non_esrd_aki = non_esrd_encounters[non_esrd_encounters['diagnosis_code'].isin(aki_codes)]

# Get encounter blocks with AKI diagnoses
blocks_with_aki = non_esrd_aki['encounter_block'].unique()
total_non_esrd_blocks = all_ids['encounter_block'].nunique()
strobe_counts['D_encounter_blocks_with_AKI_no_esrd'] = len(blocks_with_aki)

# Percentage of blocks with an AKI code; NaN when the cohort is empty (avoids ZeroDivisionError)
if total_non_esrd_blocks:
    aki_percentage = (len(blocks_with_aki) / total_non_esrd_blocks) * 100
else:
    aki_percentage = float('nan')

print(f"\nPercentage of non-ESRD encounter blocks with AKI codes: {aki_percentage:.1f}%")
print(f"({len(blocks_with_aki)} out of {total_non_esrd_blocks} blocks)")

# Distinct AKI diagnoses for inspection. Descriptive columns that are absent
# from the diagnosis table are skipped instead of raising KeyError.
aki_display_columns = [
    column
    for column in ['hospitalization_id', 'diagnosis_code', 'diagnosis_code_format',
                   'diagnosis_name', 'present_on_admission']
    if column in non_esrd_aki.columns
]
aki_diagnoses = non_esrd_aki[aki_display_columns].drop_duplicates()
print("\nSample of AKI-related diagnoses found: check aki_diagnoses.head()")

#### Sanity Check: ADT~  ICU admissions

CRRT is typically administered in the ICU setting due to the need for continuous monitoring, specialized nursing care, and close medical supervision. This section validates that our cohort consists of ICU admissions.

In [None]:
# ADT rows of the hospitalizations that remain in all_ids
in_final_cohort = adt['hospitalization_id'].isin(all_ids['hospitalization_id'])
adt_final = adt.loc[in_final_cohort].copy()
print("unique encounters in adt_final", pyCLIF.count_unique_encounters(adt_final))
adt_final['hospitalization_id'] = adt_final['hospitalization_id'].astype(str)
adt_final = pyCLIF.convert_datetime_columns_to_site_tz(adt_final, pyCLIF.helper['timezone'])

# Attach encounter_block and order the locations chronologically within each block
block_lookup = all_ids[['hospitalization_id', 'encounter_block']]
adt_final_stitched = pd.merge(adt_final, block_lookup, on='hospitalization_id', how='left')
adt_final_stitched = adt_final_stitched.sort_values(by=['encounter_block', 'in_dttm'])

# Identifier and timing columns first, every other column after them in original order
desired_order = ['hospitalization_id', 'encounter_block', 'hospital_id', 'in_dttm', 'out_dttm']
remaining_cols = [col for col in adt_final_stitched.columns if col not in desired_order]
adt_final_stitched = adt_final_stitched[desired_order + remaining_cols]

In [None]:
# Report the number of distinct encounter blocks and hospitalizations in the stitched ADT table
for label, id_column in (("blocks: ", 'encounter_block'), ("hosps: ", 'hospitalization_id')):
    print(label, pyCLIF.count_unique_encounters(adt_final_stitched, id_column))

In [None]:
print("\n=== Validating ICU Administration ===")

# Flag ADT rows whose location is an ICU
adt_final_stitched['is_icu'] = adt_final_stitched['location_category'] == 'icu'

# One boolean per ENCOUNTER BLOCK: did the block include at least one ICU location?
hosp_icu_status = adt_final_stitched.groupby('encounter_block')['is_icu'].any()
non_icu_hosps = hosp_icu_status[~hosp_icu_status].index.tolist()
strobe_counts["D1_number_hosp_without_ICU_stay"] = len(non_icu_hosps)

# The grouping key is encounter_block, so the values reported below are
# encounter blocks — the messages say so rather than "hospitalization IDs".
print(f"\nNumber of CRRT encounter blocks without any ICU stay: {len(non_icu_hosps)}")
if len(non_icu_hosps) > 0:
    print("WARNING: Found CRRT encounter blocks without ICU stays")
    print("First few encounter blocks without ICU stays:", non_icu_hosps[:5])
else:
    print("All CRRT encounter blocks had at least one ICU stay")

In [None]:
# Attach patient-level demographics and death time to every row of all_ids
demographic_columns = [
    'patient_id', 'race_category', 'ethnicity_category',
    'sex_category', 'death_dttm', 'language_category',
]
patient_filtered = patient.loc[:, demographic_columns]
all_ids = pd.merge(all_ids, patient_filtered, how='left', on='patient_id')

## (E) Exclude encounters who died close to CRRT start

Exclude encounters in which the patient died 6 hours or less after starting CRRT, because data for these patients are likely to be skewed towards the extremely sick, who were unlikely ever to recover.

In [None]:
# Mortality flag: discharge category (lower-cased) is 'expired' or 'hospice'
all_ids['discharge_category'] = all_ids['discharge_category'].str.lower()
died_or_hospice = all_ids['discharge_category'].isin(['expired', 'hospice'])
all_ids['mortality'] = died_or_hospice.astype(np.int8)

# death_dttm_proxy starts as the recorded death time ...
all_ids['death_dttm_proxy'] = all_ids['death_dttm']

# ... and is replaced by discharge_dttm in two (mutually exclusive) situations:
#   1. the recorded death time is earlier than the discharge time
#   2. no death time is recorded although mortality == 1
death_before_discharge = all_ids['death_dttm'] < all_ids['discharge_dttm']
death_time_missing = all_ids['death_dttm'].isna() & (all_ids['mortality'] == 1)
use_discharge_time = death_before_discharge | death_time_missing
all_ids.loc[use_discharge_time, 'death_dttm_proxy'] = all_ids.loc[use_discharge_time, 'discharge_dttm']

In [None]:
# Get first CRRT time for each encounter block
first_crrt = (
    crrt_stitched
    .groupby("encounter_block", as_index=False)["recorded_dttm"]
    .min()
    .rename(columns={"recorded_dttm": "first_crrt_time"})
)

# Merge first CRRT time with all_ids to get death times.
# all_ids has one row per hospitalization, so a stitched block can appear on
# several rows of early_death_df.
early_death_df = first_crrt.merge(
    all_ids[['encounter_block', 'death_dttm_proxy', 'mortality']],
    on='encounter_block',
    how='left'
)

# Calculate time from CRRT start to death
early_death_df['time_to_death'] = early_death_df['death_dttm_proxy'] - early_death_df['first_crrt_time']

# Identify encounters where death occurred within 6 hours of CRRT start
is_early_death = (
    (early_death_df['mortality'] == 1) &
    (early_death_df['time_to_death'] <= pd.Timedelta(hours=6))
)
# De-duplicate so every encounter block is listed (and counted) once
early_death_encounters = (
    early_death_df.loc[is_early_death, 'encounter_block'].drop_duplicates().tolist()
)

# Update metrics dictionary
strobe_counts['E_encounters_early_death'] = len(early_death_encounters)

# Remove early death encounters from the cohort and from the CRRT records
# (mirrors the ESRD step, which also filters both tables)
all_ids = all_ids[~all_ids['encounter_block'].isin(early_death_encounters)]
crrt_stitched = crrt_stitched[~crrt_stitched['encounter_block'].isin(early_death_encounters)]

# Count encounter blocks, not rows: all_ids holds one row per hospitalization
strobe_counts['E_encounters_after_early_death_exclusion'] = all_ids['encounter_block'].nunique()

In [None]:
all_ids.shape

In [None]:
strobe_counts

# Labs

In [None]:
# Import labs
labs_required_columns = [
    'hospitalization_id',
    'lab_result_dttm',
    'lab_category',
    'lab_value',
    'lab_value_numeric'
]

labs_of_interest = [
    'sodium',
    'potassium',
    'chloride',
    'bicarbonate',
    'bun',
    'creatinine',
    'glucose_serum',
    'calcium_total',
    'lactate',
    'magnesium',
    'ph_arterial',
    'ph_venous',
    'po2_arterial'
]

# Request labs only for hospitalizations still in the cohort (all_ids), as the
# vitals load does. Filtering on crrt_df would also fetch labs for
# hospitalizations already excluded above, only for the inner join below to
# discard them again.
labs_filters = {
    'hospitalization_id': all_ids['hospitalization_id'].unique().tolist(),
    'lab_category': labs_of_interest
}
labs = pyCLIF.load_data('clif_labs', columns=labs_required_columns, filters=labs_filters)
print("unique encounters in labs", pyCLIF.count_unique_encounters(labs))

# String IDs so the join key matches all_ids
labs['hospitalization_id'] = labs['hospitalization_id'].astype(str)
labs = labs.merge(all_ids[['hospitalization_id', 'encounter_block']],
                  on='hospitalization_id', how='inner')
labs = pyCLIF.convert_datetime_columns_to_site_tz(labs, pyCLIF.helper['timezone'])
# Non-numeric results become NaN
labs['lab_value_numeric'] = pd.to_numeric(labs['lab_value_numeric'], errors='coerce')
labs = labs.sort_values(by=['encounter_block', 'lab_result_dttm'])

In [None]:
# Reshape labs to wide format: one row per (encounter_block, result time) and
# one column per lab_category. pivot_table aggregates repeated values with its
# default aggregation (mean).
labs_wide = labs.pivot_table(
    values='lab_value_numeric',
    index=['encounter_block', 'lab_result_dttm'],
    columns='lab_category',
)
# Flatten the index and expose the timestamp under the name 'recorded_dttm'
labs_pivoted = labs_wide.reset_index().rename(columns={'lab_result_dttm': 'recorded_dttm'})

print("\nShape of pivoted labs data:", labs_pivoted.shape)
print("\nColumns in pivoted labs data:")
print(labs_pivoted.columns.tolist())

# Vitals

In [None]:
vitals_required_columns = [
    'hospitalization_id',
    'recorded_dttm',
    'vital_category',
    'vital_value'
]
vitals_of_interest = ['weight_kg', 'height_cm',
                      'heart_rate', 'respiratory_rate', 'sbp', 'dbp', 'map', 'spo2']

vitals_cohort = pyCLIF.load_data('clif_vitals',
    columns=vitals_required_columns,
    filters={'hospitalization_id': all_ids['hospitalization_id'].unique().tolist(),
             'vital_category': vitals_of_interest}
)
# String IDs so the join key matches all_ids (the other tables are cast the
# same way; merging a numeric ID with a str ID fails or matches nothing)
vitals_cohort['hospitalization_id'] = vitals_cohort['hospitalization_id'].astype(str)
vitals_cohort = pyCLIF.convert_datetime_columns_to_site_tz(vitals_cohort, pyCLIF.helper['timezone'])
# Non-numeric readings become NaN
vitals_cohort['vital_value'] = pd.to_numeric(vitals_cohort['vital_value'], errors='coerce')
# sort vitals cohort by hospitalization_id and recorded_dttm
vitals_cohort = vitals_cohort.sort_values(['hospitalization_id', 'recorded_dttm'])

# Replace outliers with NAs in the vitals table
# Extract min/max values from config for each vital
min_hr, max_hr = outlier_cfg['heart_rate']
min_rr, max_rr = outlier_cfg['respiratory_rate']
min_sbp, max_sbp = outlier_cfg['sbp']
min_dbp, max_dbp = outlier_cfg['dbp']
min_map, max_map = outlier_cfg['map']
min_spo2, max_spo2 = outlier_cfg['spo2']
min_weight, max_weight = outlier_cfg['weight_kg']
min_height, max_height = outlier_cfg['height_cm']

# Plausible (min, max) range per vital category
vital_limits = {
    'heart_rate': (min_hr, max_hr),
    'respiratory_rate': (min_rr, max_rr),
    'sbp': (min_sbp, max_sbp),
    'dbp': (min_dbp, max_dbp),
    'map': (min_map, max_map),
    'spo2': (min_spo2, max_spo2),
    'weight_kg': (min_weight, max_weight),
    'height_cm': (min_height, max_height),
}

# For each vital category, set out-of-range values to NaN
for vital_name, (lower, upper) in vital_limits.items():
    is_vital = vitals_cohort['vital_category'] == vital_name
    out_of_range = (vitals_cohort['vital_value'] < lower) | (vitals_cohort['vital_value'] > upper)
    vitals_cohort.loc[is_vital & out_of_range, 'vital_value'] = np.nan

# Right join: every hospitalization in all_ids is kept, including those
# without any vitals row (their vitals columns are NaN)
vitals_cohort = vitals_cohort.merge(all_ids[['hospitalization_id', 'encounter_block']], on='hospitalization_id', how='right')

In [None]:
# Count how many (encounter_block, recorded_dttm, vital_category) keys occur more than once
duplicates = (
    vitals_cohort
    .groupby(['encounter_block', 'recorded_dttm', 'vital_category'])
    .size()
    .reset_index(name='count')
)
n_repeated_keys = len(duplicates[duplicates['count'] > 1])
print("Number of duplicate rows before deduplication:", n_repeated_keys)

In [None]:
# One row per (encounter_block, recorded_dttm, vital_category): sort on the key,
# then keep the last row of every repeated key.
vitals_key = ['encounter_block', 'recorded_dttm', 'vital_category']
vitals_cohort = vitals_cohort.sort_values(vitals_key)
vitals_cohort = vitals_cohort.drop_duplicates(subset=vitals_key, keep='last')

# Re-count repeated keys; this should now report zero
duplicates = vitals_cohort.groupby(vitals_key).size().reset_index(name='count')
n_repeated_keys = len(duplicates[duplicates['count'] > 1])
print("Number of duplicate rows after deduplication:", n_repeated_keys)

In [None]:
# Earliest and latest vital timestamp for each encounter block. min/max are used
# (rather than first/last) so the result does not depend on the current row order.
vital_times = (
    vitals_cohort
    .groupby('encounter_block')['recorded_dttm']
    .agg(['min', 'max'])
    .reset_index()
    .rename(columns={'min': 'first_vital_dttm', 'max': 'last_vital_dttm'})
)
# Attach the vital times to every encounter block in all_ids. Copies of these
# columns left by an earlier run of this cell are dropped first, so re-running
# the cell does not create suffixed duplicates (first_vital_dttm_x / _y).
all_ids = (
    all_ids
    .drop(columns=['first_vital_dttm', 'last_vital_dttm'], errors='ignore')
    .merge(vital_times, on='encounter_block', how='left')
)

In [None]:
# Keep only rows whose category is one of the vitals of interest
# (rows with a missing vital_category do not match and are excluded)
is_wanted_vital = vitals_cohort['vital_category'].isin(vitals_of_interest)
vitals_filtered = vitals_cohort[is_wanted_vital]

# Long -> wide: one row per (encounter_block, recorded_dttm), one column per vital
vitals_pivoted = vitals_filtered.pivot(
    values='vital_value',
    index=['encounter_block', 'recorded_dttm'],
    columns='vital_category',
)
vitals_pivoted = vitals_pivoted.reset_index()

In [None]:
# Outer join of wide vitals and wide labs: a timestamp that appears in only one
# of the two tables still gets its own row in clif_wide.
clif_wide = pd.merge(
    vitals_pivoted,
    labs_pivoted,
    how='outer',
    on=['encounter_block', 'recorded_dttm'],
)

In [None]:
# Earliest CRRT timestamp for each encounter_block. min() is used instead of
# first() so the result does not depend on how crrt_stitched happens to be sorted.
vitals_bmi = crrt_stitched.groupby('encounter_block')['recorded_dttm'].min().reset_index()

# Independent long-format copies holding only weight rows and only height rows
weight_df = vitals_cohort[vitals_cohort['vital_category'] == 'weight_kg'].copy()
height_df = vitals_cohort[vitals_cohort['vital_category'] == 'height_cm'].copy()
# Function to find closest vital measurement to CRRT start
def get_closest_vital(vital_df, crrt_time_df):
    """
    Find, for each encounter block, the vital value recorded closest in time
    to that block's CRRT reference timestamp.

    Parameters
    ----------
    vital_df : pandas.DataFrame
        Long-format vitals with columns ``encounter_block``, ``recorded_dttm``
        and ``vital_value``. Not modified.
    crrt_time_df : pandas.DataFrame
        Reference times with columns ``encounter_block`` and ``recorded_dttm``.

    Returns
    -------
    list
        One entry per row of ``crrt_time_df``, in the same order. The entry is
        ``np.nan`` when the block has no usable vital row or when the
        reference time itself is missing.
    """
    # A row without a timestamp or without a value can never be a useful match.
    # Dropping them up front means a blanked-out (NaN) reading cannot shadow a
    # valid one that is slightly further away in time.
    usable = vital_df.dropna(subset=['recorded_dttm', 'vital_value'])

    # Split once by encounter block instead of re-filtering the frame per row
    vitals_by_block = {
        block_id: block_rows
        for block_id, block_rows in usable.groupby('encounter_block')
    }

    vital_values = []
    for block_id, crrt_time in zip(crrt_time_df['encounter_block'],
                                   crrt_time_df['recorded_dttm']):
        block_vitals = vitals_by_block.get(block_id)
        if block_vitals is None or pd.isna(crrt_time):
            vital_values.append(np.nan)
            continue

        # Absolute time distance to the reference time; the first minimum wins on ties
        time_gap = (block_vitals['recorded_dttm'] - crrt_time).abs()
        # Positional lookup, so duplicate index labels cannot return several rows
        closest_pos = time_gap.to_numpy().argmin()
        vital_values.append(block_vitals['vital_value'].iloc[closest_pos])

    return vital_values

# Fill vitals_bmi with the weight and the height recorded nearest to the CRRT time
for bmi_column, bmi_source in (('weight_kg', weight_df), ('height_cm', height_df)):
    vitals_bmi[bmi_column] = get_closest_vital(bmi_source, vitals_bmi)

# Distribution summaries of the selected measurements
print("Summary of measurements at CRRT start:")
weight_summary = vitals_bmi['weight_kg'].describe()
height_summary = vitals_bmi['height_cm'].describe()
print("\nWeight (kg):")
print(weight_summary)
print("\nHeight (cm):")
print(height_summary)

# Vasopressors

In [None]:
# Load continuous medication administrations for the hospitalizations in all_ids,
# restricted to the medication categories listed below
meds_required_columns = [
    'hospitalization_id', 'admin_dttm',
    'med_name', 'med_category',
    'med_dose', 'med_dose_unit',
]
meds_of_interest = [
    'norepinephrine', 'epinephrine', 'phenylephrine',
    'vasopressin', 'dopamine', 'angiotensin',
    'dobutamine', 'milrinone', 'isoproterenol',
]

cohort_hospitalization_ids = all_ids['hospitalization_id'].unique().tolist()
meds_filters = {
    'hospitalization_id': cohort_hospitalization_ids,
    'med_category': meds_of_interest,
}
meds = pyCLIF.load_data(
    'clif_medication_admin_continuous',
    columns=meds_required_columns,
    filters=meds_filters,
)

In [None]:
# Lower-case units so later unit checks compare against a single spelling
meds['med_dose_unit'] = meds['med_dose_unit'].str.lower()
# String ids for the merge with all_ids below
meds['hospitalization_id'] = meds['hospitalization_id'].astype(str)
meds = pyCLIF.convert_datetime_columns_to_site_tz(meds, pyCLIF.helper['timezone'])
# Doses that cannot be parsed as numbers become NaN
meds['med_dose'] = pd.to_numeric(meds['med_dose'], errors='coerce')
# Attach encounter_block; rows without a match in all_ids are dropped (inner join)
block_lookup = all_ids[['hospitalization_id', 'encounter_block']]
meds = meds.merge(block_lookup, how='inner', on='hospitalization_id')

In [None]:
# Dose distribution for every (med_category, med_dose_unit) pair:
# row count, min/max, quartiles and number of missing doses
dose_summary_spec = {
    'total_N': ('med_category', 'size'),
    'min': ('med_dose', 'min'),
    'max': ('med_dose', 'max'),
    'first_quantile': ('med_dose', lambda x: x.quantile(0.25)),
    'second_quantile': ('med_dose', lambda x: x.quantile(0.5)),
    'third_quantile': ('med_dose', lambda x: x.quantile(0.75)),
    'missing_values': ('med_dose', lambda x: x.isna().sum()),
}
summary_meds_cat_dose = (
    meds
    .groupby(['med_category', 'med_dose_unit'])
    .agg(**dose_summary_spec)
    .reset_index()
)
summary_meds_cat_dose.to_csv('../output/final/summary_meds_by_category_dose_units.csv', index=False)
## check the distribution of required continuous meds

In [None]:
# Diagnostic: report every (med_category, med_dose_unit) group without a single non-missing dose
print("Groups with all NaN med_dose values:")
for group_key, group in meds.groupby(['med_category', 'med_dose_unit']):
    med_category, med_dose_unit = group_key
    has_any_dose = group['med_dose'].notna().any()
    if not has_any_dose:
        print(f"  {med_category} - {med_dose_unit}: {len(group)} rows, all NaN")

In [None]:
# SANITY CHECK: tabulate which dose units occur for each med_category
med_dose_unit_check = (
    meds
    .groupby(['med_category', 'med_dose_unit'])
    .size()
    .reset_index(name='count')
)

# Label every (category, unit) row using pyCLIF.check_dose_unit
unit_validity = med_dose_unit_check.apply(pyCLIF.check_dose_unit, axis=1)
med_dose_unit_check['unit_validity'] = unit_validity

# Show the combinations that were labelled as not acceptable
is_invalid_unit = med_dose_unit_check['unit_validity'] == 'Not an acceptable unit'
invalid_units = med_dose_unit_check[is_invalid_unit]
print("Invalid units. These will be dropped:\n")
print(invalid_units)

In [None]:
# Step 1: keep rows that have a dose
has_dose = meds['med_dose'].notna()
meds_filtered = meds[has_dose].copy()
# Step 2: keep rows whose unit passes pyCLIF.has_per_hour_or_min
is_rate_unit = meds_filtered['med_dose_unit'].apply(pyCLIF.has_per_hour_or_min)
meds_filtered = meds_filtered[is_rate_unit].copy()
# Step 3: bring in the weight stored in vitals_bmi for each encounter block
block_weights = vitals_bmi[['encounter_block', 'weight_kg']]
meds_filtered = meds_filtered.merge(block_weights, how='left', on='encounter_block')
# Step 4: row-wise dose conversion via pyCLIF.convert_dose
meds_filtered["med_dose_converted"] = meds_filtered.apply(pyCLIF.convert_dose, axis=1)

In [None]:
# Keep only rows for which pyCLIF.is_dose_within_range (given outlier_cfg) returns True
dose_in_range = meds_filtered.apply(
    pyCLIF.is_dose_within_range,
    axis=1,
    args=(outlier_cfg,),
)
meds_final = meds_filtered[dose_in_range].copy()

In [None]:
# Number of rows per medication category in meds_final
meds_final.value_counts('med_category')

In [None]:
# Look for repeated (encounter_block, admin_dttm, med_category) keys before
# reshaping: DataFrame.pivot raises when an index/column pair is not unique.
meds_key = ['encounter_block', 'admin_dttm', 'med_category']
duplicates = meds_final.groupby(meds_key).size().reset_index(name='count')
duplicates = duplicates.loc[duplicates['count'] > 1]

if not duplicates.empty:
    print("Found duplicate entries for these combinations:")
    print(duplicates)
    # keep='last' retains the final row (in current row order) of each combination.
    # NOTE(review): an earlier comment here said "first"; confirm which is intended.
    meds_final = meds_final.drop_duplicates(subset=meds_key, keep='last')

In [None]:
# Long -> wide: one row per (encounter_block, admin_dttm), one column per
# med_category, cell value = converted dose
meds_final_pivoted = meds_final.pivot(
    values='med_dose_converted',
    index=['encounter_block', 'admin_dttm'],
    columns='med_category',
)
meds_final_pivoted = meds_final_pivoted.reset_index()
meds_final_pivoted.head()

In [None]:
# Use the same timestamp column name as clif_wide (admin_dttm -> recorded_dttm)
meds_final_pivoted = meds_final_pivoted.rename(columns={'admin_dttm': 'recorded_dttm'})

# Outer join of the medication columns with the existing wide table, so
# timestamps present on only one side are kept
clif_wide = pd.merge(
    meds_final_pivoted,
    clif_wide,
    how='outer',
    on=['encounter_block', 'recorded_dttm'],
)

# Respiratory support

In [None]:
print("\n=== STEP C: Load & process respiratory support => Apply Waterfall ===\n")
rst_required_columns = [
    'hospitalization_id', 'recorded_dttm',
    'device_name', 'device_category',
    'mode_name', 'mode_category',
    'tracheostomy',
    'fio2_set', 'lpm_set', 'resp_rate_set', 'peep_set', 'resp_rate_obs',
    'tidal_volume_set', 'pressure_control_set', 'pressure_support_set',
    'peak_inspiratory_pressure_set',
]

resp_support_raw = pyCLIF.load_data(
    'clif_respiratory_support',
    columns=rst_required_columns,
    filters={'hospitalization_id': all_ids['hospitalization_id'].unique().tolist()},
)

# Clean a copy so the raw load stays untouched until it is deleted below
resp_support = resp_support_raw.copy()

# Category labels are compared in lower case
for text_col in ('device_category', 'mode_category'):
    resp_support[text_col] = resp_support[text_col].str.lower()

# Numeric settings: anything that cannot be parsed becomes NaN
for numeric_col in ('lpm_set', 'resp_rate_set', 'peep_set', 'resp_rate_obs'):
    resp_support[numeric_col] = pd.to_numeric(resp_support[numeric_col], errors='coerce')

# Nullable Int8 so missing tracheostomy values stay missing
tracheostomy_numeric = pd.to_numeric(resp_support['tracheostomy'], errors='coerce')
resp_support['tracheostomy'] = tracheostomy_numeric.astype('Int8')
resp_support = resp_support.sort_values(['hospitalization_id', 'recorded_dttm'])


print("\n=== Apply outlier thresholds ===\n")

resp_support['fio2_set'] = pd.to_numeric(resp_support['fio2_set'], errors='coerce')
# A mean above 1 is taken to mean FiO2 was recorded in percent. In that case only
# the values above 1 are divided by 100, so values already stored as fractions
# are not rescaled.
fio2_mean = resp_support['fio2_set'].mean(skipna=True)
if fio2_mean and fio2_mean > 1.0:
    above_one = resp_support['fio2_set'] > 1
    resp_support.loc[above_one, 'fio2_set'] = resp_support.loc[above_one, 'fio2_set'] / 100
    print("Updated fio2_set to be between 0.21 and 1")
else:
    print("FIO2_SET mean=", fio2_mean, "is within the required range")

# Apply the configured thresholds to each setting, one column at a time
for thresholded_col in ('fio2_set', 'peep_set', 'lpm_set', 'resp_rate_set', 'resp_rate_obs'):
    pyCLIF.apply_outlier_thresholds(resp_support, thresholded_col, *outlier_cfg[thresholded_col])

del resp_support_raw

In [None]:
# Reload the waterfall module before calling it, so edits made to it during
# this session are picked up
importlib.reload(waterfall)
processed_resp_support = waterfall.process_resp_support_waterfall(
    resp_support,
    id_col="hospitalization_id",
    verbose=True,
)

In [None]:
processed_resp_support = pyCLIF.convert_datetime_columns_to_site_tz(
    processed_resp_support,
    pyCLIF.helper['timezone'],
)
# Attach encounter_block; rows without a match in all_ids are dropped (inner join)
block_lookup = all_ids[['hospitalization_id', 'encounter_block']]
processed_resp_support = processed_resp_support.merge(block_lookup, how='inner', on='hospitalization_id')

# Look for more than one respiratory-support row per (encounter_block, recorded_dttm)
resp_key = ['encounter_block', 'recorded_dttm']
duplicates = processed_resp_support.groupby(resp_key).size().reset_index(name='count')
duplicates = duplicates.loc[duplicates['count'] > 1]

if duplicates.empty:
    print("No duplicates in respiratory support")
else:
    print("Found duplicate entries for these combinations:")
    print(duplicates)
    # Keep the first row (in current row order) of each combination
    processed_resp_support = processed_resp_support.drop_duplicates(subset=resp_key, keep='first')

In [None]:
# Remove columns that are not carried into the wide table
resp_columns_to_drop = [
    'is_scaffold',
    'device_cat_id',
    'device_id',
    'mode_cat_id',
    'mode_name_id',
]
processed_resp_support.drop(columns=resp_columns_to_drop, inplace=True)

In [None]:
# Outer join of respiratory support onto the wide table, matched on encounter
# block and timestamp; rows present on only one side are kept
clif_wide = pd.merge(
    clif_wide,
    processed_resp_support,
    how='outer',
    on=['encounter_block', 'recorded_dttm'],
)

In [None]:
# Display the current contents of strobe_counts
strobe_counts

# Save final datasets

In [None]:
# Write each intermediate dataframe to its parquet file (target path -> dataframe)
intermediate_outputs = {
    '../output/intermediate/all_ids.parquet': all_ids,
    '../output/intermediate/adt_final.parquet': adt_final_stitched,
    '../output/intermediate/clif_wide.parquet': clif_wide,
    '../output/intermediate/crrt_df.parquet': crrt_stitched,
}
for parquet_path, frame in intermediate_outputs.items():
    frame.to_parquet(parquet_path)

print("Saved final datasets as parquet files to output/intermediate folder")

In [None]:
# For each dataframe: print its column dtypes, then write it to parquet.
# NOTE(review): these are the same four files written by the preceding save
# cell; confirm whether both writes are needed.
schema_targets = [
    ("\nall_ids schema:", all_ids, '../output/intermediate/all_ids.parquet'),
    ("\nadt_final_stitched schema:", adt_final_stitched, '../output/intermediate/adt_final.parquet'),
    ("\nclif_wide schema:", clif_wide, '../output/intermediate/clif_wide.parquet'),
    ("\ncrrt_stitched schema:", crrt_stitched, '../output/intermediate/crrt_df.parquet'),
]
for schema_header, frame, parquet_path in schema_targets:
    print(schema_header)
    print(frame.dtypes)
    frame.to_parquet(parquet_path)

In [None]:
import os
import matplotlib.pyplot as plt

def draw_consort_diagram(
        strobe_counts: dict,
        site_name: str,
        output_folder: str = "./outputs"
    ):
    """
    Draw, save and show a CONSORT / STROBE-style cohort diagram for the CRRT study.

    Five main boxes are stacked top-to-bottom (adult hospitalizations, after
    stitching, with CRRT, without ESRD, final analytic cohort), joined by
    vertical arrows. Between each consecutive pair an exclusion box is drawn
    to the right with a horizontal arrow pointing to it.

    Parameters
    ----------
    strobe_counts : dict
        Counts produced by the pipeline. Required keys:
        ``A_total_adult_hospitalizations``, ``B_before_stitching``,
        ``B_after_stitching``, ``C_hospitalizations_with_crrt``,
        ``D_hospitalizations_with_esrd``, ``D_hospitalizations_without_esrd``,
        ``E_encounters_early_death``,
        ``E_encounters_after_early_death_exclusion``.
        Optional key: ``D_encounter_blocks_with_AKI_no_esrd`` (defaults to 0).
    site_name : str
        Shown in the figure title and appended to the PNG file name.
    output_folder : str
        Directory that receives ``consort_diagram_<site_name>.png``. It is
        created if it does not exist; no sub-folders are added to it.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If one of the required keys is missing from ``strobe_counts``.
    """

    # ────────────────────────── numbers we need
    start_n        = strobe_counts["A_total_adult_hospitalizations"]
    stitched_n     = strobe_counts["B_after_stitching"]
    stitched_drop  = strobe_counts["B_before_stitching"] - stitched_n
    crrt_n         = strobe_counts["C_hospitalizations_with_crrt"]
    no_crrt_drop   = stitched_n - crrt_n
    esrd_drop      = strobe_counts["D_hospitalizations_with_esrd"]
    final_n        = strobe_counts["D_hospitalizations_without_esrd"]
    early_death_n  = strobe_counts["E_encounters_early_death"]
    final_cohort_n = strobe_counts["E_encounters_after_early_death_exclusion"]

    aki_blocks     = strobe_counts.get("D_encounter_blocks_with_AKI_no_esrd", 0)
    # Guard against division by zero when nothing is left after the ESRD exclusion
    aki_pct        = aki_blocks / final_n if final_n else 0

    # ────────────────────────── plotting scaffold
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.axis("off")
    ax.set_title(f"Cohort Selection for\nCRRT Study for {site_name}",
                 fontsize=14, fontweight="bold", pad=20)

    # main boxes top-to-bottom; y_levels[i] is the vertical centre of the i-th box
    y_levels   = [0.85, 0.65, 0.45, 0.25, 0.05]
    x_main     = 0.35
    main_boxes = {
        "start":   (f"Adult hospitalizations\n(date + age filters)\n(n={start_n:,})",
                    (x_main, y_levels[0])),
        "stitched":(f"After encounter stitching\n(n={stitched_n:,})",
                    (x_main, y_levels[1])),
        "crrt":    (f"Hospitalizations with CRRT\n(n={crrt_n:,})",
                    (x_main, y_levels[2])),
        "final":   (f"Hospitalizations without ESRD\n(n={final_n:,})"
                    f"\nAKI in {aki_pct:.1%} of blocks",
                    (x_main, y_levels[3])),
        "final_e": (f"Final analytic cohort\n(n={final_cohort_n:,})",
                    (x_main, y_levels[4]))
    }

    # exclusion boxes: (text, index of the main box directly above the exclusion)
    excl_info = [
        (f"Excluded:\nno stitched data\n(n={stitched_drop:,})", 0),   # above stitching box
        (f"Excluded:\nno CRRT\n(n={no_crrt_drop:,})",         1),   # above CRRT box
        (f"Excluded:\nESRD\n(n={esrd_drop:,})",               2),    # above final box
        (f"Excluded:\nearly death (n={early_death_n:,})", 3)  # above final_e box
    ]

    box_props   = dict(boxstyle="round,pad=0.4", facecolor="white",
                       edgecolor="black", linewidth=1.2)
    arrow_props = dict(arrowstyle="->", color="black", lw=1.3)

    # ── draw main boxes + vertical arrows
    for txt, (x, y) in main_boxes.values():
        ax.text(x, y, txt, ha="center", va="center", bbox=box_props, fontsize=11)

    # one arrow between each consecutive pair of levels, offset by 0.05 at both
    # ends relative to the box centres
    for y_upper, y_lower in zip(y_levels, y_levels[1:]):
        y1, y2 = y_upper - 0.05, y_lower + 0.05
        ax.annotate("", xy=(x_main, y2), xytext=(x_main, y1),
                    arrowprops=arrow_props)

    # ── draw exclusion boxes + side arrows
    x_excl = 0.75
    for txt, idx in excl_info:
        # vertically centred between main box idx and the one below it
        y_mid = (y_levels[idx] + y_levels[idx+1]) / 2
        ax.text(x_excl, y_mid, txt, ha="left", va="center",
                bbox=box_props, fontsize=9)
        ax.annotate("", xy=(x_excl-0.02, y_mid),
                    xytext=(x_main+0.05, y_mid),
                    arrowprops=arrow_props)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    # ── save
    os.makedirs(output_folder, exist_ok=True)
    save_path = os.path.join(output_folder, f"consort_diagram_{site_name}.png")
    # Save through the figure handle rather than pyplot's implicit current figure
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    print(f"Saved CONSORT diagram to: {save_path}")
    plt.show()


In [None]:
# Two-column table (Metric, Value) built from the strobe_counts dictionary, written to CSV
strobe_table = pd.DataFrame(
    list(strobe_counts.items()),
    columns=['Metric', 'Value'],
)
strobe_table.to_csv('../output/final/strobe_counts.csv', index=False)
strobe_counts

In [None]:
# Draw and save the cohort diagram for the configured site
draw_consort_diagram(
    strobe_counts,
    site_name=pyCLIF.helper['site_name'],
    output_folder="../output/final/graphs",
)