# Epidemiology of CRRT - Analysis

Author: Kaveri Chhikara

Run this script after 01_cohort_identification

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import shutil
from datetime import datetime, timedelta
import json
import warnings
import pyarrow
from tableone import TableOne
import seaborn as sns
import sofa_score  
import logging
warnings.filterwarnings('ignore')

import pyCLIF
# Root output folder; later cells read from its intermediate/ subfolder and write to final/.
output_folder = '../output'
## import outlier json
# Load the outlier configuration JSON into `outlier_cfg`.
with open('../config/outlier_config.json', 'r', encoding='utf-8') as f:
    outlier_cfg = json.load(f)

# NOTE(review): `os` is already imported earlier in this cell; this re-import is redundant but harmless.
import os
# Create the folder for final outputs up front so later writes cannot fail on a missing directory.
output_dir = "../output/final"
os.makedirs(output_dir, exist_ok=True)

# Load CLIF wide and other cohort datasets

In [None]:
# All four inputs live side by side in the intermediate output folder.
intermediate_dir = f'{output_folder}/intermediate'

all_ids_df = pd.read_parquet(f'{intermediate_dir}/all_ids.parquet')      # encounter-level ids / demographics
adt_final_df = pd.read_parquet(f'{intermediate_dir}/adt_final.parquet')  # ADT location intervals
clif_wide_df = pd.read_parquet(f'{intermediate_dir}/clif_wide.parquet')  # wide time-stamped records
crrt_df = pd.read_parquet(f'{intermediate_dir}/crrt_df.parquet')         # CRRT records

# Combine CLIF wide with ADT

Combine CLIF wide with ADT, then forward-fill the location category and location type columns so that each time point carries the patient's location.

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# 1) ANNOTATE CLIF-WIDE RECORDS WITH ADT LOCATION (FORWARD-FILL ONLY)
# ─────────────────────────────────────────────────────────────────────────────
# Keys that identify one CLIF-wide row, and the ADT columns copied onto it.
key_cols = ["encounter_block", "recorded_dttm"]
loc_cols = ["location_category", "location_type"]

adt_int = adt_final_df[["encounter_block", "in_dttm", "out_dttm"] + loc_cols]

# 1a) Every flowsheet row matched to all ADT intervals for that encounter
merged = clif_wide_df.merge(adt_int, on="encounter_block", how="left")

# 1b) Keep only rows where recorded_dttm ∈ [in_dttm, out_dttm]
mask = (
    (merged["recorded_dttm"] >= merged["in_dttm"]) &
    (merged["recorded_dttm"] <= merged["out_dttm"])
)
# in_dttm is carried along only as a tie-breaker for step 1c.
annot = merged.loc[mask, key_cols + ["in_dttm"] + loc_cols].copy()

# 1c) Forward-fill per encounter_block.
#     Selecting the location columns on the groupby keeps `encounter_block`
#     as a regular column (groupby.apply on the whole frame relied on the
#     grouping column surviving, which pandas has deprecated).
annot = annot.sort_values(key_cols + ["in_dttm"])
annot[loc_cols] = annot.groupby("encounter_block")[loc_cols].ffill()

#     The interval test above is inclusive on both ends, so a timestamp that
#     sits on a shared boundary (or inside overlapping ADT rows) matches more
#     than one interval. Keep exactly one row per key, the latest-starting
#     interval, otherwise the join in 1e would multiply CLIF-wide rows.
annot = (
    annot
    .dropna(subset=loc_cols, how="all")
    .drop_duplicates(subset=key_cols, keep="last")
    .drop(columns="in_dttm")
)

# 1d) Re-index `annot` on the two keys
annot_indexed = annot.set_index(key_cols)[loc_cols]

# 1e) Join back onto `clif_wide_df`
clif_wide_df = (
    clif_wide_df
      .set_index(key_cols)
      .join(annot_indexed, how="left")
      .reset_index()
)

# `clif_wide_df` now carries the ADT location of the interval covering each
# timestamp (the latest-starting one when several cover it); timestamps that
# fall outside every ADT interval keep NaN locations.

In [None]:
# Dump the column dtypes of each input frame, separated by a blank line.
frames_to_inspect = [
    ("all_ids_df", all_ids_df),
    ("clif_wide_df", clif_wide_df),
    ("crrt_df", crrt_df),
]
for position, (frame_name, frame) in enumerate(frames_to_inspect):
    leading_newline = "\n" if position else ""
    print(f"{leading_newline}=== {frame_name} dtypes ===")
    print(frame.dtypes)

# (A) Summary by Time from CRRT Start

* 24 h before CRRT start (Pre-24h)
* First 24 h after CRRT start (Post-24h)
* First 72 h after CRRT start (Post-72h)

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# 0) CONSTANTS & INPUTS
# ─────────────────────────────────────────────────────────────────────────────
final_dir = os.path.join(output_folder, "final")
os.makedirs(final_dir, exist_ok=True)

# Inputs forwarded to the SOFA computation: table location and the
# encounter_block <-> hospitalization_id mapping.
tables_path = pyCLIF.helper['tables_path']
id_mappings = all_ids_df.loc[:, ['encounter_block', 'hospitalization_id']].drop_duplicates()

# Column groups used by the window summaries.
demog_cols = [
    "age_at_admission",
    "sex_category",
    "race_category",
    "ethnicity_category",
]
adt_cols = ["location_category", "location_type"]
device_col = "device_category"

vasoactive_vars = [
    "angiotensin", "dobutamine", "dopamine", "epinephrine",
    "norepinephrine", "phenylephrine", "vasopressin",
]
lab_vars = [
    "bicarbonate", "bun", "calcium_total", "chloride", "creatinine",
    "magnesium", "glucose_serum", "lactate", "potassium", "sodium",
    "ph_arterial", "po2_arterial",
]
vent_vars = [
    "fio2_set", "peep_set", "resp_rate_set", "tidal_volume_set",
    "pressure_control_set", "pressure_support_set",
    "peak_inspiratory_pressure_set",
]
# Order matters downstream (table row order): vasoactives, labs, vents.
continuous_vars = vasoactive_vars + lab_vars + vent_vars

crrt_vars = [
    "blood_flow_rate",
    "pre_filter_replacement_fluid_rate",
    "post_filter_replacement_fluid_rate",
    "dialysate_flow_rate",
    "ultrafiltration_out",
]

# ─────────────────────────────────────────────────────────────────────────────
# 1) FIRST CRRT START PER ENCOUNTER
# ─────────────────────────────────────────────────────────────────────────────
# Earliest CRRT record per encounter_block, as a two-column frame.
first_crrt = (
    crrt_df
    .groupby("encounter_block")["recorded_dttm"]
    .min()
    .rename("first_crrt_time")
    .reset_index()
)

# Demographics plus the death-time proxy, indexed by encounter_block.
demog = (
    all_ids_df[["encounter_block"] + demog_cols + ["death_dttm_proxy"]]
    .set_index("encounter_block")
)

# ─────────────────────────────────────────────────────────────────────────────
# 2) helper: summarize any window
# ─────────────────────────────────────────────────────────────────────────────
def summarize_window(start_offset_h: int, end_offset_h: int, include_crrt: bool):
    """
    Build one summary row per encounter_block for the time window
    [first_crrt_time + start_offset_h, first_crrt_time + end_offset_h]
    (both ends inclusive).

    Each row contains:
     - demographics and a `died_within_window` flag
     - median, Q1 and Q3 of every variable in `continuous_vars`
     - the last non-null ADT location category/type in the window
     - the last non-null device_category in the window
     - SOFA component scores and total from `sofa_score.compute_sofa`
     - when `include_crrt` is true: the last CRRT mode plus median, Q1 and Q3
       of every variable in `crrt_vars`

    Returns a DataFrame indexed by encounter_block.
    """

    def _median_iqr(frame, value_cols):
        # Per-encounter median / 25th / 75th percentile, suffixed accordingly.
        per_block = frame.groupby("encounter_block")[value_cols]
        return (
            per_block.median().add_suffix("_median")
            .join(per_block.quantile(0.25).add_suffix("_q1"))
            .join(per_block.quantile(0.75).add_suffix("_q3"))
        )

    def _last_by_time(frame, value_cols):
        # Last non-null value per encounter once rows are in time order.
        return (
            frame.sort_values("recorded_dttm")
                 .groupby("encounter_block")[value_cols]
                 .last()
        )

    # Window bounds for every encounter that has a CRRT start.
    window = first_crrt.copy()
    window["win_start"] = window["first_crrt_time"] + pd.Timedelta(hours=start_offset_h)
    window["win_end"] = window["first_crrt_time"] + pd.Timedelta(hours=end_offset_h)
    by_block = window.set_index("encounter_block")

    # CLIF-wide rows that fall inside the window.
    flows = clif_wide_df.merge(
        window[["encounter_block", "win_start", "win_end"]],
        on="encounter_block",
        how="inner",
    )
    inside = (flows["recorded_dttm"] >= flows["win_start"]) & (flows["recorded_dttm"] <= flows["win_end"])
    flows = flows[inside]

    cont = _median_iqr(flows, continuous_vars)
    last_loc = _last_by_time(flows.dropna(subset=adt_cols, how="all"), adt_cols)
    last_dev = _last_by_time(flows.dropna(subset=[device_col]), [device_col])

    # Mortality flag: death proxy falls inside the window (missing => False).
    with_death = by_block.join(demog[["death_dttm_proxy"]], how="left")
    died = (
        with_death["death_dttm_proxy"]
        .between(with_death["win_start"], with_death["win_end"])
        .rename("died_within_window")
    )

    # Core table: bounds + demographics, then each summary piece in turn.
    summary = by_block.join(demog.drop(columns="death_dttm_proxy"))
    for piece in (died, cont, last_loc, last_dev):
        summary = summary.join(piece)

    # SOFA over the same window, keyed by encounter_block.
    sofa_in = window[["encounter_block"]].assign(
        start_dttm=window["win_start"],
        stop_dttm=window["win_end"],
    )
    sofa_out = sofa_score.compute_sofa(
        ids_w_dttm=sofa_in,
        tables_path=tables_path,
        use_hospitalization_id=False,
        id_mapping=id_mappings,
        helper_module=pyCLIF,
        output_filepath=None,
    )
    sofa_cols = [
        "sofa_cv_97", "sofa_coag", "sofa_renal",
        "sofa_liver", "sofa_resp", "sofa_cns", "sofa_total",
    ]
    summary = summary.join(sofa_out.set_index("encounter_block")[sofa_cols])

    # Optional CRRT-specific columns.
    if include_crrt:
        crrt_rows = crrt_df.merge(window, on="encounter_block", how="inner")
        in_window = (
            (crrt_rows["recorded_dttm"] >= crrt_rows["win_start"])
            & (crrt_rows["recorded_dttm"] <= crrt_rows["win_end"])
        )
        crrt_rows = crrt_rows[in_window]

        last_mode = _last_by_time(
            crrt_rows.dropna(subset=["crrt_mode_category"]),
            "crrt_mode_category",
        ).rename("last_crrt_mode")

        summary = summary.join(last_mode).join(_median_iqr(crrt_rows, crrt_vars))

    # The window bookkeeping columns are not part of the summary.
    return summary.drop(columns=["first_crrt_time", "win_start", "win_end"])


# ─────────────────────────────────────────────────────────────────────────────
# 3) BUILD THE THREE WINDOWS
# ─────────────────────────────────────────────────────────────────────────────
pre24 = summarize_window(-24, 0, include_crrt=False)   # 24 h leading up to CRRT start
post24 = summarize_window(0, 24, include_crrt=True)    # first 24 h after CRRT start
post72 = summarize_window(0, 72, include_crrt=True)    # first 72 h after CRRT start

# Label every row with its window, then stack the three summaries.
window_frames = {"Pre-24h": pre24, "Post-24h": post24, "Post-72h": post72}
for lbl, df in window_frames.items():
    df["window"] = lbl

combined = pd.concat(list(window_frames.values()), axis=0).reset_index()
combined_path = os.path.join(output_folder, "intermediate", "combined_summary.csv")
combined.to_csv(combined_path, index=False)

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# 4) MAKE TABLEONE FOR EACH WINDOW (with SOFA)
# ─────────────────────────────────────────────────────────────────────────────
def make_tableone(summary_df, window_label):
    """
    Build a TableOne from a per-encounter window summary and save it as CSV.

    `summary_df` is a frame as returned by `summarize_window`; `window_label`
    is used in the output file name (spaces become underscores). The
    per-encounter `*_median` columns are collapsed to the plain variable name
    and the `*_q1` / `*_q3` columns are discarded before tabulation, so the
    table reports the median [IQR] of per-encounter medians.

    Returns the TableOne object; the CSV is written to `<output_folder>/final`.
    """
    table_df = summary_df.reset_index()

    sofa_vars = [
        "sofa_cv_97", "sofa_coag", "sofa_renal",
        "sofa_liver", "sofa_resp", "sofa_cns", "sofa_total",
    ]
    categorical = [
        "sex_category", "race_category", "ethnicity_category",
        "location_category", "location_type", "device_category",
        "died_within_window",
    ]
    continuous = ["age_at_admission", *continuous_vars, *sofa_vars]

    # Post-start windows additionally carry the CRRT mode and CRRT settings.
    if "last_crrt_mode" in table_df.columns:
        categorical = categorical + ["last_crrt_mode"]
        continuous = continuous + crrt_vars

    # Collapse "<var>_median" into "<var>" and discard the per-encounter quartiles.
    for base in continuous_vars + crrt_vars:
        median_col = f"{base}_median"
        if median_col in table_df.columns:
            table_df[base] = table_df.pop(median_col)
        for quartile_col in (base + "_q1", base + "_q3"):
            if quartile_col in table_df.columns:
                del table_df[quartile_col]

    # Vasopressor values are rounded to 4 decimal places before tabulation.
    for pressor in ("angiotensin", "dobutamine", "dopamine", "epinephrine",
                    "norepinephrine", "phenylephrine", "vasopressin"):
        if pressor in table_df.columns:
            table_df[pressor] = table_df[pressor].round(4)

    # Every continuous variable is declared non-normal so TableOne reports
    # median [IQR] for it.
    tbl = TableOne(
        table_df,
        columns=categorical + continuous,
        categorical=categorical,
        nonnormal=continuous,
        groupby=None,
    )

    fname = "table1_" + window_label.replace(' ', '_') + ".csv"
    tbl.to_csv(os.path.join(output_folder, "final", fname))
    print("Table saved to:", f"{output_folder}/final/{fname}")
    return tbl

# One table per time window.
table_pre24  = make_tableone(pre24,  "Pre-24h")
table_post24 = make_tableone(post24, "Post-24h")
table_post72 = make_tableone(post72, "Post-72h")


In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# 1) Extract the underlying pandas DataFrames from each TableOne
# ─────────────────────────────────────────────────────────────────────────────
# The column labels are renamed per window; presumably so the three frames do
# not collide on identical "Overall"/"Missing" labels when joined below.
# NOTE(review): the exact column structure of `.tableone` is not visible here;
# confirm the rename mapping matches it.
df_pre24  = pd.DataFrame(table_pre24.tableone).rename(columns={0: "Pre-24h",   "Overall": "pre24_summary", "Missing": "pre24_missing"})
df_post24 = pd.DataFrame(table_post24.tableone).rename(columns={0: "Post-24h", "Overall": "post24_summary", "Missing": "post24_missing"})
df_post72 = pd.DataFrame(table_post72.tableone).rename(columns={0: "Post-72h", "Overall": "post72_summary", "Missing": "post72_missing"})

# ─────────────────────────────────────────────────────────────────────────────
# 2) Join them on their index (the "characteristic" labels) using post24/72 order
# ─────────────────────────────────────────────────────────────────────────────
# Get reference order from post24 or post72 if they exist
# (the outer joins below do not guarantee this row order, hence the reindex).
ref_order = None
if not df_post24.empty:
    ref_order = df_post24.index.tolist()
elif not df_post72.empty:
    ref_order = df_post72.index.tolist()

# Outer joins keep rows that appear in only some of the windows.
combined = (
    df_pre24
      .join(df_post24, how="outer")
      .join(df_post72, how="outer")
)

# Reorder rows if we have a reference order
if ref_order:
    # Only use existing indices from reference order
    valid_indices = [idx for idx in ref_order if idx in combined.index]
    # Add any remaining indices that weren't in reference
    remaining = combined.index.difference(valid_indices).tolist()
    combined = combined.reindex(valid_indices + remaining)

# ─────────────────────────────────────────────────────────────────────────────
# 3) Clean up the index name and display
# ─────────────────────────────────────────────────────────────────────────────
combined.index.name = "Characteristic"
combined.reset_index(inplace=True)

# Drop all *_missing columns
# NOTE(review): `col.endswith` assumes every column label is a plain string — confirm.
missing_cols = [col for col in combined.columns if col.endswith('_missing')]
combined = combined.drop(columns=missing_cols)

# ─────────────────────────────────────────────────────────────────────────────
# 4) Save to CSV if you like
# ─────────────────────────────────────────────────────────────────────────────

# The site name from the pyCLIF helper goes into the file name.
filename = f"table1_all_windows_{pyCLIF.helper['site_name']}.csv"
combined.to_csv(os.path.join(output_folder, "final", filename), index=False)

# (B) Summary by Time, Mortality and CRRT status


* Look at summary of characteristics in eight subcohorts 
     * Survivors - On CRRT - Post-24hr
     * Survivors - Off CRRT - Post-24hr
     * Survivors - On CRRT - Post-72hr
     * Survivors - Off CRRT - Post-72hr
     * Non-survivors - On CRRT - Post-24hr
     * Non-survivors - Off CRRT - Post-24hr
     * Non-survivors - On CRRT - Post-72hr
     * Non-survivors - Off CRRT - Post-72hr

* Derive per-encounter flags
    * Survivor vs non-survivor: taken from `all_ids_df.mortality`, with `discharge_dttm` used as a proxy for the time of death. Mortality is defined as a discharge category of "Expired" or "Hospice"
    * CRRT end time: for each encounter_block, take the max recorded_dttm in crrt_df as the end of CRRT
    * “Still on CRRT” at 24 h (and 72 h): compare first_crrt_time + N h to that end time. If CRRT end ≥ window‐end ⇒ “still on CRRT”. Else ⇒ “off CRRT”


In [None]:
log_path = os.path.join(output_folder, "final", "table_one.log")
os.makedirs(os.path.dirname(log_path), exist_ok=True)

# 1) Get (or create) our logger; it does not propagate to the root logger,
#    so only the two handlers attached below emit its records.
logger = logging.getLogger("table_one")
logger.propagate = False
logger.setLevel(logging.INFO)

# 2) Detach AND close handlers left over from an earlier run of this cell.
#    Merely clearing the list would leave the previous log file handle open.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()

