# Epidemiology of CRRT- Analysis

Author: Kaveri Chhikara

Run this script after 01_cohort_identification

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import shutil
from datetime import datetime, timedelta
import json
import warnings
import pyarrow
from tableone import TableOne
import seaborn as sns
import sofa_score  

# Silence library warnings for the rest of the notebook.
warnings.filterwarnings('ignore')

import pyCLIF
output_folder = '../output'
## import outlier json
with open('../config/outlier_config.json', 'r', encoding='utf-8') as f:
    outlier_cfg = json.load(f)

import os
output_dir = "../output/final"
os.makedirs(output_dir, exist_ok=True)
# Figures are written to <output_dir>/graphs further down; matplotlib does not
# create missing directories, so make sure the folder exists before any savefig.
os.makedirs(os.path.join(output_dir, "graphs"), exist_ok=True)

# Load CLIF wide and other cohort datasets

In [None]:
# Read the four intermediate cohort tables in one pass; the order of the paths
# matches the order of the names on the left-hand side.
all_ids_df, adt_final_df, clif_wide_df, crrt_df = [
    pd.read_parquet(parquet_path)
    for parquet_path in (
        f'{output_folder}/intermediate/all_ids.parquet',
        f'{output_folder}/intermediate/adt_final.parquet',
        f'{output_folder}/intermediate/clif_wide.parquet',
        f'{output_folder}/intermediate/crrt_df.parquet',
    )
]

In [None]:
# Print the column dtypes of every loaded table as a quick schema check.
for banner, frame in (
    ("=== all_ids_df dtypes ===", all_ids_df),
    ("\n=== adt_final_df dtypes ===", adt_final_df),
    ("\n=== clif_wide_df dtypes ===", clif_wide_df),
    ("\n=== crrt_df dtypes ===", crrt_df),
):
    print(banner)
    print(frame.dtypes)

# Combine CLIF wide with ADT

Combine CLIF wide with ADT, then forward-fill the location category and location type columns so that every time point carries the patient's location.

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# 1) ANNOTATE CLIF-WIDE RECORDS WITH ADT LOCATION (FORWARD-FILL ONLY)
# ─────────────────────────────────────────────────────────────────────────────
key_cols = ["encounter_block", "recorded_dttm"]
loc_cols = ["location_category", "location_type"]

adt_int = adt_final_df[["encounter_block", "in_dttm", "out_dttm"] + loc_cols]

# 1a) Match every distinct (encounter, timestamp) to all ADT intervals of that
#     encounter. Only the two key columns are merged: the wide measurement
#     columns are not needed here and would be copied once per interval.
merged = (
    clif_wide_df[key_cols]
    .drop_duplicates()
    .merge(adt_int, on="encounter_block", how="left")
)

# 1b) Keep only rows where recorded_dttm ∈ [in_dttm, out_dttm]
mask = (
    (merged["recorded_dttm"] >= merged["in_dttm"]) &
    (merged["recorded_dttm"] <= merged["out_dttm"])
)
annot = merged.loc[mask, key_cols + ["in_dttm"] + loc_cols].copy()

# 1c) Forward-fill per encounter_block in time order. in_dttm is the tie-breaker
#     so that, for one timestamp, the most recently started stay comes last.
annot = annot.sort_values(key_cols + ["in_dttm"])
annot[loc_cols] = annot.groupby("encounter_block")[loc_cols].ffill()
annot = annot.dropna(subset=loc_cols, how="all")

# 1d) One location per (encounter, timestamp). Both interval ends are inclusive,
#     so a timestamp equal to one stay's out_dttm and the next stay's in_dttm
#     matches two rows; keeping both would duplicate flowsheet rows in the join
#     below. Keep the stay that started last.
annot_indexed = (
    annot
    .drop_duplicates(subset=key_cols, keep="last")
    .set_index(key_cols)[loc_cols]
)

# 1e) Join back onto `clif_wide_df`
clif_wide_df = (
    clif_wide_df
      .set_index(key_cols)
      .join(annot_indexed, how="left")
      .reset_index()
)

# clif_wide_df now carries location_category / location_type for every
# timestamp that falls inside an ADT interval; other timestamps stay null.

# Table One

* Pre-24 hr: the 24 hours before CRRT start
* Post-24 hr: the first 24 hours after CRRT start
* Post-72 hr: the first 72 hours after CRRT start

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# 0) CONSTANTS & INPUTS
# ─────────────────────────────────────────────────────────────────────────────
output_folder = "../output"
os.makedirs(os.path.join(output_folder, "final"), exist_ok=True)

