# Process CRRT Therapy Table

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import gc
from pathlib import Path
import json
import pyarrow
import warnings
import clifpy
from typing import Union
from tqdm import tqdm

import sys
import clifpy
import os

# Report which interpreter and clifpy build this kernel is running.
print("=== Environment Verification ===")
print("Python executable:", sys.executable)
print("Python version:", sys.version)
print("clifpy version:", clifpy.__version__)
print("clifpy location:", clifpy.__file__)

# Flag any sys.path entry that points into the local CLIFpy checkout.
print("\n=== Python Path Check ===")
local_clifpy_path = "/Users/kavenchhikara/Desktop/CLIF/CLIFpy"
local_checkout_entries = [entry for entry in sys.path if local_clifpy_path in entry]
if not local_checkout_entries:
    print("✅ Clean environment - no local CLIFpy in path")
else:
    print("⚠️  WARNING: Local CLIFpy still in path!")
    for entry in local_checkout_entries:
        print(f"   Found: {entry}")

print("\n=== Working Directory ===")
print("Current directory:", os.getcwd())

In [None]:
# Load the project configuration (table location, file format, timezone).
config_path = "../config/config.json"
# Explicit encoding so the file is read the same way on every platform.
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

# Fail early, with the offending keys named, if the config is incomplete.
missing_config_keys = [
    key for key in ('tables_path', 'file_type', 'timezone') if key not in config
]
if missing_config_keys:
    raise KeyError(f"Missing keys in {config_path}: {missing_config_keys}")

## import outlier json
# with open('../config/outlier_config.json', 'r', encoding='utf-8') as f:
#     outlier_cfg = json.load(f)

print("\n📋 Configuration:")
print(f"   Data directory: {config['tables_path']}")
print(f"   File type: {config['file_type']}")
print(f"   Timezone: {config['timezone']}")

# Load intermediate data

In [None]:
# Intermediate parquet files saved by 00_cohort.ipynb, read in the order listed
# and unpacked into one dataframe per file.
(
    cohort_df,
    outcomes_df,
    weight_df,
    crrt_at_initiation,
    crrt_initiation,
    index_crrt_df,
) = (
    pd.read_parquet(parquet_path)
    for parquet_path in (
        "../output/intermediate/cohort_df.parquet",
        "../output/intermediate/outcomes_df.parquet",
        "../output/intermediate/weight_df.parquet",
        "../output/intermediate/crrt_at_initiation.parquet",
        "../output/intermediate/crrt_initiation.parquet",
        "../output/intermediate/index_crrt_df.parquet",
    )
)

# CRRT Dose

In [None]:
# ============================================================================
# Calculate 30-Day Mortality from CRRT Initiation
# ============================================================================
print("\nCalculating 30-day mortality from CRRT initiation...")

# Outcome fields pulled from outcomes_df. Copies already on index_crrt_df
# (e.g. from an earlier run of this cell) are dropped first; otherwise the
# merge emits *_x / *_y suffixed columns and the lookups below raise KeyError.
outcome_fields = ['death_dttm', 'died', 'in_hosp_death']
index_crrt_df = (
    index_crrt_df
    .drop(columns=outcome_fields, errors='ignore')
    .merge(
        outcomes_df[['encounter_block'] + outcome_fields],
        on='encounter_block',
        how='left'
    )
)

# Days from CRRT initiation to death. A missing death time (NaT) propagates
# to NaN through the subtraction, so no explicit null guard is needed.
index_crrt_df['days_crrt_to_death'] = (
    (index_crrt_df['death_dttm'] - index_crrt_df['crrt_initiation_time']).dt.total_seconds()
    / (24 * 3600)
)

# 1 when death occurred no later than 30 days after initiation, else 0.
# NaN (no recorded death) compares False and therefore maps to 0.
index_crrt_df['death_30d_from_crrt'] = (
    index_crrt_df['days_crrt_to_death'] <= 30
).astype(int)

print(f"   Deaths within 30 days of CRRT initiation: {index_crrt_df['death_30d_from_crrt'].sum():,}")
print(f"   30-day mortality rate: {index_crrt_df['death_30d_from_crrt'].mean() * 100:.1f}%")
print(f"   In-hospital mortality: {index_crrt_df['in_hosp_death'].sum():,} ({index_crrt_df['in_hosp_death'].mean() * 100:.1f}%)")

# ============================================================================
#  Summary Statistics
# ============================================================================
# Console report: dose availability, dose distribution, share of doses
# below / within / above 20-25 mL/kg/hr, then dose and mortality per CRRT mode.
print("\n" + "=" * 80)
print("CRRT Dose Summary Statistics at Initiation")
print("=" * 80)

dose_values = index_crrt_df['crrt_dose_ml_kg_hr']
dose_available = dose_values.notna()
n_with_dose = dose_available.sum()
print(f"\n   CRRT doses calculated: {n_with_dose:,}/{len(index_crrt_df):,} ({dose_available.mean()*100:.1f}%)")

if n_with_dose > 0:
    dose_stats = dose_values.describe()
    print("\n   Dose statistics (mL/kg/hr):")
    print(f"     Mean ± SD: {dose_stats['mean']:.1f} ± {dose_stats['std']:.1f}")
    print(f"     Median [IQR]: {dose_stats['50%']:.1f} [{dose_stats['25%']:.1f}-{dose_stats['75%']:.1f}]")
    print(f"     Range: {dose_stats['min']:.1f} - {dose_stats['max']:.1f}")

    print("\n   Clinical Target Analysis:")
    # between() is inclusive on both ends, i.e. 20 <= dose <= 25
    target_range = dose_values.between(20, 25).sum()
    below_target = dose_values.lt(20).sum()
    above_target = dose_values.gt(25).sum()

    target_rows = (
        ("Below target (<20 mL/kg/hr)", below_target),
        ("Within target (20-25 mL/kg/hr)", target_range),
        ("Above target (>25 mL/kg/hr)", above_target),
    )
    for target_label, target_count in target_rows:
        print(f"     {target_label}: {target_count:,} ({target_count/n_with_dose*100:.1f}%)")