# 3) Create a shared formatter
formatter = logging.Formatter(
    "%(asctime)s %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

# 4) Console handler
ch = logging.StreamHandler()

# 5) File handler (append mode). UTF-8 is set explicitly because the messages
#    logged below contain non-ASCII symbols (e.g. "≤", "✓", "✗") that a
#    platform-default encoding may not be able to write.
fh = logging.FileHandler(log_path, mode="a", encoding="utf-8")

# 6) Configure and attach both handlers
for handler in (ch, fh):
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

In [None]:
# First and last CRRT record time per encounter_block.
crrt_times = crrt_df.groupby("encounter_block")["recorded_dttm"]
first_crrt = crrt_times.min().rename("first_crrt_time").reset_index()
end_crrt = crrt_times.max().rename("end_crrt_time").reset_index()

# One frame with outcome columns and CRRT start/end (inner joins: only
# encounters that have CRRT records are kept).
enc = all_ids_df[["encounter_block", "mortality", "discharge_dttm", "death_dttm_proxy"]]
enc = enc.merge(first_crrt, on="encounter_block")
enc = enc.merge(end_crrt, on="encounter_block")

# "Still on CRRT at N h": the last CRRT record is at or after start + N h.
for horizon_h in (24, 72):
    horizon_end = enc["first_crrt_time"] + pd.Timedelta(hours=horizon_h)
    enc[f"on_crrt_at_{horizon_h}h"] = enc["end_crrt_time"] >= horizon_end

# "Died within N h": a death time exists and is at or before start + N h.
for horizon_h in (24, 72):
    horizon_end = enc["first_crrt_time"] + pd.Timedelta(hours=horizon_h)
    enc[f"died_within_{horizon_h}h"] = (
        enc["death_dttm_proxy"].notna()
        & (enc["death_dttm_proxy"] <= horizon_end)
    )

# ───────────────────────────────────────────────
# Sanity checks with logging
# ───────────────────────────────────────────────
if (enc["first_crrt_time"] <= enc["end_crrt_time"]).all():
    logger.info("All first_crrt_time ≤ end_crrt_time.")
else:
    logger.error("Some first_crrt_time > end_crrt_time!!! That's not cool!")

logger.info(f"Total encounters: {len(enc)}")

mort_counts = enc["mortality"].value_counts().to_dict()
logger.info(f"Mortality distribution: {mort_counts}")

# Means of boolean columns are the fraction of encounters flagged.
p24, p72 = enc["on_crrt_at_24h"].mean(), enc["on_crrt_at_72h"].mean()
logger.info(f"Still on CRRT at 24 h: {p24:.1%}")
logger.info(f"Still on CRRT at 72 h: {p72:.1%}")

d24, d72 = enc["died_within_24h"].mean(), enc["died_within_72h"].mean()
logger.info(f"Died within 24 h of CRRT start: {d24:.1%}")
logger.info(f"Died within 72 h of CRRT start: {d72:.1%}")

In [None]:
crrt_start = enc["first_crrt_time"]
death_time = enc["death_dttm_proxy"]

# Death recorded in the 24 h leading up to (but not at) CRRT start.
# Deaths recorded earlier than that are NOT captured by this flag.
enc["died_pre24h"] = (
    death_time.notna()
    & (death_time < crrt_start)
    & (death_time >= crrt_start - pd.Timedelta(hours=24))
)

# Everyone not flagged above counts as alive at CRRT start.
enc["survived_to_start"] = ~enc["died_pre24h"]

# ── Post-24h cohorts ─────────────────────────────────────────────────────────
alive_at_24h = ~enc["died_within_24h"]
alive_since_start = enc["survived_to_start"] & alive_at_24h

enc["surv_oncrrt_post24"] = alive_since_start & enc["on_crrt_at_24h"]
enc["surv_offcrrt_post24"] = alive_since_start & ~enc["on_crrt_at_24h"]

# NOTE(review): CRRT status for non-survivors is taken at the 24 h / 72 h mark,
# not at the time of death — confirm this is the intended definition.
enc["non_surv_oncrrt_post24"] = enc["died_within_24h"] & enc["on_crrt_at_24h"]
enc["non_surv_offcrrt_post24"] = enc["died_within_24h"] & ~enc["on_crrt_at_24h"]

# Death after the first 24 h but within 72 h.
enc["died_24_72h"] = enc["died_within_72h"] & ~enc["died_within_24h"]

# ── Post-72h cohorts (restricted to encounters alive at 24 h) ────────────────
alive_at_72h = alive_at_24h & ~enc["died_within_72h"]
died_after_24h = alive_at_24h & enc["died_within_72h"]

enc["surv_oncrrt_post72"] = alive_at_72h & enc["on_crrt_at_72h"]
enc["surv_offcrrt_post72"] = alive_at_72h & ~enc["on_crrt_at_72h"]
enc["non_surv_oncrrt_post72"] = died_after_24h & enc["on_crrt_at_72h"]
enc["non_surv_offcrrt_post72"] = died_after_24h & ~enc["on_crrt_at_72h"]

# ── Checks ───────────────────────────────────────────────────────────────────
# Nobody should have died in the 24 h before CRRT start; if someone did, list
# every encounter whose death time precedes CRRT start.
if not enc["died_pre24h"].any():
    logger.info("No patients died in the 24 h before CRRT start.")
else:
    logger.error(f"{enc['died_pre24h'].sum()} patients died pre-CRRT!")
    bad_pre = enc[enc["death_dttm_proxy"] < enc["first_crrt_time"]].copy()
    bad_pre["delta"] = bad_pre["first_crrt_time"] - bad_pre["death_dttm_proxy"]
    print(bad_pre[["encounter_block","death_dttm_proxy","first_crrt_time","delta"]].head(20))

# Encounters that died within 24 h must not appear in either 72 h survivor group.
early_death_rows = enc.loc[enc["died_within_24h"], ["surv_oncrrt_post72", "surv_offcrrt_post72"]]
bad = early_death_rows.any().any()
if not bad:
    logger.info("Post-72h survivor groups correctly exclude all 24h decedents.")
else:
    logger.error("Some 24h decedents slipped into 72h survivor groups!")

In [None]:
# Comprehensive sanity checks for 72h cohorts

def _log_check(passed, pass_msg, fail_msg):
    """Log `pass_msg` at INFO when the check holds, otherwise `fail_msg` at WARNING."""
    if passed:
        logger.info(pass_msg)
    else:
        logger.warning(fail_msg)


logger.info("=== SANITY CHECKS FOR 72H COHORTS ===")

# Short aliases for the boolean columns examined below.
died_24h = enc["died_within_24h"]
died_72h = enc["died_within_72h"]
surv_on, surv_off = enc["surv_oncrrt_post72"], enc["surv_offcrrt_post72"]
dead_on, dead_off = enc["non_surv_oncrrt_post72"], enc["non_surv_offcrrt_post72"]

# Basic counts
n_total = len(enc)
n_died_before_24h = died_24h.sum()
n_survived_24h = (~died_24h).sum()
n_died_24_72h = (~died_24h & died_72h).sum()
n_survived_72h = (~died_24h & ~died_72h).sum()

logger.info("\n1. OVERALL POPULATION")
logger.info(f"   Total patients: {n_total}")
logger.info(f"   Died within 24h: {n_died_before_24h}")
logger.info(f"   Survived to 24h: {n_survived_24h}")
logger.info(f"   - Died between 24-72h: {n_died_24_72h}")
logger.info(f"   - Survived to 72h: {n_survived_72h}")
logger.info(f"   Check: {n_died_24_72h + n_survived_72h} should equal {n_survived_24h}")

# 72h cohort counts
n_surv_on_72 = surv_on.sum()
n_surv_off_72 = surv_off.sum()
n_non_surv_on_72 = dead_on.sum()
n_non_surv_off_72 = dead_off.sum()
n_total_72h_cohorts = n_surv_on_72 + n_surv_off_72 + n_non_surv_on_72 + n_non_surv_off_72

logger.info("\n2. 72H COHORT BREAKDOWN")
logger.info("   Survivors at 72h:")
logger.info(f"   - Still on CRRT: {n_surv_on_72}")
logger.info(f"   - Off CRRT: {n_surv_off_72}")
logger.info(f"   - Total survivors: {n_surv_on_72 + n_surv_off_72}")
logger.info("   Non-survivors (died 24-72h):")
logger.info(f"   - On CRRT at death: {n_non_surv_on_72}")
logger.info(f"   - Off CRRT at death: {n_non_surv_off_72}")
logger.info(f"   - Total non-survivors: {n_non_surv_on_72 + n_non_surv_off_72}")
logger.info(f"   TOTAL in 72h cohorts: {n_total_72h_cohorts}")

# Key check: the four 72h cohorts together must cover every 24h survivor.
logger.info("\n3. PRIMARY CHECK")
logger.info(f"   Survived to 24h: {n_survived_24h}")
logger.info(f"   Total in 72h cohorts: {n_total_72h_cohorts}")
_log_check(
    n_survived_24h == n_total_72h_cohorts,
    "   ✓ PASS: All 24h survivors accounted for in 72h cohorts",
    f"   ✗ FAIL: Missing {n_survived_24h - n_total_72h_cohorts} patients",
)

# Survivors at 72h must split exactly into on/off CRRT.
logger.info("\n4. SURVIVOR PARTITION CHECK")
logger.info(f"   Expected survivors at 72h: {n_survived_72h}")
logger.info(f"   Actual (on + off CRRT): {n_surv_on_72 + n_surv_off_72}")
_log_check(
    n_survived_72h == n_surv_on_72 + n_surv_off_72,
    "   ✓ PASS: Survivors correctly partitioned by CRRT status",
    "   ✗ FAIL: Survivor counts don't match",
)

# Deaths between 24h and 72h must split exactly into on/off CRRT.
logger.info("\n5. NON-SURVIVOR PARTITION CHECK")
logger.info(f"   Expected deaths 24-72h: {n_died_24_72h}")
logger.info(f"   Actual (on + off CRRT): {n_non_surv_on_72 + n_non_surv_off_72}")
_log_check(
    n_died_24_72h == n_non_surv_on_72 + n_non_surv_off_72,
    "   ✓ PASS: Non-survivors correctly partitioned by CRRT status",
    "   ✗ FAIL: Non-survivor counts don't match",
)

# No encounter may sit in more than one cohort.
logger.info("\n6. MUTUAL EXCLUSIVITY CHECKS")
overlap_survivor = surv_on & surv_off
overlap_non_surv = dead_on & dead_off
overlap_surv_non_surv = (surv_on | surv_off) & (dead_on | dead_off)

logger.info(f"   Overlap between survivor cohorts: {overlap_survivor.sum()}")
logger.info(f"   Overlap between non-survivor cohorts: {overlap_non_surv.sum()}")
logger.info(f"   Overlap between survivors and non-survivors: {overlap_surv_non_surv.sum()}")
_log_check(
    not (overlap_survivor.any() or overlap_non_surv.any() or overlap_surv_non_surv.any()),
    "   ✓ PASS: All cohorts are mutually exclusive",
    "   ✗ FAIL: Found overlapping patients",
)

# Cross-check each survivor cohort against the raw on_crrt_at_72h flag.
logger.info("\n7. CRRT STATUS CONSISTENCY")
on_crrt_72h_survivors = enc[surv_on]
logger.info(f"   Survivors 'on CRRT' at 72h: {len(on_crrt_72h_survivors)}")
logger.info(f"   - Actually on CRRT: {on_crrt_72h_survivors['on_crrt_at_72h'].sum()}")
logger.info(f"   - Actually off CRRT: {(~on_crrt_72h_survivors['on_crrt_at_72h']).sum()}")

off_crrt_72h_survivors = enc[surv_off]
logger.info(f"   Survivors 'off CRRT' at 72h: {len(off_crrt_72h_survivors)}")
logger.info(f"   - Actually on CRRT: {off_crrt_72h_survivors['on_crrt_at_72h'].sum()}")
logger.info(f"   - Actually off CRRT: {(~off_crrt_72h_survivors['on_crrt_at_72h']).sum()}")

# Summary
logger.info("\n8. FINAL SUMMARY")
logger.info(f"   24h analysis includes: {n_total} patients")
logger.info(f"   72h analysis includes: {n_survived_24h} patients (excluded {n_died_before_24h} early deaths)")
logger.info("   72h breakdown:")
logger.info(f"   - {n_survived_72h} survived to 72h ({n_surv_on_72} on CRRT, {n_surv_off_72} off)")
logger.info(f"   - {n_died_24_72h} died 24-72h ({n_non_surv_on_72} on CRRT, {n_non_surv_off_72} off)")

# When the primary check fails, describe the 24h survivors that no cohort claimed.
if n_survived_24h != n_total_72h_cohorts:
    logger.warning("\n9. MISSING PATIENT ANALYSIS")
    in_72h = surv_on | surv_off | dead_on | dead_off
    survived_24h_mask = ~died_24h
    missing_mask = survived_24h_mask & ~in_72h
    missing_patients = enc[missing_mask]

    logger.warning(f"   Missing patients: {len(missing_patients)}")
    logger.warning(f"   - Died within 72h: {missing_patients['died_within_72h'].sum()}")
    logger.warning(f"   - Still alive at 72h: {(~missing_patients['died_within_72h']).sum()}")
    logger.warning(f"   - On CRRT at 72h: {missing_patients['on_crrt_at_72h'].sum()}")
    logger.warning(f"   - Off CRRT at 72h: {(~missing_patients['on_crrt_at_72h']).sum()}")

In [None]:
# Step 1: Helper that builds the summary table for a single subcohort
def create_subcohort_table(window_df, enc_df, subcohort_flag, window_name, subcohort_name):
    """
    Build the summary table for one subcohort of a window summary.

    Parameters
    ----------
    window_df : DataFrame
        Window summary; rows are selected by matching its index against
        the subcohort's encounter_block values.
    enc_df : DataFrame
        Encounter-level frame with an 'encounter_block' column and the
        boolean column named by `subcohort_flag`.
    subcohort_flag : str
        Name of the boolean column in `enc_df` that marks membership.
    window_name, subcohort_name : str
        Combined as "{window_name}_{subcohort_name}" and handed to
        make_tableone (defined elsewhere in this file) as the table name.

    Returns
    -------
    tuple
        (whatever make_tableone returns, number of flagged rows in enc_df)
    """
    # encounter_block values of the rows where the flag is set
    flagged_rows = enc_df[enc_df[subcohort_flag]]
    member_blocks = flagged_rows['encounter_block'].tolist()

    # Restrict the window summary to those encounters (matched on its index)
    belongs_to_subcohort = window_df.index.isin(member_blocks)
    table_label = f"{window_name}_{subcohort_name}"
    summary_table = make_tableone(window_df[belongs_to_subcohort], table_label)

    # The admission count is the number of flagged rows, not the filtered row count
    return summary_table, len(member_blocks)

# Step 2: Define all subcohorts as (window summary, flag column, window label, subcohort label)
subcohorts = [
    {'window_df': frame, 'flag': flag, 'window': window, 'name': label}
    for frame, flag, window, label in (
        # Pre-24h (only one subcohort - all alive at CRRT start)
        (pre24, 'survived_to_start', 'Pre-24h', 'All'),

        # Post-24h subcohorts
        (post24, 'surv_oncrrt_post24', 'Post-24h', 'Survivors_on_CRRT'),
        (post24, 'surv_offcrrt_post24', 'Post-24h', 'Survivors_off_CRRT'),
        (post24, 'non_surv_oncrrt_post24', 'Post-24h', 'Non-survivors_on_CRRT'),
        (post24, 'non_surv_offcrrt_post24', 'Post-24h', 'Non-survivors_off_CRRT'),

        # Post-72h subcohorts
        (post72, 'surv_oncrrt_post72', 'Post-72h', 'Survivors_on_CRRT'),
        (post72, 'surv_offcrrt_post72', 'Post-72h', 'Survivors_off_CRRT'),
        (post72, 'non_surv_oncrrt_post72', 'Post-72h', 'Non-survivors_on_CRRT'),
        (post72, 'non_surv_offcrrt_post72', 'Post-72h', 'Non-survivors_off_CRRT'),
    )
]

# Step 3: Build one table per subcohort, keyed by "<window>_<name>"
subcohort_tables = {}
admission_counts = {}

for sc in subcohorts:
    key = f"{sc['window']}_{sc['name']}"
    table, n_admissions = create_subcohort_table(
        window_df=sc['window_df'],
        enc_df=enc,
        subcohort_flag=sc['flag'],
        window_name=sc['window'],
        subcohort_name=sc['name'],
    )
    subcohort_tables[key] = table
    admission_counts[key] = n_admissions
    print(f"Created table for {key}: {n_admissions} admissions")

# Step 4: Pull the summary column out of each table and line them up side by side
combined_dfs = []
column_names = []

for name, table in subcohort_tables.items():
    if table is None:
        continue  # nothing was produced for this subcohort

    # Materialise the TableOne output as a plain DataFrame (index left as is)
    df = pd.DataFrame(table.tableone)

    # Use the "Overall" column when present; otherwise fall back to the
    # column at position 1
    summary_col = df[["Overall"]] if "Overall" in df.columns else df.iloc[:, [1]]

    # Column label = subcohort key with underscores shown as spaces
    col_name = name.replace("_", " ")
    summary_col.columns = [col_name]

    combined_dfs.append(summary_col)
    column_names.append(col_name)

if not combined_dfs:
    print("No tables were created!")
else:
    # Outer-join every subcohort column onto the first one
    combined_summary = combined_dfs[0]
    for df in combined_dfs[1:]:
        combined_summary = combined_summary.join(df, how='outer')

    # Row of admission counts; column labels are mapped back to their keys
    admission_values = [
        admission_counts.get(col.replace(" ", "_"), "")
        for col in combined_summary.columns
    ]
    admission_row = pd.DataFrame(
        [admission_values],
        columns=combined_summary.columns,
        index=["Admissions n"]
    )

    # Row of distinct patient counts, one entry per subcohort that has a column
    patient_counts_row = []
    for sc in subcohorts:
        key = f"{sc['window']}_{sc['name']}"
        col_name = key.replace("_", " ")
        if col_name not in combined_summary.columns:
            continue
        subcohort_encounters = enc[enc[sc['flag']]]['encounter_block'].tolist()
        in_subcohort = all_ids_df['encounter_block'].isin(subcohort_encounters)
        unique_patients = all_ids_df[in_subcohort]['patient_id'].nunique()
        patient_counts_row.append(str(unique_patients))

    patient_row = pd.DataFrame(
        [patient_counts_row],
        columns=combined_summary.columns,
        index=["Patients n"]
    )

    # Count rows go on top of the joined summary
    combined_summary = pd.concat([admission_row, patient_row, combined_summary])

    # Write the combined table
    output_file = f"{output_folder}/intermediate/table1_all_subcohorts_{pyCLIF.helper['site_name'].lower()}.csv"
    combined_summary.to_csv(output_file)
    print(f"\nFinal table saved to: {output_file}")

    # Report the size of the result
    print(f"\nTotal rows in combined table: {len(combined_summary)}")

    # Write each subcohort table to its own file as well
    for name, table in subcohort_tables.items():
        if table is None:
            continue
        individual_file = f"{output_folder}/intermediate/table1_{name}.csv"
        table.to_csv(individual_file)

We need to combine the post-24h non-survivors into a single group, rather than splitting them by CRRT status, to avoid sharing any cohort with n < 5. The same applies to the post-72h non-survivors.

In [None]:
# ============================================================================
# COMBINED NON-SURVIVOR FLAGS
# ============================================================================

# One flag per window that covers every non-survivor, whatever the CRRT status
enc['non_surv_post24'] = enc['died_within_24h']  # copies the died_within_24h flag
enc['non_surv_post72'] = enc['died_24_72h']      # copies the died_24_72h flag

print("Created simplified flags:")
for flag_name in ('non_surv_post24', 'non_surv_post72'):
    print(f"  {flag_name}: {enc[flag_name].sum()} patients")

# ============================================================================
# SIMPLIFIED TABLE 1 GENERATION WITH NEW COHORT STRUCTURE
# ============================================================================

# Step 2: Simplified subcohorts as (window summary, flag column, window label, subcohort label)
simplified_subcohorts = [
    {'window_df': frame, 'flag': flag, 'window': window, 'name': label}
    for frame, flag, window, label in (
        # Pre-24h (only one subcohort - all alive at CRRT start)
        (pre24, 'survived_to_start', 'Pre-24h', 'All'),

        # Post-24h subcohorts (non-survivors merged)
        (post24, 'surv_oncrrt_post24', 'Post-24h', 'Survivors_on_CRRT'),
        (post24, 'surv_offcrrt_post24', 'Post-24h', 'Survivors_off_CRRT'),
        (post24, 'non_surv_post24', 'Post-24h', 'Non_survivors_combined'),

        # Post-72h subcohorts (non-survivors merged)
        (post72, 'surv_oncrrt_post72', 'Post-72h', 'Survivors_on_CRRT'),
        (post72, 'surv_offcrrt_post72', 'Post-72h', 'Survivors_off_CRRT'),
        (post72, 'non_surv_post72', 'Post-72h', 'Non_survivors_combined'),
    )
]

# Step 3: Build one table per simplified subcohort with the shared helper
simplified_subcohort_tables = {}
simplified_admission_counts = {}

print("\nGenerating simplified subcohort tables:")
for sc in simplified_subcohorts:
    key = f"{sc['window']}_{sc['name']}"
    table, n_admissions = create_subcohort_table(
        window_df=sc['window_df'],
        enc_df=enc,
        subcohort_flag=sc['flag'],
        window_name=sc['window'],
        subcohort_name=sc['name'],
    )
    simplified_subcohort_tables[key] = table
    simplified_admission_counts[key] = n_admissions
    print(f"  Created table for {key}: {n_admissions} admissions")

# Step 4: Pull the summary column out of each simplified table and join them
simplified_combined_dfs = []

for name, table in simplified_subcohort_tables.items():
    if table is None:
        continue  # nothing was produced for this subcohort

    # Materialise the TableOne output as a plain DataFrame
    df = pd.DataFrame(table.tableone)

    # Use the "Overall" column when present; otherwise the column at position 1
    summary_col = df[["Overall"]] if "Overall" in df.columns else df.iloc[:, [1]]

    # Column label = subcohort key with underscores shown as spaces
    col_name = name.replace("_", " ")
    summary_col.columns = [col_name]

    simplified_combined_dfs.append(summary_col)

if not simplified_combined_dfs:
    print("No simplified tables were created!")
else:
    # Outer-join every subcohort column onto the first one
    simplified_combined_summary = simplified_combined_dfs[0]
    for df in simplified_combined_dfs[1:]:
        simplified_combined_summary = simplified_combined_summary.join(df, how='outer')

    # Row of admission counts; column labels are mapped back to their keys
    simplified_admission_values = [
        simplified_admission_counts.get(col.replace(" ", "_"), "")
        for col in simplified_combined_summary.columns
    ]
    simplified_admission_row = pd.DataFrame(
        [simplified_admission_values],
        columns=simplified_combined_summary.columns,
        index=["Admissions n"]
    )

    # Row of distinct patient counts, one entry per subcohort that has a column
    simplified_patient_counts_row = []
    for sc in simplified_subcohorts:
        key = f"{sc['window']}_{sc['name']}"
        col_name = key.replace("_", " ")
        if col_name not in simplified_combined_summary.columns:
            continue
        subcohort_encounters = enc[enc[sc['flag']]]['encounter_block'].tolist()
        in_subcohort = all_ids_df['encounter_block'].isin(subcohort_encounters)
        unique_patients = all_ids_df[in_subcohort]['patient_id'].nunique()
        simplified_patient_counts_row.append(str(unique_patients))

    simplified_patient_row = pd.DataFrame(
        [simplified_patient_counts_row],
        columns=simplified_combined_summary.columns,
        index=["Patients n"]
    )

    # Count rows go on top of the joined summary
    final_simplified_table = pd.concat(
        [simplified_admission_row, simplified_patient_row, simplified_combined_summary]
    )

    # Site-specific copy
    simplified_output_file = f"{output_folder}/final/table1_subgroups2_{pyCLIF.helper['site_name'].lower()}.csv"
    final_simplified_table.to_csv(simplified_output_file)
    print(f"\n✓ Simplified Table 1 saved to: {simplified_output_file}")

    # Generic filename as well, for multi-site compatibility
    generic_simplified_file = f"{output_folder}/final/table1_subgroups2.csv"
    final_simplified_table.to_csv(generic_simplified_file)
    print(f"✓ Also saved as: {generic_simplified_file}")

    # Echo the cohort structure: number of flagged rows per subcohort
    print(f"\n{'='*60}")
    print("SIMPLIFIED TABLE 1 COHORT STRUCTURE")
    print("="*60)
    for sc in simplified_subcohorts:
        readable_name = sc['name'].replace('_', ' ')
        flag_count = enc[sc['flag']].sum()
        print(f"{sc['window']} {readable_name}: n={flag_count}")

    print(f"\nTotal rows in simplified table: {len(final_simplified_table)}")
    print("Table columns:", list(final_simplified_table.columns))

print(f"\n{'='*60}")
print("SIMPLIFIED TABLE 1 GENERATION COMPLETE")
print("="*60)

# (C) Summary by location type

This summary covers the first 72 hours of CRRT therapy for all patients in the cohort, including those who didn't survive the first 24 hours of CRRT therapy.

Questions:

1. How to handle ICU transfers? If someone started in Medical ICU but moved to Surgical ICU, which ICU do they represent?
    * Current code uses the "last" location during the 72h window

In [None]:
# ============================================================================
# TABLE 2: Characteristics by Location Type at 72 Hours Post-CRRT
# ============================================================================

def create_table2_by_location():
    """
    Build one summary table per location_type found in the post72 window summary.

    Rows with a missing location_type are ignored. A location type is only
    tabulated when it has more than 5 rows; smaller groups are reported on
    stdout and skipped.

    Returns
    -------
    tuple of dict
        (location_tables, location_counts), both keyed by location type:
        the object returned by make_tableone and the number of rows used.
    """
    print("Creating Table 2: Characteristics by Location Type at 72h post-CRRT...")

    # post72 is the module-level window summary; move its index into a column
    location_data = post72.reset_index()

    # Distinct non-missing location types, in order of first appearance
    location_types = location_data['location_type'].dropna().unique()
    print(f"Location types found: {location_types}")

    location_tables = {}
    location_counts = {}

    for loc_type in location_types:
        subset_data = location_data[location_data['location_type'] == loc_type]
        n_rows = len(subset_data)

        # Minimum sample size: groups of 5 or fewer are left out
        if n_rows <= 5:
            print(f"  {loc_type}: {n_rows} patients (too few - excluded)")
            continue

        # make_tableone receives the subset indexed by encounter_block
        subset_data = subset_data.set_index('encounter_block')
        location_tables[loc_type] = make_tableone(subset_data, f"Location_{loc_type}")
        location_counts[loc_type] = n_rows
        print(f"  {loc_type}: {n_rows} patients")

    return location_tables, location_counts

# ============================================================================
# COMBINE INTO TEMPLATE FORMAT
# ============================================================================

def create_table2_combined():
    """
    Combine the location-stratified tables into the template layout.

    Calls create_table2_by_location(), takes the "Overall" column (or the
    column at position 1 when there is none) of each location's table, and
    arranges the columns in the fixed template order. Every template
    location gets a column: locations without a table get an empty-string
    placeholder column, regardless of where they fall in the template order.

    Side effects: writes table2_by_location_type.csv under
    {output_folder}/final and calls create_template_formatted_table2().

    Returns
    -------
    DataFrame or None
        Header row + combined summary + count row, or None when no location
        table exists for any template location.
    """
    location_tables, location_counts = create_table2_by_location()

    if not location_tables:
        print("No location tables created!")
        return None

    # Column order required by the template
    template_columns = [
        'cardiac_icu', 'cvicu_icu', 'general_icu',
        'medical_icu', 'mixed_neuro_icu', 'surgical_icu'
    ]

    # Location types with a table but no slot in the template would otherwise
    # vanish without a trace; say so explicitly.
    off_template = [loc for loc in location_tables if loc not in template_columns]
    if off_template:
        print(f"Location types not in template (omitted from Table 2): {off_template}")

    # Pass 1: extract the summary column for every template location that has a table
    summary_by_location = {}
    for loc_type in template_columns:
        if loc_type not in location_tables:
            continue
        df = pd.DataFrame(location_tables[loc_type].tableone)

        # Get the "Overall" column, falling back to the second column
        summary_col = df[["Overall"]] if "Overall" in df.columns else df.iloc[:, [1]]
        summary_col.columns = [loc_type]
        summary_by_location[loc_type] = summary_col

    if not summary_by_location:
        return None

    # Pass 2: assemble the columns in template order. Placeholders borrow the
    # row index of the first available location, so a missing location is
    # represented even when it comes before any available one.
    reference_index = next(iter(summary_by_location.values())).index
    combined_dfs = [
        summary_by_location[loc_type]
        if loc_type in summary_by_location
        else pd.DataFrame(index=reference_index, columns=[loc_type], data="")
        for loc_type in template_columns
    ]

    # Combine all location columns
    combined_table = combined_dfs[0]
    for df in combined_dfs[1:]:
        combined_table = combined_table.join(df, how='outer')

    # Header row for the template format (text in the first column only)
    header_row = pd.DataFrame(
        [["location_type at 72 hours post CRRT"] + [""] * (len(combined_table.columns) - 1)],
        columns=combined_table.columns,
        index=["Header"]
    )

    # Count row: number of rows behind each location column, blank for placeholders
    count_data = [
        str(location_counts[col]) if col in location_counts else ""
        for col in combined_table.columns
    ]
    count_row = pd.DataFrame(
        [count_data],
        columns=combined_table.columns,
        index=["n"]
    )

    # Combine all parts
    final_table = pd.concat([header_row, combined_table, count_row])

    # Save the table
    output_file = f"{output_folder}/final/table2_by_location_type.csv"
    final_table.to_csv(output_file)
    print(f"Table 2 saved to: {output_file}")

    # Also create the version that matches the template file layout
    create_template_formatted_table2(combined_table, location_counts)

    return final_table

def create_template_formatted_table2(data_table, counts):
    """
    Write Table 2 in the flat template layout and return it.

    The output has the columns level_0, level_1 and one column per template
    location. Row 0 repeats the location names, row 1 holds the counts
    (level_0 == 'n'), and each following row mirrors one row of `data_table`.

    Parameters
    ----------
    data_table : DataFrame
        Combined summary with one column per location. When an index entry
        is a tuple of length >= 2, its first two elements become
        level_0/level_1; otherwise str(index entry) becomes level_0.
    counts : dict
        Maps location name to its count; missing locations get ''.

    Returns
    -------
    DataFrame
        The table that was written to {output_folder}/final/table2_final.csv.
    """
    template_cols = ['cardiac_icu', 'cvicu_icu', 'general_icu',
                    'medical_icu', 'mixed_neuro_icu', 'surgical_icu']

    # Collect all rows first and build the frame once. Growing a DataFrame
    # with pd.concat inside a loop copies the whole table on every iteration.
    rows = [
        # Header row: each location column carries its own name
        {'level_0': '', 'level_1': '', **{col: col for col in template_cols}},
        # Count row
        {'level_0': 'n', 'level_1': '', **{col: counts.get(col, '') for col in template_cols}},
    ]

    # Values of the template columns that exist in data_table, read by
    # position so duplicated index labels cannot turn a cell into a Series.
    available = {
        col: data_table[col].tolist()
        for col in template_cols
        if col in data_table.columns
    }

    for pos, idx in enumerate(data_table.index):
        if isinstance(idx, tuple) and len(idx) >= 2:
            level_0, level_1 = idx[0], idx[1]
        else:
            level_0, level_1 = str(idx), ''

        row_data = {'level_0': level_0, 'level_1': level_1}
        for col in template_cols:
            # Template locations absent from data_table stay blank
            row_data[col] = available[col][pos] if col in available else ''
        rows.append(row_data)

    template_df = pd.DataFrame(rows, columns=['level_0', 'level_1'] + template_cols)

    # Save template-formatted version
    template_file = f"{output_folder}/final/table2_final.csv"
    template_df.to_csv(template_file, index=False)
    print(f"Template-formatted Table 2 saved to: {template_file}")

    return template_df



print("="*60)
print("GENERATING TABLE 2: CHARACTERISTICS BY LOCATION TYPE")
print("="*60)

table2_result = create_table2_combined()

print(f"\n{'='*60}")
print("TABLE 2 GENERATION COMPLETE")
print("="*60)
# create_table2_combined() returns None when it wrote nothing, so only list
# the output files when a table actually came back.
if table2_result is not None:
    print("Files created:")
    print("  - table2_by_location_type.csv")
    print("  - table2_final.csv")
else:
    print("No Table 2 files were created.")

### C.1 Summary by location type with statistical tests

1. Uses ANOVA (f_oneway) for continuous variables to test if means differ across ICU types
2. Uses Chi-square test for categorical variables to test if distributions differ across ICU types

In [None]:
# ============================================================================
# TABLE 2: Characteristics by Location Type with Statistical Tests
# ============================================================================

def create_table2_with_stats():
    """
    Build Table 2 by location type together with across-location p-values.

    Continuous variables are compared with one-way ANOVA (scipy
    stats.f_oneway) over the non-empty location groups; categorical variables
    with a chi-square test on the variable-by-location crosstab. P-values for
    continuous variables are stored under the column name with '_median'
    removed; a variable that cannot be tested gets NaN.

    The p-values use every location type present in the data, whereas a
    per-location table is only built for location types with more than 5 rows.

    Returns
    -------
    tuple
        (combined table from combine_location_tables_with_pvalues,
         dict of p-values)
    """
    print("Creating Table 2 with statistical comparisons...")

    # post72 window summary, restricted to rows with a known location_type
    location_data = post72.reset_index().dropna(subset=['location_type'])

    location_types = sorted(location_data['location_type'].unique())
    print(f"Location types found: {location_types}")

    # Candidate variables; the lab/vital/CRRT/medication entries are the *_median columns
    continuous_candidates = [
        'age_at_admission',
        'bicarbonate_median', 'bun_median', 'calcium_total_median',
        'chloride_median', 'creatinine_median', 'glucose_serum_median', 'lactate_median',
        'magnesium_median', 'sodium_median', 'potassium_median', 'ph_arterial_median',
        'po2_arterial_median', 'fio2_set_median', 'peep_set_median',
        'blood_flow_rate_median', 'dialysate_flow_rate_median', 'ultrafiltration_out_median',
        'pre_filter_replacement_fluid_rate_median', 'post_filter_replacement_fluid_rate_median',
        'angiotensin_median', 'epinephrine_median', 'dopamine_median', 'vasopressin_median',
        'dobutamine_median', 'norepinephrine_median', 'phenylephrine_median',
        'sofa_total', 'sofa_cv_97', 'sofa_resp', 'sofa_renal', 'sofa_liver', 'sofa_coag', 'sofa_cns'
    ]
    categorical_candidates = [
        'sex_category', 'race_category', 'ethnicity_category',
        'device_category', 'died_within_window', 'last_crrt_mode',
        'location_category'
    ]

    # Keep only the variables that exist in the data
    continuous_vars = [v for v in continuous_candidates if v in location_data.columns]
    categorical_vars = [v for v in categorical_candidates if v in location_data.columns]

    print(f"DEBUG: Found {len(continuous_vars)} continuous variables in data")
    print(f"DEBUG: Found {len(categorical_vars)} categorical variables in data")

    p_values = {}

    # ---- ANOVA for continuous variables ----
    from scipy import stats
    print("\nDEBUG: Testing ANOVA for continuous variables...")
    for var in continuous_vars:
        # Key used for the p-value: the column name without '_median'
        base_var_name = var.replace('_median', '')

        # Non-missing values per location; empty groups are left out
        per_location = (
            location_data[location_data['location_type'] == loc_type][var].dropna()
            for loc_type in location_types
        )
        groups = [values for values in per_location if len(values) > 0]

        # A comparison needs at least two groups
        if len(groups) < 2:
            p_values[base_var_name] = np.nan
            continue

        try:
            _, p_val = stats.f_oneway(*groups)
            p_values[base_var_name] = p_val
            print(f"  {var} -> {base_var_name}: p = {p_val:.6f}")
        except Exception as e:
            p_values[base_var_name] = np.nan
            print(f"  {var} -> ERROR: {e}")

    # ---- Chi-square for categorical variables ----
    for var in categorical_vars:
        try:
            crosstab = pd.crosstab(location_data[var], location_data['location_type'])
            n_levels, n_locations = crosstab.shape
            if n_levels <= 1 or n_locations <= 1:
                # Degenerate table: nothing to compare
                p_values[var] = np.nan
                continue
            p_val = stats.chi2_contingency(crosstab)[1]
            p_values[var] = p_val
            print(f"  {var}: p = {p_val:.6f}")
        except Exception as e:
            p_values[var] = np.nan
            print(f"  {var}: ERROR: {e}")

    n_computed = sum(1 for p in p_values.values() if pd.notna(p))
    print(f"\nDEBUG: Total p-values calculated: {n_computed}")

    # ---- One summary table per location type (more than 5 rows required) ----
    location_tables = {}
    location_counts = {}

    for loc_type in location_types:
        subset_data = location_data[location_data['location_type'] == loc_type]
        n_rows = len(subset_data)
        if n_rows <= 5:  # Minimum sample size
            continue
        subset_data = subset_data.set_index('encounter_block')
        location_tables[loc_type] = make_tableone(subset_data, f"Location_{loc_type}")
        location_counts[loc_type] = n_rows
        print(f"  {loc_type}: {n_rows} patients")

    # Merge the per-location tables and attach the p-value column
    combined_table = combine_location_tables_with_pvalues(
        location_tables, location_counts, p_values, continuous_vars, categorical_vars
    )

    return combined_table, p_values

def combine_location_tables_with_pvalues(location_tables, location_counts, p_values,
                                       continuous_vars, categorical_vars):
    """
    Merge the per-location summary tables and append a p-value column.

    Parameters
    ----------
    location_tables : dict
        Location type -> table object exposing a `.tableone` attribute.
    location_counts : dict
        Location type -> number of rows behind that table.
    p_values : dict
        Variable name -> p-value (NaN when no test was possible).
    continuous_vars, categorical_vars : list
        Accepted for the caller's convenience; not read by this function.

    Each table row is matched to a p-value through the text before the first
    comma of its (first-level) index label. Matched p-values below 0.001 are
    shown as "<0.001", others with three decimals; everything else is blank.

    Side effects: writes the combined table (site-specific filename) and
    table2_p_values.csv under {output_folder}/final.

    Returns
    -------
    DataFrame or None
        Combined table with a trailing "n" row, or None when there is
        nothing to combine.
    """
    if not location_tables:
        return None

    # Column order required by the template
    template_columns = [
        'cardiac_icu', 'cvicu_icu', 'general_icu',
        'medical_icu', 'mixed_neuro_icu', 'surgical_icu'
    ]

    # Summary series per template location that has a table
    all_data = {}
    for loc_type in template_columns:
        if loc_type not in location_tables:
            continue
        df = pd.DataFrame(location_tables[loc_type].tableone)
        # "Overall" column when present, otherwise the column at position 1
        all_data[loc_type] = df["Overall"] if "Overall" in df.columns else df.iloc[:, 1]

    if not all_data:
        return None

    combined_df = pd.DataFrame(all_data)

    # DEBUG: show what is available for matching
    print("\nDEBUG: P-values available for:")
    for var, p in p_values.items():
        if pd.notna(p):
            print(f"  {var}: {p:.4f}")

    print("\nDEBUG: Table indices (first 20):")
    for i, idx in enumerate(combined_df.index[:20]):
        print(f"  {i}: {idx}")

    def _format_p(p_val):
        """Render a p-value for the table; NaN becomes an empty cell."""
        if pd.isna(p_val):
            return ""
        return "<0.001" if p_val < 0.001 else f"{p_val:.3f}"

    # Build the p-value column row by row
    p_value_col = []
    matched_vars = set()
    unmatched_vars = []

    for idx in combined_df.index:
        # Variable label: first element of a tuple index, else the label itself
        var_name = idx[0] if isinstance(idx, tuple) else str(idx)

        # Strip suffixes such as ', median [Q1,Q3]'
        base_var = var_name.split(',')[0].strip()

        if base_var not in p_values:
            p_value_col.append("")
            if base_var != 'n':  # the 'n' row is never expected to match
                unmatched_vars.append((base_var, var_name))
            continue

        matched_vars.add(base_var)
        p_value_col.append(_format_p(p_values[base_var]))

    # DEBUG: report matching results
    print(f"\nDEBUG: Successfully matched {len(matched_vars)} variables")
    print(f"DEBUG: Failed to match {len(unmatched_vars)} variables:")
    for base, full in unmatched_vars[:10]:  # first 10 only
        print(f"  Base: '{base}' from Full: '{full}'")

    combined_df['p_value'] = p_value_col

    # Count row: one entry per location column, then a blank p-value cell
    count_data = [
        location_counts.get(col, "")
        for col in template_columns
        if col in combined_df.columns
    ]
    count_data.append("")

    count_row = pd.DataFrame(
        [count_data],
        columns=list(combined_df.columns),
        index=["n"]
    )

    final_table = pd.concat([combined_df, count_row])

    # Save the table
    output_file = f"{output_folder}/final/table2_by_location_with_stats_{pyCLIF.helper['site_name'].lower()}.csv"
    final_table.to_csv(output_file)
    print(f"Table 2 with p-values saved to: {output_file}")

    # Raw p-values, sorted ascending, kept separately for reference
    p_values_df = pd.DataFrame(list(p_values.items()), columns=['Variable', 'p_value'])
    p_values_df = p_values_df.sort_values('p_value')
    p_values_df.to_csv(f"{output_folder}/final/table2_p_values.csv", index=False)

    return final_table

# Build Table 2 with p-values
table2_with_stats, p_values_dict = create_table2_with_stats()

# List the variables with p < 0.05, smallest p-value first
print("\nStatistically significant differences (p < 0.05):")
significant = [
    (var, p_val)
    for var, p_val in p_values_dict.items()
    if pd.notna(p_val) and p_val < 0.05
]
for var, p_val in sorted(significant, key=lambda item: item[1]):
    print(f"  {var}: p = {p_val:.4f}")

# (D) CRRT summary by modes

In [None]:
# Put the CRRT records in chronological order within each encounter.
# Both the "first setting" lookup and the shift(-1) duration calculation
# below depend on row order, so the frame itself is sorted (previously only a
# temporary copy was sorted and the durations used the unsorted frame).
# A stable sort keeps records with identical timestamps in their original order.
crrt_df = crrt_df.sort_values(['encounter_block', 'recorded_dttm'], kind='stable')

# First settings per encounter and CRRT mode.
# NOTE: GroupBy.first() returns the first non-null value of each column, so
# the values in one row may come from different records.
first_settings = crrt_df.groupby(['encounter_block', 'crrt_mode_category']).first()

# Numeric setting columns to summarise
numeric_cols = [
    'blood_flow_rate',
    'pre_filter_replacement_fluid_rate',
    'post_filter_replacement_fluid_rate',
    'dialysate_flow_rate',
    'ultrafiltration_out'
]

# Median and IQR of the first settings, by mode
first_settings_summary = first_settings[numeric_cols].groupby('crrt_mode_category').agg([
    'median',
    lambda x: x.quantile(0.25),
    lambda x: x.quantile(0.75)
])

# Flatten the (column, statistic) labels. pandas names the two lambdas
# '<lambda_0>' and '<lambda_1>'; the first is Q1, the other Q3.
first_settings_summary.columns = first_settings_summary.columns.map(
    lambda x: f"{x[0]}_{x[1]}" if x[1] == 'median'
    else f"{x[0]}_q1" if x[1] == '<lambda_0>'
    else f"{x[0]}_q3"
)

# Mean and SD over all individual measurements, by mode
mode_avg_settings = crrt_df.groupby('crrt_mode_category')[numeric_cols].agg(['mean', 'std'])
mode_avg_settings.columns = mode_avg_settings.columns.map(lambda x: f"{x[0]}_{x[1]}")

# Encounter-level summaries:
# first average each setting within encounter and mode ...
patient_avg_settings = crrt_df.groupby(['encounter_block', 'crrt_mode_category'])[numeric_cols].mean()

# ... then take the median and IQR of those averages, by mode
patient_summary = patient_avg_settings.groupby('crrt_mode_category').agg([
    'median',
    lambda x: x.quantile(0.25),
    lambda x: x.quantile(0.75)
])

patient_summary.columns = patient_summary.columns.map(
    lambda x: f"{x[0]}_patient_{x[1]}" if x[1] == 'median'
    else f"{x[0]}_patient_q1" if x[1] == '<lambda_0>'
    else f"{x[0]}_patient_q3"
)

# Time spent on each mode: every record is credited with the time until the
# next record of the same encounter. The last record of an encounter has no
# successor, so its duration is NaN and is skipped by the sum.
crrt_df['next_time'] = crrt_df.groupby('encounter_block')['recorded_dttm'].shift(-1)
crrt_df['duration_hrs'] = (crrt_df['next_time'] - crrt_df['recorded_dttm']).dt.total_seconds() / 3600

mode_duration = crrt_df.groupby(['encounter_block', 'crrt_mode_category'])['duration_hrs'].sum()
duration_summary = mode_duration.groupby('crrt_mode_category').agg([
    'mean',
    'std',
    'median',
    lambda x: x.quantile(0.25),
    lambda x: x.quantile(0.75)
])

duration_summary.columns = [
    'duration_mean_hrs',
    'duration_std_hrs',
    'duration_median_hrs',
    'duration_q1_hrs',
    'duration_q3_hrs'
]

# Save results
first_settings_summary.to_csv(f'{output_folder}/final/crrt_first_settings.csv')
mode_avg_settings.to_csv(f'{output_folder}/final/crrt_average_settings.csv')
patient_summary.to_csv(f'{output_folder}/final/crrt_patient_averages.csv')
duration_summary.to_csv(f'{output_folder}/final/crrt_duration.csv')

In [None]:
# Step 1) Build the display DataFrame of formatted strings
params = [
    "blood_flow_rate",
    "pre_filter_replacement_fluid_rate",
    "post_filter_replacement_fluid_rate",
    "dialysate_flow_rate",
    "ultrafiltration_out"
]

# Rows = CRRT modes, columns = parameters
display_df = pd.DataFrame(index=mode_avg_settings.index)

for p in params:
    mean = mode_avg_settings[f"{p}_mean"]
    std = mode_avg_settings[f"{p}_std"]

    # Format each cell as "mean  (std)". Values are rounded to the nearest
    # integer (int() would truncate, e.g. 99.9 -> 99). A cell is "NA" when
    # either the mean or the SD is missing.
    display_df[p] = [
        f"{round(float(m)):,}  ({round(float(s)):,})"
        if pd.notna(m) and pd.notna(s)
        else "NA"
        for m, s in zip(mean, std)
    ]

# Human-friendly column labels
# NOTE(review): every label states mL/hr - confirm the units of each
# parameter against the CRRT data dictionary before publishing.
display_df.columns = [
    "Blood flow (mL/hr)",
    "Pre-filter repl (mL/hr)",
    "Post-filter repl (mL/hr)",
    "Dialysate flow (mL/hr)",
    "Ultrafiltration (mL/hr)"
]

# Step 2) Draw with matplotlib.table
fig, ax = plt.subplots(figsize=(10, 2 + 0.5 * len(display_df)))  # height scales with rows
ax.axis("off")

tbl = ax.table(
    cellText=display_df.values,
    rowLabels=display_df.index,
    colLabels=display_df.columns,
    cellLoc="center",
    rowLoc="center",
    loc="center"
)

tbl.auto_set_font_size(False)
tbl.set_fontsize(10)
tbl.scale(1, 1.5)  # stretch rows a bit

plt.title(f"Average CRRT Settings by CRRT Mode in {pyCLIF.helper['site_name']}: Mean (SD)", pad=20)
plt.tight_layout()

# Save the figure under output_dir like the other graphs; make sure the
# graphs folder exists so savefig cannot fail on a fresh checkout.
graphs_dir = f"{output_dir}/graphs"
os.makedirs(graphs_dir, exist_ok=True)
plt.savefig(f"{graphs_dir}/avg_crrt_settings.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# Move the CRRT mode from the index into a regular column for plotting
duration_summary_reset = duration_summary.reset_index()

# Asymmetric error bars: distance from the median down to Q1 and up to Q3
median_hrs = duration_summary_reset['duration_median_hrs']
below_median = median_hrs - duration_summary_reset['duration_q1_hrs']
above_median = duration_summary_reset['duration_q3_hrs'] - median_hrs

# Bar per CRRT mode (median duration) with the IQR drawn on top
plt.figure(figsize=(10, 6))
sns.barplot(
    x='crrt_mode_category',
    y='duration_median_hrs',
    data=duration_summary_reset,
    color='skyblue',
)
plt.errorbar(
    x=np.arange(len(duration_summary_reset)),  # bar positions 0..n-1
    y=median_hrs,
    yerr=[below_median, above_median],
    fmt='none',
    c='black',
    capsize=5,
)
plt.title("Duration on CRRT Mode (Median and IQR)")
plt.ylabel("Hours")
plt.xticks(rotation=45)
plt.tight_layout()

# Write the file first, then display and release the figure
plt.savefig(f"{output_dir}/graphs/crrt_mode_duration.png")
plt.show()
plt.close()


# (E) CRRT Mode Switches

In [None]:
# Analyze mode switches by comparing each record with the previous record of
# the same encounter. shift(1) works on row order, so first put the records
# in chronological order within each encounter (stable sort keeps records
# with identical timestamps in their original order).
crrt_df = crrt_df.sort_values(['encounter_block', 'recorded_dttm'], kind='stable')

crrt_df['prev_mode'] = crrt_df.groupby('encounter_block')['crrt_mode_category'].shift(1)
crrt_df['mode_switch'] = (
    (crrt_df['crrt_mode_category'] != crrt_df['prev_mode']) &
    (~crrt_df['prev_mode'].isna()) &          # first record of an encounter has no previous mode
    (~crrt_df['crrt_mode_category'].isna())   # exclude transitions TO NaN
)

# Count switches between each (from, to) mode pair
mode_switches = (
    crrt_df[crrt_df['mode_switch']]
    .groupby(['prev_mode', 'crrt_mode_category'])
    .size()
    .reset_index()
)
mode_switches.columns = ['from_mode', 'to_mode', 'count']

# Matrix view: rows = from mode, columns = to mode, missing pairs = 0
mode_switches_pivot = mode_switches.pivot(index='from_mode', columns='to_mode', values='count').fillna(0)

print("\nCRRT Mode Switches:")
print(mode_switches_pivot)

# Visualize mode switches as a heatmap (seaborn cannot draw an empty matrix)
if mode_switches_pivot.empty:
    print("No mode switches found - heatmap skipped.")
else:
    plt.figure(figsize=(10, 8))
    sns.heatmap(mode_switches_pivot,
                annot=True,
                fmt='.0f',
                cmap='Blues',
                cbar_kws={'label': f"Number of Mode Switches {pyCLIF.helper['site_name']}"})
    plt.title('CRRT Mode Switch Matrix')
    plt.xlabel('To Mode')
    plt.ylabel('From Mode')
    plt.tight_layout()
    plt.savefig(f"{output_dir}/graphs/crrt_mode_transitions_heatmap.png")
    plt.show()
    plt.close()

# Number of mode switches per hospitalization (encounter_block)
switches_per_hosp = crrt_df.groupby('encounter_block')['mode_switch'].sum()
avg_switches = switches_per_hosp.mean()
median_switches = switches_per_hosp.median()

print(f"\nAverage mode switches per hospitalization: {avg_switches:.2f}")
print(f"Median mode switches per hospitalization: {median_switches:.2f}")

# Save mode switches matrix
mode_switches_pivot.to_csv(f"{output_dir}/crrt_mode_switches.csv")

# Save summary statistics
with open(f"{output_dir}/crrt_mode_switches_summary.txt", "w", encoding="utf-8") as f:
    f.write(f"Average mode switches per hospitalization: {avg_switches:.2f}\n")
    f.write(f"Median mode switches per hospitalization: {median_switches:.2f}\n")

In [None]:
# Flag mode switches with the same definition used for the event-level matrix:
# rows are ordered by time within each encounter, and a switch requires both
# the previous and the current mode to be known. Without the last condition a
# row with a missing mode would be stored in crrt_df['mode_switch'] as a switch.
crrt_df = crrt_df.sort_values(['encounter_block', 'recorded_dttm'])
crrt_df['prev_mode'] = crrt_df.groupby('encounter_block')['crrt_mode_category'].shift(1)
crrt_df['mode_switch'] = (
    (crrt_df['crrt_mode_category'] != crrt_df['prev_mode'])
    & crrt_df['prev_mode'].notna()
    & crrt_df['crrt_mode_category'].notna()
)

# Count unique encounters where a switch occurred between each mode pair
mode_switch_encounters = (
    crrt_df[crrt_df['mode_switch']]
    .groupby(['prev_mode', 'crrt_mode_category'])['encounter_block']
    .nunique()
    .reset_index(name='encounter_count')
)

# Pivot for matrix view; mode pairs with no encounter become 0
switch_matrix = mode_switch_encounters.pivot(
    index='prev_mode',
    columns='crrt_mode_category',
    values='encounter_count'
).fillna(0).astype(int)

print("\nUnique Encounters with CRRT Mode Switches:")
print(switch_matrix)

# Save to CSV (same folder as the other final outputs)
switch_matrix.to_csv(f"{output_dir}/unique_encounter_mode_switches.csv")

# Create heatmap for mode switches by encounter (seaborn cannot draw an empty
# matrix, so skip the figure when no switch was found)
if switch_matrix.empty:
    print("No CRRT mode switches found; skipping heatmap.")
else:
    plt.figure(figsize=(10, 8))
    sns.heatmap(switch_matrix,
                annot=True,
                fmt='d',
                cmap='Blues',
                cbar_kws={'label': 'Number of Unique Encounters with Mode Switches'})
    plt.title(f"CRRT Mode Switches by Unique Encounters {pyCLIF.helper['site_name']}")
    plt.xlabel('To Mode')
    plt.ylabel('From Mode')
    plt.tight_layout()
    plt.savefig(f"{output_dir}/graphs/crrt_mode_transitions_by_encounter_heatmap.png")
    plt.show()
    plt.close()

In [None]:
# ============================================================================
# CRRT MODE SWITCHES WITH LOCATION CONTEXT
# ============================================================================

def analyze_mode_switches_with_location(crrt=None, clif_wide=None):
    """Flag CRRT mode switches and tabulate them per location type.

    Each CRRT record is joined (left merge on ``encounter_block`` and
    ``recorded_dttm``) to the ``location_type`` recorded at the same timestamp.
    Rows are then ordered by time within each encounter, and a row counts as a
    mode switch when its mode differs from the previous row of the same
    encounter and both modes are known. Rows with a missing mode are therefore
    never reported as switches.

    Args:
        crrt: CRRT records with ``encounter_block``, ``recorded_dttm`` and
            ``crrt_mode_category`` columns. Defaults to the notebook-level
            ``crrt_df``.
        clif_wide: Wide table with ``encounter_block``, ``recorded_dttm`` and
            ``location_type`` columns. Defaults to the notebook-level
            ``clif_wide_df``.

    Returns:
        tuple: ``(mode_switches_with_location, location_transition_matrices)``
        where the first item holds one row per switch event (including
        ``prev_mode`` and ``location_type``) and the second maps each location
        type to a from-mode x to-mode matrix of unique encounter counts.
    """
    crrt = crrt_df if crrt is None else crrt
    clif_wide = clif_wide_df if clif_wide is None else clif_wide

    print("Analyzing CRRT mode switches with location context...")

    # One location per (encounter, timestamp) so the left merge cannot
    # duplicate CRRT rows.
    location_lookup = (
        clif_wide[['encounter_block', 'recorded_dttm', 'location_type']]
        .drop_duplicates(subset=['encounter_block', 'recorded_dttm'])
    )
    crrt_with_location = crrt.merge(
        location_lookup,
        on=['encounter_block', 'recorded_dttm'],
        how='left'
    )

    has_location = crrt_with_location['location_type'].notna()
    print(f"Matched {int(has_location.sum())} of {len(crrt_with_location)} CRRT records with location data")
    print(f"Location coverage: {has_location.mean()*100:.1f}%")

    # shift(1) is only meaningful on time-ordered rows
    crrt_with_location = crrt_with_location.sort_values(['encounter_block', 'recorded_dttm'])
    crrt_with_location['prev_mode'] = (
        crrt_with_location.groupby('encounter_block')['crrt_mode_category'].shift(1)
    )
    crrt_with_location['mode_switch'] = (
        (crrt_with_location['crrt_mode_category'] != crrt_with_location['prev_mode'])
        & crrt_with_location['prev_mode'].notna()
        # exclude "transitions" into a missing mode
        & crrt_with_location['crrt_mode_category'].notna()
    )

    mode_switches_with_location = crrt_with_location[crrt_with_location['mode_switch']].copy()

    print(f"Found {len(mode_switches_with_location)} mode switches")
    print(f"Unique encounters with switches: {mode_switches_with_location['encounter_block'].nunique()}")

    location_transition_matrices = {}

    for location in mode_switches_with_location['location_type'].dropna().unique():
        location_switches = mode_switches_with_location[
            mode_switches_with_location['location_type'] == location
        ]

        # Unique encounters per (from, to) pair within this location
        transition_counts = (
            location_switches
            .groupby(['prev_mode', 'crrt_mode_category'])['encounter_block']
            .nunique()
            .reset_index(name='encounter_count')
        )

        transition_matrix = transition_counts.pivot(
            index='prev_mode',
            columns='crrt_mode_category',
            values='encounter_count'
        ).fillna(0).astype(int)

        location_transition_matrices[location] = transition_matrix

        unique_encounters = location_switches['encounter_block'].nunique()
        print(f"\n{location}: {len(location_switches)} switches, {unique_encounters} unique encounters")
        print(transition_matrix)

    return mode_switches_with_location, location_transition_matrices

# ============================================================================
# CREATE LOCATION-STRATIFIED TRANSITION ANALYSIS
# ============================================================================

def create_comprehensive_transition_analysis():
    """Summarise CRRT mode switches per location and write the result files.

    Calls ``analyze_mode_switches_with_location()`` and then:

    * prints and collects, per location type, the number of unique encounters
      with a switch, the number of switch events, and the most frequent
      from/to modes;
    * expands every non-zero cell of each location's transition matrix into a
      long-format row together with its share of that matrix's grand total;
    * writes the raw switch rows, the per-location summary, one transition
      matrix per location, and the combined long-format table as CSV files
      under ``output_folder``.

    Returns:
        tuple: ``(switches_with_location, location_matrices, all_detailed)``;
        ``all_detailed`` is an empty DataFrame when there are no matrices.
    """

    def _most_frequent(values):
        # First label among those sharing the highest count, "None" when empty
        tally = values.value_counts().to_dict()
        if not tally:
            return "None"
        return max(tally, key=tally.get)

    switches_with_location, location_matrices = analyze_mode_switches_with_location()

    print("\nMode Switch Summary by Location:")
    print("="*50)

    # ---- per-location summary ------------------------------------------------
    location_summary = []
    for location in switches_with_location['location_type'].dropna().unique():
        location_data = switches_with_location.loc[
            switches_with_location['location_type'] == location
        ]
        if len(location_data) == 0:
            continue

        unique_encounters = location_data['encounter_block'].nunique()
        total_switches = len(location_data)
        most_common_from = _most_frequent(location_data['prev_mode'])
        most_common_to = _most_frequent(location_data['crrt_mode_category'])

        print(f"\n{location}:")
        print(f"  Unique encounters with switches: {unique_encounters}")
        print(f"  Total switch events: {total_switches}")
        print(f"  Most common FROM mode: {most_common_from}")
        print(f"  Most common TO mode: {most_common_to}")

        location_summary.append({
            'location': location,
            'unique_encounters': unique_encounters,
            'total_switches': total_switches,
            'most_common_from': most_common_from,
            'most_common_to': most_common_to
        })

    location_summary_df = pd.DataFrame(location_summary)

    # ---- long-format table of non-zero transitions ---------------------------
    detailed_results = {}
    for location, matrix in location_matrices.items():
        # Denominator is the sum of all cells of this location's matrix.
        # NOTE(review): an encounter showing several distinct transitions is
        # counted once per transition here - confirm this is the intended base
        # for 'pct_of_location_encounters'.
        total_encounters_with_switches = matrix.sum().sum()

        detailed_table = []
        for from_mode in matrix.index:
            for to_mode in matrix.columns:
                count = matrix.loc[from_mode, to_mode]
                if not count > 0:
                    continue
                detailed_table.append({
                    'location': location,
                    'from_mode': from_mode,
                    'to_mode': to_mode,
                    'encounter_count': count,
                    'pct_of_location_encounters': (count / total_encounters_with_switches) * 100
                })

        detailed_results[location] = pd.DataFrame(detailed_table)

    # ---- exports -------------------------------------------------------------
    switches_with_location.to_csv(
        f"{output_folder}/intermediate/crrt_switches_with_location.csv",
        index=False
    )
    location_summary_df.to_csv(
        f"{output_folder}/final/switch_summary_by_location.csv",
        index=False
    )

    for location, matrix in location_matrices.items():
        # Make the location label usable as part of a file name
        safe_filename = location.replace('/', '_').replace(' ', '_')
        matrix.to_csv(
            f"{output_folder}/final/transition_matrix_{safe_filename}.csv"
        )

    if not detailed_results:
        all_detailed = pd.DataFrame()
    else:
        all_detailed = pd.concat(detailed_results.values(), ignore_index=True)
        all_detailed.to_csv(
            f"{output_folder}/final/detailed_transitions_by_location.csv",
            index=False
        )

    return switches_with_location, location_matrices, all_detailed

# ============================================================================
# VISUALIZATION OF LOCATION-SPECIFIC TRANSITIONS
# ============================================================================

def create_location_transition_visualizations():
    """Plot per-location transition heatmaps and a per-location summary bar chart.

    Runs ``create_comprehensive_transition_analysis()`` and draws:

    1. a grid (up to three columns) with one from-mode x to-mode heatmap per
       location type, saved as ``transition_matrices_by_location.png``;
    2. a bar chart of the number of unique encounters with a mode switch per
       location type, saved as ``mode_switches_by_location_summary.png``.

    Each heatmap title reports the number of unique encounters that had at
    least one counted switch in that location. This is not the sum of the
    matrix cells, because an encounter with several distinct transitions
    appears in several cells.

    Returns:
        pandas.DataFrame: the long-format transition table produced by
        ``create_comprehensive_transition_analysis()``.
    """
    switches_data, matrices, detailed = create_comprehensive_transition_analysis()

    if not matrices:
        print("No transition matrices to visualize!")
        return detailed

    # Unique encounters per location, restricted to rows that can appear in a
    # matrix (location, previous mode and current mode all known).
    counted_switches = switches_data.dropna(
        subset=['location_type', 'prev_mode', 'crrt_mode_category']
    )
    encounters_per_location = (
        counted_switches.groupby('location_type')['encounter_block'].nunique()
    )

    n_locations = len(matrices)
    n_cols = min(3, n_locations)
    n_rows = (n_locations + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D array of Axes, so a single flatten()
    # covers the 1-panel, 1-row and multi-row layouts alike.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, (location, matrix) in zip(axes, matrices.items()):
        if matrix.empty:
            # seaborn cannot draw an empty matrix; leave this panel blank
            ax.set_visible(False)
            continue

        sns.heatmap(matrix,
                    annot=True,
                    fmt='d',
                    cmap='Blues',
                    ax=ax,
                    cbar=True)

        total_encounters = int(encounters_per_location.get(location, 0))
        ax.set_title(f'{location}\n({total_encounters} encounters with switches)')
        ax.set_xlabel('To Mode')
        ax.set_ylabel('From Mode')

    # Hide the unused panels of the last grid row
    for spare_ax in axes[n_locations:]:
        spare_ax.set_visible(False)

    plt.tight_layout()
    plt.savefig(f"{output_folder}/final/graphs/transition_matrices_by_location.png",
                dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

    if len(switches_data) > 0:
        encounter_counts_by_location = (
            switches_data.groupby('location_type')['encounter_block']
            .nunique()
            .sort_values(ascending=False)
        )

        plt.figure(figsize=(12, 6))
        bars = plt.bar(range(len(encounter_counts_by_location)), encounter_counts_by_location.values)
        plt.xlabel('ICU Location Type')
        plt.ylabel('Number of Encounters with Mode Switches')
        plt.title('CRRT Mode Switches by ICU Location (Unique Encounters)')
        plt.xticks(range(len(encounter_counts_by_location)),
                    encounter_counts_by_location.index, rotation=45, ha='right')

        # Print each count just above its bar
        for bar, count in zip(bars, encounter_counts_by_location.values):
            plt.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                    f'{count}', ha='center', va='bottom')

        plt.tight_layout()
        plt.savefig(f"{output_folder}/final/graphs/mode_switches_by_location_summary.png",
                    dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()

    return detailed

# ============================================================================
# RUN THE ANALYSIS
# ============================================================================

# Banner, then run the whole location-stratified pipeline (analysis, CSV
# exports and figures) through its top-level entry point.
_banner_rule = "="*60
print(_banner_rule)
print("ANALYZING CRRT MODE SWITCHES WITH LOCATION CONTEXT")
print(_banner_rule)

detailed_transitions = create_location_transition_visualizations()

print(f"\n{_banner_rule}")
print("TRANSITION ANALYSIS COMPLETE")
print(_banner_rule)

# List the artefacts written by the pipeline
print("Files created:")
for _created_file_line in (
    "  - crrt_switches_with_location.csv (raw data)",
    "  - switch_summary_by_location.csv (summary by location)",
    "  - transition_matrix_[location].csv (one per ICU type)",
    "  - detailed_transitions_by_location.csv (detailed results)",
    "  - transition_matrices_by_location.png (heatmaps)",
    "  - mode_switches_by_location_summary.png (bar chart)",
):
    print(_created_file_line)

# Show the first rows of the long-format transition table, if there are any
if len(detailed_transitions) == 0:
    print("\nNo detailed transitions to display")
else:
    print("\nSample of detailed transitions:")
    print(detailed_transitions.head(10))

# (F) Hourly Trends

In [None]:
# 1) Build a relative-hour field where hour=0 is 24h before the first CRRT time
df = (
    clif_wide_df
    .merge(first_crrt, on="encounter_block", how="inner")
)
# origin = first_crrt_time - 24h
df["origin"] = df["first_crrt_time"] - pd.Timedelta(hours=24)
df["rel_hr"] = (df["recorded_dttm"] - df["origin"]) / pd.Timedelta(hours=1)
# Keep records in [0, 96] hours; copy so the column added below is written to
# an independent frame rather than to a slice of the merged one.
df = df[(df["rel_hr"] >= 0) & (df["rel_hr"] <= 96)].copy()

# assign integer hour bins
df["hour_bin"] = df["rel_hr"].floordiv(1).astype(int)

# 2) Compute per-hour median, Q1, Q3 for each variable
agg = {
    var_name: ["median", lambda x: x.quantile(0.25), lambda x: x.quantile(0.75)]
    for var_name in continuous_vars
}
hourly = df.groupby("hour_bin").agg(agg)

# Flatten the (variable, statistic) columns. pandas labels the two anonymous
# quantile functions "<lambda_0>" and "<lambda_1>"; give them readable names
# for every variable so the plot and the exported CSV use *_q1 / *_q3.
quantile_labels = {"<lambda_0>": "q1", "<lambda_1>": "q3"}
hourly.columns = pd.Index(
    [f"{var}_{quantile_labels.get(stat, stat)}" for var, stat in hourly.columns],
    name=None,
)

# 3) Plot grid of time-series with shaded IQR and CRRT band
n = len(continuous_vars)
cols = 4
rows = (n + cols - 1)//cols
fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 2.5*rows), sharex=True)
for ax, var in zip(axes.flat, continuous_vars):
    m = hourly[f"{var}_median"]
    q1 = hourly[f"{var}_q1"]  # 0.25 quantile
    q3 = hourly[f"{var}_q3"]  # 0.75 quantile

    ax.plot(hourly.index, m)
    ax.fill_between(hourly.index, q1, q3, alpha=0.3)
    # highlight the first 24h of CRRT -> that's rel_hr 24 to 48
    ax.axvspan(24, 48, color="C1", alpha=0.1)
    ax.set_title(f"{var} [median, IQR]")
    ax.set_xlim(0,96)
    ax.set_xlabel("Hours")
    ax.set_ylabel(var)

# hide any empty subplots
for extra_ax in axes.flat[n:]:
    extra_ax.set_visible(False)

fig.tight_layout()
plt.suptitle("Trajectories of ICU Labs/Vasos/Vent Settings from CRRT-24hrs to CRRT+72hrs (Median[IQR])", y=1.02)
plt.savefig("../output/final/graphs/trajectories_median_iqr.png", bbox_inches='tight', dpi=300)
plt.show()

# 4) Save hourly data: reset_index so `hour_bin` becomes a column again
site_summary = hourly.reset_index().rename(columns={"hour_bin": "hour"})
site_summary.to_csv(os.path.join(output_folder, "final", "site_hourly_summary_median.csv"), index=False)

In [None]:
# 1) Build a relative-hour field where hour=0 is 24h before the first CRRT time
df = (
    clif_wide_df
    .merge(first_crrt, on="encounter_block", how="inner")
)
# origin = first_crrt_time - 24h
df["origin"] = df["first_crrt_time"] - pd.Timedelta(hours=24)
df["rel_hr"] = (df["recorded_dttm"] - df["origin"]) / pd.Timedelta(hours=1)
# Keep records in [0, 96] hours; copy so the column added below is written to
# an independent frame rather than to a slice of the merged one.
df = df[(df["rel_hr"] >= 0) & (df["rel_hr"] <= 96)].copy()

# assign integer hour bins
df["hour_bin"] = df["rel_hr"].floordiv(1).astype(int)

# 2) Compute per-hour mean and std for each variable
agg = {var_name: ["mean", "std"] for var_name in continuous_vars}
hourly = df.groupby("hour_bin").agg(agg)
# flatten (variable, statistic) columns to "<variable>_mean" / "<variable>_std"
hourly.columns = pd.Index([f"{var}_{stat}"
                           for var, stat in hourly.columns],
                          name=None)

# 3) Plot grid of time-series with shaded +/-1 SD and CRRT band
n = len(continuous_vars)
cols = 4
rows = (n + cols - 1)//cols
fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 2.5*rows), sharex=True)
for ax, var in zip(axes.flat, continuous_vars):
    mean = hourly[f"{var}_mean"]
    sd = hourly[f"{var}_std"]

    ax.plot(hourly.index, mean)
    ax.fill_between(hourly.index, mean-sd, mean+sd, alpha=0.3)
    # highlight the first 24h of CRRT -> that's rel_hr 24 to 48
    ax.axvspan(24, 48, color="C1", alpha=0.1)
    ax.set_title(f"{var} [mean ± SD]")
    ax.set_xlim(0,96)
    ax.set_xlabel("Hours")
    ax.set_ylabel(var)

