In [None]:
# EquiFlow Testing Notebook

# Install package if needed (uncomment to run)
# !pip install equiflow

# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from equiflow import *  # Import all components from equiflow

# Set random seed for reproducibility.
# This seeds NumPy's global (legacy) random state, which is what the
# np.random.* calls in the data-generation cell draw from. Re-run this cell
# before regenerating the data to obtain the same synthetic values again.
np.random.seed(42)

In [None]:
# PART 1: Generate Synthetic Data
# ===============================
#
# Builds `data`, a 1,000-row synthetic patient table with demographic,
# clinical and hospital-stay columns plus deliberately injected missing values.
# NOTE: the np.random.* calls below share one global random stream, so their
# order determines which values each variable receives for a given seed.
# Keep the order stable if reproducibility matters.

print("Generating synthetic patient data...")

# Define sample size
n_patients = 1000

# Generate demographic variables
age = np.random.normal(65, 15, n_patients).round(1)
age = np.clip(age, 18, 95)  # Limit to realistic values

sex = np.random.choice(['Male', 'Female'], n_patients, p=[0.55, 0.45])

# np.random.choice on a list of strings returns a fixed-width unicode array.
# Assigning None into such an array stores the *text* 'None' instead of a
# null, which would make every later .isna() check on race come back False.
# Casting to object dtype lets the array hold real None values.
race = np.random.choice(
    ['White', 'Black', 'Hispanic', 'Asian', 'Other'],
    n_patients,
    p=[0.65, 0.15, 0.10, 0.05, 0.05]
).astype(object)

# Clinical variables
bmi = np.random.normal(28, 6, n_patients).round(1)
bmi = np.clip(bmi, 15, 50)

systolic_bp = np.random.normal(130, 20, n_patients).round(0)
diastolic_bp = np.random.normal(80, 10, n_patients).round(0)

lab_value = np.random.lognormal(mean=1.5, sigma=0.7, size=n_patients).round(2)

# Hospital metrics
los_days = np.random.exponential(5, n_patients).round(0)  # Length of stay
los_days = np.clip(los_days, 1, 60)

readmission_count = np.random.poisson(0.5, n_patients)  # Previous readmissions

# Create some missing data patterns
# Random missingness (about 5% of BMI values, about 3% of race values)
missing_mask_bmi = np.random.choice([True, False], n_patients, p=[0.05, 0.95])
bmi[missing_mask_bmi] = np.nan  # explicit NaN for a float array

missing_mask_race = np.random.choice([True, False], n_patients, p=[0.03, 0.97])
race[missing_mask_race] = None  # a real null now that race is object dtype

# Differential missingness (older patients more likely to have missing lab values):
# each patient's lab value is dropped with probability age / 200.
missing_mask_lab = np.random.rand(n_patients) < (age / 200)  # Age-dependent probability of missing
lab_value[missing_mask_lab] = np.nan

# Create the dataset
data = pd.DataFrame({
    'age': age,
    'sex': sex,
    'race': race,
    'bmi': bmi,
    'systolic_bp': systolic_bp,
    'diastolic_bp': diastolic_bp,
    'lab_value': lab_value,
    'los_days': los_days,
    'readmission_count': readmission_count
})

# Add a derived variable for hypertension:
# 'Yes' when systolic >= 140 or diastolic >= 90, otherwise 'No'
data['hypertension'] = ((data['systolic_bp'] >= 140) | (data['diastolic_bp'] >= 90)).map({True: 'Yes', False: 'No'})

# Display the first few rows
print("\nSample of synthetic patient data:")
print(data.head())

# Show basic statistics
print("\nData summary:")
print(data.describe())

# Check missingness
print("\nMissing values per column:")
print(data.isnull().sum())

In [None]:
# PART 2: Create Exclusion Criteria
# =================================
#
# Each mask is a boolean Series built from `data`; True marks a patient who
# satisfies the criterion. For the three numeric criteria a missing value
# counts as satisfying it, so missingness alone never fails those checks.

print("Defining exclusion criteria...")