# SOFA inputs: table location and the encounter_block ↔ hospitalization_id map
tables_path = pyCLIF.helper['tables_path']
id_mappings = all_ids_df.loc[:, ['encounter_block','hospitalization_id']].drop_duplicates()

# Variable groups used by the summaries below
demog_cols = [
    "age_at_admission",
    "sex_category",
    "race_category",
    "ethnicity_category",
]
adt_cols = ["location_category", "location_type"]
device_col = "device_category"

continuous_vars = [
    # vasoactives
    "angiotensin",
    "dobutamine",
    "dopamine",
    "epinephrine",
    "norepinephrine",
    "phenylephrine",
    "vasopressin",
    # labs
    "bicarbonate",
    "bun",
    "calcium_total",
    "chloride",
    "creatinine",
    "magnesium",
    "glucose_serum",
    "lactate",
    "potassium",
    "sodium",
    "ph_arterial",
    # vents
    "fio2_set",
    "peep_set",
    "resp_rate_set",
    "tidal_volume_set",
    "pressure_control_set",
    "pressure_support_set",
    "peak_inspiratory_pressure_set",
]

crrt_vars = [
    "blood_flow_rate",
    "pre_filter_replacement_fluid_rate",
    "post_filter_replacement_fluid_rate",
    "dialysate_flow_rate",
    "ultrafiltration_out",
]

# ─────────────────────────────────────────────────────────────────────────────
# 1) FIRST CRRT START PER ENCOUNTER
# ─────────────────────────────────────────────────────────────────────────────
# Earliest CRRT recorded_dttm of each encounter, as a two-column frame
# (encounter_block, first_crrt_time).
first_crrt = (
    crrt_df
    .groupby("encounter_block")["recorded_dttm"]
    .min()
    .rename("first_crrt_time")
    .reset_index()
)

# Demographics plus death time, indexed by encounter_block
demog = all_ids_df[["encounter_block"] + demog_cols + ["death_dttm"]].set_index("encounter_block")

# ─────────────────────────────────────────────────────────────────────────────
# 2) helper: summarize any window
# ─────────────────────────────────────────────────────────────────────────────
def _median_q1_q3(frame, value_cols):
    """Per-encounter median, Q1 and Q3 of value_cols as <col>_median/_q1/_q3 columns."""
    grouped = frame.groupby("encounter_block")[value_cols]
    return (
        grouped.median().add_suffix("_median")
               .join(grouped.quantile(0.25).add_suffix("_q1"))
               .join(grouped.quantile(0.75).add_suffix("_q3"))
    )