# hide any empty subplots
for extra_ax in axes.flat[n:]:
    extra_ax.set_visible(False)

fig.tight_layout()
plt.suptitle("Trajectories of ICU Labs/Vasos/Vent Settings from CRRT-24hrs to CRRT+72hrs (Mean ± SD)", y=1.02)
plt.savefig("../output/final/graphs/trajectories_mean_sd.png", bbox_inches='tight', dpi=300)
plt.show()

# 4) Save hourly data: reset_index so `hour_bin` becomes a column again.
# (No column renaming is needed here: this table only has *_mean / *_std.)
site_summary = hourly.reset_index().rename(columns={"hour_bin": "hour"})
site_summary.to_csv(os.path.join(output_folder, "final", "site_hourly_summary_mean.csv"), index=False)

# (G) Analysis

In [None]:
# ============================================================================
# SECTION G: Chi-Square Statistical Analysis
# ============================================================================

import scipy.stats as stats
import pandas as pd
import numpy as np
from scipy.stats import chi2_contingency
import json
import logging
import os

# Set up logging to both file and console
log_dir = "../output/final"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, "statistical_analysis.log")

# Create logger
logger = logging.getLogger("statistical_analysis")
logger.setLevel(logging.INFO)

# Close and detach handlers left over from an earlier run of this cell (for
# example one that stopped on an error before reaching its clean-up), so
# re-running the cell never logs every message more than once.
for stale_handler in list(logger.handlers):
    stale_handler.close()
    logger.removeHandler(stale_handler)