def _at_most_or_missing(column, upper_limit):
    """Return a boolean Series: True where `column` is <= `upper_limit` or is missing."""
    return column.le(upper_limit) | column.isna()


# Exclusion 1: race must be recorded
complete_race = data['race'].notna()
print(f"Patients with complete race data: {complete_race.sum()} of {len(data)}")

# Exclusion 2: BMI above 35 fails (could affect demographic distribution)
normal_bmi = _at_most_or_missing(data['bmi'], 35)
print(f"Patients with BMI ≤ 35 or missing BMI: {normal_bmi.sum()} of {len(data)}")

# Exclusion 3: lab values above 10 are treated as abnormal
normal_labs = _at_most_or_missing(data['lab_value'], 10)
print(f"Patients with normal lab values or missing labs: {normal_labs.sum()} of {len(data)}")

# Exclusion 4: hospital stays longer than 14 days fail
short_stay = _at_most_or_missing(data['los_days'], 14)
print(f"Patients with stays ≤ 14 days or missing stay data: {short_stay.sum()} of {len(data)}")

In [None]:
# PART 3: Initialize and Use EquiFlow
# ===================================

print("Initializing EquiFlow...")

# How each column should be summarised, grouped by variable type
categorical_vars = ['sex', 'race', 'hypertension']
normal_vars = ['age', 'bmi', 'systolic_bp', 'diastolic_bp']
nonnormal_vars = ['lab_value', 'los_days', 'readmission_count']

# Human-readable labels, keyed by column name, handed to EquiFlow via `rename`
column_labels = {
    'age': 'Age (years)',
    'sex': 'Sex',
    'race': 'Race/Ethnicity',
    'bmi': 'BMI (kg/m²)',
    'systolic_bp': 'Systolic BP (mmHg)',
    'diastolic_bp': 'Diastolic BP (mmHg)',
    'lab_value': 'Lab Test Result',
    'los_days': 'Length of Stay (days)',
    'readmission_count': 'Prior Readmissions',
    'hypertension': 'Hypertension',
}

# Start the flow from the full synthetic dataset
flow = EquiFlow(
    data=data,
    initial_cohort_label="All patients",
    categorical=categorical_vars,
    normal=normal_vars,
    nonnormal=nonnormal_vars,
    decimals=1,
    format_cat='N (%)',
    format_normal='Mean ± SD',
    format_nonnormal='Median [IQR]',
    missingness=True,
    rename=column_labels,
)

# Register the exclusion steps in the order they should be applied.
# NOTE: each mask is passed through `keep=`; an earlier revision of this
# notebook used `mask=`, which the author flagged as deprecated.
print("\nAdding exclusion steps...")

flow.add_exclusion(
    keep=complete_race,
    exclusion_reason="Missing race/ethnicity data",
    new_cohort_label="Complete demographic data",
)

flow.add_exclusion(
    keep=normal_bmi,
    exclusion_reason="BMI > 35 kg/m²",
    new_cohort_label="Normal BMI patients",
)

flow.add_exclusion(
    keep=normal_labs,
    exclusion_reason="Abnormal lab values (> 10)",
    new_cohort_label="Normal lab values",
)

flow.add_exclusion(
    keep=short_stay,
    exclusion_reason="Length of stay > 14 days",
    new_cohort_label="Standard inpatients",
)

In [None]:
# PART 4: Generate and Display Tables
# ==================================

print("Generating tables...")


def _print_table_section(heading, build_table):
    """Print `heading`, call `build_table()`, print the result and return it.

    The heading is printed before the table is built so that anything the
    builder itself writes appears underneath its heading.
    """
    print(heading)
    table = build_table()
    print(table)
    return table


# Cohort sizes at each step
flow_table = _print_table_section(
    "\n===== COHORT FLOW TABLE =====", flow.view_table_flows
)

# Variable summaries for each cohort
characteristics_table = _print_table_section(
    "\n===== COHORT CHARACTERISTICS TABLE =====", flow.view_table_characteristics
)

# Distribution drifts; called with default arguments only
drifts_table = _print_table_section(
    "\n===== DISTRIBUTION DRIFTS TABLE =====", flow.view_table_drifts
)