def summarize_window(start_offset_h: int, end_offset_h: int, include_crrt: bool):
    """
    Summarize every encounter_block over the window
    [first_crrt_time + start_offset_h, first_crrt_time + end_offset_h]
    (both ends inclusive).

    The result has one row per encounter in `first_crrt` and contains:
     - demographics + died_within_window (death_dttm inside the window)
     - <var>_median / <var>_q1 / <var>_q3 for each of continuous_vars
     - last non-null ADT location + type
     - last non-null device_category
     - SOFA components + total
     - if include_crrt: last_crrt_mode + median/Q1/Q3 of crrt_vars

    Parameters
    ----------
    start_offset_h : hours added to first_crrt_time to get the window start
                     (negative values reach back before CRRT start).
    end_offset_h   : hours added to first_crrt_time to get the window end.
    include_crrt   : whether to add the CRRT-specific columns.

    Returns a DataFrame indexed by encounter_block.

    Reads the notebook-level frames first_crrt, clif_wide_df, crrt_df and demog.
    """

    # 2a) window definitions
    bounds = first_crrt.copy()
    bounds["win_start"] = bounds["first_crrt_time"] + pd.Timedelta(hours=start_offset_h)
    bounds["win_end"]   = bounds["first_crrt_time"] + pd.Timedelta(hours=end_offset_h)
    win_cols = bounds[["encounter_block", "win_start", "win_end"]]
    bnd      = bounds.set_index("encounter_block")

    # 2b) CLIF-wide flowsheet slice
    cw = clif_wide_df.merge(win_cols, on="encounter_block", how="inner")
    cw = cw[(cw.recorded_dttm >= cw.win_start) & (cw.recorded_dttm <= cw.win_end)]

    # continuous vars: median & IQR
    cont = _median_q1_q3(cw, continuous_vars)

    # ADT location/type: last non-null in window
    loc_win = cw.dropna(subset=adt_cols, how="all")
    last_loc = (
        loc_win.sort_values("recorded_dttm")
               .groupby("encounter_block")[adt_cols]
               .last()
    )

    # device_category: last non-null in window
    dev_win = cw.dropna(subset=[device_col])
    last_dev = (
        dev_win.sort_values("recorded_dttm")
               .groupby("encounter_block")[[device_col]]
               .last()
    )

    # One demographics row per encounter, so the join below cannot multiply rows.
    demog_one = demog[~demog.index.duplicated(keep="first")]

    # mortality flag: Series.between compares element-wise and refuses Series
    # whose labels differ, so bring death_dttm onto the window index first
    # (first non-null death time per encounter, NaT when none is recorded).
    death_dttm = demog["death_dttm"].groupby(level=0).first().reindex(bnd.index)
    death_flag = (
        death_dttm
        .between(bnd["win_start"], bnd["win_end"])
        .rename("died_within_window")
    )

    # assemble core
    df = (
        bnd
        .join(demog_one.drop(columns="death_dttm"))
        .join(death_flag)
        .join(cont)
        .join(last_loc)
        .join(last_dev)
    )

    # 2c) SOFA
    sofa_in = bounds[["encounter_block"]].copy()
    sofa_in["start_dttm"] = bounds["win_start"]
    sofa_in["stop_dttm"]  = bounds["win_end"]
    sofa_out = sofa_score.compute_sofa(
        ids_w_dttm            = sofa_in,
        tables_path           = tables_path,
        use_hospitalization_id= False,
        id_mapping            = id_mappings,
        helper_module         = pyCLIF,
        output_filepath       = None
    )
    sofa_cols = [
      "sofa_cv_97","sofa_coag","sofa_renal",
      "sofa_liver","sofa_resp","sofa_cns","sofa_total"
    ]
    sofa_df = sofa_out.set_index("encounter_block")[sofa_cols]
    df = df.join(sofa_df)

    # 2d) CRRT-specific (optional)
    if include_crrt:
        # Only the window bounds are merged in; nothing else from `bounds` is used.
        cr = crrt_df.merge(win_cols, on="encounter_block", how="inner")
        cr_win = cr[(cr.recorded_dttm >= cr.win_start) & (cr.recorded_dttm <= cr.win_end)]

        # last CRRT mode
        mode_win = cr_win.dropna(subset=["crrt_mode_category"])
        last_mode = (
            mode_win.sort_values("recorded_dttm")
                    .groupby("encounter_block")["crrt_mode_category"]
                    .last()
                    .rename("last_crrt_mode")
        )

        # CRRT cont. vars
        crrt_cont = _median_q1_q3(cr_win, crrt_vars)

        df = df.join(last_mode).join(crrt_cont)

    return df.drop(columns=["first_crrt_time","win_start","win_end"])


# ─────────────────────────────────────────────────────────────────────────────
# 3) BUILD THE THREE WINDOWS
# ─────────────────────────────────────────────────────────────────────────────
# One summary per analysis window, relative to the first CRRT record.
pre24 = summarize_window(-24, 0, include_crrt=False)
post24 = summarize_window(0, 24, include_crrt=True)
post72 = summarize_window(0, 72, include_crrt=True)

# Label each summary with its window, then stack them into one long table.
window_frames = {"Pre-24h": pre24, "Post-24h": post24, "Post-72h": post72}
for window_name, window_frame in window_frames.items():
    window_frame["window"] = window_name

combined = pd.concat(list(window_frames.values()), axis=0).reset_index()
combined_path = os.path.join(output_folder,"intermediate","combined_summary.csv")
combined.to_csv(combined_path, index=False)

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# 4) MAKE TABLEONE FOR EACH WINDOW (with SOFA)
# ─────────────────────────────────────────────────────────────────────────────
def make_tableone(summary_df, window_label):
    """
    Build a TableOne for one window summary, write it to CSV and return it.

    summary_df   : per-encounter summary as returned by summarize_window.
    window_label : window name; spaces become underscores in the file name
                   table1_<window_label>.csv under <output_folder>/final.

    Continuous variables are passed as `nonnormal`, so the table reports them
    as median [IQR] of the per-encounter values.
    """
    data = summary_df.reset_index()

    sofa_vars = [
        "sofa_cv_97", "sofa_coag", "sofa_renal",
        "sofa_liver", "sofa_resp", "sofa_cns", "sofa_total",
    ]
    categorical = [
        "sex_category", "race_category", "ethnicity_category",
        "location_category", "location_type", "device_category",
        "died_within_window",
    ]
    nonnormal = ["age_at_admission"] + continuous_vars + sofa_vars

    # Post-CRRT summaries also carry the CRRT mode and CRRT settings.
    if "last_crrt_mode" in data.columns:
        categorical.append("last_crrt_mode")
        nonnormal = nonnormal + crrt_vars

    # Each per-encounter median takes the plain variable name; the
    # per-encounter quartile columns are not tabulated and are dropped.
    summarized = continuous_vars + crrt_vars
    for var in summarized:
        median_col = f"{var}_median"
        if median_col in data:
            data[var] = data.pop(median_col)
    quartile_cols = [
        var + suffix
        for var in summarized
        for suffix in ("_q1", "_q3")
        if var + suffix in data
    ]
    data = data.drop(columns=quartile_cols)

    # Vasopressor values are rounded to 4 decimal places.
    for drug in ("angiotensin", "dobutamine", "dopamine", "epinephrine",
                 "norepinephrine", "phenylephrine", "vasopressin"):
        if drug in data.columns:
            data[drug] = data[drug].round(4)

    tbl = TableOne(
        data,
        columns=categorical + nonnormal,
        categorical=categorical,
        nonnormal=nonnormal,
        groupby=None,
    )

    # save & return
    fname = f"table1_{window_label.replace(' ','_')}.csv"
    tbl.to_csv(os.path.join(output_folder,"final",fname))
    print("Table saved to:", f"{output_folder}/final/{fname}")
    return tbl