print("\n   Dose and Mortality by CRRT Mode:")
for mode in index_crrt_df['crrt_mode_category'].unique():
    # Encounters with no recorded mode are not reported
    if pd.isna(mode):
        continue

    mode_data = index_crrt_df.loc[index_crrt_df['crrt_mode_category'] == mode]
    mode_doses = mode_data['crrt_dose_ml_kg_hr'].dropna()
    if len(mode_data) == 0:
        continue

    n_mode = len(mode_data)
    mode_deaths = mode_data['death_30d_from_crrt']
    print(f"     {mode.upper()}:")
    print(f"       Count: {n_mode:,} ({n_mode/len(index_crrt_df)*100:.1f}%)")
    if len(mode_doses) > 0:
        print(f"       Mean dose: {mode_doses.mean():.1f} mL/kg/hr")
        print(f"       Median dose: {mode_doses.median():.1f} mL/kg/hr")
    print(f"       30-day mortality: {mode_deaths.sum():,} ({mode_deaths.mean()*100:.1f}%)")

# ============================================================================
# Create Final Analysis Dataset
# ============================================================================
print("\n" + "=" * 80)
print("Creating Final Analysis Dataset")
print("=" * 80)

# Columns carried from index_crrt_df into the analysis dataset, in output order.
analysis_columns = [
    # identifiers and timing
    'encounter_block',
    'hospitalization_id',
    'crrt_initiation_time',
    # CRRT settings and dose
    'crrt_mode_category',
    'weight_kg',
    'total_flow_rate',
    'crrt_dose_ml_kg_hr',
    'total_flow_rate_full',
    'crrt_dose_ml_kg_hr_full',
    'dialysate_flow_rate',
    'pre_filter_replacement_fluid_rate',
    'post_filter_replacement_fluid_rate',
    'ultrafiltration_out',
    # mortality
    'died',
    'in_hosp_death',
    'death_30d_from_crrt',
    'days_crrt_to_death',
    'death_dttm',
    # durations
    'duration_days',
    'imv_duration_days',
]

# Independent copy so later edits do not write back into index_crrt_df
crrt_analysis_df = index_crrt_df.loc[:, analysis_columns].copy()

n_records = len(crrt_analysis_df)
n_blocks = crrt_analysis_df['encounter_block'].nunique()
n_valid_dose = crrt_analysis_df['crrt_dose_ml_kg_hr'].notna().sum()
n_with_weight = crrt_analysis_df['weight_kg'].notna().sum()
deaths_30d_flag = crrt_analysis_df['death_30d_from_crrt']

print("\n   Final analysis dataset created:")
print(f"     Total records: {n_records:,}")
print(f"     Unique encounter blocks: {n_blocks:,}")
print(f"     Records with valid dose: {n_valid_dose:,}")
print(f"     Records with weight: {n_with_weight:,}")
print(f"     30-day mortality: {deaths_30d_flag.sum():,} ({deaths_30d_flag.mean()*100:.1f}%)")

print("\n✅ CRRT analysis completed successfully!")
print("   Dataset 'crrt_analysis_df' ready for further analysis")

In [None]:
# Show the column index of the analysis dataset as the cell output.
crrt_analysis_df.columns

# SOFA

In [None]:
# One SOFA window per encounter block: 12 h before CRRT initiation through
# 3 h after it. Only the ids and the window bounds are kept.
sofa_cohort_df = (
    index_crrt_df[['hospitalization_id', 'encounter_block', 'crrt_initiation_time']]
    .assign(
        start_dttm=lambda frame: frame['crrt_initiation_time'] - pd.Timedelta(hours=12),
        end_dttm=lambda frame: frame['crrt_initiation_time'] + pd.Timedelta(hours=3),
    )
    [['hospitalization_id', 'encounter_block', 'start_dttm', 'end_dttm']]
)

# Distinct hospitalization ids of the full cohort, as strings
cohort_ids_as_str = cohort_df['hospitalization_id'].astype(str)
sofa_cohort_ids = cohort_ids_as_str.unique().tolist()

In [None]:
import polars as pl
import sofa_calculator
import importlib
import sys
importlib.reload(sofa_calculator)
from sofa_calculator import compute_sofa_polars


# Map start_time / end_time onto the start_dttm / end_dttm names wanted by
# sofa_calculator. NOTE(review): sofa_cohort_df is already built with the
# *_dttm names, so this rename currently changes nothing.
window_column_names = {'start_time': 'start_dttm', 'end_time': 'end_dttm'}
sofa_input_df = sofa_cohort_df.rename(columns=window_column_names)

# pandas -> Polars for the calculator
sofa_input_pl = pl.from_pandas(sofa_input_df)

print("Calculating SOFA scores...")
sofa_settings = {
    'data_directory': config['tables_path'],
    'cohort_df': sofa_input_pl,
    'filetype': config['file_type'],
    'id_name': 'encounter_block',        # Group by encounter blocks
    'extremal_type': 'worst',
    'fill_na_scores_with_zero': False,   # Leave null as requested
    'remove_outliers': True,
    'timezone': config['timezone'],
}
sofa_scores_pl = compute_sofa_polars(**sofa_settings)

# Polars -> pandas for the rest of the notebook
sofa_scores_df = sofa_scores_pl.to_pandas()