# File handler: written as UTF-8 because later messages contain non-ASCII
# characters (e.g. "≥") that some platform default encodings cannot represent.
file_handler = logging.FileHandler(log_file, encoding="utf-8")

# Console handler that prints to the notebook
console_handler = logging.StreamHandler()

# Same level and format for both destinations
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for new_handler in (file_handler, console_handler):
    new_handler.setLevel(logging.INFO)
    new_handler.setFormatter(formatter)
    logger.addHandler(new_handler)

logger.info("=== CHI-SQUARE ANALYSIS FOR CRRT COHORTS ===")

# Use existing cohort flags from Section C - no need to recreate them!
logger.info("\n=== COHORT SIZES (from Section C) ===")
for cohort_column in (
    'surv_oncrrt_post24',
    'surv_offcrrt_post24',
    'non_surv_oncrrt_post24',
    'non_surv_offcrrt_post24',
    'surv_oncrrt_post72',
    'surv_offcrrt_post72',
    'non_surv_oncrrt_post72',
    'non_surv_offcrrt_post72',
):
    # The log label spells the non-survivor cohorts "nonsurv_..."
    cohort_label = cohort_column.replace('non_surv', 'nonsurv')
    logger.info(f"{cohort_label}: {enc[cohort_column].sum()} patients")