# Build (and save) one Table One per window, in pre / post-24 / post-72 order.
table_pre24, table_post24, table_post72 = [
    make_tableone(window_summary, window_name)
    for window_summary, window_name in (
        (pre24, "Pre-24h"),
        (post24, "Post-24h"),
        (post72, "Post-72h"),
    )
]


In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# 1) Copy each TableOne's summary frame and give its columns per-window names
# ─────────────────────────────────────────────────────────────────────────────
df_pre24, df_post24, df_post72 = [
    pd.DataFrame(table_obj.tableone).rename(columns=new_names)
    for table_obj, new_names in (
        (table_pre24,  {0: "Pre-24h",   "Overall": "pre24_summary", "Missing": "pre24_missing"}),
        (table_post24, {0: "Post-24h", "Overall": "post24_summary", "Missing": "post24_missing"}),
        (table_post72, {0: "Post-72h", "Overall": "post72_summary", "Missing": "post72_missing"}),
    )
]

# ─────────────────────────────────────────────────────────────────────────────
# 2) Outer-join the three frames on their row labels
# ─────────────────────────────────────────────────────────────────────────────
# Row order is taken from the first non-empty post-CRRT table, if there is one.
ref_order = next(
    (frame.index.tolist() for frame in (df_post24, df_post72) if not frame.empty),
    None,
)

combined = df_pre24.join(df_post24, how="outer").join(df_post72, how="outer")

if ref_order:
    # Reference rows first (those actually present), then whatever is left.
    valid_indices = [label for label in ref_order if label in combined.index]
    remaining = combined.index.difference(valid_indices).tolist()
    combined = combined.reindex(valid_indices + remaining)

# ─────────────────────────────────────────────────────────────────────────────
# 3) Turn the row labels into a column and drop the missing-count columns
# ─────────────────────────────────────────────────────────────────────────────
combined.index.name = "Characteristic"
combined = combined.reset_index()

missing_cols = [name for name in combined.columns if name.endswith('_missing')]
combined = combined.drop(columns=missing_cols)

# ─────────────────────────────────────────────────────────────────────────────
# 4) Save the site-level table
# ─────────────────────────────────────────────────────────────────────────────
filename = f"table1_all_windows_{pyCLIF.helper['site_name']}.csv"
combined.to_csv(os.path.join(output_folder, "final", filename), index=False)

# CRRT summary by modes

In [None]:
# Numeric CRRT settings summarised in this cell
numeric_cols = ['blood_flow_rate', 'pre_filter_replacement_fluid_rate',
               'post_filter_replacement_fluid_rate', 'dialysate_flow_rate',
               'ultrafiltration_out']

# Put every encounter's records in chronological order and keep that order:
# both the "first setting" lookup and the shift-based duration below read
# neighbouring rows, so they are only meaningful on time-sorted data.
crrt_df = crrt_df.sort_values(['encounter_block', 'recorded_dttm'])

# First non-null value of each setting per encounter and CRRT mode
first_settings = crrt_df.groupby(['encounter_block', 'crrt_mode_category']).first()

# Median and IQR of the first settings by mode. pandas labels the two lambdas
# '<lambda_0>' and '<lambda_1>' in the order given, i.e. Q1 then Q3.
first_settings_summary = (
    first_settings[numeric_cols]
    .groupby('crrt_mode_category')
    .agg(['median', lambda x: x.quantile(0.25), lambda x: x.quantile(0.75)])
)
stat_suffix = {'median': 'median', '<lambda_0>': 'q1', '<lambda_1>': 'q3'}
first_settings_summary.columns = [
    f"{setting}_{stat_suffix[stat]}" for setting, stat in first_settings_summary.columns
]