In [None]:
# PART 5: Generate Flow Diagram
# ============================

print("Generating flow diagram...")

# Options for the full diagram: sized boxes, per-cohort distributions,
# SMDs and a legend, with display_flow_diagram switched off.
full_diagram_options = {
    'output_folder': "output",
    'output_file': "patient_selection_flow",
    'box_width': 3.5,
    'box_height': 1.2,
    'plot_dists': True,
    'smds': True,
    'legend': True,
    'display_flow_diagram': False,
}

# Best-effort: a failure is reported as a message instead of stopping the notebook.
try:
    flow.plot_flows(**full_diagram_options)
    print("Flow diagram generated successfully: output/patient_selection_flow.pdf")
except Exception as err:
    print(f"Error generating flow diagram: {err}")

In [None]:
# PART 6: Analyze Equity Impact
# ============================

print("\n===== EQUITY IMPACT ANALYSIS =====")
print("Examining potential bias introduced by exclusion criteria:")

# Absolute SMDs above this value are reported as meaningful differences
smd_threshold = 0.1

for step_name in drifts_table.columns:
    # Anything that cannot be parsed as a number becomes NaN and is never flagged
    parsed = pd.to_numeric(drifts_table[step_name], errors='coerce')

    # Select the rows whose absolute SMD exceeds the threshold
    exceeds_threshold = parsed.abs().gt(smd_threshold)
    flagged_rows = drifts_table.index[exceeds_threshold].tolist()

    if not flagged_rows:
        continue

    print(f"\nExclusion step {step_name} significantly affected distributions of:")
    for row_label in flagged_rows:
        try:
            smd_value = float(drifts_table.loc[row_label, step_name])
            print(f"  - {row_label}: SMD = {smd_value:.3f}")
        except (ValueError, TypeError):
            # The raw cell could not be turned into a single float
            print(f"  - {row_label}: SMD value could not be converted to float")

In [None]:
# PART 7: Additional EquiFlow Features
# ===================================

print("\n===== ADDITIONAL FEATURES =====")

# 1. Drift table with default arguments
print("\nDrift table (SMDs between consecutive cohorts):")
drifts = flow.view_table_drifts()
print(drifts)

# 2. Characteristics table with overridden formatting:
#    percentages only for categorical variables, means only for normal
#    variables, and no thousands separator
print("\nCustomizing table characteristics display:")
compact_format = {
    'format_cat': '%',
    'format_normal': 'Mean',
    'thousands_sep': False,
}
custom_chars = flow.view_table_characteristics(**compact_format)
print(custom_chars.iloc[:10])  # Only the first 10 rows are shown

# 3. A pared-down flow diagram with the distribution plots turned off
print("\nGenerating a minimalist flow diagram (without distributions)...")

minimal_diagram_options = {
    'output_folder': "output",
    'output_file': "minimal_flow_diagram",
    'plot_dists': False,
    'display_flow_diagram': False,
}

# Best-effort: a failure is reported as a message instead of stopping the notebook.
try:
    flow.plot_flows(**minimal_diagram_options)
    print("Minimalist flow diagram generated: output/minimal_flow_diagram.pdf")
except Exception as err:
    print(f"Error generating minimalist diagram: {err}")

print("\nEquiFlow testing complete!")

In [None]:
# PART 8: P-Values with Multiple Testing Correction
# =================================================

print("\n===== P-VALUE ANALYSIS =====")


def _print_pvalues(heading, correction):
    """Print `heading`, build the p-value table for `correction`, print and return it."""
    print(heading)
    table = flow.view_table_pvalues(correction=correction)
    print(table)
    return table


# Uncorrected p-values
pvals_none = _print_pvalues("\nP-values (no correction):", "none")

# Bonferroni-corrected p-values
pvals_bonf = _print_pvalues("\nP-values (Bonferroni correction):", "bonferroni")

# FDR (Benjamini-Hochberg) corrected p-values
pvals_fdr = _print_pvalues("\nP-values (FDR/Benjamini-Hochberg correction):", "fdr_bh")