# ============================================================================
# 1. CRRT LIBERATION vs MORTALITY ANALYSIS
# ============================================================================

logger.info("\n" + "="*60)
logger.info("1. CRRT LIBERATION vs MORTALITY ANALYSIS")
logger.info("="*60)

def _liberation_vs_mortality_at(hours, header):
    """Chi-square test of CRRT status (on/off) vs survival for one time window.

    Builds the 2x2 table from the cohort flag columns of ``enc`` named
    ``surv_oncrrt_post<hours>``, ``non_surv_oncrrt_post<hours>``,
    ``surv_offcrrt_post<hours>`` and ``non_surv_offcrrt_post<hours>``, logs the
    table and the test output, and returns the result dictionary.

    Args:
        hours: Window suffix of the cohort flag columns (24 or 72).
        header: Log line announcing the window.

    Returns:
        dict: contingency table (as nested dict), chi-square statistic,
        p-value, Cramer's V and total count.
    """
    logger.info(header)

    # Rows: outcome (Survived / Died); columns: CRRT status at the window
    contingency = pd.DataFrame({
        'On_CRRT': [enc[f'surv_oncrrt_post{hours}'].sum(),
                    enc[f'non_surv_oncrrt_post{hours}'].sum()],
        'Off_CRRT': [enc[f'surv_offcrrt_post{hours}'].sum(),
                     enc[f'non_surv_offcrrt_post{hours}'].sum()]
    }, index=['Survived', 'Died'])

    logger.info(f"\nContingency Table ({hours}h):")
    logger.info(str(contingency))

    # chi2_contingency is called with its defaults; for a 2x2 table SciPy
    # applies Yates' continuity correction by default, so the statistic (and
    # the Cramer's V derived from it) is the corrected one.
    chi2_stat, p_value, dof, _ = chi2_contingency(contingency)

    logger.info(f"\nChi-square statistic: {chi2_stat:.4f}")
    logger.info(f"p-value: {p_value:.6f}")
    logger.info(f"Degrees of freedom: {dof}")

    # Effect size: Cramer's V = sqrt(chi2 / (n * (min(rows, cols) - 1)))
    n_total = contingency.sum().sum()
    cramers_v = np.sqrt(chi2_stat / (n_total * (min(contingency.shape) - 1)))
    logger.info(f"Cramer's V (effect size): {cramers_v:.4f}")

    return {
        'contingency_table': contingency.to_dict(),
        'chi2_stat': chi2_stat,
        'p_value': p_value,
        'cramers_v': cramers_v,
        'n_total': int(n_total)
    }


def chi_square_liberation_mortality():
    """Test if CRRT liberation is associated with survival.

    Runs the same chi-square analysis for the 24-hour and the 72-hour cohort
    flags of ``enc``.

    Returns:
        dict: ``{'24h': {...}, '72h': {...}}``, each value holding
        ``contingency_table``, ``chi2_stat``, ``p_value``, ``cramers_v`` and
        ``n_total``.
    """
    results = {}

    # 24-hour analysis
    results['24h'] = _liberation_vs_mortality_at(24, "\n--- 24-Hour Analysis ---")

    # 72-hour analysis (survivors of 24h only)
    results['72h'] = _liberation_vs_mortality_at(
        72, "\n--- 72-Hour Analysis (24h survivors only) ---"
    )

    return results

# Run section 1 and keep its results for the summary/export step
liberation_mortality_results = chi_square_liberation_mortality()

# ============================================================================
# 2. ICU LOCATION vs OUTCOMES ANALYSIS
# ============================================================================

# Section banner in the log: blank line + rule, title, rule
_section_rule = "="*60
logger.info("\n" + _section_rule)
logger.info("2. ICU LOCATION vs OUTCOMES ANALYSIS")
logger.info(_section_rule)

def _location_chi_square(contingency):
    """Run a chi-square test on a location x outcome table and log the result.

    Rows and columns whose counts are all zero are dropped first, because they
    produce zero expected frequencies, which ``chi2_contingency`` rejects. The
    test is skipped (``None`` is returned) unless at least two locations and
    two outcome levels remain; with a single outcome level Cramer's V would
    divide by zero.

    Args:
        contingency: Counts with one row per location and one column per
            outcome level (no margins).

    Returns:
        dict | None: contingency table (as nested dict), chi-square statistic,
        p-value, Cramer's V and total count, or ``None`` when the test cannot
        be run.
    """
    contingency = contingency.loc[contingency.sum(axis=1) > 0,
                                  contingency.sum(axis=0) > 0]

    if contingency.shape[0] < 2 or contingency.shape[1] < 2:
        logger.info("\nChi-square test skipped: need at least two locations and two outcome levels")
        return None

    chi2_stat, p_value, dof, _ = chi2_contingency(contingency)

    logger.info(f"\nChi-square statistic: {chi2_stat:.4f}")
    logger.info(f"p-value: {p_value:.6f}")
    logger.info(f"Degrees of freedom: {dof}")

    # Effect size: Cramer's V = sqrt(chi2 / (n * (min(rows, cols) - 1)))
    n_total = contingency.sum().sum()
    cramers_v = np.sqrt(chi2_stat / (n_total * (min(contingency.shape) - 1)))
    logger.info(f"Cramer's V (effect size): {cramers_v:.4f}")

    return {
        'contingency_table': contingency.to_dict(),
        'chi2_stat': chi2_stat,
        'p_value': p_value,
        'cramers_v': cramers_v,
        'n_total': int(n_total)
    }


def chi_square_location_outcomes():
    """Test if ICU location is associated with outcomes.

    Joins each encounter in ``enc`` to the ``location_type`` found in the
    'Post-72h' rows of ``combined_summary.csv`` and tests location against
    (a) ``died_within_24h`` and (b) not being on CRRT at 24h. Only locations
    with at least 10 encounters in the mortality table are tested.

    Returns:
        dict: may contain ``'location_24h_mortality'`` and
        ``'location_24h_liberation'``; a key is absent when its test could not
        be run.
    """
    results = {}

    # Load the combined summary data to get location info
    combined_summary = pd.read_csv('../output/intermediate/combined_summary.csv')

    # Get location info from the post72 window (this has location_type)
    post72_df = combined_summary[combined_summary['window'] == 'Post-72h'].set_index('encounter_block')

    # Merge location info with our encounter-level data.
    # NOTE(review): encounters without a 'Post-72h' row get a missing location
    # and are therefore left out of the crosstabs below - confirm this is the
    # intended population for 24-hour outcomes.
    enc_with_location = enc.merge(
        post72_df[['location_type']],
        left_on='encounter_block',
        right_index=True,
        how='left'
    )

    logger.info("\n--- Location Distribution Analysis ---")
    logger.info(f"Patients with location data: {(~enc_with_location['location_type'].isna()).sum()}")
    logger.info(f"Location types: {enc_with_location['location_type'].dropna().unique()}")

    # ---- 24-hour mortality by location ---------------------------------------
    logger.info("\n24-Hour Mortality by ICU Location:")

    location_24h_mortality = pd.crosstab(
        enc_with_location['location_type'],
        enc_with_location['died_within_24h'],
        margins=True
    )
    logger.info(str(location_24h_mortality))

    # Remove the 'All' margin row/column for the chi-square test
    contingency_loc_24h = location_24h_mortality.iloc[:-1, :-1]

    # Only include locations with sufficient sample size (n >= 10)
    valid_locations = contingency_loc_24h.index[contingency_loc_24h.sum(axis=1) >= 10]
    contingency_loc_24h_filtered = contingency_loc_24h.loc[valid_locations]

    logger.info("\nFiltered contingency table (locations with n≥10):")
    logger.info(str(contingency_loc_24h_filtered))

    mortality_result = _location_chi_square(contingency_loc_24h_filtered)
    if mortality_result is not None:
        results['location_24h_mortality'] = mortality_result

    # ---- 24-hour CRRT liberation by location ---------------------------------
    logger.info("\n24-Hour CRRT Liberation by ICU Location:")

    location_24h_liberation = pd.crosstab(
        enc_with_location['location_type'],
        ~enc_with_location['on_crrt_at_24h'],  # Liberation = NOT on CRRT
        margins=True
    )
    logger.info(str(location_24h_liberation))

    contingency_lib_24h = location_24h_liberation.iloc[:-1, :-1]
    # Same locations as the mortality test; select by membership so a location
    # that is absent from this table cannot raise a KeyError.
    contingency_lib_24h_filtered = contingency_lib_24h[
        contingency_lib_24h.index.isin(valid_locations)
    ]

    liberation_result = _location_chi_square(contingency_lib_24h_filtered)
    if liberation_result is not None:
        results['location_24h_liberation'] = liberation_result

    return results

location_outcomes_results = chi_square_location_outcomes()

# ============================================================================
# 3. SUMMARY AND EXPORT RESULTS
# ============================================================================

logger.info("\n" + "="*60)
logger.info("3. SUMMARY OF STATISTICAL RESULTS")
logger.info("="*60)

# Compile all results. The site name comes from the pyCLIF configuration, like
# everywhere else in this notebook, so the JSON content and file name are
# correct for every site without editing the code.
statistical_results = {
    'site_name': pyCLIF.helper['site_name'],
    'analysis_date': pd.Timestamp.now().strftime('%Y-%m-%d'),
    'total_patients': len(enc),
    'liberation_vs_mortality': liberation_mortality_results,
    'location_vs_outcomes': location_outcomes_results
}

# Print summary
logger.info(f"\nSite: {statistical_results['site_name']}")
logger.info(f"Total patients: {statistical_results['total_patients']}")
logger.info(f"Analysis date: {statistical_results['analysis_date']}")

logger.info("\nKey Findings:")
p_24h = statistical_results['liberation_vs_mortality']['24h']['p_value']
p_72h = statistical_results['liberation_vs_mortality']['72h']['p_value']
logger.info(f"- CRRT liberation vs mortality (24h): p = {p_24h:.6f}")
logger.info(f"- CRRT liberation vs mortality (72h): p = {p_72h:.6f}")

# The location tests are optional: a key is only present when its test ran
if 'location_24h_mortality' in statistical_results['location_vs_outcomes']:
    p_loc_mort = statistical_results['location_vs_outcomes']['location_24h_mortality']['p_value']
    logger.info(f"- ICU location vs 24h mortality: p = {p_loc_mort:.6f}")

if 'location_24h_liberation' in statistical_results['location_vs_outcomes']:
    p_loc_lib = statistical_results['location_vs_outcomes']['location_24h_liberation']['p_value']
    logger.info(f"- ICU location vs 24h liberation: p = {p_loc_lib:.6f}")