# Mean settings across the entire hospitalization by mode:
# per-encounter means first, then mean and SD of those across encounters.
hosp_avg_settings = crrt_df.groupby(['encounter_block', 'crrt_mode_category'])[numeric_cols].mean()
mode_avg_settings = hosp_avg_settings.groupby('crrt_mode_category').agg(['mean', 'std'])
mode_avg_settings.columns = mode_avg_settings.columns.map(lambda x: f"{x[0]}_{x[1]}")

# Duration of each record = time until the next record of the same encounter
# (the last record of an encounter has no successor and gets NaN).
crrt_df['next_time'] = crrt_df.groupby('encounter_block')['recorded_dttm'].shift(-1)
crrt_df['duration_hrs'] = (crrt_df['next_time'] - crrt_df['recorded_dttm']).dt.total_seconds() / 3600

# NOTE(review): sum() skips NaN, so an encounter/mode pair whose only record is
# the encounter's last one contributes 0 hours — confirm this is intended.
mode_duration = crrt_df.groupby(['encounter_block', 'crrt_mode_category'])['duration_hrs'].sum()
duration_summary = mode_duration.groupby('crrt_mode_category').agg(['mean', 'std', 'median',
                                                                  lambda x: x.quantile(0.25),
                                                                  lambda x: x.quantile(0.75)])
duration_summary.columns = ['duration_mean_hrs', 'duration_std_hrs', 'duration_median_hrs',
                          'duration_q1_hrs', 'duration_q3_hrs']

# Save results
first_settings_summary.to_csv(f'{output_folder}/final/crrt_first_settings.csv')
mode_avg_settings.to_csv(f'{output_folder}/final/crrt_average_settings.csv')
duration_summary.to_csv(f'{output_folder}/final/crrt_duration.csv')

In [None]:
# Step 1) One formatted "median [q1–q3]" string per CRRT mode and parameter.
# Keys are the columns of first_settings_summary (without suffix); values are
# the column headers shown in the figure.
param_labels = {
    "blood_flow_rate": "Blood flow (mL/hr)",
    "pre_filter_replacement_fluid_rate": "Pre-filter repl (mL/hr)",
    "post_filter_replacement_fluid_rate": "Post-filter repl (mL/hr)",
    "dialysate_flow_rate": "Dialysate flow (mL/hr)",
    "ultrafiltration_out": "Ultrafiltration (mL)",
}
params = list(param_labels)


def _median_iqr_cells(p):
    """Return the formatted cell for every mode of parameter p ("NA" when the median is missing)."""
    stats = zip(
        first_settings_summary[f"{p}_median"],
        first_settings_summary[f"{p}_q1"],
        first_settings_summary[f"{p}_q3"],
    )
    return [
        f"{int(m):,} [{int(lo):,}–{int(hi):,}]" if pd.notna(m) else "NA"
        for m, lo, hi in stats
    ]


# rows = modes, cols = parameters (already under their display labels)
display_df = pd.DataFrame(
    {param_labels[p]: _median_iqr_cells(p) for p in params},
    index=first_settings_summary.index,
)

# Step 2) Render the table as a matplotlib figure; height grows with the row count
n_modes = len(display_df)
fig, ax = plt.subplots(figsize=(10, 2 + 0.5 * n_modes))
ax.axis("off")

tbl = ax.table(
    cellText=display_df.values,
    rowLabels=display_df.index,
    colLabels=display_df.columns,
    cellLoc="center",
    rowLoc="center",
    loc="center",
)
tbl.auto_set_font_size(False)
tbl.set_fontsize(10)
tbl.scale(1, 1.5)  # taller rows

plt.title(f"First CRRT Settings by CRRT Mode in {pyCLIF.helper['site_name']}: (Median [IQR])", pad=20)
plt.tight_layout()
# Write the figure to disk before showing it
plt.savefig(f"../output/final/graphs/first_crrt_settings.png", bbox_inches='tight', dpi=300)
plt.show()

In [None]:
# Step 1) Build the display DataFrame of formatted strings
params = [
    "blood_flow_rate",
    "pre_filter_replacement_fluid_rate",
    "post_filter_replacement_fluid_rate",
    "dialysate_flow_rate",
    "ultrafiltration_out"
]