n_scored = len(sofa_scores_df)
mean_total_sofa = sofa_scores_df['sofa_total'].mean()
print(f"✓ SOFA scores calculated for {n_scored} encounter blocks")
print(f"  Mean total SOFA: {mean_total_sofa:.2f}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Histogram of total SOFA score with mean/median markers and a statistics box,
# saved as a PNG. Missing scores are excluded once, up front.
sofa_total = sofa_scores_df['sofa_total'].dropna()

fig, ax = plt.subplots(figsize=(10, 6))

# One bin per integer score. With no scores at all, fall back to a single
# [0, 1] bin instead of failing on int(NaN).
max_score = int(sofa_total.max()) if not sofa_total.empty else 0
ax.hist(sofa_total,
        bins=range(0, max_score + 2),
        edgecolor='black',
        alpha=0.7,
        color='steelblue')

# Add vertical line for mean
mean_sofa = sofa_total.mean()
ax.axvline(mean_sofa, color='red', linestyle='--', linewidth=2,
            label=f'Mean: {mean_sofa:.1f}')

# Add vertical line for median
median_sofa = sofa_total.median()
ax.axvline(median_sofa, color='orange', linestyle='--', linewidth=2,
            label=f'Median: {median_sofa:.1f}')

# Labels and title
ax.set_xlabel('SOFA Total Score', fontsize=12, fontweight='bold')
ax.set_ylabel('Number of Encounters', fontsize=12, fontweight='bold')
ax.set_title('Distribution of SOFA Total Scores at CRRT Initiation',
            fontsize=14, fontweight='bold', pad=20)

# Add grid
ax.grid(axis='y', alpha=0.3, linestyle='--')

# Add statistics text box
stats_text = f'n = {len(sofa_total)}\n'
stats_text += f'Mean ± SD: {mean_sofa:.1f} ± {sofa_total.std():.1f}\n'
stats_text += f'Median [IQR]: {median_sofa:.1f} [{sofa_total.quantile(0.25):.1f}-{sofa_total.quantile(0.75):.1f}]'

ax.text(0.98, 0.97, stats_text,
        transform=ax.transAxes,
        fontsize=10,
        verticalalignment='top',
        horizontalalignment='right',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Legend
# ax.legend(loc='upper right', fontsize=10)

# Adjust layout
plt.tight_layout()

# Save figure; create the target directory first so a fresh checkout does not
# fail with FileNotFoundError.
histogram_path = Path('../output/final/graphs/sofa_total_distribution.png')
histogram_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(histogram_path, dpi=300, bbox_inches='tight')
print("✓ Histogram saved to output/final/graphs/sofa_total_distribution.png")


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path

# Attach SOFA scores to the outcomes table; encounter blocks without a score
# keep null SOFA columns (left join).
outcomes_with_sofa = outcomes_df.merge(sofa_scores_df, on='encounter_block', how='left')

# In-hospital deaths, encounter count, and mortality rate per total SOFA score.
# Named aggregation yields flat column names directly; rows with a null
# sofa_total are left out by groupby.
mortality_by_sofa = (
    outcomes_with_sofa
    .groupby('sofa_total')['in_hosp_death']
    .agg(deaths='sum', n_encounters='count', mortality_rate='mean')
    .reset_index()
    .sort_values('sofa_total')
)
mortality_by_sofa['mortality_pct'] = mortality_by_sofa['mortality_rate'] * 100

# Save CSV for this SOFA vs mortality data; create the target directory first
# so a fresh checkout does not fail on a missing folder.
sofa_mortality_csv = Path('../output/final/sofa_mortality_data.csv')
sofa_mortality_csv.parent.mkdir(parents=True, exist_ok=True)
mortality_by_sofa.to_csv(sofa_mortality_csv, index=False)
print("✓ SOFA-mortality summary CSV saved to ../output/final/sofa_mortality_data.csv")

from pathlib import Path

# Two stacked axes: mortality bars on top, and below them a table listing each
# SOFA score with its encounter count.
fig = plt.figure(figsize=(14, 8))
gs = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.05)
ax_main = fig.add_subplot(gs[0])
ax_table = fig.add_subplot(gs[1])

# Main plot - Mortality rate bars
bars = ax_main.bar(mortality_by_sofa['sofa_total'],
                    mortality_by_sofa['mortality_pct'],
                    color='steelblue',
                    edgecolor='black',
                    alpha=0.7,
                    width=0.8)

# Percentage label just above each bar
for sofa, pct in zip(mortality_by_sofa['sofa_total'],
                     mortality_by_sofa['mortality_pct']):
    ax_main.text(sofa, pct + 1, f'{pct:.1f}%',
                ha='center', va='bottom', fontsize=9, fontweight='bold')

# Main plot formatting
ax_main.set_ylabel('In-Hospital Mortality (%)', fontsize=12, fontweight='bold')
ax_main.set_title('In-Hospital Mortality by SOFA Total Score at CRRT Initiation',
                fontsize=14, fontweight='bold', pad=20)
ax_main.set_ylim(0, mortality_by_sofa['mortality_pct'].max() * 1.15)
ax_main.grid(axis='y', alpha=0.3, linestyle='--')
ax_main.set_xlim(mortality_by_sofa['sofa_total'].min() - 0.5,
                mortality_by_sofa['sofa_total'].max() + 0.5)

# Remove x-axis labels from main plot (will show in table)
ax_main.set_xticklabels([])
ax_main.set_xlabel('')

# Create table with encounter counts
ax_table.axis('tight')
ax_table.axis('off')

# Row 0: SOFA scores, row 1: encounter counts
table_data = [
    [f'{int(score)}' for score in mortality_by_sofa['sofa_total']],
    [f'n={int(n)}' for n in mortality_by_sofa['n_encounters']]
]

# Equal-width columns, one per SOFA score
n_score_columns = len(mortality_by_sofa)
table = ax_table.table(cellText=table_data,
                        rowLabels=['SOFA Score', 'N Encounters'],
                        cellLoc='center',
                        loc='center',
                        colWidths=[1/n_score_columns] * n_score_columns)

# Format table
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1, 2)