def convert_types(obj):
    """Convert NumPy scalars/arrays to native Python types for ``json``.

    Used as the ``default=`` hook of ``json.dumps``. Raises ``TypeError`` for
    anything it does not know, as the hook contract requires; returning the
    object unchanged would make ``json`` fail with a misleading
    "Circular reference detected" error.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Export results to JSON for multi-site comparison. Serialise first, so a
# serialisation error cannot leave an empty or truncated file behind.
output_file = f"../output/final/statistical_analysis_{statistical_results['site_name'].lower()}.json"
json_ready = json.loads(json.dumps(statistical_results, default=convert_types))
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(json_ready, f, indent=2)

logger.info(f"\nResults exported to: {output_file}")
logger.info(f"Log file saved to: {log_file}")

# Close and detach the handlers to avoid duplicate logging in future cells and
# to release the log file
for used_handler in list(logger.handlers):
    used_handler.close()
    logger.removeHandler(used_handler)

In [None]:
# ============================================================================
# STATISTICAL ANALYSIS VISUALIZATIONS FOR MULTI-SITE COMPARISON
# ============================================================================

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Plot defaults for the figures in this cell: stock matplotlib style plus
# seaborn's "husl" colour palette
plt.style.use('default')
sns.set_palette("husl")

# Make sure the folder the figures are saved into exists
import os
os.makedirs("../output/final/graphs", exist_ok=True)

# Site label read from the pyCLIF helper config; used in figure titles below
site_name = pyCLIF.helper['site_name']

print(f"Creating statistical analysis visualizations for {site_name}...")

# ============================================================================
# 1. CRRT LIBERATION vs MORTALITY VISUALIZATION (Dynamic)
# ============================================================================

def _draw_liberation_mortality_panel(ax, contingency, p_value, cramers_v, heading, palette):
    """Draw one stacked Survived/Died bar chart with per-segment labels.

    Parameters:
        ax: matplotlib Axes to draw on.
        contingency: DataFrame with rows labelled 'Survived' and 'Died' and
            two columns. The columns are plotted in order as 'On CRRT' and
            'Off CRRT' (assumes the stored table uses that column order --
            TODO confirm against the cell that builds it).
        p_value, cramers_v: test statistics shown in the panel title.
        heading: first line of the panel title.
        palette: two colours, [survived, died].

    Returns:
        (survived_bars, died_bars) bar containers.
    """
    groups = ['On CRRT', 'Off CRRT']

    # Look rows up by label for both the bars and the annotations so the
    # printed numbers can never disagree with the bar they sit on.
    survived = contingency.loc['Survived']
    died = contingency.loc['Died']

    surv_bars = ax.bar(groups, survived, color=palette[0], label='Survived', alpha=0.8)
    died_bars = ax.bar(groups, died, bottom=survived, color=palette[1],
                       label='Died', alpha=0.8)

    for surv_bar, died_bar, n_surv, n_died in zip(surv_bars, died_bars, survived, died):
        total = n_surv + n_died
        # An empty group would otherwise divide by zero
        surv_pct = n_surv / total * 100 if total else 0.0
        died_pct = n_died / total * 100 if total else 0.0
        x_center = surv_bar.get_x() + surv_bar.get_width()/2

        # Survived label: middle of the lower segment
        ax.text(x_center, surv_bar.get_height()/2,
                f'{surv_pct:.1f}%\n(n={n_surv})',
                ha='center', va='center', fontweight='bold', color='white')

        # Died label: middle of the upper (stacked) segment
        ax.text(x_center, surv_bar.get_height() + died_bar.get_height()/2,
                f'{died_pct:.1f}%\n(n={n_died})',
                ha='center', va='center', fontweight='bold', color='white')

    ax.set_title(f"{heading}\np = {p_value:.6f}, Cramer's V = {cramers_v:.3f}",
                 fontsize=12, fontweight='bold')
    ax.set_ylabel('Number of Patients')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)

    return surv_bars, died_bars


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Extract data dynamically from statistical results
lib_mort_24h = liberation_mortality_results['24h']
lib_mort_72h = liberation_mortality_results['72h']

# Recreate contingency tables from the results
contingency_24h = pd.DataFrame(lib_mort_24h['contingency_table'])
contingency_72h = pd.DataFrame(lib_mort_72h['contingency_table'])

# Colors
colors = ['#2E8B57', '#DC143C']  # Green for survived, red for died

# 24-hour analysis
p_val_24h = lib_mort_24h['p_value']
cramers_24h = lib_mort_24h['cramers_v']
bars1, bars2 = _draw_liberation_mortality_panel(
    ax1, contingency_24h, p_val_24h, cramers_24h, '24-Hour Analysis', colors
)

# 72-hour analysis
p_val_72h = lib_mort_72h['p_value']
cramers_72h = lib_mort_72h['cramers_v']
bars3, bars4 = _draw_liberation_mortality_panel(
    ax2, contingency_72h, p_val_72h, cramers_72h,
    '72-Hour Analysis (24h Survivors Only)', colors
)

plt.suptitle(f'{site_name}: CRRT Liberation vs Mortality',
            fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('../output/final/graphs/crrt_liberation_vs_mortality.png',
            dpi=300, bbox_inches='tight')
plt.show()
plt.close()

# ============================================================================
# 2. ICU LOCATION OUTCOMES VISUALIZATION (Dynamic)
# ============================================================================

# Only drawn when the location analysis produced results
if 'location_vs_outcomes' in statistical_results and location_outcomes_results:

    # Get location info from the combined summary data (Post-72h window rows)
    combined_summary = pd.read_csv('../output/intermediate/combined_summary.csv')
    post72_df = combined_summary[combined_summary['window'] == 'Post-72h'].set_index('encounter_block')

    # Merge location info with our encounter-level data
    enc_with_location = enc.merge(
        post72_df[['location_type']],
        left_on='encounter_block',
        right_index=True,
        how='left'
    )

    # Calculate liberation rates by location.
    # reindex guarantees both outcome columns exist even when every patient
    # falls in one category; assigning two column names to a one-column
    # crosstab would raise ValueError.
    location_liberation = pd.crosstab(
        enc_with_location['location_type'],
        ~enc_with_location['on_crrt_at_24h'],  # Liberation = NOT on CRRT
        margins=False
    ).reindex(columns=[False, True], fill_value=0)

    location_liberation.columns = ['Still_on_CRRT', 'Liberated']
    location_liberation['Total'] = location_liberation.sum(axis=1)
    location_liberation['Liberation_Rate'] = location_liberation['Liberated'] / location_liberation['Total'] * 100

    # Calculate mortality rates by location (same one-category safeguard)
    location_mortality = pd.crosstab(
        enc_with_location['location_type'],
        enc_with_location['died_within_24h'],
        margins=False
    ).reindex(columns=[False, True], fill_value=0)

    location_mortality.columns = ['Survived', 'Died']
    location_mortality['Total'] = location_mortality.sum(axis=1)
    location_mortality['Mortality_Rate'] = location_mortality['Died'] / location_mortality['Total'] * 100

    # Filter locations with sufficient sample size (n>=10)
    valid_locations = location_liberation.index[location_liberation['Total'] >= 10]
    location_liberation_filtered = location_liberation.loc[valid_locations]
    location_mortality_filtered = location_mortality.loc[valid_locations]

    if len(location_liberation_filtered) > 0:
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

        # Liberation rates
        colors_lib = sns.color_palette("viridis", len(location_liberation_filtered))
        bars1 = ax1.bar(location_liberation_filtered.index, location_liberation_filtered['Liberation_Rate'],
                        color=colors_lib, alpha=0.8)

        # Add value labels on bars
        for bar, idx in zip(bars1, location_liberation_filtered.index):
            rate = location_liberation_filtered.loc[idx, 'Liberation_Rate']
            total = location_liberation_filtered.loc[idx, 'Total']
            ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                    f'{rate:.1f}%\n(n={total})', ha='center', va='bottom', fontweight='bold')

        # Get p-value from results ('N/A' when the test was not stored)
        lib_p_val = location_outcomes_results.get('location_24h_liberation', {}).get('p_value', 'N/A')
        if isinstance(lib_p_val, float):
            lib_p_str = f'p = {lib_p_val:.6f}'
        else:
            lib_p_str = 'p = N/A'

        ax1.set_title(f'24-Hour CRRT Liberation Rates by ICU Type\n{lib_p_str}',
                    fontsize=14, fontweight='bold')
        ax1.set_ylabel('Liberation Rate (%)')
        ax1.set_ylim(0, 100)
        ax1.grid(axis='y', alpha=0.3)
        ax1.tick_params(axis='x', rotation=45)

        # Mortality rates
        colors_mort = sns.color_palette("plasma", len(location_mortality_filtered))
        bars2 = ax2.bar(location_mortality_filtered.index, location_mortality_filtered['Mortality_Rate'],
                        color=colors_mort, alpha=0.8)

        # Add value labels on bars
        for bar, idx in zip(bars2, location_mortality_filtered.index):
            rate = location_mortality_filtered.loc[idx, 'Mortality_Rate']
            total = location_mortality_filtered.loc[idx, 'Total']
            ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.2,
                    f'{rate:.1f}%\n(n={total})', ha='center', va='bottom', fontweight='bold')

        # Get p-value from results ('N/A' when the test was not stored)
        mort_p_val = location_outcomes_results.get('location_24h_mortality', {}).get('p_value', 'N/A')
        if isinstance(mort_p_val, float):
            mort_p_str = f'p = {mort_p_val:.6f}'
        else:
            mort_p_str = 'p = N/A'

        ax2.set_title(f'24-Hour Mortality Rates by ICU Type\n{mort_p_str}',
                    fontsize=14, fontweight='bold')
        ax2.set_ylabel('Mortality Rate (%)')
        max_mort_rate = location_mortality_filtered['Mortality_Rate'].max()
        # A maximum of 0 (no deaths) or NaN would give a degenerate or invalid
        # axis range, so fall back to a fixed upper limit in that case.
        ax2.set_ylim(0, max_mort_rate * 1.2 if max_mort_rate > 0 else 1)
        ax2.grid(axis='y', alpha=0.3)
        ax2.tick_params(axis='x', rotation=45)

        plt.suptitle(f'{site_name}: ICU Location vs CRRT Outcomes',
                    fontsize=16, fontweight='bold', y=0.98)

        # Add caption explaining liberation rate definition
        fig.text(0.5, 0.02, 'Liberation Rate Definition: Percentage of patients not on CRRT at 24 hours post-initiation',
                ha='center', fontsize=10, style='italic')

        plt.tight_layout()
        plt.subplots_adjust(bottom=0.1)
        plt.savefig('../output/final/graphs/icu_location_outcomes.png',
                    dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()

# ============================================================================
# 3. SUMMARY DASHBOARD FOR MULTI-SITE COMPARISON (Dynamic)
# ============================================================================

# 2x2 dashboard; this part fills the top-left panel (ax1)
fig, _axes_grid = plt.subplots(2, 2, figsize=(16, 12))
(ax1, ax2), (ax3, ax4) = _axes_grid

# Headline counts taken straight from the encounter-level table
total_patients = len(enc)
mortality_24h = enc['died_within_24h'].sum()
mortality_72h = enc['died_within_72h'].sum()
liberation_24h = (~enc['on_crrt_at_24h']).sum()

# Same counts expressed as a share of the whole cohort
metrics = {'Total Patients': total_patients}
for _metric_name, _metric_count in (
    ('24h Mortality', mortality_24h),
    ('72h Mortality', mortality_72h),
    ('24h Liberation', liberation_24h),
):
    metrics[_metric_name] = f"{_metric_count/total_patients*100:.1f}%"

# Numbers behind the patient-flow lines
survivors_24h = total_patients - mortality_24h
survivors_72h = total_patients - mortality_72h
on_crrt_24h = enc['on_crrt_at_24h'].sum()
off_crrt_24h = liberation_24h
# The 72h CRRT status is counted among 72h survivors only
on_crrt_72h = enc.loc[~enc['died_within_72h'], 'on_crrt_at_72h'].sum()
off_crrt_72h = survivors_72h - on_crrt_72h

cohort_flow = pd.DataFrame({
    'Time_Point': ['Baseline', '24 Hours', '72 Hours'],
    'Total': [total_patients, total_patients, survivors_24h],
    'Alive': [total_patients, survivors_24h, survivors_72h],
    'On_CRRT': [total_patients, on_crrt_24h, on_crrt_72h],
    'Off_CRRT': [0, off_crrt_24h, off_crrt_72h]
})

# One line per patient state across the three time points
for _flow_column, _line_format, _line_label, _line_color in (
    ('Alive', 'o-', 'Alive', 'green'),
    ('On_CRRT', 's-', 'On CRRT', 'blue'),
    ('Off_CRRT', '^-', 'Off CRRT', 'orange'),
):
    ax1.plot(cohort_flow['Time_Point'], cohort_flow[_flow_column], _line_format,
             linewidth=3, markersize=8, label=_line_label, color=_line_color)

ax1.set_title(f'{site_name}: Patient Flow Over Time', fontsize=14, fontweight='bold')
ax1.set_ylabel('Number of Patients')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Effect sizes comparison (dynamic): one bar per test, height = Cramer's V
effects_data = []

# Liberation vs Mortality effects
effects_data.append(['Liberation vs\nMortality (24h)',
                    liberation_mortality_results['24h']['cramers_v'],
                    liberation_mortality_results['24h']['p_value']])

effects_data.append(['Liberation vs\nMortality (72h)',
                    liberation_mortality_results['72h']['cramers_v'],
                    liberation_mortality_results['72h']['p_value']])

# Location effects (if available)
if 'location_24h_mortality' in location_outcomes_results:
    effects_data.append(['Location vs\nMortality',
                        location_outcomes_results['location_24h_mortality']['cramers_v'],
                        location_outcomes_results['location_24h_mortality']['p_value']])

if 'location_24h_liberation' in location_outcomes_results:
    effects_data.append(['Location vs\nLiberation',
                        location_outcomes_results['location_24h_liberation']['cramers_v'],
                        location_outcomes_results['location_24h_liberation']['p_value']])

effects = pd.DataFrame(effects_data, columns=['Analysis', 'Cramers_V', 'P_Value'])

# Significance levels
effects['Significant'] = effects['P_Value'].apply(
    lambda p: '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 else 'NS'
)

# Colour uses the same p < 0.05 rule as the labels above, so a bar is green
# exactly when it carries a star. (The previous `p > 0.05` test painted
# p == 0.05 and NaN p-values green while labelling them 'NS'.)
colors_effect = ['darkgreen' if p < 0.05 else 'red' for p in effects['P_Value']]
bars = ax2.bar(effects['Analysis'], effects['Cramers_V'], color=colors_effect, alpha=0.7)

for bar, cramers, sig in zip(bars, effects['Cramers_V'], effects['Significant']):
    ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
            f'{cramers:.3f}\n{sig}', ha='center', va='bottom', fontweight='bold')

ax2.set_title(f'{site_name}: Effect Sizes (Cramer\'s V)', fontsize=14, fontweight='bold')
ax2.set_ylabel('Effect Size')
max_effect = effects['Cramers_V'].max()
# A maximum of 0 or NaN would give a degenerate or invalid axis range,
# so fall back to the full 0-1 range in that case.
ax2.set_ylim(0, max_effect * 1.2 if max_effect > 0 else 1)
ax2.tick_params(axis='x', rotation=0)
ax2.grid(axis='y', alpha=0.3)

# Bottom-left panel: 24h liberation rate per ICU, or a placeholder when the
# per-location table was never built (or is empty)
if 'location_liberation_filtered' not in locals() or len(location_liberation_filtered) == 0:
    ax3.text(0.5, 0.5, 'Location data\nnot available', ha='center', va='center',
            transform=ax3.transAxes, fontsize=14)
    ax3.set_title(f'{site_name}: Liberation Rates by ICU', fontsize=14, fontweight='bold')
else:
    # Lowest rate first, so the highest ends up at the top of the chart
    lib_comparison = location_liberation_filtered.sort_values('Liberation_Rate', ascending=True)
    bars3 = ax3.barh(lib_comparison.index, lib_comparison['Liberation_Rate'],
                    color=sns.color_palette("viridis", len(lib_comparison)))

    # Annotate every bar with its rate and group size
    for i, (rate, total) in enumerate(zip(lib_comparison['Liberation_Rate'],
                                          lib_comparison['Total'])):
        ax3.text(rate + 1, i, f'{rate:.1f}% (n={total})', va='center', fontweight='bold')

    # p-value for the panel title ('N/A' when the test result is missing)
    lib_p_val = location_outcomes_results.get('location_24h_liberation', {}).get('p_value', 'N/A')
    lib_p_str = f'p = {lib_p_val:.6f}' if isinstance(lib_p_val, float) else 'p = N/A'

    ax3.set_title(f'{site_name}: 24h Liberation Rates by ICU\n{lib_p_str}', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Liberation Rate (%)')
    ax3.set_xlim(0, 100)
    ax3.grid(axis='x', alpha=0.3)

# Bottom-right panel: plain-text summary of the headline numbers and p-values
_summary_lines = [
    "",
    f"{site_name} CRRT EPIDEMIOLOGY SUMMARY",
    "",
    f"Total Patients: {metrics['Total Patients']:,}",
    f"24h Mortality: {metrics['24h Mortality']} (n={mortality_24h})",
    f"72h Mortality: {metrics['72h Mortality']} (n={mortality_72h})",
    f"24h Liberation: {metrics['24h Liberation']} (n={liberation_24h})",
    "",
    "STATISTICAL TESTS:",
    f"Liberation vs Mortality (24h): p = {liberation_mortality_results['24h']['p_value']:.6f}",
    f"Liberation vs Mortality (72h): p = {liberation_mortality_results['72h']['p_value']:.6f}",
]

# Location tests are listed only when they were run
if 'location_24h_mortality' in location_outcomes_results:
    _summary_lines.append(
        f"Location vs Mortality: p = {location_outcomes_results['location_24h_mortality']['p_value']:.6f}"
    )
if 'location_24h_liberation' in location_outcomes_results:
    _summary_lines.append(
        f"Location vs Liberation: p = {location_outcomes_results['location_24h_liberation']['p_value']:.6f}"
    )

# Every line, including the last one, is newline-terminated
stats_text = "\n".join(_summary_lines) + "\n"

_text_box_style = dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.5)
ax4.text(0.1, 0.5, stats_text, transform=ax4.transAxes, fontsize=11,
        verticalalignment='center', bbox=_text_box_style)
ax4.set_xlim(0, 1)
ax4.set_ylim(0, 1)
ax4.axis('off')

# Finish and save the whole dashboard figure
plt.suptitle(f'{site_name}: CRRT Epidemiology Statistical Analysis Dashboard',
            fontsize=18, fontweight='bold', y=0.98)
plt.tight_layout()
plt.savefig('../output/final/graphs/statistical_analysis_dashboard.png',
            dpi=300, bbox_inches='tight')
plt.show()
plt.close()

## summary of results for a quick view
def create_site_comparison_template():
    """Build and save a one-row summary table for cross-site comparison.

    Reads the notebook-level results (`site_name`, the cohort counts,
    `liberation_mortality_results`, `location_outcomes_results`), writes the
    row to `../output/final/<site>_statistical_summary.csv` and returns it
    as a DataFrame. Location columns are present only when the
    corresponding test result exists.
    """
    lib_24h = liberation_mortality_results['24h']
    lib_72h = liberation_mortality_results['72h']

    summary_row = {
        'Site': site_name,
        'N_Patients': total_patients,
        'Mortality_24h_pct': mortality_24h/total_patients*100,
        'Mortality_72h_pct': mortality_72h/total_patients*100,
        'Liberation_24h_pct': liberation_24h/total_patients*100,
        'Liberation_vs_Mortality_24h_p': lib_24h['p_value'],
        'Liberation_vs_Mortality_72h_p': lib_72h['p_value'],
        'Cramers_V_Liberation_24h': lib_24h['cramers_v'],
        'Cramers_V_Liberation_72h': lib_72h['cramers_v']
    }

    # (result key, p-value column, effect-size column) for the optional tests
    location_columns = (
        ('location_24h_mortality', 'Location_vs_Mortality_p', 'Cramers_V_Location_Mortality'),
        ('location_24h_liberation', 'Location_vs_Liberation_p', 'Cramers_V_Location_Liberation'),
    )
    for result_key, p_column, v_column in location_columns:
        if result_key not in location_outcomes_results:
            continue
        location_result = location_outcomes_results[result_key]
        summary_row[p_column] = location_result['p_value']
        summary_row[v_column] = location_result['cramers_v']

    # Save as CSV for easy comparison
    summary_df = pd.DataFrame([summary_row])
    summary_df.to_csv(f'../output/final/{site_name.lower()}_statistical_summary.csv', index=False)

    return summary_df

template_df = create_site_comparison_template()

# (H) Kaplan-Meier Survival Analysis

Variable definitions: 
* Time 0: CRRT start time
* Liberation event: the patient stopped CRRT while alive. A patient who dies cannot be liberated, so in this analysis patients who die are censored for liberation at their time of death.
* Death indicator: `death_dttm_proxy`, i.e. the discharge date-time when the discharge category is Expired or Hospice.
* For survival analysis, follow-up is capped at 28 days (`max_followup`).
* Censoring Rules:
    1. Liberation Analysis Censoring
        - Event (1): Patient stopped CRRT while alive
        - Censored (0): Patient died while on CRRT
    2. Survival Analysis Censoring
        - Event (1): Patient died
        - Censored (0): Patient survived to end of follow-up (28 days)
* Landmark times (24h and 72h)
    1. Only include patients who survived to landmark time: `landmark_df = df[df['followup_death'] >= landmark_time]`
    2. Reset time origin to landmark time: `landmark_df['followup_from_landmark'] = landmark_df['followup_death'] - landmark_time`
* Valid locations are defined: A location needs at least 10 patients to be considered for analysis. 


Limitations:
* Changes during the CRRT course are not accounted for, i.e. time-varying covariates are not handled.
* There is no adjustment for baseline differences between ICUs, so confounding is not addressed.
* There may be selection bias because ICU admission is not random.
* Administrative censoring: the 28-day follow-up cap may miss late events.

In [None]:
# ============================================================================
# SECTION H: Kaplan-Meier Survival Analysis
# ============================================================================

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test, multivariate_logrank_test
from lifelines.plotting import plot_lifetimes
import warnings
import logging
import sys
warnings.filterwarnings('ignore')

# Set up logger
logger = logging.getLogger('km_analysis')
logger.setLevel(logging.INFO)

# logging.getLogger returns the same logger object on every call, so
# re-running this cell would otherwise stack a second file and console
# handler on it and write every message twice. Close and detach any handlers
# left over from a previous run first.
for _old_handler in list(logger.handlers):
    _old_handler.close()
    logger.removeHandler(_old_handler)

# Create file handler
fh = logging.FileHandler('../output/final/km_analysis.log')
fh.setLevel(logging.INFO)

# Create console handler
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.INFO)

# Create formatter (message text only, no timestamp or level prefix)
formatter = logging.Formatter('%(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)

# Add handlers to logger
logger.addHandler(fh)
logger.addHandler(ch)

logger.info("=== KAPLAN-MEIER SURVIVAL ANALYSIS ===")

# Get site name from config
site_name = pyCLIF.helper['site_name']

# ============================================================================
# 1. CREATE SURVIVAL DATASET
# ============================================================================

def create_survival_dataset():
    """Create the time-to-event dataset used by the Kaplan-Meier analyses.

    Starts from the encounter-level table `enc` and adds, per encounter:

    * `time_to_death`, `time_to_liberation`: hours from `first_crrt_time`
      to `death_dttm_proxy` / `end_crrt_time`.
    * `followup_death`, `death_event`: follow-up for mortality, capped at
      28 days; the event flag is 1 only for deaths within that window.
    * `followup_liberation`, `liberation_event`: follow-up for liberation,
      capped at 28 days. Patients with a recorded death are censored at
      their death time; the rest count as liberated at `end_crrt_time`
      unless that falls after the 28-day cap.
    * `location_type` (Post-72h window) and `sofa_total` (Pre-24h window),
      merged from `combined_summary.csv`.

    Rows with a missing or non-positive follow-up time are dropped.

    Returns:
        DataFrame with one row per retained encounter.
    """

    logger.info("Creating survival dataset...")

    max_followup = 28 * 24  # 28 days in hours

    # Start with encounter-level data
    survival_df = enc.copy()

    # Calculate time to events (in hours)
    survival_df['time_to_death'] = np.where(
        survival_df['death_dttm_proxy'].notna(),
        (survival_df['death_dttm_proxy'] - survival_df['first_crrt_time']).dt.total_seconds() / 3600,
        np.nan
    )

    survival_df['time_to_liberation'] = (
        survival_df['end_crrt_time'] - survival_df['first_crrt_time']
    ).dt.total_seconds() / 3600

    # Create event indicators
    survival_df['death_event'] = survival_df['death_dttm_proxy'].notna().astype(int)

    # Liberation event: stopped CRRT while alive
    # If patient died, they cannot be "liberated" - this is a competing risk
    survival_df['liberation_event'] = np.where(
        survival_df['death_event'] == 1,
        0,  # Dead patients cannot be liberated
        1   # Alive patients who stopped CRRT are liberated
    )

    # Get location data
    combined_summary = pd.read_csv('../output/intermediate/combined_summary.csv')
    post72_df = combined_summary[combined_summary['window'] == 'Post-72h'].set_index('encounter_block')

    # Merge location info
    survival_df = survival_df.merge(
        post72_df[['location_type']],
        left_on='encounter_block',
        right_index=True,
        how='left'
    )

    # Get baseline SOFA scores for risk adjustment
    pre24_df = combined_summary[combined_summary['window'] == 'Pre-24h'].set_index('encounter_block')
    survival_df = survival_df.merge(
        pre24_df[['sofa_total']],
        left_on='encounter_block',
        right_index=True,
        how='left',
        suffixes=('', '_baseline')
    )

    # Create follow-up time for each outcome.
    # Both are derived from the "ever died" flag BEFORE any event is censored
    # at the 28-day horizon below.
    died = survival_df['death_event'] == 1

    # For liberation: follow until death or liberation (whichever applies)
    survival_df['followup_liberation'] = np.where(
        died,
        survival_df['time_to_death'],      # Follow until death
        survival_df['time_to_liberation']   # Follow until liberation
    )

    # For mortality: follow until death or end of observation
    survival_df['followup_death'] = np.where(
        died,
        survival_df['time_to_death'],
        max_followup  # Censored at 28 days
    )

    # Administrative censoring at the 28-day horizon: an event that happens
    # after the cap was not observed within follow-up, so its indicator must
    # become 0 when the time is truncated. Capping the time alone would count
    # every late event as occurring exactly at day 28.
    for duration_col, event_col in (
        ('followup_liberation', 'liberation_event'),
        ('followup_death', 'death_event'),
    ):
        beyond_horizon = survival_df[duration_col] > max_followup
        survival_df[event_col] = np.where(beyond_horizon, 0, survival_df[event_col])
        survival_df[duration_col] = np.minimum(survival_df[duration_col], max_followup)

    # Filter out invalid times (NaN compares False, so missing times drop too)
    survival_df = survival_df[
        (survival_df['followup_liberation'] > 0) &
        (survival_df['followup_death'] > 0)
    ].copy()

    logger.info(f"Created survival dataset with {len(survival_df)} patients")
    logger.info(f"Liberation events: {survival_df['liberation_event'].sum()}")
    logger.info(f"Death events: {survival_df['death_event'].sum()}")

    return survival_df

survival_data = create_survival_dataset()

# ============================================================================
# 2. KAPLAN-MEIER ANALYSIS: TIME TO CRRT LIBERATION
# ============================================================================

def _km_by_location(df, duration_col, event_col, *, outcome, title, ylabel,
                    figure_path, min_patients=10):
    """Fit overall and per-ICU Kaplan-Meier curves and compare the ICUs.

    Shared implementation behind `km_liberation_analysis` and
    `km_survival_analysis`, which differ only in the columns and labels used.

    Parameters:
        df: survival dataset with a `location_type` column.
        duration_col, event_col: names of the follow-up time and event
            indicator columns to analyse.
        outcome: short outcome name ('liberation' or 'survival'); used in the
            log message and in the `logrank_<outcome>` result key.
        title, ylabel: figure title and y-axis label.
        figure_path: where the per-location figure is saved.
        min_patients: minimum number of rows a location needs to be analysed.

    Returns:
        dict with 'overall' (fitted KaplanMeierFitter), 'by_location'
        ({location: fitted KaplanMeierFitter}) and, when at least two
        locations qualify and the test succeeds, `logrank_<outcome>` with
        the chi-square statistic and p-value.
    """
    results = {}

    # Overall curve
    kmf_overall = KaplanMeierFitter()
    kmf_overall.fit(
        durations=df[duration_col],
        event_observed=df[event_col],
        label='Overall'
    )
    results['overall'] = kmf_overall

    # Locations with enough patients (value_counts ignores missing locations)
    location_counts = df['location_type'].value_counts()
    valid_locations = location_counts[location_counts >= min_patients].index

    location_results = {}

    fig, ax = plt.subplots(figsize=(12, 8))
    colors = sns.color_palette("husl", len(valid_locations))

    for color, location in zip(colors, valid_locations):
        location_data = df[df['location_type'] == location]

        kmf_loc = KaplanMeierFitter()
        kmf_loc.fit(
            durations=location_data[duration_col],
            event_observed=location_data[event_col],
            label=f'{location} (n={len(location_data)})'
        )

        location_results[location] = kmf_loc
        kmf_loc.plot_survival_function(ax=ax, color=color)

    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Time from CRRT Initiation (hours)')
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.savefig(figure_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

    results['by_location'] = location_results

    # Log-rank test for differences between locations
    if len(location_results) > 1:
        durations_list = []
        events_list = []
        groups_list = []

        for loc in location_results:
            loc_data = df[df['location_type'] == loc]
            durations_list.extend(loc_data[duration_col].tolist())
            events_list.extend(loc_data[event_col].tolist())
            groups_list.extend([loc] * len(loc_data))

        # Best effort: a failed test is logged and the curves are still returned
        try:
            logrank_result = multivariate_logrank_test(
                durations_list, groups_list, events_list
            )

            logger.info(f"\nLog-rank test for {outcome} differences across ICU locations:")
            logger.info(f"Chi-square statistic: {logrank_result.test_statistic:.4f}")
            logger.info(f"p-value: {logrank_result.p_value:.6f}")

            results[f'logrank_{outcome}'] = {
                'chi2': logrank_result.test_statistic,
                'p_value': logrank_result.p_value
            }

        except Exception as e:
            logger.info(f"Could not perform log-rank test: {e}")

    return results


logger.info("\n" + "="*60)
logger.info("2. TIME TO CRRT LIBERATION ANALYSIS")
logger.info("="*60)

def km_liberation_analysis(df):
    """Kaplan-Meier analysis for time to CRRT liberation.

    Uses `followup_liberation` / `liberation_event`, saves the per-location
    figure to `km_liberation_by_location.png` and returns a dict with
    'overall', 'by_location' and (when computed) 'logrank_liberation'.
    """
    return _km_by_location(
        df,
        'followup_liberation',
        'liberation_event',
        outcome='liberation',
        title=f'{site_name}: Time to CRRT Liberation by ICU Location',
        ylabel='Probability of Remaining on CRRT',
        figure_path='../output/final/graphs/km_liberation_by_location.png',
    )

liberation_results = km_liberation_analysis(survival_data)

# ============================================================================
# 3. KAPLAN-MEIER ANALYSIS: SURVIVAL (TIME TO DEATH)
# ============================================================================

logger.info("\n" + "="*60)
logger.info("3. SURVIVAL ANALYSIS")
logger.info("="*60)

def km_survival_analysis(df):
    """Kaplan-Meier analysis for survival (time to death).

    Uses `followup_death` / `death_event`, saves the per-location figure to
    `km_survival_by_location.png` and returns a dict with 'overall',
    'by_location' and (when computed) 'logrank_survival'.
    """
    return _km_by_location(
        df,
        'followup_death',
        'death_event',
        outcome='survival',
        title=f'{site_name}: Survival by ICU Location',
        ylabel='Survival Probability',
        figure_path='../output/final/graphs/km_survival_by_location.png',
    )

survival_results = km_survival_analysis(survival_data)

# ============================================================================
# 4. LANDMARK ANALYSIS: SURVIVAL BY EARLY LIBERATION STATUS
# ============================================================================

# Section banner in the log: a 60-character rule above and below the heading
_banner_rule = "=" * 60
logger.info("\n" + _banner_rule)
logger.info("4. LANDMARK ANALYSIS: EARLY LIBERATION vs SURVIVAL")
logger.info(_banner_rule)

def landmark_analysis(df, landmark_time=24):
    """Compare post-landmark survival between liberated and non-liberated patients.

    Rows whose ``followup_death`` is shorter than ``landmark_time`` are dropped.
    The remaining patients are split by whether a liberation event was recorded
    at or before the landmark.  For every group with more than five patients a
    Kaplan-Meier curve (clock re-zeroed at the landmark) is drawn and the figure
    is saved under ``../output/final/graphs``.  A log-rank test is run only when
    both groups have more than five patients.

    Args:
        df: DataFrame providing ``followup_death``, ``followup_liberation``,
            ``liberation_event`` and ``death_event`` columns.
        landmark_time: Landmark in the unit of the follow-up columns (the plot
            labels present it as hours).

    Returns:
        dict with the landmark time, group sizes and the log-rank statistic and
        p-value; ``None`` when either group has five or fewer patients or the
        test raises.
    """
    logger.info(f"Landmark analysis at {landmark_time} hours...")

    # Keep only patients whose death follow-up reaches the landmark.
    at_risk = df.loc[df['followup_death'] >= landmark_time].copy()
    logger.info(f"Patients surviving to {landmark_time}h landmark: {len(at_risk)}")

    # Liberated = a liberation event was observed no later than the landmark.
    freed_in_time = at_risk['followup_liberation'] <= landmark_time
    had_liberation = at_risk['liberation_event'] == 1
    at_risk['liberated_at_landmark'] = freed_in_time & had_liberation

    # Re-zero the clock at the landmark; the death indicator is unchanged.
    at_risk['followup_from_landmark'] = at_risk['followup_death'] - landmark_time

    status_counts = at_risk['liberated_at_landmark'].value_counts()
    logger.info(f"Liberation status at {landmark_time}h:")
    logger.info(f"  Still on CRRT: {status_counts.get(False, 0)}")
    logger.info(f"  Liberated: {status_counts.get(True, 0)}")

    # Split once; both the plot and the log-rank test use these subsets.
    groups = {
        flag: at_risk[at_risk['liberated_at_landmark'] == flag]
        for flag in (False, True)
    }

    # Kaplan-Meier curve per liberation status (small groups are not drawn).
    fig, ax = plt.subplots(figsize=(10, 6))
    for flag, subset in groups.items():
        if len(subset) <= 5:
            continue

        if flag:
            curve_label = f"Liberated by {landmark_time}h (n={len(subset)})"
            curve_color = 'green'
        else:
            curve_label = f"Still on CRRT at {landmark_time}h (n={len(subset)})"
            curve_color = 'red'

        fitter = KaplanMeierFitter()
        fitter.fit(
            durations=subset['followup_from_landmark'],
            event_observed=subset['death_event'],
            label=curve_label,
        )
        fitter.plot_survival_function(ax=ax, color=curve_color)

    ax.set_title(
        f'{site_name}: Survival from {landmark_time}h by Early Liberation Status',
        fontsize=14,
        fontweight='bold',
    )
    ax.set_xlabel(f'Time from {landmark_time}h Landmark (hours)')
    ax.set_ylabel('Survival Probability')
    ax.grid(True, alpha=0.3)
    ax.legend()

    # Footnote reminding readers of the landmark inclusion rule.
    plt.figtext(
        0.5,
        0.02,
        f'Landmark Analysis: Only includes patients who survived to {landmark_time} hours',
        ha='center',
        fontsize=10,
        style='italic',
    )

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.1)
    plt.savefig(
        f'../output/final/graphs/landmark_analysis_{landmark_time}h.png',
        dpi=300,
        bbox_inches='tight',
    )
    plt.show()
    plt.close()

    # Log-rank test needs more than five patients in each arm.
    liberated_data = groups[True]
    not_liberated_data = groups[False]
    if len(liberated_data) <= 5 or len(not_liberated_data) <= 5:
        return None

    try:
        test_outcome = logrank_test(
            liberated_data['followup_from_landmark'],
            not_liberated_data['followup_from_landmark'],
            liberated_data['death_event'],
            not_liberated_data['death_event'],
        )

        logger.info(f"\nLog-rank test for survival difference by {landmark_time}h liberation status:")
        logger.info(f"Chi-square statistic: {test_outcome.test_statistic:.4f}")
        logger.info(f"p-value: {test_outcome.p_value:.6f}")

        return {
            'landmark_time': landmark_time,
            'n_landmark_survivors': len(at_risk),
            'n_liberated': len(liberated_data),
            'n_not_liberated': len(not_liberated_data),
            'logrank_chi2': test_outcome.test_statistic,
            'logrank_p': test_outcome.p_value,
        }
    except Exception as e:
        # Best effort: a failed test is logged and reported as "no result".
        logger.info(f"Could not perform log-rank test: {e}")
        return None

# Landmark analyses at the two landmarks used throughout this section.
landmark_24h = landmark_analysis(survival_data, landmark_time=24)
landmark_72h = landmark_analysis(survival_data, landmark_time=72)

# ----------------------------------------------------------------------------
# Section 5 -- summary of survival analysis results
# ----------------------------------------------------------------------------

logger.info("\n" + 60 * "=")
logger.info("5. SURVIVAL ANALYSIS SUMMARY")
logger.info(60 * "=")

# Skeleton of the exported summary. The three nested sections stay empty when
# the corresponding test result is not available.
km_results = dict(
    site_name=site_name,
    analysis_date=pd.Timestamp.now().strftime('%Y-%m-%d'),
    n_patients=len(survival_data),
    n_liberation_events=int(survival_data['liberation_event'].sum()),
    n_death_events=int(survival_data['death_event'].sum()),
    liberation_analysis={},
    survival_analysis={},
    landmark_analysis={},
)

# Copy the by-location log-rank results (liberation, then survival) if present.
_logrank_sources = (
    ('liberation_analysis', liberation_results, 'logrank_liberation'),
    ('survival_analysis', survival_results, 'logrank_survival'),
)
for _section, _source, _key in _logrank_sources:
    if _key in _source:
        km_results[_section] = {
            'logrank_chi2': _source[_key]['chi2'],
            'logrank_p': _source[_key]['p_value'],
        }

# Attach each landmark result that was actually produced.
_landmark_outputs = (('24h', landmark_24h), ('72h', landmark_72h))
for _tag, _outcome in _landmark_outputs:
    if _outcome:
        km_results['landmark_analysis'][_tag] = _outcome

# Export results
import json


def convert_types(obj):
    """Map NumPy scalars/arrays to built-in types for the JSON encoder."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    return obj