def _format_mean_sd(m, s):
    """Format one cell as "mean  (sd)".

    Returns "NA" when the mean is missing. The SD alone can be missing — the
    sample SD of a single encounter is NaN — and int(NaN) raises, so it is
    rendered as "NA" inside the parentheses instead.
    """
    if pd.isna(m):
        return "NA"
    sd_text = f"{int(s):,}" if pd.notna(s) else "NA"
    return f"{int(m):,}  ({sd_text})"


# This will be our new table: rows = modes, cols = params
display_df = pd.DataFrame(index=mode_avg_settings.index)

for p in params:
    mean = mode_avg_settings[f"{p}_mean"]
    std = mode_avg_settings[f"{p}_std"]
    display_df[p] = [_format_mean_sd(m, s) for m, s in zip(mean, std)]

# optional: rename columns to more human-friendly labels
display_df.columns = [
    "Blood flow (mL/hr)",
    "Pre-filter repl (mL/hr)",
    "Post-filter repl (mL/hr)",
    "Dialysate flow (mL/hr)",
    "Ultrafiltration (mL)"
]

# Step 2) Draw with matplotlib.table
fig, ax = plt.subplots(figsize=(10, 2 + 0.5 * len(display_df)))  # height scales with rows
ax.axis("off")

tbl = ax.table(
    cellText=display_df.values,
    rowLabels=display_df.index,
    colLabels=display_df.columns,
    cellLoc="center",
    rowLoc="center",
    loc="center"
)

tbl.auto_set_font_size(False)
tbl.set_fontsize(10)
tbl.scale(1, 1.5)  # stretch rows a bit

plt.title(f"Average CRRT Settings by CRRT Mode in {pyCLIF.helper['site_name']}: Mean (SD)", pad=20)
plt.tight_layout()
# Save the figure
plt.savefig(f"../output/final/graphs/avg_crrt_settings.png", bbox_inches='tight', dpi=300)
plt.show()



In [None]:
# Move crrt_mode_category out of the index so it can be used as the x variable
duration_summary_reset = duration_summary.reset_index()

# Bars show the median duration per mode; whiskers run from Q1 to Q3
median_hrs = duration_summary_reset['duration_median_hrs']
below_median = median_hrs - duration_summary_reset['duration_q1_hrs']
above_median = duration_summary_reset['duration_q3_hrs'] - median_hrs
bar_positions = np.arange(len(duration_summary_reset))

plt.figure(figsize=(10, 6))
sns.barplot(
    data=duration_summary_reset,
    x='crrt_mode_category',
    y='duration_median_hrs',
    color='skyblue',
)
plt.errorbar(
    x=bar_positions,
    y=median_hrs,
    yerr=[below_median, above_median],
    fmt='none',
    c='black',
    capsize=5,
)
plt.title("Duration on CRRT Mode (Median and IQR)")
plt.ylabel("Hours")
plt.xticks(rotation=45)
plt.tight_layout()
# Write the file first: the figure is closed right after it is shown
plt.savefig(f"{output_dir}/graphs/crrt_mode_duration.png")
plt.show()
plt.close()


# CRRT Mode Switches

In [None]:
# Mode switches are read off consecutive rows of an encounter, so the rows must
# be in chronological order before shifting.
crrt_df = crrt_df.sort_values(['encounter_block', 'recorded_dttm'])

# A switch is a row whose mode differs from the previous row's mode within the
# same encounter, with both modes known.
crrt_df['prev_mode'] = crrt_df.groupby('encounter_block')['crrt_mode_category'].shift(1)
crrt_df['mode_switch'] = (
    (crrt_df['crrt_mode_category'] != crrt_df['prev_mode']) &
    (~crrt_df['prev_mode'].isna()) &
    (~crrt_df['crrt_mode_category'].isna())  # ← exclude transitions TO NaN
)

# Count switches between each mode pair
mode_switches = crrt_df[crrt_df['mode_switch']].groupby(['prev_mode', 'crrt_mode_category']).size()
mode_switches = mode_switches.reset_index()
mode_switches.columns = ['from_mode', 'to_mode', 'count']

# Create pivot table for display
mode_switches_pivot = mode_switches.pivot(index='from_mode', columns='to_mode', values='count').fillna(0)

print("\nCRRT Mode Switches:")
print(mode_switches_pivot)

# Visualize mode switches as a heatmap
import seaborn as sns
import matplotlib.pyplot as plt

if mode_switches_pivot.empty:
    # seaborn cannot draw a heatmap of an empty matrix
    print("No CRRT mode switches found; skipping heatmap.")