# Style the cells: the score row is shaded and bold, the count row plain.
# The row-label cell of each row is addressed as column -1.
for row_idx in range(len(table_data)):
    is_score_row = row_idx == 0
    for col_idx in range(n_score_columns):
        cell = table[(row_idx, col_idx)]
        cell.set_facecolor('lightgray' if is_score_row else 'white')
        cell.set_text_props(weight='bold' if is_score_row else 'normal')

    label_cell = table[(row_idx, -1)]
    label_cell.set_facecolor('lightsteelblue')
    label_cell.set_text_props(weight='bold', ha='right')

# Add overall statistics box (computed over all rows of outcomes_with_sofa,
# including those without a SOFA score)
overall_mortality = outcomes_with_sofa['in_hosp_death'].mean() * 100
total_n = len(outcomes_with_sofa)
stats_text = f'Overall Mortality: {overall_mortality:.1f}%\nTotal N: {total_n:,}'

ax_main.text(0.02, 0.98, stats_text,
            transform=ax_main.transAxes,
            fontsize=11,
            verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

# Save figure; create the target directory first so a fresh checkout does not
# fail with FileNotFoundError.
mortality_plot_path = Path('../output/final/graphs/mortality_by_sofa_score.png')
mortality_plot_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(mortality_plot_path, dpi=300, bbox_inches='tight')
print("✓ Mortality by SOFA score plot saved to output/final/graphs/mortality_by_sofa_score.png")

plt.show()

# Labs

In [None]:
from clifpy.clif_orchestrator import ClifOrchestrator

# Initialize ClifOrchestrator
# ClifOrchestrator configured from config.json: table directory, file format,
# and timezone.
orchestrator_settings = {
    'data_directory': config['tables_path'],
    'filetype': config['file_type'],
    'timezone': config['timezone'],
}
clif = ClifOrchestrator(**orchestrator_settings)

In [None]:
# Labs: load only the columns and lab categories needed, restricted to the
# hospitalizations in the cohort.
labs_required_columns = [
    'hospitalization_id',
    'lab_result_dttm',
    'lab_category',
    'lab_value',
    'lab_value_numeric',
]
labs_of_interest = [
    'po2_arterial', 'pco2_arterial', 'ph_arterial', 'ph_venous', 'bicarbonate', 'so2_arterial',
    'sodium', 'potassium', 'chloride', 'calcium_total', 'magnesium', 'creatinine',
    'bun', 'glucose_serum', 'lactate', 'hemoglobin', 'phosphate',
]

# labs_of_interest = ['ph_arterial', 'lactate', 'bicarbonate', 'potassium']

print("\nLoading labs table...")
lab_filters = {
    'hospitalization_id': cohort_df['hospitalization_id'].unique().tolist(),
    'lab_category': labs_of_interest,
}
clif.load_table('labs', columns=labs_required_columns, filters=lab_filters)

loaded_labs = clif.labs.df
print(f"   Labs loaded: {len(loaded_labs):,} rows")
print(f"   Unique lab categories: {loaded_labs['lab_category'].nunique()}")
print(f"   Unique lab hospitalizations: {loaded_labs['hospitalization_id'].nunique()}")

In [None]:
# ============================================================================
# Get Most Recent Labs Around CRRT Initiation
# ============================================================================
# For every encounter_block + lab_category, pick one numeric result near CRRT
# initiation and pivot to one row per encounter block:
#   1. the MOST RECENT result in the look-back window before initiation, else
#   2. the EARLIEST result in the look-ahead window after initiation.
# The window sizes are named here so the log messages always match the filters
# (the messages previously said "6h" while the filters used 12 h / 3 h).
LAB_LOOKBACK_HOURS = 12
LAB_LOOKAHEAD_HOURS = 3
lab_key_cols = ['encounter_block', 'lab_category']
lab_sort_cols = lab_key_cols + ['lab_result_dttm']

print("\n" + "=" * 80)
print(f"Processing Labs - Most Recent Within {LAB_LOOKBACK_HOURS}h Before / {LAB_LOOKAHEAD_HOURS}h After CRRT Initiation")
print("=" * 80)

# Get labs dataframe
labs_df = clif.labs.df.copy()

# Attach the CRRT initiation time (and encounter block) to every lab row
labs_with_crrt = labs_df.merge(
    crrt_analysis_df[['encounter_block', 'hospitalization_id', 'crrt_initiation_time']],
    on='hospitalization_id',
    how='inner'
)

print(f"   Labs after merging with CRRT cohort: {len(labs_with_crrt):,}")

lab_time = labs_with_crrt['lab_result_dttm']
crrt_start = labs_with_crrt['crrt_initiation_time']

# Labs at or before initiation, no older than the look-back window
labs_before = labs_with_crrt[
    (lab_time <= crrt_start) &
    (lab_time >= crrt_start - pd.Timedelta(hours=LAB_LOOKBACK_HOURS))
].copy()

# Labs strictly after initiation, within the look-ahead window
labs_after = labs_with_crrt[
    (lab_time > crrt_start) &
    (lab_time <= crrt_start + pd.Timedelta(hours=LAB_LOOKAHEAD_HOURS))
].copy()

print(f"   Labs within {LAB_LOOKBACK_HOURS}h before CRRT: {len(labs_before):,}")
print(f"   Labs within {LAB_LOOKAHEAD_HOURS}h after CRRT: {len(labs_after):,}")

# Rows without a numeric value are dropped before selection, and whole rows are
# kept with drop_duplicates. groupby().last()/.first() skip NA per column, so
# they can stitch one output row together from several lab results, and a
# "before" group with no numeric value would block the "after" fallback.
labs_before_sorted = (
    labs_before
    .dropna(subset=['lab_value_numeric'])
    .sort_values(lab_sort_cols)
)
labs_before_most_recent = (
    labs_before_sorted
    .drop_duplicates(subset=lab_key_cols, keep='last')  # Most recent = closest to initiation
    .reset_index(drop=True)
)
labs_before_most_recent['source'] = 'before'

labs_after_sorted = (
    labs_after
    .dropna(subset=['lab_value_numeric'])
    .sort_values(lab_sort_cols)
)
labs_after_earliest = (
    labs_after_sorted
    .drop_duplicates(subset=lab_key_cols, keep='first')  # Earliest = closest to initiation
    .reset_index(drop=True)
)
labs_after_earliest['source'] = 'after'

# Anti-join: keep AFTER rows only for keys that have no BEFORE row.
# before_keys is unique per key, so the left merge cannot duplicate rows, and
# it also behaves correctly when either side is empty.
before_keys = labs_before_most_recent[lab_key_cols]
after_vs_before = labs_after_earliest.merge(
    before_keys,
    on=lab_key_cols,
    how='left',
    indicator=True
)
labs_after_filtered = (
    after_vs_before[after_vs_before['_merge'] == 'left_only']
    .drop(columns='_merge')
)

# Combine: all BEFORE labs + AFTER labs (only where no BEFORE exists)
labs_final = pd.concat([labs_before_most_recent, labs_after_filtered], ignore_index=True)

print(f"\n   Final lab selection:")
print(f"     From BEFORE window: {(labs_final['source'] == 'before').sum():,}")
print(f"     From AFTER window: {(labs_final['source'] == 'after').sum():,}")
print(f"     Total unique encounter_block + lab combinations: {len(labs_final):,}")

# Pivot to wide format (one row per encounter_block, one column per lab)
labs_wide = labs_final.pivot(
    index='encounter_block',
    columns='lab_category',
    values='lab_value_numeric'
).reset_index()

# Add suffix to lab column names for clarity
labs_wide.columns = ['encounter_block'] + [f'{col}_peri_crrt' for col in labs_wide.columns if col != 'encounter_block']

print(f"\n   Labs in wide format:")
print(f"     Encounter blocks: {len(labs_wide):,}")
print(f"     Lab columns: {list(labs_wide.columns)}")

# Show availability by lab
print(f"\n   Lab value availability (non-null):")
for col in labs_wide.columns:
    if col != 'encounter_block':
        n_available = labs_wide[col].notna().sum()
        pct = n_available / len(labs_wide) * 100
        print(f"     {col}: {n_available:,} ({pct:.1f}%)")

# Merge back with CRRT analysis dataset; encounter blocks with no selected lab
# keep null lab columns (left join).
crrt_with_labs = crrt_analysis_df.merge(
    labs_wide,
    on='encounter_block',
    how='left'
)

print(f"\n   Final dataset with labs:")
print(f"     Total records: {len(crrt_with_labs):,}")
print(f"     Records with at least one lab: {(crrt_with_labs[[col for col in crrt_with_labs.columns if '_peri_crrt' in col]].notna().any(axis=1)).sum():,}")

print("\n✅ Lab processing completed successfully!")

In [None]:
# ============================================================================
# Identify Encounter Blocks Without Labs
# ============================================================================
# Flags each encounter block by whether any peri-CRRT lab value is present,
# reports availability per lab, and profiles the blocks with no labs at all.
print("\n" + "=" * 80)
print("Identifying Encounter Blocks Without Labs")
print("=" * 80)

# Lab columns are the ones carrying the _peri_crrt marker
lab_columns = [name for name in crrt_with_labs.columns if '_peri_crrt' in name]

# Non-null count per lab, taken before any flag columns are added
lab_value_counts = crrt_with_labs[lab_columns].notna().sum()

# True when at least one lab column is populated for the row
crrt_with_labs['has_any_lab'] = crrt_with_labs[lab_columns].notna().any(axis=1)
no_lab_mask = ~crrt_with_labs['has_any_lab']
encounters_without_labs = crrt_with_labs[no_lab_mask].copy()

# Summary statistics
total_encounters = len(crrt_with_labs)
encounters_with_labs = crrt_with_labs['has_any_lab'].sum()
encounters_without_labs_count = len(encounters_without_labs)

print("\n   Lab Documentation Status:")
print(f"     Total encounter blocks: {total_encounters:,}")
print(f"     With at least one lab: {encounters_with_labs:,} ({encounters_with_labs/total_encounters*100:.1f}%)")
print(f"     WITHOUT any labs: {encounters_without_labs_count:,} ({encounters_without_labs_count/total_encounters*100:.1f}%)")

# Availability per lab, relative to all encounter blocks
print("\n   Lab-specific availability:")
for lab_col, n_available in lab_value_counts.items():
    pct = n_available / total_encounters * 100
    lab_name = lab_col.replace('_peri_crrt', '')
    print(f"     {lab_name}: {n_available:,}/{total_encounters:,} ({pct:.1f}%)")

# Profile of the encounters that have no labs
if encounters_without_labs_count > 0:
    print("\n   Characteristics of encounters WITHOUT labs:")
    print("     CRRT modes:")
    mode_dist = encounters_without_labs['crrt_mode_category'].value_counts()
    for mode, count in mode_dist.items():
        print(f"       {mode}: {count:,} ({count/encounters_without_labs_count*100:.1f}%)")

    deaths_30d_no_labs = encounters_without_labs['death_30d_from_crrt']
    deaths_hosp_no_labs = encounters_without_labs['in_hosp_death']
    print("     Mortality:")
    print(f"       30-day deaths: {deaths_30d_no_labs.sum():,} ({deaths_30d_no_labs.mean()*100:.1f}%)")
    print(f"       In-hospital deaths: {deaths_hosp_no_labs.sum():,} ({deaths_hosp_no_labs.mean()*100:.1f}%)")

# Identifier listing of the encounter blocks without labs, for QC
qc_columns = ['encounter_block', 'hospitalization_id', 'crrt_initiation_time']
encounters_without_labs_list = encounters_without_labs[qc_columns].copy()

print("\n   Sample encounter blocks without labs (first 10):")
if len(encounters_without_labs_list) == 0:
    print("     None - all encounters have at least one lab!")
else:
    print(encounters_without_labs_list.head(10).to_string(index=False))

# Analysis-ready flag: currently identical to has_any_lab, so downstream
# analyses can exclude encounters without labs
crrt_with_labs['analysis_ready'] = crrt_with_labs['has_any_lab']

n_ready = crrt_with_labs['analysis_ready'].sum()
n_excluded = (~crrt_with_labs['analysis_ready']).sum()
print(f"\n   Analysis-ready encounters (with labs): {n_ready:,}")
print(f"   Encounters to exclude (no labs): {n_excluded:,}")

print("\n✅ Lab completeness check completed!")

In [None]:
# Persist the CRRT + labs dataset to the intermediate output folder.
crrt_with_labs.to_parquet('../output/intermediate/crrt_analysis_with_labs.parquet')

# Competing Risk Dataset

Outcome coding:  

* 0 = Censored (>90 days or still hospitalized)
* 1 = Discharged alive (within 90 days)
* 2 = Died (within 90 days)

In [None]:
# Print the column inventory of each dataframe used in the merge that follows.
frames_to_inspect = {
    "outcomes_with_sofa": outcomes_with_sofa,
    "crrt_analysis_df": crrt_analysis_df,
    "crrt_with_labs": crrt_with_labs,
}
for df_name, df in frames_to_inspect.items():
    column_names = list(df.columns)
    print(f"\nDataFrame: {df_name}")
    print(f"Columns ({len(column_names)}): {column_names}")



In [None]:
# ============================================================================
# Create Competing Risk Analysis Dataset
# ============================================================================
print("\n" + "=" * 80)
print("Creating Competing Risk Analysis Dataset")
print("=" * 80)

# ============================================================================
# STEP 1: Merge All Data Sources
# ============================================================================
print("\n1. Merging data sources...")

# Start with CRRT data (has crrt_initiation_time and CRRT parameters)
competing_risk_df = crrt_with_labs.copy()

print(f"   Starting with CRRT data: {len(competing_risk_df):,} encounter blocks")

# Demographics, outcomes, and SOFA columns wanted from outcomes_with_sofa
outcomes_cols = [
    'encounter_block',
    'age_at_admission',
    'sex_category',
    'race_category',
    'ethnicity_category',
    'icu_los_days',
    'hosp_los_days',
    'duration_days', 'imv_duration_days',
    'admission_type_category',
    'discharge_category',
    'first_vital_dttm',
    'last_vital_dttm',
    'final_outcome_dttm',
    'sofa_cv_97',
    'sofa_coag',
    'sofa_liver',
    'sofa_resp',
    'sofa_cns',
    'sofa_renal',
    'sofa_total'
]

# Bring in a column only if outcomes_with_sofa has it AND the CRRT frame does
# not (the join key excepted). A column present on both sides would come out of
# the merge as <name>_x / <name>_y, so the plain name would no longer exist for
# any later selection by name.
already_present = set(competing_risk_df.columns) - {'encounter_block'}
available_cols = [
    col for col in outcomes_cols
    if col in outcomes_with_sofa.columns and col not in already_present
]

skipped_cols = [
    col for col in outcomes_cols
    if col in outcomes_with_sofa.columns and col in already_present
]
if skipped_cols:
    print(f"   Keeping CRRT-side copy of columns present in both sources: {skipped_cols}")

competing_risk_df = competing_risk_df.merge(
    outcomes_with_sofa[available_cols],
    on='encounter_block',
    how='left'
)

print(f"   After merging outcomes: {len(competing_risk_df):,} rows, {len(competing_risk_df.columns)} columns")

In [None]:
# ============================================================================
# STEP 2: Calculate Time-to-Event (Days from CRRT to Discharge/Death)
# ============================================================================
print("\n2. Calculating time-to-event...")

# Discharge time is approximated by the last recorded vital sign;
# death time is taken from final_outcome_dttm.
competing_risk_df['discharge_dttm'] = competing_risk_df['last_vital_dttm']

seconds_per_day = 24 * 3600
crrt_start_time = competing_risk_df['crrt_initiation_time']
days_to_final_outcome = (
    (competing_risk_df['final_outcome_dttm'] - crrt_start_time).dt.total_seconds() / seconds_per_day
)
days_to_discharge = (
    (competing_risk_df['discharge_dttm'] - crrt_start_time).dt.total_seconds() / seconds_per_day
)

# Rows with died == 1 use the time to final outcome; all others the time to discharge
competing_risk_df['time_to_event_days'] = days_to_final_outcome.where(
    competing_risk_df['died'] == 1,
    days_to_discharge
)

# Negative durations are replaced by half a day
is_negative = competing_risk_df['time_to_event_days'].lt(0)
negative_times = is_negative.sum()
if negative_times > 0:
    print(f"   ⚠️  {negative_times} records with negative time-to-event (setting to 0.5 days)")
    competing_risk_df.loc[is_negative, 'time_to_event_days'] = 0.5

# Zero durations get the same half-day value
is_zero = competing_risk_df['time_to_event_days'].eq(0)
zero_times = is_zero.sum()
if zero_times > 0:
    print(f"   ⚠️  {zero_times} records with zero time-to-event (setting to 0.5 days)")
    competing_risk_df.loc[is_zero, 'time_to_event_days'] = 0.5

adjusted_times = competing_risk_df['time_to_event_days']
print(f"   Time-to-event range: {adjusted_times.min():.1f} - {adjusted_times.max():.1f} days")
print(f"   Median: {adjusted_times.median():.1f} days")

In [None]:
# ============================================================================
# STEP 3: Apply 90-Day Censoring and Calculate 90-Day Mortality
# ============================================================================
print("\n3. Applying 90-day censoring...")

event_time_days = competing_risk_df['time_to_event_days']
died_flag = competing_risk_df['died'] == 1

# Follow-up time capped at 90 days
competing_risk_df['time_to_event_90d'] = event_time_days.clip(upper=90)

# 1 when the event happened after day 90
competing_risk_df['censored_at_90d'] = event_time_days.gt(90).astype(int)

# 1 when the patient died and the event time is no later than day 90
competing_risk_df['death_90d_from_crrt'] = (died_flag & event_time_days.le(90)).astype(int)

n_rows = len(competing_risk_df)
events_censored = competing_risk_df['censored_at_90d'].sum()
deaths_90d = competing_risk_df['death_90d_from_crrt'].sum()

print(f"   Events censored at 90 days: {events_censored:,} ({events_censored/n_rows*100:.1f}%)")
print(f"   Deaths within 90 days of CRRT: {deaths_90d:,} ({deaths_90d/n_rows*100:.1f}%)")

# ============================================================================
# STEP 4: Create Competing Risk Outcome Variable
# ============================================================================
print("\n4. Creating competing risk outcome variable...")

# Outcome coding:
# 0 = Censored (>90 days or still hospitalized)
# 1 = Discharged alive (within 90 days)
# 2 = Died (within 90 days)
died_within_90d = competing_risk_df['death_90d_from_crrt'] == 1
discharged_alive_within_90d = (
    (competing_risk_df['death_90d_from_crrt'] == 0) &
    (competing_risk_df['time_to_event_days'] <= 90)
)

# The two conditions cannot both hold; rows matching neither stay 0 (censored)
competing_risk_df['outcome'] = np.select(
    [died_within_90d, discharged_alive_within_90d],
    [2, 1],
    default=0
).astype('int64')

# Outcome distribution
print("\n   Outcome distribution:")
outcome_labels = {0: 'Censored (>90d)', 1: 'Discharged alive', 2: 'Died'}
for outcome_val, outcome_label in outcome_labels.items():
    count = (competing_risk_df['outcome'] == outcome_val).sum()
    pct = count / len(competing_risk_df) * 100
    print(f"     {outcome_label}: {count:,} ({pct:.1f}%)")

In [None]:
# ============================================================================
# STEP 5: Select Final Columns
# ============================================================================
print("\n5. Selecting final analysis columns...")

# Columns wanted in the final dataset, in output order
final_columns = [
    # Identifiers
    'encounter_block',

    # Time and outcome
    'crrt_initiation_time',
    'time_to_event_90d',
    'outcome',
    'censored_at_90d',

    # Demographics
    'age_at_admission',
    'sex_category',
    'race_category',
    'ethnicity_category',

    # CRRT parameters
    'crrt_mode_category',
    'crrt_dose_ml_kg_hr',
    'crrt_dose_ml_kg_hr_full',
    'dialysate_flow_rate',
    'pre_filter_replacement_fluid_rate',
    'post_filter_replacement_fluid_rate',
    'ultrafiltration_out',
    'total_flow_rate',
    'weight_kg',

    # Labs (peri-CRRT)
    'ph_arterial_peri_crrt',
    'lactate_peri_crrt',
    'bicarbonate_peri_crrt',
    'potassium_peri_crrt',
    'sodium_peri_crrt',
    'creatinine_peri_crrt',
    'bun_peri_crrt',
    'hemoglobin_peri_crrt',
    'glucose_serum_peri_crrt',
    'phosphate_peri_crrt',


    # SOFA scores
    'sofa_cv_97',
    'sofa_coag',
    'sofa_liver',
    'sofa_resp',
    'sofa_cns',
    'sofa_renal',
    'sofa_total',

    # LOS
    'icu_los_days',
    'hosp_los_days',

    # Quality flags
    'has_any_lab',
    'analysis_ready',

    # Treatment duration
    'duration_days',
    'imv_duration_days'
]

# Keep only columns that exist, but say which requested ones are missing so a
# renamed or suffixed column does not vanish from the output unnoticed.
available_final_cols = [col for col in final_columns if col in competing_risk_df.columns]
missing_final_cols = [col for col in final_columns if col not in competing_risk_df.columns]
if missing_final_cols:
    print(f"   ⚠️  Requested columns not found (omitted): {missing_final_cols}")

competing_risk_final = competing_risk_df[available_final_cols].copy()

print(f"   Final dataset: {len(competing_risk_final):,} rows × {len(competing_risk_final.columns)} columns")

In [None]:
# ============================================================================
# STEP 6: Data Validation
# ============================================================================
print("\n6. Data validation...")

# Collected descriptions of any impossible or missing values
issues = []

# Time-to-event validation
time_to_event = competing_risk_final['time_to_event_90d']
if (time_to_event < 0).any():
    issues.append("Negative time-to-event values")

# A missing time compares False against every threshold, so the check above
# cannot see it; count missing values explicitly.
missing_time = time_to_event.isna().sum()
if missing_time > 0:
    issues.append(f"{missing_time} records with missing time-to-event")

# Age validation (accepted range: 18 to 120 years)
if 'age_at_admission' in competing_risk_final.columns:
    invalid_age = ((competing_risk_final['age_at_admission'] < 18) |
                    (competing_risk_final['age_at_admission'] > 120)).sum()
    if invalid_age > 0:
        issues.append(f"{invalid_age} records with invalid age")

# Outcome validation
if competing_risk_final['outcome'].isna().any():
    issues.append("Missing outcome values")

if issues:
    print("   ⚠️  Data quality issues:")
    for issue in issues:
        print(f"     - {issue}")
else:
    print("   ✓ No data quality issues detected")

In [None]:
# ============================================================================
# STEP 7: Summary Statistics
# ============================================================================
# Console summary of the final dataset: outcome mix, follow-up time, and, when
# the columns are present, age, CRRT dose, and lab availability.
print("\n" + "=" * 80)
print("Summary Statistics")
print("=" * 80)

n_cohort = len(competing_risk_final)
print(f"\n   Cohort Size: {n_cohort:,} encounter blocks")

print("\n   Competing Risk Outcomes:")
for outcome_val, outcome_label in outcome_labels.items():
    count = (competing_risk_final['outcome'] == outcome_val).sum()
    pct = count / n_cohort * 100
    print(f"     {outcome_label}: {count:,} ({pct:.1f}%)")

followup_days = competing_risk_final['time_to_event_90d']
followup_q1, followup_q3 = followup_days.quantile(0.25), followup_days.quantile(0.75)
print("\n   Time-to-Event (days):")
print(f"     Mean ± SD: {followup_days.mean():.1f} ± {followup_days.std():.1f}")
print(f"     Median [IQR]: {followup_days.median():.1f} [{followup_q1:.1f}-{followup_q3:.1f}]")

if 'age_at_admission' in competing_risk_final:
    age_years = competing_risk_final['age_at_admission']
    print("\n   Age at Admission (years):")
    print(f"     Mean ± SD: {age_years.mean():.1f} ± {age_years.std():.1f}")

if 'crrt_dose_ml_kg_hr' in competing_risk_final:
    dose_values = competing_risk_final['crrt_dose_ml_kg_hr']
    dose_available = dose_values.notna()
    print("\n   CRRT Dose (mL/kg/hr):")
    print(f"     Available: {dose_available.sum():,}/{n_cohort:,} ({dose_available.mean()*100:.1f}%)")
    if dose_available.any():
        print(f"     Mean ± SD: {dose_values.mean():.1f} ± {dose_values.std():.1f}")

if 'has_any_lab' in competing_risk_final:
    print("\n   Lab Availability:")
    with_labs = competing_risk_final['has_any_lab'].sum()
    print(f"     With labs: {with_labs:,} ({with_labs/n_cohort*100:.1f}%)")

In [None]:
# Persist the final competing-risk dataset, then show its columns as the cell output.
competing_risk_final.to_parquet("../output/intermediate/competing_risk_final.parquet")
competing_risk_final.columns

In [None]:
# Descriptive Statistics & Data Quality
# Prints the cohort size, the outcome distribution, per-variable missingness
# for a fixed list of key columns, and the count of analysis-ready rows.

# Overall summary
print(f"Total encounters: {len(competing_risk_final):,}")
print("Outcome distribution:")
print(competing_risk_final['outcome'].value_counts().sort_index())

# Check missingness for key variables
key_vars = ['crrt_dose_ml_kg_hr', 'crrt_dose_ml_kg_hr_full', 'age_at_admission', 'weight_kg',
            'crrt_mode_category', 'sofa_total']
print("\nMissingness in key variables:")
for var in key_vars:
    # The summary-statistics cell treats several of these columns as optional,
    # so report an absent column instead of aborting the cell with a KeyError.
    if var not in competing_risk_final.columns:
        print(f"  {var}: column not present")
        continue
    missing = competing_risk_final[var].isna().sum()
    print(f"  {var}: {missing} ({missing/len(competing_risk_final)*100:.1f}%)")

# Lab availability (same absent-column handling as the loop above)
if 'analysis_ready' in competing_risk_final.columns:
    print(f"\nAnalysis-ready (complete labs): {competing_risk_final['analysis_ready'].sum():,}")
else:
    print("\nAnalysis-ready (complete labs): column not present")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns  # NOTE(review): not referenced in this cell

# Distribution of time-to-event by outcome: two side-by-side histograms,
# outcome code 1 on the left panel and outcome code 2 on the right.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
alive_ax, death_ax = ax

outcome_code = competing_risk_final['outcome']
event_days = competing_risk_final['time_to_event_90d']

event_days[outcome_code == 1].hist(bins=30, ax=alive_ax)
alive_ax.set_title('Time to Discharge (Alive)')

event_days[outcome_code == 2].hist(bins=30, ax=death_ax)
death_ax.set_title('Time to Death')

# TableOne

In [None]:
from utils import create_table_one_competing_risk, print_table_one_summary


def _table_one(stratify_by, csv_path):
    """Call create_table_one_competing_risk on the final cohort.

    `stratify_by` is forwarded unchanged and `csv_path` is passed as the
    helper's `output_path` argument.
    """
    return create_table_one_competing_risk(
        competing_risk_final,
        stratify_by=stratify_by,
        output_path=csv_path,
    )


# Full Table 1 stratified by outcome, shown inline in the notebook
table1 = _table_one('outcome', '../output/final/table1_by_outcome.csv')
display(table1)

# Quick summary to console
print_table_one_summary(competing_risk_final)

# Table 1 by CRRT mode instead
table1_mode = _table_one('crrt_mode_category', '../output/final/table1_by_crrt_mode.csv')

# Overall only (no stratification)
table1_overall = _table_one(None, '../output/final/table1_overall.csv')

In [None]:
import pandas as pd

# Read the final competing-risk table back from the parquet file.
# Use the same project-relative location the notebook writes to, rather than
# an absolute path that only exists on one developer's machine.
parquet_path = "../output/intermediate/competing_risk_final.parquet"
competing_risk_final_from_parquet = pd.read_parquet(parquet_path)
# Preview the loaded dataframe
# display(competing_risk_final_from_parquet)