with open(f'../output/final/km_analysis_{site_name.lower()}.json', 'w') as f:
    # Round-trip through a JSON string so only built-in types reach the file.
    _encoded = json.dumps(km_results, default=convert_types)
    json_ready = json.loads(_encoded)
    json.dump(json_ready, f, indent=2)

# Print summary
logger.info(f"\nKaplan-Meier Analysis Summary for {site_name}:")
logger.info(f"Total patients analyzed: {len(survival_data)}")
logger.info(f"Liberation events: {survival_data['liberation_event'].sum()}")
logger.info(f"Death events: {survival_data['death_event'].sum()}")

if 'logrank_liberation' in liberation_results:
    p_lib = liberation_results['logrank_liberation']['p_value']
    logger.info(f"Liberation differences by ICU location: p = {p_lib:.6f}")

if 'logrank_survival' in survival_results:
    p_surv = survival_results['logrank_survival']['p_value']
    logger.info(f"Survival differences by ICU location: p = {p_surv:.6f}")

for _tag, _outcome in _landmark_outputs:
    if _outcome:
        p_land = _outcome['logrank_p']
        logger.info(f"{_tag} landmark analysis (early liberation vs survival): p = {p_land:.6f}")

# List the artefacts this section is expected to have written.
logger.info("\nFiles created:")
for _artifact in (
    "km_liberation_by_location.png",
    "km_survival_by_location.png",
    "landmark_analysis_24h.png",
    "landmark_analysis_72h.png",
    f"km_analysis_{site_name.lower()}.json",
):
    logger.info(f"  - {_artifact}")

logger.info("\n" + 60 * "=")
logger.info("KAPLAN-MEIER ANALYSIS COMPLETE")
logger.info(60 * "=")

# Combined 2x2 dashboard: liberation and survival curves by ICU location on the
# top row, 24h and 72h landmark analyses on the bottom row.
plt.figure(figsize=(20, 15))

# Top row: re-plot the already fitted per-location curves.
_location_panels = (
    (1, liberation_results, 'Time to CRRT Liberation by ICU Location',
     'Probability of Remaining on CRRT'),
    (2, survival_results, 'Survival by ICU Location', 'Survival Probability'),
)
for _panel_no, _fitted, _panel_title, _panel_ylabel in _location_panels:
    plt.subplot(2, 2, _panel_no)
    for location, kmf in _fitted['by_location'].items():
        kmf.plot_survival_function(label=f'{location}')
    plt.title(_panel_title)
    plt.xlabel('Time from CRRT Initiation (hours)')
    plt.ylabel(_panel_ylabel)
    plt.grid(True, alpha=0.3)
    plt.legend()

# Bottom row: one landmark panel per landmark time, built the same way.
_landmark_frames = {}
for _panel_no, _hours in ((3, 24), (4, 72)):
    plt.subplot(2, 2, _panel_no)

    # Patients whose death follow-up reaches the landmark, clock re-zeroed.
    _frame = survival_data[survival_data['followup_death'] >= _hours].copy()
    _frame['liberated_at_landmark'] = (
        (_frame['followup_liberation'] <= _hours) &
        (_frame['liberation_event'] == 1)
    )
    _frame['followup_from_landmark'] = _frame['followup_death'] - _hours
    _landmark_frames[_hours] = _frame

    for lib_status in [False, True]:
        status_data = _frame[_frame['liberated_at_landmark'] == lib_status]
        if len(status_data) <= 5:
            continue  # groups this small are not drawn
        kmf = KaplanMeierFitter()
        kmf.fit(
            durations=status_data['followup_from_landmark'],
            event_observed=status_data['death_event'],
            label=f"{'Liberated' if lib_status else 'Still on CRRT'} at {_hours}h",
        )
        kmf.plot_survival_function()

    plt.title(f'{_hours}h Landmark Analysis')
    plt.xlabel(f'Time from {_hours}h (hours)')
    plt.ylabel('Survival Probability')
    plt.grid(True, alpha=0.3)
    plt.legend()

# Keep the per-landmark frames available under their original names.
landmark_df_24 = _landmark_frames[24]
landmark_df_72 = _landmark_frames[72]