else:
    plt.figure(figsize=(10, 8))
    sns.heatmap(mode_switches_pivot,
                annot=True,
                fmt='.0f',
                cmap='Blues',
                cbar_kws={'label': f"Number of Mode Switches {pyCLIF.helper['site_name']}"})
    plt.title('CRRT Mode Switch Matrix')
    plt.xlabel('To Mode')
    plt.ylabel('From Mode')
    plt.tight_layout()
    plt.savefig(f"{output_dir}/graphs/crrt_mode_transitions_heatmap.png")
    plt.close()

# Mean and median number of mode switches per encounter block
switches_per_hosp = crrt_df.groupby('encounter_block')['mode_switch'].sum()
avg_switches = switches_per_hosp.mean()
median_switches = switches_per_hosp.median()

print(f"\nAverage mode switches per hospitalization: {avg_switches:.2f}")
print(f"Median mode switches per hospitalization: {median_switches:.2f}")

# Save mode switches matrix
mode_switches_pivot.to_csv(f"{output_dir}/crrt_mode_switches.csv")

# Save summary statistics
with open(f"{output_dir}/crrt_mode_switches_summary.txt", "w", encoding="utf-8") as f:
    f.write(f"Average mode switches per hospitalization: {avg_switches:.2f}\n")
    f.write(f"Median mode switches per hospitalization: {median_switches:.2f}\n")

In [None]:
# Flag mode switches on chronologically ordered rows. The flag matches the
# definition used for the switch-count matrix: previous and current mode are
# both known and differ. Without the last condition a row with a missing mode
# would be stored as a "switch" in crrt_df['mode_switch'].
crrt_df = crrt_df.sort_values(['encounter_block', 'recorded_dttm'])
crrt_df['prev_mode'] = crrt_df.groupby('encounter_block')['crrt_mode_category'].shift(1)
crrt_df['mode_switch'] = (
    (crrt_df['crrt_mode_category'] != crrt_df['prev_mode'])
    & (~crrt_df['prev_mode'].isna())
    & (~crrt_df['crrt_mode_category'].isna())
)

# Count unique encounters where a switch occurred between each mode pair
mode_switch_encounters = (
    crrt_df[crrt_df['mode_switch']]
    .groupby(['prev_mode', 'crrt_mode_category'])['encounter_block']
    .nunique()
    .reset_index(name='encounter_count')
)

# Pivot for matrix view (rows = from mode, columns = to mode)
switch_matrix = mode_switch_encounters.pivot(
    index='prev_mode',
    columns='crrt_mode_category',
    values='encounter_count'
).fillna(0).astype(int)

print("\nUnique Encounters with CRRT Mode Switches:")
print(switch_matrix)

# Save to CSV
switch_matrix.to_csv("../output/final/unique_encounter_mode_switches.csv")

# Create heatmap for mode switches by encounter
if switch_matrix.empty:
    # seaborn cannot draw a heatmap of an empty matrix
    print("No CRRT mode switches found; skipping heatmap.")
else:
    plt.figure(figsize=(10, 8))
    sns.heatmap(switch_matrix,
                annot=True,
                fmt='d',
                cmap='Blues',
                cbar_kws={'label': 'Number of Unique Encounters with Mode Switches'})
    plt.title(f"CRRT Mode Switches by Unique Encounters {pyCLIF.helper['site_name']}")
    plt.xlabel('To Mode')
    plt.ylabel('From Mode')
    plt.tight_layout()
    plt.savefig(f"{output_dir}/graphs/crrt_mode_transitions_by_encounter_heatmap.png")
    plt.close()


# Hourly Trends

In [None]:
# 1) Build a relative-hour field where hour=0 is 24h before CRRT
df = (
    clif_wide_df
    .merge(first_crrt, on="encounter_block", how="inner")
)
# origin = first_crrt_time - 24h
df["origin"] = df["first_crrt_time"] - pd.Timedelta(hours=24)
df["rel_hr"] = ((df["recorded_dttm"] - df["origin"])
                / pd.Timedelta(hours=1))
# Keep [0, 96): 96 one-hour bins (0..95). Including rel_hr == 96 would add a
# 97th bin that only holds records at exactly 96 h.
# .copy() so the new column is added to an independent frame, not a slice.
df = df[(df["rel_hr"] >= 0) & (df["rel_hr"] < 96)].copy()

# assign integer hour bins
df["hour_bin"] = df["rel_hr"].floordiv(1).astype(int)

# 2) Compute per-hour median, Q1, Q3 for each variable
agg = {
    v: ["median", lambda x: x.quantile(0.25), lambda x: x.quantile(0.75)]
    for v in continuous_vars
}
hourly = df.groupby("hour_bin").agg(agg)
# Flatten the (variable, statistic) columns. pandas labels the two lambdas
# '<lambda_0>' and '<lambda_1>' in the order given, i.e. Q1 then Q3; give every
# variable readable _q1/_q3 names here so the plot and the CSV agree.
stat_names = {"median": "median", "<lambda_0>": "q1", "<lambda_1>": "q3"}
hourly.columns = pd.Index([f"{var}_{stat_names[stat]}"
                           for var, stat in hourly.columns],
                          name=None)