plt.suptitle(f'{site_name}: Survival Analysis Results', fontsize=16)
plt.tight_layout()
plt.savefig('../output/final/graphs/survival_analysis_all.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()


In [None]:
# ============================================================================
# EXPORT SURVIVAL CURVE DATA FOR MULTI-SITE COMPARISON
# ============================================================================

def export_survival_curves_for_comparison():
    """Dump the fitted Kaplan-Meier curves to JSON for multi-site plotting.

    Reads the module-level ``liberation_results``, ``survival_results``,
    ``landmark_72h``, ``survival_data`` and ``site_name``.  Writes
    ``../output/final/km_curves_<site>.json`` containing, per ICU location, the
    curve timeline, survival estimates, patient count and event count for both
    the liberation and the survival analysis.  When ``landmark_72h`` is truthy
    the 72h landmark curves are refitted and added as well.

    Returns:
        dict: the structure that was written to disk.
    """

    def _curve_payload(fitter):
        # Curve coordinates plus the sample size / event count behind them.
        return {
            'timeline': fitter.timeline.tolist(),
            'survival_function': fitter.survival_function_.iloc[:, 0].tolist(),
            'n_patients': int(fitter.durations.shape[0]),
            'n_events': int(fitter.event_observed.sum()),
        }

    curve_data = {}

    # 1. Liberation curves by ICU
    curve_data['liberation_by_icu'] = {
        icu: _curve_payload(fitter)
        for icu, fitter in liberation_results['by_location'].items()
    }

    # 2. Survival curves by ICU
    curve_data['survival_by_icu'] = {
        icu: _curve_payload(fitter)
        for icu, fitter in survival_results['by_location'].items()
    }

    # 3. Landmark curves, only when the 72h landmark analysis produced a result
    if landmark_72h:
        cohort_72h = survival_data[survival_data['followup_death'] >= 72].copy()
        freed_by_72h = cohort_72h['followup_liberation'] <= 72
        had_liberation = cohort_72h['liberation_event'] == 1
        cohort_72h['liberated_at_landmark'] = freed_by_72h & had_liberation
        cohort_72h['followup_from_landmark'] = cohort_72h['followup_death'] - 72

        landmark_curves = {}
        status_keys = ((False, 'still_on_crrt_at_72h'), (True, 'liberated_by_72h'))
        for liberated, status_label in status_keys:
            subset = cohort_72h[cohort_72h['liberated_at_landmark'] == liberated]
            if len(subset) <= 5:
                continue  # groups this small are not exported

            fitter = KaplanMeierFitter()
            fitter.fit(
                durations=subset['followup_from_landmark'],
                event_observed=subset['death_event'],
            )
            landmark_curves[status_label] = {
                'timeline': fitter.timeline.tolist(),
                'survival_function': fitter.survival_function_.iloc[:, 0].tolist(),
                'n_patients': int(len(subset)),
                'n_events': int(subset['death_event'].sum()),
            }

        curve_data['landmark_72h'] = landmark_curves

    # Provenance metadata, then persist.
    curve_data['site_name'] = site_name
    curve_data['analysis_date'] = pd.Timestamp.now().strftime('%Y-%m-%d')

    out_path = f'../output/final/km_curves_{site_name.lower()}.json'
    with open(out_path, 'w') as f:
        json.dump(curve_data, f, indent=2)

    print(f"✅ Exported survival curve data to km_curves_{site_name.lower()}.json")

    return curve_data

# Write the per-site KM curve coordinates to JSON and keep the returned dict
# in `curve_export`.
curve_export = export_survival_curves_for_comparison()

# ============================================================================
# SUMMARY TABLE FOR MULTI-SITE COMPARISON
# ============================================================================

def _km_median_or_none(kmf):
    """Return the fitter's median survival time, or None when it is unavailable."""
    try:
        median = kmf.median_survival_time_
    except Exception:
        # Best effort: an unavailable median must not abort the whole table.
        return None
    return None if pd.isna(median) else median


def _km_event_rate_at(kmf, hours):
    """Return 1 - S(hours) for a fitted curve, or None beyond observed follow-up.

    The estimate is read with ``survival_function_at_times``, so ``hours`` does
    not have to coincide with an observed time on ``kmf.timeline``.
    """
    timeline = np.asarray(kmf.timeline, dtype=float)
    if timeline.size == 0 or hours > timeline.max():
        return None
    return float(1 - kmf.survival_function_at_times(hours).iloc[0])


def create_km_summary_table():
    """Build and save a per-ICU summary of key Kaplan-Meier metrics.

    Reads the module-level ``liberation_results``, ``survival_results`` and
    ``site_name``.  For every location it reports patient/event counts, the
    median time, and the 24h / 72h cumulative event rates for liberation and
    for death.  Liberation and survival rows are matched by location name, so
    a location present in only one analysis still gets a row, with the other
    analysis' columns left empty.

    The table is written to ``../output/final/km_summary_<site>.csv`` and a
    short preview is printed.

    Returns:
        pandas.DataFrame: one row per ICU location.
    """
    columns = [
        'site', 'icu_type', 'n_patients',
        'n_liberation_events', 'median_liberation_time_hrs',
        'liberation_rate_24h', 'liberation_rate_72h',
        'n_death_events', 'median_survival_time_hrs',
        'mortality_24h', 'mortality_72h',
    ]

    # Keyed by location so the two analyses are joined by name, not position.
    rows = {}

    for location, kmf in liberation_results['by_location'].items():
        rows[location] = {
            'site': site_name,
            'icu_type': location,
            'n_patients': int(kmf.durations.shape[0]),
            'n_liberation_events': int(kmf.event_observed.sum()),
            'median_liberation_time_hrs': _km_median_or_none(kmf),
            'liberation_rate_24h': _km_event_rate_at(kmf, 24),
            'liberation_rate_72h': _km_event_rate_at(kmf, 72),
        }

    for location, kmf in survival_results['by_location'].items():
        # Create the row if this location had no liberation curve.
        row = rows.setdefault(location, {
            'site': site_name,
            'icu_type': location,
            'n_patients': int(kmf.durations.shape[0]),
        })
        row.update({
            'n_death_events': int(kmf.event_observed.sum()),
            'median_survival_time_hrs': _km_median_or_none(kmf),
            'mortality_24h': _km_event_rate_at(kmf, 24),
            'mortality_72h': _km_event_rate_at(kmf, 72),
        })

    # Explicit columns keep the layout stable even when there are no rows.
    summary_df = pd.DataFrame(list(rows.values()), columns=columns)
    summary_df.to_csv(f'../output/final/km_summary_{site_name.lower()}.csv', index=False)

    print(f"✅ Created KM summary table: km_summary_{site_name.lower()}.csv")
    print("\nSample of summary data:")
    print(summary_df[['icu_type', 'n_patients', 'liberation_rate_24h',
                     'mortality_24h']].round(3))

    return summary_df

# Build the per-ICU KM summary table (the function also writes it to CSV) and
# keep the resulting DataFrame.
km_summary = create_km_summary_table()

In [None]:
# Diagnostic cell: recompute the time-to-event and follow-up columns from `enc`
# and log how many encounters fail each check, so exclusions can be explained.
# NOTE(review): the calculations are meant to mirror the survival-data
# preparation done elsewhere in this notebook -- confirm they stay in sync.
import logging

# Configure logging
logging.basicConfig(
    filename='../output/final/km_analysis.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

logger = logging.getLogger(__name__)
logger.info("=== DIAGNOSTIC ===")

# Start with original enc dataframe
logger.info(f"Original enc dataframe: {len(enc)} patients")

# Step 1: Check time calculations (work on a copy so `enc` is left untouched)
enc_with_times = enc.copy()

# Hours from first CRRT to the death proxy; NaN when no death is recorded.
enc_with_times['time_to_death'] = np.where(
    enc_with_times['death_dttm_proxy'].notna(),
    (enc_with_times['death_dttm_proxy'] - enc_with_times['first_crrt_time']).dt.total_seconds() / 3600,
    np.nan
)

# Hours from first CRRT to the end of CRRT.
enc_with_times['time_to_liberation'] = (
    enc_with_times['end_crrt_time'] - enc_with_times['first_crrt_time']
).dt.total_seconds() / 3600

# Check for negative or zero liberation times
bad_liberation_times = (enc_with_times['time_to_liberation'] <= 0)
logger.info(f"Patients with bad liberation times (≤0): {bad_liberation_times.sum()}")

if bad_liberation_times.sum() > 0:
    # Row-level values are deliberately not logged; inspect them interactively.
    logger.info("Sample of bad liberation times: check enc_with_times[bad_liberation_times][['encounter_block', 'time_to_liberation', 'first_crrt_time', 'end_crrt_time']].head()")

# Create follow-up times
enc_with_times['death_event'] = enc_with_times['death_dttm_proxy'].notna().astype(int)

# Liberation follow-up: time to death for patients who died, otherwise time
# to the end of CRRT.
enc_with_times['followup_liberation'] = np.where(
    enc_with_times['death_event'] == 1,
    enc_with_times['time_to_death'],
    enc_with_times['time_to_liberation']
)

# Death follow-up: time to death, or the 28-day horizon when no death is recorded.
max_followup = 28 * 24
enc_with_times['followup_death'] = np.where(
    enc_with_times['death_event'] == 1,
    enc_with_times['time_to_death'],
    max_followup
)

# Cap follow-up times at the 28-day horizon
enc_with_times['followup_liberation'] = np.minimum(enc_with_times['followup_liberation'], max_followup)
enc_with_times['followup_death'] = np.minimum(enc_with_times['followup_death'], max_followup)

# Check for invalid follow-up times
bad_followup_lib = (enc_with_times['followup_liberation'] <= 0)
bad_followup_death = (enc_with_times['followup_death'] <= 0)

logger.info(f"Patients with bad liberation follow-up (≤0): {bad_followup_lib.sum()}")
logger.info(f"Patients with bad death follow-up (≤0): {bad_followup_death.sum()}")

# Missing follow-up (NaN) also fails the `> 0` filter below, so count it
# separately; otherwise those exclusions would go unexplained.
missing_followup_lib = enc_with_times['followup_liberation'].isna()
missing_followup_death = enc_with_times['followup_death'].isna()

logger.info(f"Patients with missing liberation follow-up (NaN): {missing_followup_lib.sum()}")
logger.info(f"Patients with missing death follow-up (NaN): {missing_followup_death.sum()}")

# Check location merge
combined_summary = pd.read_csv('../output/intermediate/combined_summary.csv')
post72_df = combined_summary[combined_summary['window'] == 'Post-72h'].set_index('encounter_block')

logger.info(f"Patients in post72 location data: {len(post72_df)}")
logger.info(f"Patients in enc: {len(enc)}")

# Check merge success. A left merge never drops rows, so the row count alone
# cannot show failures: also report rows left without a location and rows
# added by duplicate keys.
merge_result = enc.merge(post72_df[['location_type']], left_on='encounter_block', right_index=True, how='left')
logger.info(f"Patients after location merge: {len(merge_result)}")
logger.info(f"Rows with no location_type after merge: {merge_result['location_type'].isna().sum()}")
logger.info(f"Extra rows introduced by merge: {len(merge_result) - len(enc)}")

# Final filter: both follow-up times must be strictly positive (NaN fails too)
final_valid = (
    (enc_with_times['followup_liberation'] > 0) &
    (enc_with_times['followup_death'] > 0)
)

logger.info(f"Patients passing final filter: {final_valid.sum()}")
logger.info(f"Patients excluded by final filter: {(~final_valid).sum()}")

# Show which patients are excluded
if (~final_valid).sum() > 0:
    excluded = enc_with_times[~final_valid]
    # Row-level values are deliberately not logged; inspect them interactively.
    logger.info("\nSample of excluded patients: check excluded[['encounter_block', 'followup_liberation', 'followup_death', 'time_to_liberation', 'time_to_death']].head() ")

    # Check if these are data errors
    logger.info("\nBreakdown of exclusion reasons:")
    logger.info(f"- Bad liberation follow-up: {(excluded['followup_liberation'] <= 0).sum()}")
    logger.info(f"- Bad death follow-up: {(excluded['followup_death'] <= 0).sum()}")
    logger.info(f"- Missing liberation follow-up: {excluded['followup_liberation'].isna().sum()}")
    logger.info(f"- Missing death follow-up: {excluded['followup_death'].isna().sum()}")

# (I) Final HTML Results 

In [None]:
# ============================================================================
# HTML TABLE GENERATION FUNCTIONS
# ============================================================================

def _p_value_css_class(value):
    """Return the CSS class for a p-value cell.

    ``""`` and ``"NA"`` get no class.  ``"<0.001"`` and any number or numeric
    string below 0.05 get ``"significant p-value"``.  Everything else,
    including text that cannot be parsed as a number, gets ``"p-value"``.
    """
    if isinstance(value, str):
        if value in ("", "NA"):
            return ""
        if value == "<0.001":
            return "significant p-value"
    try:
        # Handles numeric strings and real numbers alike; a p-value column that
        # pandas parsed as float must be highlighted as well.
        return "significant p-value" if float(value) < 0.05 else "p-value"
    except (TypeError, ValueError):
        return "p-value"


def create_html_table_with_styling(df, title, subtitle="", highlight_significant=False, p_value_col=None):
    """Render a DataFrame as a complete, styled HTML document.

    The index is shown in a leading "Variable" column.  A row whose index is
    ``"n"`` gets the ``count-row`` class and one whose index is ``"header"``
    (case-insensitive) gets ``header-row``.  A row whose label ends with
    ``", n (%)"`` and whose cells are all missing is rendered as one spanning
    group-header cell.  Missing values are shown as empty cells.

    Args:
        df: Table to render.
        title: Page title and main heading; inserted as-is.
        subtitle: Sub-heading; inserted as-is, so it may contain markup such
            as ``<br>``.
        highlight_significant: When true, cells of ``p_value_col`` are given
            p-value styling and values below 0.05 (or ``"<0.001"``) are
            highlighted.
        p_value_col: Name of the p-value column; ignored unless
            ``highlight_significant`` is true.

    Returns:
        str: The HTML document.  Column names, index labels and cell values
        are HTML-escaped.
    """
    # Local import keeps this block self-contained (standard library only).
    from html import escape

    page_head = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>{title}</title>
        <style>
            body {{
                font-family: Arial, sans-serif;
                margin: 20px;
                background-color: #f5f5f5;
            }}
            h1 {{
                color: #333;
                text-align: center;
                margin-bottom: 5px;
            }}
            h3 {{
                color: #666;
                text-align: center;
                margin-top: 5px;
                font-weight: normal;
            }}
            table {{
                border-collapse: collapse;
                width: 100%;
                background-color: white;
                box-shadow: 0 2px 4px rgba(0,0,0,0.1);
                margin-top: 20px;
            }}
            th {{
                background-color: #4CAF50;
                color: white;
                text-align: left;
                padding: 12px;
                font-weight: bold;
                position: sticky;
                top: 0;
                z-index: 10;
            }}
            td {{
                border: 1px solid #ddd;
                padding: 8px;
                text-align: left;
            }}
            tr:nth-child(even) {{
                background-color: #f9f9f9;
            }}
            tr:hover {{
                background-color: #f5f5f5;
            }}
            .index-col {{
                font-weight: bold;
                background-color: #e8f5e9;
                width: 250px;
            }}
            .header-row {{
                background-color: #2196F3 !important;
            }}
            .count-row {{
                background-color: #FFC107 !important;
                font-weight: bold;
            }}
            .significant {{
                background-color: #ffeb3b;
                font-weight: bold;
            }}
            .p-value {{
                text-align: center;
                font-weight: bold;
            }}
            .group-header {{
                background-color: #e0e0e0;
                font-weight: bold;
                font-style: italic;
            }}
        </style>
    </head>
    <body>
        <h1>{title}</h1>
        <h3>{subtitle}</h3>
    """

    page_tail = """
    </body>
    </html>
    """

    # Collect fragments and join once instead of growing a string with +=.
    parts = ["<table>\n", "<tr>", "<th>Variable</th>"]
    parts.extend(f"<th>{escape(str(col))}</th>" for col in df.columns)
    parts.append("</tr>\n")

    n_columns = len(df.columns)
    style_p_values = bool(highlight_significant and p_value_col)

    for idx, row in df.iterrows():
        # Special rows are recognised by their index label.
        if idx == "n":
            row_class = "count-row"
        elif isinstance(idx, str) and idx.lower() == "header":
            row_class = "header-row"
        else:
            row_class = ""
        parts.append(f"<tr class='{row_class}'>")

        # Multi-index labels are flattened to "level0 - level1".
        if isinstance(idx, tuple):
            index_text = idx[0] if idx[1] == '' else f"{idx[0]} - {idx[1]}"
        else:
            index_text = str(idx)

        if index_text.endswith(', n (%)') and row.isna().all():
            # Group header: a single cell spanning the whole table width.
            parts.append(
                f"<td colspan='{n_columns + 1}' class='group-header'>"
                f"{escape(str(index_text))}</td>"
            )
        else:
            parts.append(f"<td class='index-col'>{escape(str(index_text))}</td>")

            for col in df.columns:
                cell_value = row[col]

                cell_class = ""
                if style_p_values and col == p_value_col:
                    cell_class = _p_value_css_class(cell_value)

                # Missing values render as empty cells.
                if pd.isna(cell_value):
                    cell_value = ""

                parts.append(f"<td class='{cell_class}'>{escape(str(cell_value))}</td>")

        parts.append("</tr>\n")

    parts.append("</table>")

    return page_head + "".join(parts) + page_tail

def save_tables_as_html():
    """Render the saved Table 1 / Table 2 CSVs as HTML files.

    Reads ``table1_subgroups2_<site>.csv`` and
    ``table2_by_location_with_stats_<site>.csv`` from ``{output_folder}/final``
    and writes three files to the same folder:

    * ``table1_subgroups2_<site>.html``
    * ``table2_by_location_with_stats.html`` (p-value column styled)
    * ``all_tables_combined.html`` (both tables on one page)

    Each step is best effort: a failure is printed and the remaining steps
    still run.  The combined page is skipped with an explicit message when
    either table could not be loaded.
    """
    print("\nGenerating HTML versions of tables...")

    # Stay None when the CSV cannot be read, so the combined step can tell
    # "not loaded" apart from "loaded".
    table1_df = None
    table2_df = None

    # 1. Table 1 - All subcohorts
    try:
        site = pyCLIF.helper["site_name"]
        table1_df = pd.read_csv(
            f'{output_folder}/final/table1_subgroups2_{site.lower()}.csv', index_col=0
        )

        # Create subtitle with cohort definitions
        subtitle1 = """
        Pre-24h: Baseline before CRRT | Post-24h: 0-24 hours after CRRT start | Post-72h: 24-72 hours after CRRT start<br>
        Survivors: Alive at end of window | Non-survivors: Died within window | On/Off CRRT: Status at end of window
        """

        html1 = create_html_table_with_styling(
            table1_df,
            f"Table 1: Patient Characteristics by Time Window and Outcome - {site}",
            subtitle=subtitle1
        )

        with open(f'{output_folder}/final/table1_subgroups2_{site.lower()}.html', 'w') as f:
            f.write(html1)
        # Report the file name that was actually written.
        print(f"  ✓ Saved table1_subgroups2_{site.lower()}.html")
    except Exception as e:
        print(f"  ✗ Error creating Table 1 HTML: {e}")

    # 2. Table 2 - By location with p-values
    try:
        site = pyCLIF.helper["site_name"]
        table2_df = pd.read_csv(
            f'{output_folder}/final/table2_by_location_with_stats_{site.lower()}.csv', index_col=0
        )

        subtitle2 = ("Patient characteristics at 72 hours post-CRRT by ICU location type. P-values from ANOVA (continuous) or "
                    "Chi-square (categorical) tests.")

        html2 = create_html_table_with_styling(
            table2_df,
            f"Table 2: Patient Characteristics by ICU Location Type - {site}",
            subtitle=subtitle2,
            highlight_significant=True,
            p_value_col='p_value'
        )

        with open(f'{output_folder}/final/table2_by_location_with_stats.html', 'w') as f:
            f.write(html2)
        print("  ✓ Saved table2_by_location_with_stats.html")
    except Exception as e:
        print(f"  ✗ Error creating Table 2 HTML: {e}")

    # 3. Create a combined HTML with both tables
    if table1_df is None or table2_df is None:
        # Say why the page is skipped instead of relying on an
        # unbound-variable error being swallowed by the handler below.
        print("  ✗ Error creating combined HTML: Table 1 and/or Table 2 could not be loaded")
        return

    try:
        site = pyCLIF.helper['site_name']
        combined_html = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <title>CRRT Epidemiology Tables - {site}</title>
            <style>
                body {{
                    font-family: Arial, sans-serif;
                    margin: 20px;
                    background-color: #f5f5f5;
                }}
                .table-container {{
                    margin-bottom: 50px;
                }}
                h1 {{
                    color: #333;
                    text-align: center;
                    margin-bottom: 5px;
                }}
                h2 {{
                    color: #444;
                    margin-top: 40px;
                    border-bottom: 2px solid #4CAF50;
                    padding-bottom: 10px;
                }}
                h3 {{
                    color: #666;
                    text-align: center;
                    margin-top: 5px;
                    font-weight: normal;
                }}
                table {{
                    border-collapse: collapse;
                    width: 100%;
                    background-color: white;
                    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
                    margin-top: 20px;
                }}
                th {{
                    background-color: #4CAF50;
                    color: white;
                    text-align: left;
                    padding: 12px;
                    font-weight: bold;
                }}
                td {{
                    border: 1px solid #ddd;
                    padding: 8px;
                    text-align: left;
                }}
                tr:nth-child(even) {{
                    background-color: #f9f9f9;
                }}
                tr:hover {{
                    background-color: #f5f5f5;
                }}
                .index-col {{
                    font-weight: bold;
                    background-color: #e8f5e9;
                    width: 250px;
                }}
                .count-row {{
                    background-color: #FFC107 !important;
                    font-weight: bold;
                }}
                .significant {{
                    background-color: #ffeb3b;
                    font-weight: bold;
                }}
                .p-value {{
                    text-align: center;
                    font-weight: bold;
                }}
            </style>
        </head>
        <body>
            <h1>CRRT Epidemiology Analysis - {site}</h1>
            
            <div class="table-container">
                <h2>Table 1: Patient Characteristics by Time Window and Outcome</h2>
                <h3>Pre-24h: Baseline before CRRT | Post-24h: 0-24 hours after CRRT start | Post-72h: 24-72 hours after CRRT start<br>
                Survivors: Alive at end of window | Non-survivors: Died within window | On/Off CRRT: Status at end of window</h3>
                {table1_df.to_html(classes='styled-table', escape=False)}
            </div>
            
            <div class="table-container">
                <h2>Table 2: Patient Characteristics by ICU Location Type (72 hours post-CRRT)</h2>
                <h3>P-values from ANOVA (continuous) or Chi-square (categorical) tests. Highlighted values indicate p < 0.05</h3>
                {table2_df.to_html(classes='styled-table', escape=False)}
            </div>
        </body>
        </html>
        """

        with open(f'{output_folder}/final/all_tables_combined.html', 'w') as f:
            f.write(combined_html)
        print("  ✓ Saved all_tables_combined.html")

    except Exception as e:
        print(f"  ✗ Error creating combined HTML: {e}")

# Render the saved Table 1 / Table 2 CSVs as standalone HTML files plus one
# combined page (each step prints its own success or error line).
save_tables_as_html()