# 3) Plot grid of time-series with shaded IQR and CRRT-band
n = len(continuous_vars)
cols = 4
rows = (n + cols - 1)//cols
fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 2.5*rows), sharex=True)
for ax, var in zip(axes.flat, continuous_vars):
    med_line = hourly[f"{var}_median"]
    lower = hourly[f"{var}_q1"]
    upper = hourly[f"{var}_q3"]

    ax.plot(hourly.index, med_line)
    ax.fill_between(hourly.index, lower, upper, alpha=0.3)
    # highlight the first 24h of CRRT → that's rel_hr 24→48
    ax.axvspan(24, 48, color="C1", alpha=0.1)
    ax.set_title(f"{var} [median, IQR]")
    ax.set_xlim(0,96)
    ax.set_xlabel("Hours")
    ax.set_ylabel(var)

# hide any empty subplots
for extra_ax in axes.flat[n:]:
    extra_ax.set_visible(False)

fig.tight_layout()
plt.suptitle("Trajectories of ICU Labs/Vasos/Vent Settings from CRRT-24hrs to CRRT+72hrs (Median[IQR])", y=1.02)
plt.savefig("../output/final/graphs/trajectories_median_iqr.png", bbox_inches='tight', dpi=300)
plt.show()

# 4) Save hourly data: reset_index so `hour_bin` becomes a column again
site_summary = hourly.reset_index().rename(columns={"hour_bin": "hour"})
site_summary.to_csv(os.path.join(output_folder, "final", "site_hourly_summary_median.csv"), index=False)

In [None]:
# 1) Build a relative-hour field where hour=0 is 24h before CRRT
df = (
    clif_wide_df
    .merge(first_crrt, on="encounter_block", how="inner")
)
# origin = first_crrt_time - 24h
df["origin"] = df["first_crrt_time"] - pd.Timedelta(hours=24)
df["rel_hr"] = ((df["recorded_dttm"] - df["origin"])
                / pd.Timedelta(hours=1))
# Keep [0, 96): 96 one-hour bins (0..95). Including rel_hr == 96 would add a
# 97th bin that only holds records at exactly 96 h.
# .copy() so the new column is added to an independent frame, not a slice.
df = df[(df["rel_hr"] >= 0) & (df["rel_hr"] < 96)].copy()

# assign integer hour bins
df["hour_bin"] = df["rel_hr"].floordiv(1).astype(int)

# 2) Compute per-hour mean and std for each variable
agg = {v: ["mean", "std"] for v in continuous_vars}
hourly = df.groupby("hour_bin").agg(agg)
# flatten columns to <variable>_mean / <variable>_std
hourly.columns = pd.Index([f"{var}_{stat}"
                           for var,stat in hourly.columns],
                          name=None)

# 3) Plot grid of time-series with shaded ±1 SD and CRRT-band
n = len(continuous_vars)
cols = 4
rows = (n + cols - 1)//cols
fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 2.5*rows), sharex=True)
for ax, var in zip(axes.flat, continuous_vars):
    mean = hourly[f"{var}_mean"]
    sd = hourly[f"{var}_std"]

    ax.plot(hourly.index, mean)
    ax.fill_between(hourly.index, mean-sd, mean+sd, alpha=0.3)
    # highlight the first 24h of CRRT → that's rel_hr 24→48
    ax.axvspan(24, 48, color="C1", alpha=0.1)
    ax.set_title(f"{var} [mean ± SD]")
    ax.set_xlim(0,96)
    ax.set_xlabel("Hours")
    ax.set_ylabel(var)

# hide any empty subplots
for extra_ax in axes.flat[n:]:
    extra_ax.set_visible(False)

fig.tight_layout()
plt.suptitle("Trajectories of ICU Labs/Vasos/Vent Settings from CRRT-24hrs to CRRT+72hrs (Mean ± SD)", y=1.02)
plt.savefig("../output/final/graphs/trajectories_mean_sd.png", bbox_inches='tight', dpi=300)
plt.show()

# 4) Save hourly data: reset_index so `hour_bin` becomes a column again.
#    (This table has only _mean/_std columns, so no quantile renaming applies.)
site_summary = hourly.reset_index().rename(columns={"hour_bin": "hour"})
site_summary.to_csv(os.path.join(output_folder, "final", "site_hourly_summary_mean.csv"), index=False)