In [None]:
import os
import math
import warnings
import datetime
import numpy as np
import pandas as pd
import seaborn as sns
import miceforest as mf
import matplotlib.pyplot as plt
import statsmodels.api as sm
import statsmodels.formula.api as smf

from scipy import stats
from patsy import dmatrices
from tableone import TableOne
from IPython.display import display
from dateutil.relativedelta import relativedelta
from statsmodels.discrete.count_model import ZeroInflatedNegativeBinomialP

In [None]:
### important dates (written as ISO year-month-day)
START_DATE      = pd.to_datetime('2019-01-01', format='%Y-%m-%d') # first day covered by every dataset
EXTRACTION_DATE = pd.to_datetime('2022-12-02', format='%Y-%m-%d') # second (final) extraction from MB
END_DATE_SLAM   = pd.to_datetime('2022-06-30', format='%Y-%m-%d') # end of the extended SLAM extraction
END_DATE_LDN    = pd.to_datetime('2021-10-07', format='%Y-%m-%d') # last consultation date, but spec says up to November 2021...

### modifiable period parameters (all expressed in months)
EXPOSURE_LENGTH = 6
OUTCOME_LENGTH  = 6
WINDOW_STEP     = 6
PERIODS_START   = 1 # 1 is first available month

### input / output locations
RESULTS_DIR = '../results/'
CLEANED_DATA_PATH = '../data/cleaned/'

# make sure the results directory is there before anything tries to write into it
if not os.path.exists(RESULTS_DIR):
    os.makedirs(RESULTS_DIR)

### Loading and preparing datasets
No saved output. Run this manually first.

In [None]:
# static patient-level table; the two listed columns are parsed as dates on load
df_demos = pd.read_csv(
    f"{CLEANED_DATA_PATH}Patient_Level.csv",
    parse_dates=['yearofbirth', 'dateofdeath'],
)

In [None]:
# registration-history table; registration start/end columns are parsed as dates on load
df_demos_dynamic = pd.read_csv(
    f"{CLEANED_DATA_PATH}Patient_Level_Dynamic.csv",
    parse_dates=['registrationstartdate', 'RegistrationEndDate'],
)

In [None]:
# LDN consultations; the consultation timestamp column is parsed as a datetime on load
df_cons_ldn = pd.read_csv(
    f"{CLEANED_DATA_PATH}LDN_Consultations.csv",
    parse_dates=['EffectiveDateTime'],
)

In [None]:
# SLaM contacts; the event date column is parsed as a date on load
df_cons_slam = pd.read_csv(
    f"{CLEANED_DATA_PATH}SLaM_Contacts.csv",
    parse_dates=['SLAM_event_date'],
)

In [None]:
# SLaM admissions; admission and discharge columns are parsed as dates on load
df_adms_slam = pd.read_csv(
    f"{CLEANED_DATA_PATH}SLaM_Admissions.csv",
    parse_dates=['Admission_Date', 'Discharge_Date'],
)

In [None]:
# Build the static demographics table (`demos`):
#   - keep only patients with at least one mental health diagnosis flag
#   - rare sex codes ('I', 'U') and 'Missing' ethnicity are set to missing

# raw column -> short name; the diagnosis flags come from the LDN_N_* columns
# (SLAM alternatives, not used: 'SLAM_anxiety_ever', 'SLAM_depression_ever', 'SLAM_SMI_ever')
demo_renames = {
    'PseudonymisedNHSNumber': 'nhs',
    'LDN_N_anxiety_ever': 'anx',
    'LDN_N_depression_ever': 'dep',
    'LDN_N_SMI_ever': 'smi',
    'LDN_N_AF_ever': 'af',
    'LDN_N_heart_failure_ever': 'hf',
    'LDN_N_IHD_ever': 'ihd',
    'yearofbirth': 'birth_year',
    'dateofdeath': 'dod',
    'Sex': 'sex',
    'NationalPracticeCode': 'gp',
    'IMD_Decile': 'imd',
    'Ethnicity': 'ethnicity',
}
diagnosis_cols = ['anx', 'dep', 'smi', 'af', 'hf', 'ihd']
other_demo_cols = ['birth_year', 'dod', 'sex', 'gp', 'imd', 'ethnicity']

#### IMPORTANT: this list defines the cohort restriction
mental_health_cols = ['anx', 'dep', 'smi']  # use ['smi'] to filter to SMI only

df_demos_copy = df_demos.rename(columns=demo_renames)
df_demos_copy['sex'] = df_demos_copy['sex'].replace({'I': pd.NA, 'U': pd.NA})
df_demos_copy['imd'] = df_demos_copy['imd'].astype('Int32')
df_demos_copy['ethnicity'] = df_demos_copy['ethnicity'].replace({'Missing': pd.NA})
# any value above zero becomes True
df_demos_copy[diagnosis_cols] = df_demos_copy[diagnosis_cols].gt(0)

has_mh_diagnosis = df_demos_copy[mental_health_cols].any(axis=1)
df_demos_copy = (
    df_demos_copy
    .loc[has_mh_diagnosis, ['nhs'] + diagnosis_cols + other_demo_cols]
    .reset_index(drop=True)
    .rename(columns={flag: f"{flag}_ever" for flag in diagnosis_cols})
)
display(df_demos_copy.sample(2).T)
demos = df_demos_copy.copy()

In [None]:
# Build the registration-history table (`demos_dynamic`), restricted to cohort patients.
# NOTE(review): unlike the other tables, no 'PseudonymisedNHSNumber' -> 'nhs' rename is applied
# here, so the dynamic CSV is assumed to already carry an `nhs` column - confirm.

# rename columns
dynamic_renames = {
    'RegistrationEndDate': 'registration_end',
    'registrationstartdate': 'registration_start',
    'NationalPracticeCode': 'gp',
    'IMD_Decile': 'imd',
}
df_demos_dynamic_copy = df_demos_dynamic.rename(columns=dynamic_renames)

### match patients in demographic dataset, then format IMD
# The cast is done with .assign on the filtered frame rather than by assigning a column onto
# the filtered slice, which is what triggers pandas' SettingWithCopyWarning.
df_demos_dynamic_copy = (
    df_demos_dynamic_copy
    .loc[df_demos_dynamic_copy.nhs.isin(df_demos_copy.nhs)]
    .assign(imd=lambda frame: frame.imd.astype('Int32'))
)

# show dataframe
display(df_demos_dynamic_copy.sample(3))
demos_dynamic = df_demos_dynamic_copy.copy()

In [None]:
# Build the GP consultation table (`consults`): rows with userrolename 'GP' only, cohort patients only.
# Dates are truncated to the day and exact duplicate rows are then dropped; multiple
# consultations per patient per day are otherwise left in (can have up to 24).
# NOTE(review): no consultation-ID column is kept, so drop_duplicates() also collapses distinct
# same-day consultations whose remaining columns are identical - confirm this is intended.

consult_renames = {
    'PseudonymisedNHSNumber': 'nhs',
    'EffectiveDateTime': 'date',
    'ConsultationTypeTerm': 'modality',
    'NationalPracticeCode': 'gp',
    'IMD_Decile': 'imd',
    'GranularModality': 'modality_granular',
}
consult_cols = ['nhs', 'date', 'modality', 'gp', 'imd', 'modality_granular']

is_gp_consult = df_cons_ldn.userrolename.eq('GP')
df_cons_ldn_copy = df_cons_ldn.loc[is_gp_consult].rename(columns=consult_renames)

### match patients in demographic dataset
in_cohort = df_cons_ldn_copy.nhs.isin(df_demos_copy.nhs)
df_cons_ldn_copy = df_cons_ldn_copy.loc[in_cohort, consult_cols].reset_index(drop=True)

df_cons_ldn_copy['imd'] = df_cons_ldn_copy['imd'].astype('Int32')
df_cons_ldn_copy['date'] = df_cons_ldn_copy['date'].dt.normalize()
df_cons_ldn_copy['modality'] = df_cons_ldn_copy['modality'].replace({'Missing': pd.NA})
df_cons_ldn_copy = df_cons_ldn_copy.drop_duplicates()

# the granular modality is kept aside for descriptive stats, then removed from the main table
granular_modalities = df_cons_ldn_copy[['nhs', 'date', 'modality', 'modality_granular']].copy()
df_cons_ldn_copy = df_cons_ldn_copy.drop(columns='modality_granular')
display(df_cons_ldn_copy.sample(3))
consults = df_cons_ldn_copy.copy()

In [None]:
# Build the emergency contact table (`emergencies`): attended emergency contacts only (liaison)

attended_emergency = (
    df_cons_slam.team_type.eq('emergency')
    & df_cons_slam.dimension_2_outcome.eq('attended')
)
df_cons_slam_copy = df_cons_slam.loc[attended_emergency].copy()

# where the contact type is 'na', set the modality from the contact medium
contact_type_na = df_cons_slam_copy.event_type_of_contact.eq('na')
for medium, label in (('f2f', 'F2F'), ('Indirect_clinical', 'Remote (Unknown)')):
    df_cons_slam_copy.loc[df_cons_slam_copy.dimension_1_medium.eq(medium) & contact_type_na, 'Modality'] = label

df_cons_slam_copy = df_cons_slam_copy.rename(columns={
    'PseudonymisedNHSNumber': 'nhs',
    'SLAM_event_date': 'date',
    'Modality': 'modality',
    'location_name': 'location',
})

### inclusion criteria of cardiac/mental condition not applied - so remove extra patients not in demographic data
in_cohort = df_cons_slam_copy.nhs.isin(df_demos_copy.nhs)
df_cons_slam_copy = df_cons_slam_copy.loc[in_cohort, ['nhs', 'date', 'modality', 'location']].reset_index(drop=True)

display(df_cons_slam_copy.sample(3))
emergencies = df_cons_slam_copy.copy()

In [None]:
# Build the psychiatric admission table (`admissions`) with an `mha` flag per admission.
# Length-of-stay is not computed in this cell; a disabled draft computed it as
# discharge_date (EXTRACTION_DATE if still admitted) minus admission_date, in days.

# admission statuses flagged as MHA
mha_statuses = ['Admitted On Section', 'Pre-existing Section', 'Formally detained under MHA Sec 3']

df_adms_slam_copy = (
    df_adms_slam
    .assign(mha=df_adms_slam.MHA_Admission_Status_ID.isin(mha_statuses))
    .rename(columns={
        'PseudonymisedNHSNumber': 'nhs',
        'Admission_Date': 'admission_date',
        'Discharge_Date': 'discharge_date',
    })
)

### inclusion criteria of cardiac/mental condition not applied - so remove extra patients not in demographic data
in_cohort = df_adms_slam_copy.nhs.isin(df_demos_copy.nhs)
df_adms_slam_copy = (
    df_adms_slam_copy
    .loc[in_cohort, ['nhs', 'mha', 'admission_date', 'discharge_date']]
    .reset_index(drop=True)
)

display(df_adms_slam_copy.sample(3))
admissions = df_adms_slam_copy.copy()

### Summary of initial cohort
File sub-directory: ```descriptive/```
- Saves text to: ```summary_statistics.txt```

In [None]:
def save_summary_statistics(verbose=True, results_dir=None, ldn_dates=None, slam_dates=None):
    """Write cohort summary statistics to ``summary_statistics.txt``.

    Reads the module-level frames ``demos``, ``demos_dynamic``, ``consults``,
    ``granular_modalities``, ``emergencies`` and ``admissions``, restricts them to
    patients with a GP registration overlapping the LDN date range, and writes counts
    of consultations, emergency contacts, admissions and bed-days.

    Args:
        verbose: If True, print every line to the console as well as to the file.
        results_dir: Directory for the output file (created if missing). If None,
            falls back to the ``descriptive`` sub-directory of ``RESULTS_DIR``.
        ldn_dates: Tuple of (start_date, end_date) used to filter LDN data.
            If None, uses (START_DATE, END_DATE_LDN).
        slam_dates: Tuple of (start_date, end_date) used to filter SLAM data.
            If None, uses (START_DATE, END_DATE_SLAM).

    Returns:
        None. The file is written as a side effect.
    """

    # The default used to reach os.makedirs()/os.path.join() as None and raise TypeError.
    if results_dir is None:
        results_dir = os.path.join(RESULTS_DIR, 'descriptive')

    # Set date ranges
    ldn_start, ldn_end = ldn_dates if ldn_dates is not None else (START_DATE, END_DATE_LDN)
    slam_start, slam_end = slam_dates if slam_dates is not None else (START_DATE, END_DATE_SLAM)

    # Find patients with active GP registration during LDN period
    # A registration overlaps if: registration_start <= ldn_end AND (registration_end >= ldn_start OR registration_end is NaT)
    active_registrations = demos_dynamic[
        (demos_dynamic.registration_start <= ldn_end) &
        ((demos_dynamic.registration_end >= ldn_start) | (demos_dynamic.registration_end.isna()))
    ]
    active_patients = active_registrations.nhs.unique()

    # Filter LDN dataframes by date and patient subset
    consults_filtered = consults[
        (consults.date.between(ldn_start, ldn_end)) &
        (consults.nhs.isin(active_patients))
    ].copy()
    granular_modalities_filtered = granular_modalities[
        (granular_modalities.date.between(ldn_start, ldn_end)) &
        (granular_modalities.nhs.isin(active_patients))
    ].copy()

    # Filter SLAM dataframes by date and patient subset
    emergencies_filtered = emergencies[
        (emergencies.date.between(slam_start, slam_end)) &
        (emergencies.nhs.isin(active_patients))
    ].copy()

    # Filter admissions by overlap: active at any point during SLAM period
    # admission_date <= slam_end AND (discharge_date >= slam_start OR discharge_date is NaT)
    admissions_filtered = admissions[
        (admissions.admission_date <= slam_end) &
        ((admissions.discharge_date >= slam_start) | (admissions.discharge_date.isna())) &
        (admissions.nhs.isin(active_patients))
    ].copy()

    # Count admissions that started before the period (ongoing)
    started_before = admissions_filtered.admission_date < slam_start
    ongoing_admissions = started_before.sum()
    ongoing_mha = (started_before & admissions_filtered.mha).sum()

    # Bed-days inside the SLAM period: missing dates are filled with the period end and both
    # dates are clipped to the period before subtracting.
    stay_bounds = (
        admissions_filtered[['admission_date', 'discharge_date']]
        .fillna(slam_end)
        .clip(lower=slam_start, upper=slam_end)
    )
    bed_days = (stay_bounds['discharge_date'] - stay_bounds['admission_date']).dt.days.sum()

    # Collapse the granular modality labels into a few categories (NaN = not categorised)
    granular_lookup = {
        # f2f
        'Nursing home visit note':'Nursing home visit',
        # remote
        'Telephone consultation':'Telephone',
        'Telephone triage encounter':'Telephone',
        'Other consultation medium used':np.nan,
        'Telephone call to a patient':'Telephone',
        'Consultation via SMS text message':'Text',
        'Telephone call from a patient':'Telephone',
        'E-mail sent to patient':'Email',
        'Telephone call to relative/carer':'Telephone',
        'Telephone call from relative/carer':'Telephone',
        'E-mail received from patient':'Email',
        'Consultation via video conference':'Video',
        'E-mail consultation':'Email',
        'E-mail received from carer':'Email',
        'Consultation via multimedia':np.nan,
        'E-mail sent to carer':'Email',
    }
    gm = granular_modalities_filtered.copy()
    gm.modality_granular = gm.modality_granular.replace(granular_lookup)
    # NOTE(review): the series is titled 'Top 3' but four rows are taken - confirm which is intended.
    remote_granular = gm.loc[gm.modality.isin(['Telephone', 'Video/Email/Text'])].modality_granular.rename('Top 3 Remote modalities').value_counts().head(4).to_string()
    f2f_granular = gm.loc[gm.modality.eq('F2F')].modality_granular.rename('Top 3 F2F modalities').value_counts().head(3).to_string()

    # Open file for writing
    os.makedirs(results_dir, exist_ok=True)
    output_file_path = os.path.join(results_dir, 'summary_statistics.txt')

    def tee_print(*args, file=None, **kwargs):
        """Print to both console and file"""
        if verbose:
            print(*args, **kwargs)  # Print to console
        if file:
            print(*args, file=file, **kwargs)  # Print to file

    with open(output_file_path, 'w') as f:
        tee_print(f"Overall statistics, pre-filtering and before period creation\n{'-' * 60}", end='\n\n', file=f)

        tee_print(f"LDN DATA (From {pd.Timestamp(ldn_start).date()} to {pd.Timestamp(ldn_end).date()})", file=f)
        tee_print(f"Total patients with active registration during period: {len(active_patients)} ({consults_filtered.nhs.nunique()} with consultations)", file=f)
        tee_print(f"Total GP count: {consults_filtered.gp.nunique()}", file=f)
        tee_print(f"Total GP consultation count: {consults_filtered.shape[0]}", file=f)

        tee_print(f"Total GP consultation count, by modality:", file=f)
        tee_print(f"F2F: {consults_filtered.modality.eq('F2F').sum()}", file=f)
        tee_print(f"Remote: {consults_filtered.modality.isin(['Telephone', 'Video/Email/Text']).sum()}", file=f)
        tee_print(f"Missing: {consults_filtered.modality.isna().sum()}\n", file=f)

        tee_print(f2f_granular, end='\n\n', file=f)
        tee_print(remote_granular, file=f)

        tee_print(f"\nSLAM DATA (From {pd.Timestamp(slam_start).date()} to {pd.Timestamp(slam_end).date()}) - for above patient cohort only", file=f)
        tee_print(f"Total emergency liaison contact count: {emergencies_filtered.shape[0]}", file=f)
        tee_print(f"Total psychiatric admission count: {admissions_filtered.shape[0]} ({ongoing_admissions} ongoing from before period)", file=f)
        tee_print(f"Total psychiatric MHA-sectioned admission count: {admissions_filtered.mha.sum()} ({ongoing_mha} ongoing from before period)", file=f)
        tee_print(f"Total psychiatric bed-days (in period) count: {bed_days}", file=f)

    if verbose:
        print(f"\nSummary statistics saved to: {output_file_path}")

# # With date filtering
# save_summary_statistics(
#     verbose=True,
#     results_dir=f"{RESULTS_DIR}/descriptive",
#     ldn_dates=(START_DATE, END_DATE_LDN),
#     slam_dates=(START_DATE, END_DATE_SLAM)
# )

### Define periods
File sub-directory: ```period_definition/```
- Saves table to: ```periods.csv```
- Saves text to: ```timeline_visualization.txt```
- Saves figure to: ```timeline_visualization.png```

In [None]:
def generate_and_visualize_periods(
    start_date,
    end_date_ldn,
    end_date_slam,
    exposure_length,
    outcome_length,
    window_step,
    periods_start=1,
    visualize=True,
    silence=False,
    save_png=True,
    results_dir=RESULTS_DIR
):
    """Generate exposure/outcome periods and save a table and a text/PNG timeline.

    Each period is an exposure window of ``exposure_length`` months followed
    immediately by an outcome window of ``outcome_length`` months. Period starts are
    ``window_step`` months apart. Generation stops at the first period whose exposure
    window ends after ``end_date_ldn`` or whose outcome window ends after
    ``end_date_slam``.

    Args:
        start_date: Timestamp of the first day of the timeline.
        end_date_ldn: Last date an exposure window may end on.
        end_date_slam: Last date an outcome window may end on (also the timeline end).
        exposure_length: Exposure window length in months (>= 1).
        outcome_length: Outcome window length in months (>= 1).
        window_step: Months between consecutive period starts (>= 1).
        periods_start: 1-based month offset of the first period (1 = ``start_date``).
        visualize: If True print the detailed timeline, otherwise a one-line summary.
        silence: If True, print nothing to the console.
        save_png: If True (and periods exist), also render the timeline to a PNG.
        results_dir: Output directory (created if missing).

    Returns:
        pd.DataFrame with columns period_id, exposure_start, exposure_end,
        outcome_start, outcome_end (empty, but with these columns, if no period fits).

    Side effects:
        Writes ``periods.csv`` and ``timeline_visualization.txt`` to ``results_dir``,
        plus ``timeline_visualization.png`` when ``save_png`` applies.
    """
    import matplotlib.pyplot as plt

    assert exposure_length >= 1
    assert outcome_length >= 1
    assert window_step >= 1
    assert periods_start >= 1

    # --- Core Logic ---
    timeline_start_month = start_date.to_period('M')
    total_months = (end_date_slam.to_period('M') - timeline_start_month).n + 1
    char_per_month = 3

    # Generate Periods
    period_columns = ['period_id', 'exposure_start', 'exposure_end', 'outcome_start', 'outcome_end']
    periods = []
    months_to_shift = periods_start - 1
    period_id = 1

    while True:
        # Offset every start from the anchor date instead of adding the step cumulatively:
        # repeated month addition drifts once an end-of-month day has been clamped
        # (e.g. 31 Jan -> 28 Feb -> 28 Mar). Results are identical for anchors on day <= 28.
        current_start = start_date + relativedelta(months=months_to_shift + (period_id - 1) * window_step)
        exp_end = current_start + relativedelta(months=exposure_length) - pd.Timedelta(days=1)
        if exp_end > end_date_ldn: break
        out_start = exp_end + pd.Timedelta(days=1)
        out_end = out_start + relativedelta(months=outcome_length) - pd.Timedelta(days=1)
        if out_end > end_date_slam: break

        periods.append({
            'period_id': period_id,
            'exposure_start': current_start,
            'exposure_end': exp_end,
            'outcome_start': out_start,
            'outcome_end': out_end
        })
        period_id += 1

    # Fixed columns so that an empty result still produces a CSV with a header row
    periods_df = pd.DataFrame(periods, columns=period_columns)

    # --- Create output directory ---
    os.makedirs(results_dir, exist_ok=True)

    # --- Save DataFrame ---
    csv_path = os.path.join(results_dir, 'periods.csv')
    periods_df.to_csv(csv_path, index=False)

    # --- Generate detailed visualization (always for file) ---
    viz_lines = []

    if periods_df.empty:
        summary_line = f"No periods generated. (Exp: {exposure_length}m, Out: {outcome_length}m, Start Month: {periods_start})"
        viz_lines.append(summary_line)
    else:
        # Width of the "P<n>|" row label. It is 3 for up to nine periods (the layout the
        # timeline was designed for) and grows with the largest id so that rows with
        # two-digit ids stay aligned with the header and the LDN marker.
        label_width = max(3, len(f"P{len(periods)}") + 1)
        side_pad = " " * label_width
        total_width = 2 * label_width + total_months * char_per_month

        # One text row per period: 'e--' marks exposure months, 'o--' outcome months
        viz_data = []
        for p in periods:
            line_parts = ['   '] * total_months
            row_label = f"P{p['period_id']}".ljust(label_width - 1)
            idx_s_exp = (p['exposure_start'].to_period('M') - timeline_start_month).n
            idx_e_exp = (p['exposure_end'].to_period('M') - timeline_start_month).n
            idx_s_out = (p['outcome_start'].to_period('M') - timeline_start_month).n
            idx_e_out = (p['outcome_end'].to_period('M') - timeline_start_month).n

            if idx_e_exp >= total_months: idx_e_exp = total_months - 1
            if idx_e_out >= total_months: idx_e_out = total_months - 1

            for i in range(idx_s_exp, idx_e_exp + 1): line_parts[i] = 'e--'
            for i in range(idx_s_out, idx_e_out + 1): line_parts[i] = 'o--'
            viz_data.append(f"{row_label}|{''.join(line_parts)}|{row_label}")

        # Character column of the month containing the LDN end date
        end_ldn_month = end_date_ldn.to_period('M')
        end_ldn_idx = (end_ldn_month - timeline_start_month).n
        end_ldn_col = label_width + end_ldn_idx * char_per_month

        def insert_vline(s, pos):
            """Overwrite the character at pos with '|' (no-op if pos is outside s)."""
            if 0 <= pos < len(s):
                return s[:pos] + '|' + s[pos+1:]
            return s

        # Build detailed visualization (always for file)
        viz_lines.append("=" * total_width)
        viz_lines.append(f"Detailed Month-by-Month Timeline ({total_months} Months):")
        viz_lines.append("=" * total_width)

        # Year header: each year is left-aligned over the months it spans
        header_line, month_line = side_pad, ""
        current_year, year_start_idx = timeline_start_month.year, 0
        for i in range(total_months):
            m_p = timeline_start_month + i
            if m_p.year > current_year:
                header_line += f"{current_year:<{(i - year_start_idx) * char_per_month}}"
                current_year = m_p.year
                year_start_idx = i
            month_line += m_p.strftime('%b')

        header_line += f"{current_year:<{(total_months - year_start_idx) * char_per_month}}" + side_pad

        viz_lines.append(header_line)
        viz_lines.append("-" * total_width)
        viz_lines.append(insert_vline(f"{side_pad}{month_line}{side_pad}", end_ldn_col))
        viz_lines.append(insert_vline("-" * total_width, end_ldn_col))
        for line in viz_data:
            viz_lines.append(insert_vline(line, end_ldn_col))
        viz_lines.append(insert_vline("=" * total_width, end_ldn_col))

        bottom_line = " " * total_width
        viz_lines.append(insert_vline(bottom_line, end_ldn_col).replace(" " * end_ldn_col + "|", "LDN END".rjust(end_ldn_col) + "|"))

        # Summary line for console output
        start_str = periods_df.iloc[0]['exposure_start'].strftime('%Y-%m-%d')
        end_str = periods_df.iloc[-1]['outcome_end'].strftime('%Y-%m-%d')
        count = len(periods_df)
        summary_line = f"Generated {count} periods from {start_str} to {end_str} (Exp: {exposure_length}m, Out: {outcome_length}m, Step: {window_step}m, Start Month: {periods_start})"

    # --- Console output based on visualize flag ---
    if not silence:
        if visualize:
            # Print detailed visualization
            for line in viz_lines:
                print(line)
        else:
            # Print only summary
            print(summary_line)

    # --- Save visualization to text file (always detailed) ---
    viz_path = os.path.join(results_dir, 'timeline_visualization.txt')
    with open(viz_path, 'w') as f:
        f.write('\n'.join(viz_lines))

    # --- Save as PNG ---
    if save_png and not periods_df.empty:
        png_path = os.path.join(results_dir, 'timeline_visualization.png')

        # Figure height scales with the number of text lines
        fig, ax = plt.subplots(figsize=(20, len(viz_lines) * 0.3 + 1))
        ax.axis('off')

        # Monospace font keeps the character columns aligned
        full_text = '\n'.join(viz_lines)
        ax.text(0.02, 0.98, full_text,
                fontfamily='monospace',
                fontsize=10,
                verticalalignment='top',
                horizontalalignment='left',
                transform=ax.transAxes)

        fig.tight_layout()
        fig.savefig(png_path, dpi=300, bbox_inches='tight', facecolor='white')
        plt.close(fig)

        if not silence:
            print(f"PNG saved to: {png_path}")

    return periods_df
    
# # Example Call
# periods_df = generate_and_visualize_periods(
#     START_DATE, 
#     END_DATE_LDN, 
#     END_DATE_SLAM, 
#     EXPOSURE_LENGTH, 
#     OUTCOME_LENGTH, 
#     WINDOW_STEP,
#     PERIODS_START, # 1 means start at the beginning
#     visualize=False,
#     save_png=True,  # New parameter
#     results_dir=f"{RESULTS_DIR}/period_definition/"
# )

### Populate periods with data
No saved output

In [None]:
def generate_cohort(
    demos, 
    consults, 
    emergencies, 
    admissions, 
    demos_dynamic, 
    periods_df,
    extraction_date,
    use_single_random_period=True,
    sample_after_eligibility=True,
    random_seed=2025,
    verbose=True,
    long_stay_threshold_days=7
):
    """
    Executes the full cohort extraction pipeline.
    
    Args:
        demos, consults, emergencies, admissions, demos_dynamic, periods_df: Input DataFrames.
        extraction_date: The date used to cap open-ended records (e.g., current date).
        use_single_random_period (bool): If True, targets 1 period per patient. If False, keeps all valid periods.
        sample_after_eligibility (bool): If True, checks all periods first, then samples 1. 
                                         (Only applies if use_single_random_period=True).
        random_seed (int): Seed for reproducibility.
        verbose (bool): If True, prints detailed step logs. If False, prints 1 summary line.
        long_stay_threshold_days (int): Minimum length of stay (in days) to count as a "long stay" admission.
        
    Returns:
        pd.DataFrame: The final analysis cohort.
    """
    
    # Helper for conditional printing
    def log(msg):
        if verbose:
            print(msg)

    # ==============================================================================
    # STEP 1: DATA INTEGRITY & MASTER LIST
    # ==============================================================================
    log("--- Step 1: Data Integrity Checks ---")

    # 1. Establish Master List from Demos
    required_demo_cols = ['nhs', 'birth_year', 'dod', 'sex', 'ethnicity']
    if not set(required_demo_cols).issubset(demos.columns):
        raise ValueError(f"Demos dataframe missing required columns: {set(required_demo_cols) - set(demos.columns)}")

    master_patients = demos[required_demo_cols].copy()
    valid_nhs_set = set(master_patients['nhs'])
    log(f"Master Patient Index established. Total Patients: {len(master_patients):,}")

    # 2. Validate Foreign Keys
    def validate_ids(df, name):
        orphan_mask = ~df['nhs'].isin(valid_nhs_set)
        orphan_count = orphan_mask.sum()
        if orphan_count > 0:
            log(f"WARNING: Found {orphan_count:,} records in '{name}' with NHS numbers NOT in demos. These will be ignored.")
            return df[~orphan_mask].copy()
        log(f" - '{name}' integrity check passed.")
        return df

    consults = validate_ids(consults, 'consults')
    emergencies = validate_ids(emergencies, 'emergencies')
    admissions = validate_ids(admissions, 'admissions')
    demos_dynamic = validate_ids(demos_dynamic, 'demos_dynamic')


    # ==============================================================================
    # STEP 2: COHORT CANDIDATE GENERATION
    # ==============================================================================
    log("\n--- Step 2: Period Assignment ---")

    if use_single_random_period and not sample_after_eligibility:
        log("Mode: Single Random Period Assignment (Pre-Check / Random Assign First)")
        np.random.seed(random_seed)
        
        master_patients['period_id'] = np.random.choice(periods_df['period_id'].unique(), size=len(master_patients))
        candidates_df = pd.merge(master_patients, periods_df, on='period_id', how='left')
        
    else:
        if use_single_random_period and sample_after_eligibility:
            log("Mode: Single Random Period Assignment (Post-Check / Maximize Retention)")
            log(" -> Generating full patient-period panel first...")
        else:
            log("Mode: Full Panel (Multi-Period)")
            
        # Cartesian Product
        candidates_df = pd.merge(
            master_patients.assign(key=1), 
            periods_df.assign(key=1), 
            on='key'
        ).drop('key', axis=1)

    log(f"Generated {len(candidates_df):,} candidate patient-period pairs.")


    # ==============================================================================
    # STEP 3: ELIGIBILITY & ACTIVE DAYS CALCULATION
    # ==============================================================================
    log("\n--- Step 3: Eligibility Filtering & Active Days ---")

    # 1. Death Check
    dead_mask = (candidates_df['dod'] <= candidates_df['outcome_start'])
    n_dead = dead_mask.sum()
    candidates_alive = candidates_df[~dead_mask].copy()

    # 2. Age Check (18+)
    # Calculate age exactly as it was done in Step 6, but do it here now.
    candidates_alive['age_at_exposure'] = (candidates_alive['exposure_start'] - candidates_alive['birth_year']).dt.days / 365.25
    
    under_18_mask = candidates_alive['age_at_exposure'] < 18
    n_under_18 = under_18_mask.sum()
    
    # Filter to adults only
    candidates_adults = candidates_alive[~under_18_mask].copy()

    # 3. Registration Expansion
    candidates_expanded = pd.merge(
        candidates_adults,
        demos_dynamic[['nhs', 'registration_start', 'registration_end']],
        on='nhs',
        how='left'
    )
    candidates_expanded['registration_end'] = candidates_expanded['registration_end'].fillna(extraction_date)

    # A. Calculate Active Days
    candidates_expanded['clip_start'] = candidates_expanded[['registration_start', 'exposure_start']].max(axis=1)
    candidates_expanded['clip_end'] = candidates_expanded[['registration_end', 'exposure_end']].min(axis=1)
    candidates_expanded['segment_days'] = (candidates_expanded['clip_end'] - candidates_expanded['clip_start']).dt.days
    candidates_expanded['segment_days'] = candidates_expanded['segment_days'].clip(lower=0)

    active_days_agg = candidates_expanded.groupby(['nhs', 'period_id'])['segment_days'].sum().reset_index(name='active_days_in_exposure')

    # B. Determine Eligibility Flags
    candidates_expanded['is_active'] = (
        (candidates_expanded['registration_start'] < candidates_expanded['exposure_end']) & 
        (candidates_expanded['registration_end'] > candidates_expanded['exposure_start'])
    )
    candidates_expanded['is_partial'] = (
        candidates_expanded['is_active'] & 
        (candidates_expanded['registration_start'] > candidates_expanded['exposure_start'])
    )

    validity_aggs = candidates_expanded.groupby(['nhs', 'period_id']).agg(
        is_active=('is_active', 'any'),
        is_partial=('is_partial', 'any')
    ).reset_index()

    candidates_checked = pd.merge(candidates_adults, validity_aggs, on=['nhs', 'period_id'], how='left')
    candidates_checked = pd.merge(candidates_checked, active_days_agg, on=['nhs', 'period_id'], how='left')
    
    # 2. FILL NA TO ENSURE BOOLEANS (Safety net)
    candidates_checked['is_active'] = candidates_checked['is_active'].fillna(False).astype(bool)
    candidates_checked['is_partial'] = candidates_checked['is_partial'].fillna(False).astype(bool)
    candidates_checked['active_days_in_exposure'] = candidates_checked['active_days_in_exposure'].fillna(0).astype(int)

    # Stats
    n_total = len(candidates_df)
    n_unreg = len(candidates_checked[~candidates_checked['is_active']])
    n_valid = len(candidates_checked[candidates_checked['is_active']])
    n_partial = len(candidates_checked[candidates_checked['is_active'] & candidates_checked['is_partial']])

    log(f"Eligibility Report:")
    log(f" - Total Candidates:                            {n_total:,}")
    log(f" - Excluded (died on/before outcome start):     {n_dead:,}")
    log(f" - Excluded (under 18 at exposure start):       {n_under_18:,}")
    log(f" - Excluded (not actively registered):          {n_unreg:,}")
    log(f" - Final Eligible Candidates:                   {n_valid:,}")
    log(f" ---> Partially registered during exposure (noiser exposure signal):    {n_partial:,}")

    eligible_cohort = candidates_checked[candidates_checked['is_active']].copy()
    eligible_cohort.drop(['is_active', 'is_partial'], axis=1, inplace=True)


    # ==============================================================================
    # STEP 4: EXPOSURE CALCULATION
    # ==============================================================================
    log("\n--- Step 4: Calculating Exposures ---")

    cohort_consults = pd.merge(eligible_cohort, consults, on='nhs', how='left')
    valid_consults_mask = (
        (cohort_consults['date'] >= cohort_consults['exposure_start']) &
        (cohort_consults['date'] < cohort_consults['exposure_end'])
    )
    cohort_consults = cohort_consults[valid_consults_mask].copy()

    remote_mods = ['Telephone', 'Video/Email/Text']
    cohort_consults['is_remote'] = cohort_consults['modality'].isin(remote_mods).astype(int)
    cohort_consults['is_missing'] = (
        cohort_consults['modality'].isna() | 
        (cohort_consults['modality'].astype(str).str.lower() == 'unknown')
    ).astype(int)
    cohort_consults['is_f2f'] = (
        (cohort_consults['is_remote'] == 0) & 
        (cohort_consults['is_missing'] == 0)
    ).astype(int)

    exposure_metrics = cohort_consults.groupby(['nhs', 'period_id']).agg(
        total_consults=('date', 'size'),
        remote_consults=('is_remote', 'sum'),
        f2f_consults=('is_f2f', 'sum'),
        missing_modality_consults=('is_missing', 'sum')
    ).reset_index()

    cohort_with_exposure = pd.merge(eligible_cohort, exposure_metrics, on=['nhs', 'period_id'], how='left')
    fill_cols = ['total_consults', 'remote_consults', 'f2f_consults', 'missing_modality_consults']
    cohort_with_exposure[fill_cols] = cohort_with_exposure[fill_cols].fillna(0).astype(int)

    n_before_drop = len(cohort_with_exposure)
    cohort_with_exposure = cohort_with_exposure[cohort_with_exposure['total_consults'] > 0].copy()
    log(f"Dropped {n_before_drop - len(cohort_with_exposure):,} patient-periods pairs with 0 consults.")

    # --- NEW SAMPLING LOGIC ---
    if use_single_random_period and sample_after_eligibility:
        log(f"\n[Sampling] Reducing to 1 period per patient (Post-Check Mode)")
        log(f" - Candidates before sampling: {len(cohort_with_exposure):,} rows")
        
        cohort_with_exposure = cohort_with_exposure.groupby('nhs').sample(n=1, random_state=random_seed)
        
        log(f" - Final Unique Patients:      {len(cohort_with_exposure):,} rows")


    # ==============================================================================
    # STEP 5: OUTCOME CALCULATION
    # ==============================================================================
    log("\n--- Step 5: Calculating Outcomes ---")

    # A. Emergencies
    cohort_emerg = pd.merge(cohort_with_exposure, emergencies, on='nhs', how='left')
    valid_emerg = cohort_emerg[
        (cohort_emerg['date'] >= cohort_emerg['outcome_start']) &
        (cohort_emerg['date'] < cohort_emerg['outcome_end'])
    ]
    emerg_counts = valid_emerg.groupby(['nhs', 'period_id']).size().reset_index(name='outcome_emergencies')

    # B. Admissions
    cohort_adm = pd.merge(cohort_with_exposure, admissions, on='nhs', how='left')
    valid_adm = cohort_adm[
        (cohort_adm['admission_date'] >= cohort_adm['outcome_start']) &
        (cohort_adm['admission_date'] < cohort_adm['outcome_end'])
    ]
    adm_counts = valid_adm.groupby(['nhs', 'period_id']).size().reset_index(name='outcome_admissions')

    # C. MHA
    valid_mha = valid_adm[valid_adm['mha'] == True]
    mha_counts = valid_mha.groupby(['nhs', 'period_id']).size().reset_index(name='outcome_mha_admissions')

    # D. Bed Days
    cohort_bed = pd.merge(cohort_with_exposure, admissions, on='nhs', how='left')
    relevant_adm = cohort_bed[
        (cohort_bed['admission_date'] < cohort_bed['outcome_end']) &
        (cohort_bed['discharge_date'].fillna(extraction_date) >= cohort_bed['outcome_start'])
    ].copy()
    relevant_adm['clip_start'] = np.maximum(relevant_adm['admission_date'], relevant_adm['outcome_start'])
    relevant_adm['clip_end'] = np.minimum(relevant_adm['discharge_date'].fillna(extraction_date), relevant_adm['outcome_end'])
    relevant_adm['bed_days'] = (relevant_adm['clip_end'] - relevant_adm['clip_start']).dt.days
    bed_day_sums = relevant_adm.groupby(['nhs', 'period_id'])['bed_days'].sum().reset_index(name='outcome_bed_days')

    # E. Long Stay Admissions (admissions > N days)
    valid_adm_los = valid_adm.copy()
    valid_adm_los['length_of_stay'] = (
        valid_adm_los['discharge_date'].fillna(extraction_date) - valid_adm_los['admission_date']
    ).dt.days
    long_stay_adm = valid_adm_los[valid_adm_los['length_of_stay'] > long_stay_threshold_days]
    long_stay_counts = long_stay_adm.groupby(['nhs', 'period_id']).size().reset_index(name='outcome_long_stay_admissions')

    cohort_with_outcomes = cohort_with_exposure.copy()
    cohort_with_outcomes = pd.merge(cohort_with_outcomes, emerg_counts, on=['nhs', 'period_id'], how='left')
    cohort_with_outcomes = pd.merge(cohort_with_outcomes, adm_counts, on=['nhs', 'period_id'], how='left')
    cohort_with_outcomes = pd.merge(cohort_with_outcomes, mha_counts, on=['nhs', 'period_id'], how='left')
    cohort_with_outcomes = pd.merge(cohort_with_outcomes, bed_day_sums, on=['nhs', 'period_id'], how='left')
    cohort_with_outcomes = pd.merge(cohort_with_outcomes, long_stay_counts, on=['nhs', 'period_id'], how='left')

    out_cols = ['outcome_emergencies', 'outcome_admissions', 'outcome_mha_admissions', 'outcome_bed_days', 'outcome_long_stay_admissions']
    cohort_with_outcomes[out_cols] = cohort_with_outcomes[out_cols].fillna(0).astype(int)


    # ==============================================================================
    # STEP 6: FEATURE ENGINEERING & RATES
    # ==============================================================================
    log("\n--- Step 6: Feature Engineering & Rate Calculation ---")

    cohort_sorted = cohort_with_outcomes.sort_values('exposure_start')
    demos_dynamic_sorted = demos_dynamic.sort_values('registration_start')

    # 1. Dynamic GP
    valid_gp_hist = demos_dynamic_sorted[demos_dynamic_sorted['gp'].notna()][['nhs', 'registration_start', 'gp']]
    cohort_w_gp = pd.merge_asof(
        cohort_sorted,
        valid_gp_hist,
        left_on='exposure_start',
        right_on='registration_start',
        by='nhs',
        direction='nearest'
    )
    cohort_w_gp.rename(columns={'gp': 'gp_at_exposure'}, inplace=True)
    cohort_w_gp.drop('registration_start', axis=1, inplace=True)

    # 2. Dynamic IMD
    valid_imd_hist = demos_dynamic_sorted[demos_dynamic_sorted['imd'].notna()][['nhs', 'registration_start', 'imd']]
    cohort_w_covars = pd.merge_asof(
        cohort_w_gp,
        valid_imd_hist,
        left_on='exposure_start',
        right_on='registration_start',
        by='nhs',
        direction='nearest'
    )
    cohort_w_covars.rename(columns={'imd': 'imd_at_exposure'}, inplace=True)
    cohort_w_covars.drop('registration_start', axis=1, inplace=True)

    # 3. Static Demographics
    exclude_cols = ['nhs', 'birth_year', 'dod', 'sex', 'ethnicity', 'gp', 'imd']
    extra_static_cols = [c for c in demos.columns if c not in exclude_cols]
    final_cohort = pd.merge(cohort_w_covars, demos[['nhs'] + extra_static_cols], on='nhs', how='left')

    # 4. Derived Features
    final_cohort['died_in_outcome'] = (
        (final_cohort['dod'] >= final_cohort['outcome_start']) &
        (final_cohort['dod'] <= final_cohort['outcome_end'])
    )

    # 6. Outcome Days at Risk
    risk_end_date = np.minimum(
        final_cohort['dod'].fillna(final_cohort['outcome_end']),
        final_cohort['outcome_end']
    )
    final_cohort['outcome_days_at_risk'] = (risk_end_date - final_cohort['outcome_start']).dt.days
    final_cohort['outcome_days_at_risk'] = final_cohort['outcome_days_at_risk'].clip(lower=0)
    final_cohort.drop('dod', axis=1, inplace=True)

    # 7. Calculate Rates
    exp_count_cols = ['total_consults', 'remote_consults', 'f2f_consults', 'missing_modality_consults']
    for col in exp_count_cols:
        final_cohort[f'{col}_rate'] = final_cohort[col] / final_cohort['active_days_in_exposure']

    out_count_cols = ['outcome_emergencies', 'outcome_admissions', 'outcome_bed_days', 'outcome_mha_admissions', 'outcome_long_stay_admissions']
    for col in out_count_cols:
        final_cohort[f'{col}_rate'] = final_cohort[col] / final_cohort['outcome_days_at_risk']

    # 8. Calculate Proportion Remote (UPDATED: Removed * 100)
    known_consults = final_cohort['total_consults'] - final_cohort['missing_modality_consults']
    final_cohort['pct_remote'] = np.where(
        known_consults > 0,
        (final_cohort['remote_consults'] / known_consults), 
        np.nan
    )

    # ==============================================================================
    # STEP 7: FINAL ORDERING
    # ==============================================================================
    desired_order = [
        'nhs', 'period_id', 
        'exposure_start', 'exposure_end', 
        'outcome_start', 'outcome_end',
        'total_consults', 'total_consults_rate',
        'remote_consults', 'remote_consults_rate',
        'f2f_consults', 'f2f_consults_rate',
        'missing_modality_consults', 'missing_modality_consults_rate',
        'pct_remote',
        'active_days_in_exposure',
        'outcome_days_at_risk', 
        'died_in_outcome',
        'outcome_emergencies', 'outcome_emergencies_rate',
        'outcome_admissions', 'outcome_admissions_rate',
        'outcome_bed_days', 'outcome_bed_days_rate',
        'outcome_mha_admissions', 'outcome_mha_admissions_rate',
        'outcome_long_stay_admissions', 'outcome_long_stay_admissions_rate',
        'sex', 'ethnicity', 
        'age_at_exposure', 'imd_at_exposure', 'gp_at_exposure',
        'anx_ever', 'dep_ever', 'smi_ever', 'af_ever', 'hf_ever', 'ihd_ever'
    ]

    final_analysis_cohort = final_cohort[desired_order]

    log("\n--- Processing Complete ---")
    log(f"Final Analysis Cohort Shape: {final_analysis_cohort.shape}")
    
    if not verbose:
        print(f"[Summary] Extraction Complete | SinglePeriod: {use_single_random_period}, SamplePostCheck: {sample_after_eligibility} | Input Patients: {len(master_patients):,} -> Final Cohort: {len(final_analysis_cohort):,}")
        
    return final_analysis_cohort

# # --- USAGE EXAMPLE ---
# final_df = generate_cohort(
#     demos, consults, emergencies, admissions, demos_dynamic, periods_df,
#     extraction_date=EXTRACTION_DATE,
#     use_single_random_period=True,
#     sample_after_eligibility=True,
#     random_seed=2025,
#     verbose=True,
#     long_stay_threshold_days=7  # Count admissions longer than 7 days
# )

# display(final_df.head(2).T)

### Normality and parametric testing
File sub-directory: ```distribution_checks/```
- Saves table to: ```distribution_assumptions.csv```

In [None]:
def check_distribution_assumptions(df, outcome_cols, results_dir=None, verbose=True):
    """
    Check normality and overdispersion of outcome columns.

    For every column in `outcome_cols` the function computes the mean,
    variance, skew, variance/mean dispersion ratio and the p-value of
    D'Agostino's K^2 normality test, and draws a histogram with a KDE overlay.
    The summary is intended to justify a Negative Binomial model over
    Poisson / normal-theory alternatives.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing the outcome columns.
    outcome_cols : sequence of str
        Columns to check; must contain at least one name.
    results_dir : str or None
        If given, the formatted table is written to
        ``<results_dir>/distribution_assumptions.csv`` (directory is created).
    verbose : bool
        If True, show the figure, the table and an interpretation guide;
        if False, the figure is closed without being shown.

    Raises
    ------
    ValueError
        If `outcome_cols` is empty.
    """
    outcome_cols = list(outcome_cols)
    n_cols = len(outcome_cols)
    if n_cols == 0:
        # plt.subplots(1, 0) would fail with a much less helpful message.
        raise ValueError("outcome_cols must contain at least one column name.")

    # scipy's normaltest relies on skewtest, which is not valid below 8 samples.
    min_n_normaltest = 8

    # squeeze=False always yields a 2-D array, so a single column needs no special case.
    fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 4), squeeze=False)
    axes = axes.ravel()

    results_list = []
    for ax, col in zip(axes, outcome_cols):
        data = df[col].dropna()

        # 1. Descriptive stats
        mean_val = data.mean()
        var_val = data.var()
        skew_val = data.skew()

        # Overdispersion ratio (phi); reported as 0 when the mean is not positive
        dispersion_ratio = var_val / mean_val if mean_val > 0 else 0

        # 2. Normality test (D'Agostino's K^2 test), skipped for tiny samples
        if len(data) >= min_n_normaltest:
            _, p_val = stats.normaltest(data)
        else:
            p_val = np.nan

        results_list.append({
            'Variable': col,
            'Mean': mean_val,
            'Variance': var_val,
            'Dispersion_Ratio': dispersion_ratio,
            'Skew': skew_val,
            'Normality_p_val': p_val
        })

        # 3. Visual check (histogram + KDE) - display only, never saved
        sns.histplot(data, kde=True, ax=ax, bins=30, color='skyblue')
        ax.set_title(f"{col}\n(Skew: {skew_val:.2f})")
        ax.set_xlabel("Count")

    fig.tight_layout()
    if verbose:
        plt.show()
    else:
        # Close this specific figure rather than whichever one is "current".
        plt.close(fig)

    # --- Process results table ---
    df_results = pd.DataFrame(results_list)

    # Formatted (string) version used for both display and saving
    df_formatted = df_results.copy()
    for float_col in ['Mean', 'Variance', 'Dispersion_Ratio', 'Skew']:
        df_formatted[float_col] = df_formatted[float_col].map('{:.2f}'.format)

    def format_pval(p):
        """Format a p-value for publication; '-' when the test was not run."""
        if pd.isna(p):
            return "-"
        if p < 0.001:
            return "< 0.001"
        return f"{p:.4f}"

    df_formatted['Normality_p_val'] = df_formatted['Normality_p_val'].apply(format_pval)

    if verbose:
        print("\n=== DISTRIBUTION ASSUMPTIONS CHECK ===")
        display(df_formatted)

        print("\nINTERPRETATION:")
        print("1. If 'Normality p-val' < 0.05, data is NOT normal (justifies non-parametric/GLM).")
        print("2. If 'Dispersion_Ratio' > 1.0, data is OVERDISPERSED (justifies Negative Binomial over Poisson).")

    # --- Save to CSV ---
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

        save_path = os.path.join(results_dir, 'distribution_assumptions.csv')
        df_formatted.to_csv(save_path, index=False)
        if verbose:
            print(f"\nSaved distribution table to: {save_path}")

# # Usage:
# check_distribution_assumptions(
#     final_df,
#     ['outcome_emergencies', 'outcome_admissions', 'outcome_bed_days', 'outcome_mha_admissions'],
#     results_dir=f"{RESULTS_DIR}/distribution_checks",
#     verbose=True
# )

### Missingness analysis
File sub-directory: ```missingness/```
- Saves table to: ```missingness_summary_counts.csv```
- Saves table to: ```missingness_table_pct_remote.csv```
- Saves table to: ```missingness_table_sex.csv```
- Saves table to: ```missingness_table_imd_at_exposure.csv```
- Saves table to: ```missingness_table_ethnicity.csv```
- Saves figure to: ```gp_missingness_grid.png```

In [None]:
# ==========================================
# 1. HELPER FUNCTIONS
# ==========================================
def format_p_value(p):
    """
    Render a p-value as a 3-decimal string with significance stars.

    Missing values become "-", values below 0.001 become "<0.001 ***",
    and values below 0.01 / 0.05 get " **" / " *" appended.
    """
    if pd.isna(p):
        return "-"
    if p < 0.001:
        return "<0.001 ***"

    # First cutoff that p falls under decides the star suffix (none if p >= 0.05).
    star_levels = ((0.01, " **"), (0.05, " *"))
    suffix = next((stars for cutoff, stars in star_levels if p < cutoff), "")
    return f"{p:.3f}{suffix}"

def calculate_smd(series1, series2, is_categorical=False):
    """
    Standardized mean difference (series1 minus series2) after dropping NaNs.

    Parameters
    ----------
    series1, series2 : pandas.Series
        The two groups to compare. For `is_categorical=True` they are expected
        to be 0/1 indicators, so their means are proportions.
    is_categorical : bool
        If True the pooled variance is built from p * (1 - p) of each group;
        otherwise from the sample variances.

    Returns
    -------
    float
        NaN if either group is empty after dropping NaNs; 0.0 if the groups
        have no spread and identical means; signed infinity if they have no
        spread but different means (complete separation); otherwise
        (mean1 - mean2) / sqrt(pooled variance).
    """
    s1 = series1.dropna()
    s2 = series2.dropna()
    if len(s1) == 0 or len(s2) == 0:
        return np.nan

    m1, m2 = s1.mean(), s2.mean()
    if is_categorical:
        pooled_var = (m1 * (1 - m1) + m2 * (1 - m2)) / 2
    else:
        # .var() is NaN for a single observation; that NaN propagates to the result.
        pooled_var = (s1.var() + s2.var()) / 2

    diff = m1 - m2
    if pooled_var == 0:
        # Zero spread in both groups: identical means are perfectly balanced,
        # differing means are maximally imbalanced - never report those as 0.
        if diff == 0:
            return 0.0
        return float(np.copysign(np.inf, diff))

    return diff / np.sqrt(pooled_var)

def get_numeric_summary(series):
    """
    Summarise a numeric series as "median [Q1-Q3]" with two decimals.

    Missing values are ignored. Returns "-" when there is nothing to
    summarise, i.e. the series is empty or contains only missing values.
    """
    observed = series.dropna()
    if len(observed) == 0:
        return "-"
    med = observed.median()
    q1 = observed.quantile(0.25)
    q3 = observed.quantile(0.75)
    return f"{med:.2f} [{q1:.2f}-{q3:.2f}]"

def create_missing_summary(df):
    """
    Tabulate missing values per column.

    Returns a DataFrame indexed by column name with 'Missing (N)' and
    'Missing (%)' (rounded to 2 decimals), restricted to columns that have at
    least one missing value and ordered by missing count, largest first.
    """
    null_counts = df.isnull().sum()
    null_share = 100 * null_counts / len(df)

    overview = pd.concat(
        [null_counts, null_share],
        axis=1,
        keys=['Missing (N)', 'Missing (%)'],
    )
    with_gaps = overview[overview['Missing (N)'] > 0]
    return with_gaps.sort_values('Missing (N)', ascending=False).round(2)

# ==========================================
# 2. DETAILED PATTERN ANALYSIS
# ==========================================
def create_missingness_table(df, missing_col, columns_to_compare):
    """
    Compare rows where `missing_col` is missing against rows where it is known.

    Each covariate in `columns_to_compare` (skipping `missing_col` itself and
    names not present in `df`) is handled by the first matching rule:

    * Expanded categorical ('ethnicity', 'imd_at_exposure', 'period_id',
      'sex'): a header row carrying a chi-square p-value (missing covariate
      values count as their own level), then one row per category with
      "count (percent)" for each group and the absolute SMD.
    * Numeric, non-boolean: "median [Q1-Q3]" for each group, absolute SMD, and
      a Welch t-test p-value for 'age_at_exposure' or a Mann-Whitney U p-value
      for every other column.
    * Boolean dtype: "count (percent)" of True for each group, absolute SMD
      and a chi-square p-value.

    Columns matching none of the rules are silently omitted. A failing test
    is reported as "Err" in the P-value cell instead of aborting the table.

    Parameters
    ----------
    df : pandas.DataFrame
    missing_col : str
        Column whose missingness defines the two groups.
    columns_to_compare : iterable of str
        Covariates to tabulate (required).

    Returns
    -------
    pandas.DataFrame or None
        None when `missing_col` has no missing values.
    """
    missing_mask = df[missing_col].isna()
    # Read-only views are enough here; nothing below mutates the two groups.
    df_missing = df[missing_mask]
    df_present = df[~missing_mask]

    if len(df_missing) == 0:
        return None

    known_label = f'Known (N={len(df_present)})'
    missing_label = f'Missing (N={len(df_missing)})'

    results = []

    expand_cols = ['ethnicity', 'imd_at_exposure', 'period_id', 'sex']

    for col in columns_to_compare:
        if col == missing_col or col not in df.columns:
            continue

        # --- LOGIC TYPE A: EXPANDED CATEGORICAL ---
        if col in expand_cols:
            try:
                # Stringify so missing covariate values form their own 'nan' level.
                s_col = df[col].astype(str).replace('<NA>', 'nan')
                contingency = pd.crosstab(missing_mask, s_col)
                if contingency.size > 0:
                    _, p, _, _ = stats.chi2_contingency(contingency)
                    p_str = format_p_value(p)
                else:
                    p_str = "-"
            except Exception:
                # Best effort: a failed test must not abort the whole table.
                p_str = "Err"

            results.append({
                'Characteristic': f"{col} (Categorical)",
                known_label: '',
                missing_label: '',
                'P-value': p_str,
                'SMD': ''
            })

            valid_vals = df[col].dropna().unique()
            try:
                categories = sorted(valid_vals)
            except Exception:
                # Values that cannot be ordered keep their order of appearance.
                categories = valid_vals

            # Per-category percentages use non-missing covariate values as denominator.
            valid_pres = df_present[col].dropna()
            valid_miss = df_missing[col].dropna()

            for cat in categories:
                s_pres_bin = (valid_pres == cat).astype(int)
                s_miss_bin = (valid_miss == cat).astype(int)

                count_pres = s_pres_bin.sum()
                pct_pres = s_pres_bin.mean() * 100 if len(s_pres_bin) > 0 else 0
                count_miss = s_miss_bin.sum()
                pct_miss = s_miss_bin.mean() * 100 if len(s_miss_bin) > 0 else 0

                smd = calculate_smd(s_pres_bin, s_miss_bin, is_categorical=True)

                results.append({
                    'Characteristic': f"  {cat}",
                    known_label: f"{count_pres} ({pct_pres:.1f}%)",
                    missing_label: f"{count_miss} ({pct_miss:.1f}%)",
                    'P-value': '',
                    'SMD': f"{abs(smd):.3f}"
                })

        # --- LOGIC TYPE B: NUMERIC ---
        elif pd.api.types.is_numeric_dtype(df[col]) and not pd.api.types.is_bool_dtype(df[col]):
            # Raw values (no scaling)
            val_pres = get_numeric_summary(df_present[col])
            val_miss = get_numeric_summary(df_missing[col])

            smd = calculate_smd(df_present[col], df_missing[col], is_categorical=False)

            try:
                c_pres = df_present[col].dropna()
                c_miss = df_missing[col].dropna()
                if len(c_pres) > 0 and len(c_miss) > 0:
                    if col == 'age_at_exposure':
                        # Welch's t-test for age; rank-based test for everything else.
                        _, p = stats.ttest_ind(c_pres, c_miss, equal_var=False)
                    else:
                        _, p = stats.mannwhitneyu(c_pres, c_miss)
                    p_str = format_p_value(p)
                else:
                    p_str = "-"
            except Exception:
                p_str = "Err"

            results.append({
                'Characteristic': f"{col} (Median [IQR])",
                known_label: val_pres,
                missing_label: val_miss,
                'P-value': p_str,
                'SMD': f"{abs(smd):.3f}"
            })

        # --- LOGIC TYPE C: BOOLEAN ---
        elif pd.api.types.is_bool_dtype(df[col]):
            s_pres_bin = df_present[col].dropna().astype(int)
            s_miss_bin = df_missing[col].dropna().astype(int)

            count_pres = s_pres_bin.sum()
            pct_pres = s_pres_bin.mean() * 100 if len(s_pres_bin) > 0 else 0
            count_miss = s_miss_bin.sum()
            pct_miss = s_miss_bin.mean() * 100 if len(s_miss_bin) > 0 else 0

            smd = calculate_smd(s_pres_bin, s_miss_bin, is_categorical=True)

            try:
                contingency = pd.crosstab(missing_mask, df[col].fillna(-1))
                if contingency.size > 0:
                    _, p, _, _ = stats.chi2_contingency(contingency)
                    p_str = format_p_value(p)
                else:
                    p_str = "-"
            except Exception:
                p_str = "Err"

            results.append({
                'Characteristic': f"{col} (True)",
                known_label: f"{count_pres} ({pct_pres:.1f}%)",
                missing_label: f"{count_miss} ({pct_miss:.1f}%)",
                'P-value': p_str,
                'SMD': f"{abs(smd):.3f}"
            })

    return pd.DataFrame(results)

# ==========================================
# EXECUTION & SAVING
# ==========================================

def run_missingness_analysis(df, target_vars, columns_to_compare, 
                             verbose=True, results_dir=None):
    """
    Summarise missingness and tabulate missing-vs-known patterns per target.

    Required Arguments:
    - df: The dataframe
    - target_vars: Columns whose missingness should be analysed; names absent
      from df or without any missing value are skipped
    - columns_to_compare: Covariates passed to create_missingness_table

    Optional Arguments:
    - verbose: Print banners and display each table
    - results_dir: Directory for CSV output (created if needed); nothing is
      written when None

    Returns a dict with the completeness table under 'summary' plus one
    pattern table per analysed target, keyed by the target name.
    """
    should_save = bool(results_dir)
    if should_save:
        os.makedirs(results_dir, exist_ok=True)

    # --- Overall completeness ---
    if verbose:
        print("="*60, "\n DATA COMPLETENESS SUMMARY\n", "="*60, sep='')
    summary_df = create_missing_summary(df)
    if verbose:
        display(summary_df)
    if should_save:
        summary_path = os.path.join(results_dir, "missingness_summary_counts.csv")
        summary_df.to_csv(summary_path)

    # --- Per-target pattern tables ---
    results = {'summary': summary_df}
    for target in target_vars:
        if target not in df.columns:
            continue
        if not df[target].isna().sum() > 0:
            continue

        if verbose:
            print(f"\n{'='*60}\n PATTERN ANALYSIS: {target.upper()}\n{'='*60}")

        table = create_missingness_table(df, target, columns_to_compare)
        if table is None:
            continue

        results[target] = table
        if verbose:
            display(table)
        if should_save:
            table_path = os.path.join(results_dir, f"missingness_table_{target}.csv")
            table.to_csv(table_path, index=False)

    return results

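# # Usage (target_vars and columns_to_compare are required):
# missingness_results = run_missingness_analysis(
#     final_df,
#     target_vars=['ethnicity', 'imd_at_exposure', 'sex', 'pct_remote'],
#     columns_to_compare=[...],  # fill in the covariates to compare between known/missing groups
#     verbose=True,
#     results_dir=f"{RESULTS_DIR}/missingness"
# )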

In [None]:
def analyze_gp_missingness_grid(df, targets, gp_col='gp_at_exposure', min_patients=50, verbose=True, results_dir=RESULTS_DIR):
    """
    Plot the share of missing values per GP practice, one panel per target.

    For every target present in `df`, rows are grouped by `gp_col`, practices
    with fewer than `min_patients` rows are dropped, and the remaining
    practices are drawn as horizontal bars sorted from highest to lowest
    missing percentage. Practices above mean + 1 SD are coloured red and a
    dashed line marks the mean. Panels are laid out two per row.

    Parameters
    ----------
    df : pandas.DataFrame
    targets : iterable of str
        Columns whose missingness is plotted; names not in `df` are ignored.
    gp_col : str
        Column identifying the practice.
    min_patients : int
        Minimum number of rows a practice needs to be plotted.
    verbose : bool
        If True, print progress and show the figure; otherwise close it.
    results_dir : str or None
        Directory for ``gp_missingness_grid.png`` (created if needed).
        Nothing is saved when this is None or empty.

    Returns
    -------
    None
    """

    VARIABLE_RENAME = {
        'ethnicity':'Ethnicity', 
        'imd_at_exposure':'IMD', 
        'sex':'Sex', 
        'pct_remote':'Consultation Modality'
    }

    # Filter valid targets that exist in DF
    valid_targets = [t for t in targets if t in df.columns]
    n_targets = len(valid_targets)

    if n_targets == 0:
        print("No valid targets found in dataframe.")
        return

    # Two panels per row; an odd number of targets leaves one unused axis
    cols = 2
    rows = (n_targets + 1) // 2
    fig, axes = plt.subplots(rows, cols, figsize=(15, 6 * rows))
    axes = axes.flatten()  # 1-D for easy iteration

    if verbose: print(f"Generating GP Missingness Grid for: {valid_targets}...")

    for ax, target_col in zip(axes, valid_targets):
        # Fall back to the raw column name for targets without a display label
        label = VARIABLE_RENAME.get(target_col, target_col)

        # 1. Binary missing flag (missing is NaN for every target)
        is_missing = df[target_col].isna().astype(int)

        # 2. Group by GP
        gp_stats = pd.DataFrame({
            'gp': df[gp_col],
            'is_missing': is_missing
        }).groupby('gp').agg(
            total_patients=('is_missing', 'count'),
            missing_count=('is_missing', 'sum'),
            missing_pct=('is_missing', 'mean')
        ).reset_index()

        # Convert to %
        gp_stats['missing_pct'] = gp_stats['missing_pct'] * 100

        # 3. Filter small GPs
        gp_stats = gp_stats[gp_stats['total_patients'] >= min_patients].copy()

        if gp_stats.empty:
            # Nothing to plot: say so on the panel instead of drawing NaN statistics.
            ax.set_title(f'Missing {label}', fontsize=14, fontweight='bold')
            ax.text(0.5, 0.5, f"No GPs with >= {min_patients} records",
                    transform=ax.transAxes, fontsize=10,
                    verticalalignment='center', horizontalalignment='center')
            ax.set_xticks([])
            ax.set_yticks([])
            continue

        # 4. Sort (worst first)
        gp_stats = gp_stats.sort_values('missing_pct', ascending=False)

        # 5. Stats
        mean_miss = gp_stats['missing_pct'].mean()
        std_miss = gp_stats['missing_pct'].std()
        threshold = mean_miss + std_miss

        # 6. Plot on specific axis
        colors = ['red' if x > threshold else 'steelblue' for x in gp_stats['missing_pct']]

        sns.barplot(
            x='missing_pct', y='gp', data=gp_stats, 
            hue='gp', palette=colors, orient='h', legend=False,
            ax=ax
        )

        # Add mean line
        ax.axvline(mean_miss, color='black', linestyle='--', alpha=0.7)

        # Formatting
        ax.set_title(f'Missing {label}', fontsize=14, fontweight='bold')
        ax.set_xlabel('% Missing')
        ax.set_ylabel('') # Remove label to save space
        ax.set_yticks([]) # Anonymize GP codes
        ax.grid(axis='x', alpha=0.3)

        # Info box (bottom right)
        stats_text = (
            f"Mean: {mean_miss:.1f}%\n"
            f"SD: {std_miss:.1f}%\n"
            f"Range: {gp_stats['missing_pct'].min():.1f}% - {gp_stats['missing_pct'].max():.1f}%\n"
            f"────────────────\n"
            f"-- Line: Average\n"
            f"■ Red: > Mean + 1 SD"
        )

        ax.text(0.95, 0.05, stats_text, transform=ax.transAxes, 
                fontsize=10, verticalalignment='bottom', horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.95, edgecolor='gray'))

    # Hide unused subplots, if any
    for unused_ax in axes[n_targets:]:
        unused_ax.axis('off')

    fig.tight_layout()

    # Save only when a directory was supplied
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
        save_path = os.path.join(results_dir, "gp_missingness_grid.png")
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        if verbose:
            print(f"Grid saved to: {save_path}")

    if verbose:
        plt.show()
    else:
        plt.close(fig)

# ==========================================
# EXECUTION
# ==========================================
# analyze_gp_missingness_grid(final_df, ['ethnicity', 'imd_at_exposure', 'sex', 'pct_remote'], verbose=True, results_dir=f"{RESULTS_DIR}/missingness")

### Missing data and imputation
File sub-directory: ```imputation/```
- Saves text to: ```complete_case_summary.txt```
- Saves text to: ```variables_used.txt```
- Saves table to: ```imputation_stats_ethnicity.csv```
- Saves table to: ```imputation_stats_sex.csv```
- Saves table to: ```imputation_stats_imd_at_exposure.csv```
- Saves table to: ```imputation_stats_pct_remote.csv```

In [None]:
def handle_missing_data(
    df,
    cols_to_impute,
    targets_to_update,
    complete_case=False,
    keep_pct_remote_nan=False,
    verbose=True,
    n_datasets=5,
    n_iterations=100,
    mice_cycles=5,
    random_state=42,
    results_dir=None
):
    """
    Handle missing data by complete-case deletion or MICE imputation.

    Summaries are written to `results_dir` when it is provided; the datasets
    themselves are never saved. The input frame is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Analysis cohort.
    cols_to_impute : list of str
        Columns handed to the imputation model (targets and predictors).
        Names missing from `df` are ignored and flagged in variables_used.txt.
    targets_to_update : list of str
        Columns whose values are replaced by imputed values (MICE path) or
        that must be non-missing for a row to be kept (complete-case path).
    complete_case : bool
        True drops incomplete rows; False runs MICE.
    keep_pct_remote_nan : bool
        Complete-case path only: do not drop rows merely because 'pct_remote'
        is missing. NOTE(review): the MICE path ignores this flag - confirm
        that is intended.
    verbose : bool
        Print progress and display comparison tables.
    n_datasets : int
        Number of imputed datasets to generate.
    n_iterations : int
        Forwarded to `kds.mice` as `num_iterations` (logged as "trees per model").
    mice_cycles : int
        Number of MICE iterations.
    random_state : int
        Seed for the imputation kernel.
    results_dir : str or None
        Output directory for text/CSV summaries (created if missing).

    Returns
    -------
    pandas.DataFrame
        When `complete_case` is True: the rows with all targets present.
    dict[int, pandas.DataFrame]
        Otherwise: one completed copy of `df` per imputed dataset, keyed 0..n_datasets-1.
    """

    # Create a working copy
    df_in = df.copy()

    # Ensure results directory exists if provided
    if results_dir and not os.path.exists(results_dir):
        os.makedirs(results_dir)
        if verbose: print(f"Created directory: {results_dir}")

    # --- Save variables list to file ---
    if results_dir:
        vars_path = os.path.join(results_dir, 'variables_used.txt')
        with open(vars_path, 'w') as f:
            f.write("=== Variables Provided for Imputation ===\n")
            for col in cols_to_impute:
                # Check if the variable actually exists in the DF
                status = "OK" if col in df_in.columns else "MISSING IN DATAFRAME"
                f.write(f"{col} : {status}\n")

            f.write("\n=== Targets Being Updated ===\n")
            for col in targets_to_update:
                f.write(f"{col}\n")
        if verbose: print(f"Saved variables list to: {vars_path}")

    # ==========================================
    # PATH A: COMPLETE CASE ANALYSIS
    # ==========================================
    if complete_case:
        subset_cols = [c for c in targets_to_update if c in df_in.columns]

        if keep_pct_remote_nan and 'pct_remote' in subset_cols:
            subset_cols.remove('pct_remote')

        original_len = len(df_in)
        df_cc = df_in.dropna(subset=subset_cols).copy()
        rows_dropped = original_len - len(df_cc)
        # An empty input has nothing to drop; avoid dividing by zero in the report.
        share_dropped = rows_dropped / original_len if original_len else 0.0

        # --- Report string ---
        report_str = (
            "=== COMPLETE CASE SUMMARY ===\n"
            f"Original rows: {original_len}\n"
            f"Rows dropped:  {rows_dropped} ({share_dropped:.1%})\n"
            f"Remaining:     {len(df_cc)}\n"
        )
        if keep_pct_remote_nan:
            report_str += "Note: Rows with missing 'pct_remote' were retained.\n"

        if verbose:
            print("\n" + report_str)

        # --- Saving logic: summary report only, no dataset ---
        if results_dir:
            with open(os.path.join(results_dir, 'complete_case_summary.txt'), 'w') as f:
                f.write(report_str)
            if verbose: print(f"Saved Complete Case summary to: {results_dir}")

        return df_cc

    # ==========================================
    # PATH B: MICE IMPUTATION
    # ==========================================

    # 1. Prepare model data
    cols_to_use = [c for c in cols_to_impute if c in df_in.columns]
    df_model = df_in[cols_to_use].copy()

    categorical_cols = ['ethnicity', 'sex']
    for col in categorical_cols:
        if col in df_model.columns:
            df_model[col] = df_model[col].astype('category')

    if 'imd_at_exposure' in df_model.columns:
        df_model['imd_at_exposure'] = df_model['imd_at_exposure'].astype(float)

    if verbose:
        print(f"Running MICE (LightGBM) on {len(cols_to_use)} columns...")
        print(f"Settings: {n_datasets} datasets, {mice_cycles} cycles, {n_iterations} trees per model.")

    # 2. Run MICE
    kds = mf.ImputationKernel(
        df_model,
        datasets=n_datasets,
        save_all_iterations=True,
        random_state=random_state
    )

    kds.mice(iterations=mice_cycles, verbose=0, num_iterations=n_iterations)

    # 3. Finalize datasets
    imputed_datasets = {}

    if verbose: print(f"\nProcessing {n_datasets} imputed datasets (Internal Only)...")
    for i in range(n_datasets):
        df_completed = kds.complete_data(i)

        # Post-processing: keep imputed values inside each variable's valid range
        if 'imd_at_exposure' in df_completed.columns:
            df_completed['imd_at_exposure'] = df_completed['imd_at_exposure'].round().clip(1, 10)
        if 'pct_remote' in df_completed.columns:
            # pct_remote is a proportion (remote / known consults), so its range is [0, 1]
            df_completed['pct_remote'] = df_completed['pct_remote'].clip(0, 1)

        # Merge back: only the requested targets are overwritten
        df_temp = df_in.copy()
        for col in targets_to_update:
            if col in df_completed.columns:
                df_temp[col] = df_completed[col]

        imputed_datasets[i] = df_temp

    # 4. Reporting (using dataset 0)
    final_cohort_imputed = imputed_datasets[0]

    if verbose:
        print(f"\n=== IMPUTATION SUMMARY (Dataset 0 of {n_datasets}) ===")

    # --- Categorical comparisons ---
    check_cols = ['ethnicity', 'sex', 'imd_at_exposure']
    for col in check_cols:
        if col in df_in.columns:
            try:
                # A. Calculation (always)
                before = df_in[col].value_counts(dropna=False).sort_index().rename("Before")
                after = final_cohort_imputed[col].value_counts(dropna=False).sort_index().rename("After")
                comp = pd.concat([before, after], axis=1).fillna(0).astype(int)
                comp['Change'] = comp['After'] - comp['Before']

                # Rename NaN level for CSV clarity
                comp.index = comp.index.map(lambda x: 'Missing' if pd.isna(x) else x)

                # B. Saving (if a directory was given)
                if results_dir:
                    comp.to_csv(os.path.join(results_dir, f'imputation_stats_{col}.csv'))

                # C. Printing (only if verbose)
                if verbose:
                    print(f"\n--- {col} Distribution ---")
                    try:
                        display(comp)
                    except Exception:
                        # Fall back to plain text when rich display is unavailable
                        print(comp)

            except Exception as e:
                if verbose: print(f"Could not display/save {col}: {e}")

    # --- Continuous statistics (pct_remote) ---
    if 'pct_remote' in df_in.columns:
        # A. Calculation (always)
        stats_before = df_in['pct_remote'].describe().rename("Before")
        stats_after = final_cohort_imputed['pct_remote'].describe().rename("After")

        miss_before = pd.Series({'missing_count': df_in['pct_remote'].isna().sum()}, name="Before")
        miss_after = pd.Series({'missing_count': final_cohort_imputed['pct_remote'].isna().sum()}, name="After")

        stats_compare = pd.concat([stats_before, stats_after], axis=1)
        stats_compare = pd.concat([stats_compare, pd.concat([miss_before, miss_after], axis=1).T]).T

        # B. Saving (if a directory was given)
        if results_dir:
            stats_compare.to_csv(os.path.join(results_dir, 'imputation_stats_pct_remote.csv'))

        # C. Printing (only if verbose)
        if verbose:
            print("\n--- pct_remote Statistics ---")
            try:
                display(stats_compare.round(2))
            except Exception:
                print(stats_compare.round(2))

    return imputed_datasets

# ==========================================
# EXECUTION BLOCKS
# ==========================================

# cols_to_impute_list = [
#     # --- TARGETS ---
#     'sex', 'ethnicity', 'imd_at_exposure', 'pct_remote',
#     # --- PREDICTORS ---
#     'age_at_exposure',
#     'total_consults', 'total_consults_rate',
#     'remote_consults', 'remote_consults_rate',
#     'f2f_consults', 'f2f_consults_rate',
#     'missing_modality_consults', 'missing_modality_consults_rate',
#     'outcome_emergencies', 'outcome_emergencies_rate',
#     'outcome_admissions', 'outcome_admissions_rate',
#     'outcome_bed_days', 'outcome_bed_days_rate',
#     'outcome_mha_admissions', 'outcome_mha_admissions_rate',
#     'active_days_in_exposure',
#     'outcome_days_at_risk',
#     'died_in_outcome',
#     'anx_ever', 'dep_ever', 'smi_ever', 
#     'af_ever', 'hf_ever', 'ihd_ever'
# ]

# # 1. RUN COMPLETE CASE (Dropping pct_remote NaNs)
# # -----------------------------------------------
# print("Running Complete Case Analysis...")
# df_complete_case = handle_missing_data(
#     df=final_df,
#     cols_to_impute=cols_to_impute_list, 
#     targets_to_update=['ethnicity', 'imd_at_exposure', 'sex', 'pct_remote'],
#     complete_case=True,
#     keep_pct_remote_nan=False, 
#     verbose=True,
#     results_dir=f"{RESULTS_DIR}/imputation"
# )

# # 2. RUN IMPUTATION (Standard 5 datasets)
# # -----------------------------------------------
# print("\n" + "="*40 + "\n")
# print("Running MICE Imputation...")
# imputed_dict = handle_missing_data(
#     df=final_df,
#     cols_to_impute=cols_to_impute_list,
#     targets_to_update=['ethnicity', 'imd_at_exposure', 'sex', 'pct_remote'],
#     complete_case=False,
#     keep_pct_remote_nan=False,
#     verbose=True,
#     n_datasets=5,
#     n_iterations=25, # 100 for final run
#     mice_cycles=5,    
#     results_dir=f"{RESULTS_DIR}/imputation"
# )

### Create baseline characteristics tables
File sub-directory: ```descriptive/```
- Saves table to: ```baseline_characteristics_initial_cohort.csv```
- Saves table to: ```baseline_characteristics_sample_imputed_cohort.csv```

In [None]:
def generate_table_one(df, save_dir=None, filename='Table1_Baseline.csv', 
                       groupby_mode='remote', show_pval=True, 
                       remote_as_continuous=False, verbose=True,
                       selected_columns=None):
    """
    Build a baseline-characteristics table ("Table 1") with TableOne and save it as CSV.

    Parameters
    ----------
    df : pd.DataFrame
        Cohort data. Must contain 'sex', 'ethnicity' and 'pct_remote' plus every
        column that ends up in the table. The input frame is not modified.
    save_dir : str, optional
        Directory the CSV is written to (created if missing). Defaults to the
        current working directory.
    filename : str
        Name of the CSV file written inside ``save_dir``.
    groupby_mode : {'remote', 'period'}
        'remote' stratifies by the three-level remote-usage group derived from
        'pct_remote'; 'period' stratifies by 'period_id'.
    show_pval : bool
        Passed to TableOne as ``pval``; when True, significance stars are appended
        to the 'P-Value' column if it is present.
    remote_as_continuous : bool
        Only used when ``groupby_mode='period'``: report 'pct_remote' as a
        non-normal continuous variable instead of the categorical usage group.
    verbose : bool
        Display the table and print the save path.
    selected_columns : list, optional
        If given, only the default columns that also appear in this list are kept.

    Returns
    -------
    pd.DataFrame
        The formatted table that was written to disk.

    Raises
    ------
    ValueError
        If ``groupby_mode`` is not 'remote' or 'period'.
    """
    # Fail fast on an invalid mode before doing any work
    if groupby_mode not in ('remote', 'period'):
        raise ValueError("groupby_mode must be 'remote' or 'period'")

    # 1. SETUP & CLEANING
    df_table = df.copy()

    # Textual placeholders for missing values become real NaNs
    replace_vals = ['None', 'none', 'NaN', 'nan', 'Missing', 'missing']
    df_table['sex'] = df_table['sex'].replace(replace_vals, np.nan)
    df_table['ethnicity'] = df_table['ethnicity'].replace(replace_vals, np.nan)
    df_table = df_table.dropna(subset=['pct_remote'])

    # 2. CREATE BINS (Trimodal Split)
    # Right-closed bins: (-0.1, 0.0001], (0.0001, 0.999], (0.999, 1.1]
    bins = [-0.1, 0.0001, 0.999, 1.1]
    labels = ['Never Remote (0%)', 'Hybrid Usage', 'Fully Remote (100%)']
    df_table['Remote_Group'] = pd.cut(df_table['pct_remote'], bins=bins, labels=labels).astype(str)

    # 3. CONVERT BOOLEANS
    bool_cols = ['anx_ever', 'dep_ever', 'smi_ever', 'af_ever', 'hf_ever', 'ihd_ever']
    for col in bool_cols:
        # Tolerate frames that lack a flag (e.g. when it is excluded via selected_columns)
        if col in df_table.columns:
            df_table[col] = df_table[col].replace({True: 'Yes', False: 'No'})

    # 4. CONFIGURATION (mode-dependent)
    # The three layouts differ only in the grouping column and in one extra
    # descriptive column inserted after 'total_consults'.
    if groupby_mode == 'remote':
        groupby_col, extra_col = 'Remote_Group', 'period_id'
    elif remote_as_continuous:
        groupby_col, extra_col = 'period_id', 'pct_remote'
    else:
        groupby_col, extra_col = 'period_id', 'Remote_Group'

    columns = (
        ['age_at_exposure', 'sex', 'ethnicity', 'imd_at_exposure', 'total_consults', extra_col]
        + bool_cols
    )
    categorical = ['sex', 'ethnicity']
    nonnormal = ['imd_at_exposure', 'total_consults']
    if extra_col == 'pct_remote':
        nonnormal.append(extra_col)
    else:
        categorical.append(extra_col)
    categorical += bool_cols

    rename_dict = {
        'age_at_exposure': 'Age (years)',
        'imd_at_exposure': 'IMD Decile (Deprivation)',
        'period_id': 'Study Period',
        'total_consults': 'Total Consultations',
        'anx_ever': 'History of Anxiety',
        'dep_ever': 'History of Depression',
        'smi_ever': 'History of SMI',
        'af_ever': 'History of Atrial Fib',
        'hf_ever': 'History of Heart Failure',
        'ihd_ever': 'History of IHD',
        'sex': 'Sex',
        'ethnicity': 'Ethnicity',
        'Remote_Group': 'Remote Consult Usage',
        'pct_remote': 'Remote Consult % (continuous)'
    }

    # Force 'Yes' to top and limit to 1 row
    order = {k: ['Yes', 'No'] for k in bool_cols}
    limit = {k: 1 for k in bool_cols}

    if selected_columns:
        # Keep the default ordering, restricted to the requested columns, and
        # drop configuration entries that refer to removed columns.
        columns = [c for c in columns if c in selected_columns]
        categorical = [c for c in categorical if c in columns]
        nonnormal = [c for c in nonnormal if c in columns]
        order = {k: v for k, v in order.items() if k in columns}
        limit = {k: v for k, v in limit.items() if k in columns}

    # 5. GENERATE TABLE
    mytable = TableOne(
        df_table, 
        columns=columns, 
        categorical=categorical, 
        groupby=groupby_col, 
        nonnormal=nonnormal, 
        rename=rename_dict, 
        limit=limit,        
        order=order,        
        pval=show_pval,
        missing=False,      
        include_null=False, 
        sort=False,         
        overall=True,       
        smd=False            
    )

    # 6. POST-PROCESSING
    final_output = mytable.tableone.copy()

    # Flatten the two-level column header down to its second level
    if isinstance(final_output.columns, pd.MultiIndex):
        final_output.columns = final_output.columns.get_level_values(1)

    if groupby_mode == 'remote':
        priority_cols = ['Overall', 'Never Remote (0%)', 'Hybrid Usage', 'Fully Remote (100%)']
    else:
        period_ids = sorted(df_table['period_id'].dropna().unique())
        priority_cols = ['Overall'] + [str(pid) for pid in period_ids]
    
    # Priority columns first (when present), everything else in original order
    all_cols = list(final_output.columns)
    new_order = [c for c in priority_cols if c in all_cols] + \
                [c for c in all_cols if c not in priority_cols]
    final_output = final_output[new_order]

    if show_pval:
        def add_stars(val):
            """Append significance stars to a p-value; non-numeric cells pass through."""
            if isinstance(val, str) and '<' in val:
                return val + '***'
            try:
                p = float(val)
            except (TypeError, ValueError):
                # Blank or otherwise non-numeric cells are left untouched
                return val
            # Report very small p-values as a bound rather than a misleading "0.000"
            if p < 0.001: return "<0.001***"
            if p < 0.01:  return f"{p:.3f}**"
            if p < 0.05:  return f"{p:.3f}*"
            return f"{p:.3f}"

        if 'P-Value' in final_output.columns:
            final_output['P-Value'] = final_output['P-Value'].apply(add_stars)

    # 7. DISPLAY & SAVE
    if verbose: display(final_output)

    if save_dir is None:
        save_dir = '.' 
    # Ensure the target directory exists so to_csv cannot fail on a fresh results tree
    os.makedirs(save_dir, exist_ok=True)
    
    save_path = os.path.join(save_dir, filename)
    final_output.to_csv(save_path)
    if verbose: print(f"✅ Table saved successfully to: {save_path}")
    
    return final_output

# # generate table 1 for final (non-imputed dataset)
# table1_non_impute = generate_table_one(
#     final_df, 
#     f"{RESULTS_DIR}/descriptive", 
#     filename='baseline_characteristics_initial_cohort.csv', 
#     verbose=True,
#     groupby_mode='period',
#     show_pval=False,
#     remote_as_continuous=True
# )

# # generate table 1 for (one sample) imputed dataset
# sampled_imputed_dataset = imputed_dict[0].copy()
# table1_impute = generate_table_one(
#     sampled_imputed_dataset, 
#     f"{RESULTS_DIR}/descriptive", 
#     filename='baseline_characteristics_sample_imputed_cohort.csv', 
#     verbose=True,
#     groupby_mode='period',
#     show_pval=False,
#     remote_as_continuous=True
# )

### Alpha parameter refinement
File sub-directory: ```regression/```
- Saves table to: ```gee_CC_alpha_optimization_results.csv```
- Saves table to: ```gee_MICE_alpha_optimization_results_pooled.csv```

In [None]:
def _alpha_opt_prepare_frame(df, offset_col):
    """Return a copy of ``df`` with the 'log_offset' and 'pct_remote_10' model columns added."""
    working_df = df.copy()

    # Zero offsets are replaced by 1e-5 so the log offset stays finite
    working_df['log_offset'] = np.log(working_df[offset_col].replace(0, 1e-5))

    # Create pct_remote_10 if not present
    if 'pct_remote_10' not in working_df.columns:
        if 'pct_remote' not in working_df.columns:
            raise ValueError("'pct_remote' column not found in dataframe")
        working_df['pct_remote_10'] = working_df['pct_remote'] * 10

    return working_df


def _alpha_opt_fit_qic(working_df, formula, alpha, grouping_col):
    """Fit one Negative Binomial GEE; return (qic, status, message) with status 'ok', 'convergence' or 'error'."""
    try:
        model = smf.gee(
            formula=formula,
            data=working_df,
            groups=working_df[grouping_col],
            offset=working_df['log_offset'],
            family=sm.families.NegativeBinomial(alpha=alpha),
            cov_struct=sm.cov_struct.Exchangeable()
        )

        # Record warnings raised during fitting so convergence problems can be detected
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            res = model.fit()

            convergence_failed = any(
                "convergence" in str(warning.message).lower() or
                "iteration limit" in str(warning.message).lower()
                for warning in caught
            )
            if convergence_failed:
                return np.nan, 'convergence', ''

            qic, _ = res.qic(scale=1.0)
            return qic, 'ok', ''

    except Exception as e:
        # Any construction/fit error is reported to the caller rather than raised
        return np.nan, 'error', str(e)


def optimize_gee_alpha(
    df,
    outcomes,
    alphas_to_test,
    formula_template,
    offset_col='outcome_days_at_risk',
    grouping_col='gp_at_exposure',
    verbose=True,
    save_dir=None
):
    """
    Optimize alpha parameter for Negative Binomial GEE models using QIC.
    
    Parameters:
    -----------
    df : pd.DataFrame or dict
        Single dataframe for complete case analysis, or dict of dataframes for 
        multiple imputation (will pool QIC values across imputations)
    outcomes : list
        List of outcome variable names to test
    alphas_to_test : list
        List of alpha values to test for each outcome (must be > 0)
    formula_template : str
        Formula template with {outcome} placeholder.
        Example: "{outcome} ~ pct_remote_10 + total_consults + ..."
    offset_col : str, default='outcome_days_at_risk'
        Column name for offset variable (will be log-transformed; zeros are
        replaced by 1e-5 first)
    grouping_col : str, default='gp_at_exposure'
        Column name for grouping variable in GEE
    verbose : bool, default=True
        If True, displays graphs and results table. If False, suppresses all output.
    save_dir : str, optional
        Directory path to save QIC results table as CSV. If None, no file is saved.
        Files will be labeled with "_CC" (complete case) or "_MICE" (multiple imputation).
    
    Returns:
    --------
    dict
        Dictionary mapping each outcome to its optimal alpha value, or to None
        when no model for that outcome could be fitted successfully
        
    Raises:
    -------
    ValueError
        If any alpha is <= 0, or if a dataframe has neither 'pct_remote_10' nor
        'pct_remote'.
        
    Note:
    -----
    Alpha cannot be zero - it must be positive. Alpha = 0 would correspond to 
    a Poisson distribution (no overdispersion), while alpha > 0 allows for 
    overdispersion in the Negative Binomial model.
    
    When df is a dict (multiple imputation), QIC values are averaged across all
    imputed datasets before selecting the optimal alpha. Only alphas with successful
    convergence are considered - models that fail to converge are excluded.
    A single dataframe is handled as a one-dataset pool, so both modes share the
    same code path.
    """
    # Validate alpha values
    if any(a <= 0 for a in alphas_to_test):
        raise ValueError("All alpha values must be positive (> 0). Alpha = 0 would be Poisson, not Negative Binomial.")
    
    # Determine if we're doing multiple imputation
    is_imputed = isinstance(df, dict)
    analysis_type = "MICE" if is_imputed else "CC"
    datasets = df if is_imputed else {0: df}
    
    # Validate save directory
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
    
    # Initialize storage
    optimal_alphas = {}
    convergence_failures = []   # unique (outcome, alpha)
    other_failures = []         # unique (outcome, alpha, error message)
    
    # Fit every (outcome, alpha) combination on every dataset
    qic_by_dataset = {}
    for key, dataset in datasets.items():
        working_df = _alpha_opt_prepare_frame(dataset, offset_col)
        per_outcome = {outcome: {} for outcome in outcomes}
        
        for outcome in outcomes:
            current_formula = formula_template.format(outcome=outcome)
            
            for alpha in alphas_to_test:
                qic, status, message = _alpha_opt_fit_qic(
                    working_df, current_formula, alpha, grouping_col
                )
                if status == 'convergence':
                    failure_key = (outcome, alpha)
                    if failure_key not in convergence_failures:
                        convergence_failures.append(failure_key)
                elif status == 'error':
                    failure_key = (outcome, alpha, message)
                    if failure_key not in other_failures:
                        other_failures.append(failure_key)
                per_outcome[outcome][alpha] = qic
        
        qic_by_dataset[key] = per_outcome
    
    # Pool QICs by averaging the valid values across datasets
    # (with a single dataset the mean is just that dataset's QIC)
    qic_matrix = {outcome: {} for outcome in outcomes}
    for outcome in outcomes:
        for alpha in alphas_to_test:
            qic_values = [qic_by_dataset[key][outcome].get(alpha, np.nan) for key in datasets]
            valid_qics = [q for q in qic_values if not np.isnan(q)]
            qic_matrix[outcome][alpha] = np.mean(valid_qics) if valid_qics else np.nan
    
    # Find optimal alphas (only from successfully converged models)
    for outcome in outcomes:
        valid_qics = {k: v for k, v in qic_matrix[outcome].items() if not np.isnan(v)}
        
        if valid_qics:
            optimal_alphas[outcome] = min(valid_qics, key=valid_qics.get)
        else:
            if verbose:
                print(f"⚠️  WARNING: Could not determine alpha for {outcome} (All models failed)")
            optimal_alphas[outcome] = None  # Don't use a fallback - make it explicit
    
    results_df = pd.DataFrame(qic_matrix)
    results_df.index.name = 'alpha'
    
    if verbose:
        n_outcomes = len(outcomes)
        
        # Create combined plot with all outcomes (skipped when there is nothing to plot)
        if n_outcomes > 0:
            n_cols = 2
            n_rows = (n_outcomes + 1) // 2
            
            # squeeze=False always yields a 2-D axes array, so flatten() is safe
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows), squeeze=False)
            axes = axes.flatten()
            
            for idx, outcome in enumerate(outcomes):
                valid_qics = {k: v for k, v in qic_matrix[outcome].items() if not np.isnan(v)}
                
                if valid_qics:
                    axes[idx].plot(list(valid_qics.keys()), list(valid_qics.values()), 
                                  marker='o', linewidth=2, markersize=8)
                    best_alpha = optimal_alphas.get(outcome)
                    if best_alpha is not None:
                        axes[idx].axvline(best_alpha, color='red', linestyle='--', alpha=0.5, 
                                         label=f'Best α={best_alpha}')
                    title_suffix = " (pooled)" if is_imputed else ""
                    axes[idx].set_title(f"{outcome}\nBest Alpha: {best_alpha}{title_suffix}", 
                                       fontsize=11, fontweight='bold')
                    axes[idx].set_xlabel("Alpha", fontsize=10)
                    axes[idx].set_ylabel("QIC" + (" (pooled)" if is_imputed else ""), fontsize=10)
                    axes[idx].grid(True, alpha=0.3)
                    axes[idx].legend()
                else:
                    axes[idx].text(0.5, 0.5, f'{outcome}\nAll models failed', 
                                  ha='center', va='center', transform=axes[idx].transAxes)
                    axes[idx].set_xlabel("Alpha", fontsize=10)
                    axes[idx].set_ylabel("QIC", fontsize=10)
            
            # Hide extra subplots if odd number of outcomes
            for idx in range(n_outcomes, len(axes)):
                axes[idx].axis('off')
            
            plt.tight_layout()
            plt.show()
        
        # Display failure summary
        if convergence_failures or other_failures:
            print("\n" + "="*50)
            print("FAILURE SUMMARY:")
            print("="*50)
            
            if convergence_failures:
                print("\n⚠️  Convergence failures (excluded from recommendations):")
                # Group by outcome
                conv_by_outcome = {}
                for failed_outcome, failed_alpha in convergence_failures:
                    alphas_for_outcome = conv_by_outcome.setdefault(failed_outcome, [])
                    if failed_alpha not in alphas_for_outcome:
                        alphas_for_outcome.append(failed_alpha)
                
                for failed_outcome, alphas in conv_by_outcome.items():
                    alphas_str = ", ".join([f"α={a}" for a in sorted(alphas)])
                    print(f"  {failed_outcome}: {alphas_str}")
            
            if other_failures:
                print("\n❌ Other failures:")
                for failed_outcome, failed_alpha, error in other_failures:
                    print(f"  {failed_outcome}, α={failed_alpha}: {error[:80]}")
        
        # Display results table
        print("\n" + "="*50)
        print(f"QIC RESULTS TABLE ({analysis_type}):")
        print("="*50)
        display(results_df.round(0))
        
        # Display final recommendations
        print("\n" + "="*50)
        print(f"FINAL RECOMMENDATION ({analysis_type}):")
        print("="*50)
        print("Use these alphas for your final models:")
        for outcome, alpha in optimal_alphas.items():
            if alpha is not None:
                print(f"  {outcome}: {alpha}")
            else:
                print(f"  {outcome}: NO VALID ALPHA (all models failed)")
        print("="*50)
    
    # Save results table
    if save_dir is not None:
        # Suffix built outside the f-string: reusing the enclosing quote character
        # inside a replacement field is a SyntaxError before Python 3.12.
        pooled_suffix = '_pooled' if is_imputed else ''
        results_path = os.path.join(
            save_dir, f'gee_{analysis_type}_alpha_optimization_results{pooled_suffix}.csv'
        )
        results_df.to_csv(results_path)
    
    return optimal_alphas


# Example usage - Complete Case
# optimal_alphas_cc = optimize_gee_alpha(
#     df=complete_case_df,
#     outcomes=OUTCOMES,
#     alphas_to_test=alphas_to_test,
#     formula_template=formula_template,
#     save_dir=f"{RESULTS_DIR}/regression",
#     verbose=True
# )

# Example usage - Multiple Imputation (pooled)
# optimal_alphas_mice = optimize_gee_alpha(
#     df=imputed_dict,  # Pass the entire dict
#     outcomes=OUTCOMES,
#     alphas_to_test=alphas_to_test,
#     formula_template=formula_template,
#     save_dir=f"{RESULTS_DIR}/regression",
#     verbose=True
# )

### Perform regression
File sub-directory: ```regression/```

In [None]:
import scipy.stats as stats

# ==============================================================================
# 1. HELPER: DHARMa DIAGNOSTICS
# ==============================================================================
def _save_dharma_diagnostics(models_dict, df, alpha_dict, save_dir, mode_label, verbose=True, show_plots=False, n_sims=250):
    """
    Generates DHARMa residual plots (QQ and Res-vs-Pred).
    """

    OUTCOME_RENAME = {
        'outcome_emergencies':'Emergency contacts', 
        'outcome_admissions':'Psychiatric hospital admissions', 
        'outcome_bed_days':'Inpatient bed-days', 
        'outcome_mha_admissions':'MHA admissions', 
        # 'outcome_long_stay_admissions':'Long-stay psychiatric hospital admissions' 
    }
    
    outcomes = list(models_dict.keys())
    n_outcomes = len(outcomes)
    
    if n_outcomes == 0:
        return

    # Calculate grid dimensions
    n_cols = math.ceil(math.sqrt(n_outcomes))
    n_rows = math.ceil(n_outcomes / n_cols)
    
    # Create figures
    fig_qq, axes_qq = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
    fig_rvp, axes_rvp = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
    
    # Flatten axes
    if n_outcomes > 1:
        axes_qq = axes_qq.flatten()
        axes_rvp = axes_rvp.flatten()
    else:
        axes_qq = [axes_qq]
        axes_rvp = [axes_rvp]
        
    if verbose:
        print(f"   Generating DHARMa diagnostics ({mode_label}) for {n_outcomes} outcomes...")
    
    for i, outcome in enumerate(outcomes):
        model = models_dict[outcome]
        alpha = alpha_dict.get(outcome, 1.0)
        
        # Simulate
        mu = model.fittedvalues
        y_observed = df[outcome].values
        n_param = 1.0 / alpha
        p_param = 1.0 / (1.0 + alpha * mu)
        n_obs = len(y_observed)
        
        simulated_responses = np.zeros((n_obs, n_sims))
        for s in range(n_sims):
            simulated_responses[:, s] = stats.nbinom.rvs(n=n_param, p=p_param)
            
        # Dithering
        obs_dithered = y_observed[:, None] + np.random.uniform(-0.5, 0.5, (n_obs, 1))
        sim_dithered = simulated_responses + np.random.uniform(-0.5, 0.5, (n_obs, n_sims))
        
        rank = np.sum(sim_dithered <= obs_dithered, axis=1)
        residuals = rank / n_sims
        
        # KS Test
        ks_stat, ks_pval = stats.kstest(residuals, 'uniform')
        title_text = f"{OUTCOME_RENAME[outcome]}\nKS p={ks_pval:.3f}"
        
        # Plot QQ
        ax_q = axes_qq[i]
        sm.qqplot(residuals, dist=stats.uniform, line='45', ax=ax_q)
        ax_q.set_title(title_text, fontsize=10, fontweight='bold')
        ax_q.set_xlabel("Theoretical", fontsize=9)
        ax_q.set_ylabel("Observed", fontsize=9)
        
        # Plot Res vs Pred
        ax_r = axes_rvp[i]
        ax_r.scatter(mu, residuals, alpha=0.2, s=10, color='black')
        ax_r.axhline(0.5, color='red', linestyle='--', linewidth=1)
        ax_r.axhline(0.25, color='gray', linestyle=':', alpha=0.5)
        ax_r.axhline(0.75, color='gray', linestyle=':', alpha=0.5)
        ax_r.set_xscale('log')
        ax_r.set_title(f"{OUTCOME_RENAME[outcome]}", fontsize=10, fontweight='bold')
        ax_r.set_xlabel("Predicted (Log)", fontsize=9)
        ax_r.set_ylabel("Scaled Residual", fontsize=9)
        ax_r.set_ylim(-0.05, 1.05)

    # Cleanup
    for j in range(i + 1, len(axes_qq)):
        axes_qq[j].axis('off')
        axes_rvp[j].axis('off')

    # Save
    fig_qq.suptitle(f"Supplementary Figure A ({mode_label}): Uniformity Check (QQ Plots)", fontsize=14, y=0.99)
    fig_qq.tight_layout()
    qq_path = os.path.join(save_dir, f"gee_{mode_label}_dharma_qq_plots.png")
    fig_qq.savefig(qq_path, dpi=300)
    
    fig_rvp.suptitle(f"Supplementary Figure B ({mode_label}): Dispersion Check (Res vs Pred)", fontsize=14, y=0.99)
    fig_rvp.tight_layout()
    rvp_path = os.path.join(save_dir, f"gee_{mode_label}_dharma_residuals_vs_predicted.png")
    fig_rvp.savefig(rvp_path, dpi=300)
    
    if verbose:
        print(f"   ✓ Saved QQ plots to: {qq_path}")
        print(f"   ✓ Saved Res/Pred plots to: {rvp_path}")
    
    if not show_plots:
        plt.close(fig_qq)
        plt.close(fig_rvp)
    else:
        plt.show()


# ==============================================================================
# 2. CORE: GEE ANALYSIS (Supports Pooling & Diagnostics)
# ==============================================================================
def run_gee_analysis(
    data,
    outcomes,
    formula_template,
    exposure_var='pct_remote_10',
    alpha_dict=None,
    fixed_alpha=2.0,
    verbose=True,
    full_output=True,
    save_tables=False,
    results_dir=None,
    save_diagnostics=False,
    show_plots=False,
    offset_col='outcome_days_at_risk',
    grouping_col='gp_at_exposure'
):
    """
    Core function to run GEE. Handles both Complete Case (Single) and MICE (Pooled).

    For every outcome a Negative Binomial GEE with an exchangeable working
    correlation is fitted on each dataset. With several datasets the estimates
    are pooled with Rubin's rules; with one dataset its estimates are used as is.

    Parameters
    ----------
    data : pd.DataFrame or dict
        A single dataframe (complete case) or a dict of dataframes (MICE).
    outcomes : list
        Outcome column names.
    formula_template : str
        Formula with an ``{outcome}`` placeholder.
    exposure_var : str
        Coefficient name reported in the summary table.
    alpha_dict : dict, optional
        Per-outcome Negative Binomial alpha; outcomes not listed use ``fixed_alpha``.
    fixed_alpha : float
        Default alpha.
    verbose : bool
        Print progress and display the formatted summary.
    full_output : bool
        Currently unused; kept for backward compatibility.
    save_tables : bool
        Write summary, fit-statistics and coefficient CSVs to ``results_dir``.
    results_dir : str, optional
        Output directory; required when saving tables or diagnostics.
    save_diagnostics : bool
        Produce DHARMa plots for the models fitted on the first dataset.
    show_plots : bool
        Forwarded to the diagnostics helper.
    offset_col : str
        Column whose log is used as the model offset (zeros replaced by 1e-5).
    grouping_col : str
        Column defining the GEE clusters.

    Returns
    -------
    dict
        'summary' (DataFrame), 'coefficients' (dict of DataFrames by outcome),
        'fit_stats' (DataFrame) and 'models' (fits kept for diagnostics).

    Raises
    ------
    ValueError
        If saving is requested without ``results_dir``.
    TypeError
        If ``data`` is neither a DataFrame nor a dict.
    """
    if save_tables or save_diagnostics:
        if results_dir is None:
            raise ValueError("results_dir required for saving")
        os.makedirs(results_dir, exist_ok=True)
        
    if alpha_dict is None: alpha_dict = {}

    # Detect Mode
    if isinstance(data, dict):
        mode = 'MICE'
        datasets = data
        mode_label = 'MICE'
    elif isinstance(data, pd.DataFrame):
        mode = 'Complete Case'
        datasets = {0: data}
        mode_label = 'CC'
    else:
        raise TypeError("data must be DataFrame or dict")
    
    if verbose:
        print(f"\n{'='*60}\nGEE CORE ANALYSIS | Mode: {mode}\n{'='*60}")
    
    # Storage
    summary_columns = ['Outcome', 'IRR', 'CI_Low', 'CI_High', 'P_Value', 'Significant', 'Method']
    fit_stats_columns = ['Outcome', 'Alpha', 'Method', 'N_Obs', 'QIC']
    summary_results = []
    all_coefficients = {}
    all_fit_stats = []
    all_mice_diagnostics = [] # For MICE mode
    all_models = {} 
    effective_alphas = dict(alpha_dict)  # alpha actually used per outcome (for diagnostics)

    # --- MAIN LOOP ---
    for outcome in outcomes:
        outcome_alpha = alpha_dict.get(outcome, fixed_alpha)
        effective_alphas[outcome] = outcome_alpha
        
        # Pooled Storage
        all_params, all_vcovs, fit_stats_list = [], [], []
        param_names = None
        
        for i, (key, df) in enumerate(datasets.items()):
            try:
                working_df = df.copy()
                working_df['log_offset'] = np.log(working_df[offset_col].replace(0, 1e-5))
                if 'pct_remote_10' not in working_df.columns:
                    working_df['pct_remote_10'] = working_df['pct_remote'] * 10
                
                # Fit
                formula = formula_template.format(outcome=outcome)
                model = smf.gee(
                    formula=formula, data=working_df, groups=working_df[grouping_col],
                    offset=working_df['log_offset'],
                    family=sm.families.NegativeBinomial(alpha=outcome_alpha),
                    cov_struct=sm.cov_struct.Exchangeable()
                )
                fit = model.fit()

                # Save Model for Diagnostics (first dataset only)
                if save_diagnostics and i == 0:
                    all_models[outcome] = fit
                
                if param_names is None: param_names = fit.params.index.tolist()
                
                all_params.append(fit.params.values)
                all_vcovs.append(fit.cov_params().values)
                
                # Fit Stats (QIC is best-effort; a failure must not discard the fit)
                try:
                    qic_vals = fit.qic(scale=fit.scale)
                    qic, qicu = qic_vals[0], qic_vals[1]
                except Exception:
                    qic, qicu = np.nan, np.nan
                    
                fit_stats_list.append({
                    'n_obs': int(fit.nobs), 'n_groups': len(working_df[grouping_col].unique()),
                    'scale': fit.scale, 'qic': qic, 'qicu': qicu
                })
                
            except Exception as e:
                if verbose: print(f"   Dataset {i} failed: {e}")
                continue
        
        # --- POOLING (Rubin's Rules) ---
        if not all_params:
            if verbose: print(f"   No successful fit for {outcome}; skipping.")
            continue
        
        all_params = np.array(all_params)
        all_vcovs = np.array(all_vcovs)
        m = len(all_params)
        
        if mode == 'Complete Case' or m == 1:
            pooled_params = all_params[0]
            pooled_se = np.sqrt(np.diag(all_vcovs[0]))
            df_dof = 1000 
            fmi, riv = np.nan, np.nan
        else:
            Q_bar = np.mean(all_params, axis=0)          # pooled point estimates
            U_bar = np.mean(all_vcovs, axis=0)           # within-imputation variance
            B = np.cov(all_params.T, ddof=1)             # between-imputation variance
            if B.ndim == 0: B = np.array([[B]])
            T = U_bar + (1 + 1/m) * B                    # total variance
            pooled_params = Q_bar
            pooled_se = np.sqrt(np.diag(T))
            
            # MICE Stats
            r = (1 + 1/m) * np.diag(B) / np.maximum(np.diag(U_bar), 1e-10)
            riv = r
            df_dof = (m - 1) * (1 + 1/np.maximum(r, 1e-10))**2
            # Fraction of missing information: (r + 2/(df + 3)) / (r + 1).
            # The degrees of freedom belong in the denominator, not r itself.
            fmi = (r + 2/(df_dof + 3)) / (r + 1)

        # Inference
        # NOTE(review): CIs use the normal quantile (1.96) while p-values use a
        # t reference distribution capped at 1000 df; kept as in the original.
        t_stats = pooled_params / pooled_se
        p_vals = 2 * (1 - stats.t.cdf(np.abs(t_stats), df=np.minimum(df_dof, 1000)))
        ci_low = pooled_params - 1.96 * pooled_se
        ci_high = pooled_params + 1.96 * pooled_se

        # Build Results Tables
        coef_df = pd.DataFrame({
            'Coefficient': param_names, 'Beta': pooled_params, 'SE': pooled_se,
            'IRR': np.exp(pooled_params), 'IRR_CI_Low': np.exp(ci_low), 'IRR_CI_High': np.exp(ci_high),
            'P_Value': p_vals, 'Significant': ['*' if p < 0.05 else '' for p in p_vals]
        })
        
        # MICE Diagnostics (FMI/RIV)
        if mode == 'MICE': 
            coef_df['FMI'] = fmi
            coef_df['RIV'] = riv
            
            # Store for MICE diagnostic file
            mice_diag = pd.DataFrame({
                'Coefficient': param_names,
                'FMI': fmi,
                'RIV': riv,
                'Outcome': outcome
            })
            all_mice_diagnostics.append(mice_diag)
        
        all_coefficients[outcome] = coef_df
        
        # Summary Row (Primary Exposure)
        if exposure_var in param_names:
            exp_idx = param_names.index(exposure_var)
            summary_results.append({
                'Outcome': outcome, 'IRR': np.exp(pooled_params[exp_idx]),
                'CI_Low': np.exp(ci_low[exp_idx]), 'CI_High': np.exp(ci_high[exp_idx]),
                'P_Value': p_vals[exp_idx], 'Significant': '*' if p_vals[exp_idx] < 0.05 else '',
                'Method': mode
            })
        elif verbose:
            print(f"   Exposure '{exposure_var}' not among coefficients for {outcome}; no summary row added.")
        
        # Fit Stats Row
        avg_fit = {
            'Outcome': outcome, 'Alpha': outcome_alpha, 'Method': mode,
            'N_Obs': int(np.mean([f['n_obs'] for f in fit_stats_list])),
            'QIC': np.mean([f['qic'] for f in fit_stats_list])
        }
        all_fit_stats.append(avg_fit)

    # --- COMPILING ---
    # Explicit columns keep the frames well-formed even when no row was produced
    summary_df = pd.DataFrame(summary_results, columns=summary_columns)
    fit_stats_df = pd.DataFrame(all_fit_stats, columns=fit_stats_columns)
    
    # Create Formatted Summary (The Pretty One)
    display_rows = [
        {
            'Outcome': row['Outcome'],
            'IRR': f"{row['IRR']:.3f}",
            '95% CI': f"{row['CI_Low']:.3f}-{row['CI_High']:.3f}",
            'P-value': f"{row['P_Value']:.4f}{row['Significant']}",
            'Method': row['Method'],
        }
        for row in summary_results
    ]
    display_summary = pd.DataFrame(
        display_rows, columns=['Outcome', 'IRR', '95% CI', 'P-value', 'Method']
    )
    
    # Display
    if verbose:
        print(f"\nSUMMARY RESULTS ({mode}):")
        display(display_summary)
        
    # Saving
    if save_tables:
        # 1. Summary (Raw)
        summary_df.to_csv(os.path.join(results_dir, f'gee_{mode_label}_summary.csv'), index=False)
        # 2. Summary (Formatted)
        display_summary.to_csv(os.path.join(results_dir, f'gee_{mode_label}_summary_formatted.csv'), index=False)
        # 3. Fit Stats
        fit_stats_df.to_csv(os.path.join(results_dir, f'gee_{mode_label}_fit_statistics.csv'), index=False)
        # 4. Combined Coefficients (tagged with their outcome so rows stay identifiable)
        if all_coefficients:
            combined_coefs = pd.concat(
                [coef_df.assign(Outcome=outcome) for outcome, coef_df in all_coefficients.items()]
            )
            combined_coefs.to_csv(os.path.join(results_dir, f'gee_{mode_label}_all_coefficients.csv'))
        
        # 5. Individual Coefficients
        for outcome, coef_df in all_coefficients.items():
            safe_outcome = outcome.replace(' ', '_')
            coef_df.to_csv(os.path.join(results_dir, f'gee_{mode_label}_coefficients_{safe_outcome}.csv'), index=False)

        # 6. MICE Diagnostics
        if mode == 'MICE' and all_mice_diagnostics:
            pd.concat(all_mice_diagnostics, ignore_index=True).to_csv(
                os.path.join(results_dir, f'gee_{mode_label}_mice_diagnostics.csv'), index=False
            )

        if verbose: print(f"   ✓ Saved all tables to {results_dir}")

    # Diagnostics Trigger
    if save_diagnostics and all_models:
        # First dataset in iteration order, matching the models stored above
        # (dict keys are not guaranteed to include 0)
        df_diag = next(iter(datasets.values())).copy()
        if 'pct_remote_10' not in df_diag.columns:
            df_diag['pct_remote_10'] = df_diag['pct_remote'] * 10
            
        _save_dharma_diagnostics(
            models_dict=all_models, df=df_diag, alpha_dict=effective_alphas,
            save_dir=results_dir, mode_label=mode_label, 
            verbose=verbose, show_plots=show_plots
        )

    return {
        'summary': summary_df,
        'coefficients': all_coefficients,
        'fit_stats': fit_stats_df,
        'models': all_models
    }


# ==============================================================================
# 3. INTERACTION TESTS (Pooled for MICE)
# ==============================================================================
def run_interaction_tests(
    data, outcomes, formula_template,
    interaction_vars, exposure_var='pct_remote_10',
    alpha_dict=None, results_dir=None, verbose=True
):
    """
    Test exposure-by-covariate interactions, one interaction term at a time.

    For every (outcome, interaction variable) pair, ``formula_template`` is
    extended with a single ``exposure_var:term`` product and refitted through
    ``run_gee_analysis``. Coefficient rows whose name contains both ':' and
    ``exposure_var`` are collected into one table.

    Parameters
    ----------
    data : dict or other
        Forwarded unchanged to ``run_gee_analysis``. A dict is labelled 'MICE'
        (presumably the imputed datasets, with pooling handled inside
        ``run_gee_analysis`` -- confirm there); anything else is labelled 'CC'.
    outcomes : iterable of str
        Outcome names; each is fitted separately.
    formula_template : str
        Base formula; the interaction term is appended with ' + '.
    interaction_vars : iterable of str
        Variables to interact with the exposure. Bare 'ethnicity' / 'sex' are
        wrapped in ``C(...)``; any other value is used verbatim.
    exposure_var : str
        Name of the exposure term in the formula.
    alpha_dict : optional
        Forwarded to ``run_gee_analysis``.
    results_dir : str or None
        When truthy, the table is written there as
        'gee_MICE_interactions_pooled.csv' or 'gee_CC_interactions.csv'.
        The directory is created if it does not exist.
    verbose : bool
        Print progress and display the final table. When False, failed fits
        are reported through ``warnings.warn`` instead of being dropped.

    Returns
    -------
    pandas.DataFrame
        Columns: Outcome, Interaction, Term, IRR, CI_Low, CI_High, P_Value,
        Significant. The columns are present even when no row was collected.
    """
    # Fixed schema so an empty result is still a well-formed table / CSV.
    result_columns = [
        'Outcome', 'Interaction', 'Term', 'IRR',
        'CI_Low', 'CI_High', 'P_Value', 'Significant'
    ]

    mode_label = 'MICE' if isinstance(data, dict) else 'CC'

    if verbose:
        print(f"\n{'='*60}\nINTERACTION TESTS ({mode_label}) | Pooling: {'Yes' if mode_label=='MICE' else 'No'}\n{'='*60}")

    pooled_results = []

    for outcome in outcomes:
        if verbose: print(f"Testing interactions for: {outcome}...")
        for int_var in interaction_vars:
            # Bare categorical names need C(); pre-wrapped terms pass through.
            term = f"C({int_var})" if int_var in ('ethnicity', 'sex') else int_var
            int_formula = f"{formula_template} + {exposure_var}:{term}"

            try:
                # Refit quietly with the single extra interaction term.
                res = run_gee_analysis(
                    data=data, outcomes=[outcome], formula_template=int_formula,
                    exposure_var=exposure_var, alpha_dict=alpha_dict,
                    verbose=False, full_output=True, save_tables=False,
                    save_diagnostics=False
                )

                coefs = res['coefficients'][outcome]
                coef_names = coefs['Coefficient']
                # Literal (non-regex) matching so special characters in
                # exposure_var cannot change which rows are selected.
                is_interaction = (
                    coef_names.str.contains(':', regex=False, na=False)
                    & coef_names.str.contains(exposure_var, regex=False, na=False)
                )

                for _, row in coefs[is_interaction].iterrows():
                    pooled_results.append({
                        'Outcome': outcome, 'Interaction': int_var, 'Term': row['Coefficient'],
                        'IRR': row['IRR'], 'CI_Low': row['IRR_CI_Low'], 'CI_High': row['IRR_CI_High'],
                        'P_Value': row['P_Value'], 'Significant': row['Significant']
                    })
            except Exception as e:
                # Best effort: one failed fit must not abort the remaining
                # tests, but it must not disappear silently either.
                if verbose:
                    print(f"   Failed {int_var}: {e}")
                else:
                    warnings.warn(
                        f"Interaction test failed for {outcome} x {int_var}: {e}",
                        RuntimeWarning
                    )

    df_int = pd.DataFrame(pooled_results, columns=result_columns)

    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
        # Pooled (MICE) and complete-case results use different file names.
        if mode_label == 'MICE':
            fname = f'gee_{mode_label}_interactions_pooled.csv'
        else:
            fname = f'gee_{mode_label}_interactions.csv'

        save_path = os.path.join(results_dir, fname)
        df_int.to_csv(save_path, index=False)
        if verbose: print(f"\n✓ Saved interaction results to: {save_path}")

    if verbose: display(df_int)
    return df_int


# ==============================================================================
# 4. MASTER PIPELINE
# ==============================================================================
def run_full_analysis_pipeline(
    data, outcomes, formula_template, alpha_dict, results_dir, interaction_vars,
    run_main=True, run_diagnostics=True, run_interactions=True, verbose=True
):
    """
    Run the main regression (with optional diagnostics) and the interaction tests.

    Parameters
    ----------
    data, outcomes, formula_template, alpha_dict, results_dir
        Forwarded unchanged to ``run_gee_analysis`` and ``run_interaction_tests``.
    interaction_vars : iterable of str
        Forwarded to ``run_interaction_tests``.
    run_main : bool
        Run ``run_gee_analysis`` with ``save_tables=True``.
    run_diagnostics : bool
        Passed as ``save_diagnostics`` to ``run_gee_analysis``; it therefore
        has no effect when ``run_main`` is False.
    run_interactions : bool
        Run ``run_interaction_tests``.
    verbose : bool
        Forwarded to both steps; also prints the completion banner.

    Returns
    -------
    dict
        ``{'main': <run_gee_analysis result or None>,
           'interactions': <run_interaction_tests result or None>}``.
        A step that was switched off is reported as None.
    """
    main_results = None
    interaction_results = None

    # 1. Main regression & diagnostics
    if run_main:
        main_results = run_gee_analysis(
            data=data,
            outcomes=outcomes,
            formula_template=formula_template,
            alpha_dict=alpha_dict,
            verbose=verbose,
            save_tables=True,
            results_dir=results_dir,
            save_diagnostics=run_diagnostics,
            show_plots=False # Always suppress screen plots for pipeline
        )

    # 2. Interactions
    if run_interactions:
        interaction_results = run_interaction_tests(
            data=data,
            outcomes=outcomes,
            formula_template=formula_template,
            interaction_vars=interaction_vars,
            alpha_dict=alpha_dict,
            results_dir=results_dir,
            verbose=verbose
        )

    if verbose:
        print("\n" + "="*60)
        print("PIPELINE COMPLETE")
        print("="*60)

    # Hand the results back instead of discarding them so callers can inspect
    # them without re-reading the saved CSV files.
    return {'main': main_results, 'interactions': interaction_results}

### Full run

In [None]:
# ==============================================================================
# CONFIGURATION & CONSTANTS
# ==============================================================================

# --- Outcomes -----------------------------------------------------------------
# Outcome columns to analyse. The loop further down also registers the
# matching "<outcome>_rate" column for each entry.
OUTCOMES = [
    'outcome_emergencies',
    'outcome_admissions',
    'outcome_bed_days',
    'outcome_mha_admissions',
    # 'outcome_long_stay_admissions'  (disabled)
]

# --- Missingness --------------------------------------------------------------
# Variables whose missingness is analysed (and later updated by imputation).
MISSINGNESS_TARGET_VARS = ['ethnicity', 'imd_at_exposure', 'sex', 'pct_remote']

# Characteristics compared in the missingness tables.
# Cardiac flags ('af_ever', 'hf_ever', 'ihd_ever') are deliberately left out.
MISSINGNESS_COMPARE_VARS = [
    'age_at_exposure',
    'sex',
    'ethnicity',
    'imd_at_exposure',
    'period_id',
    'total_consults_rate',
    'pct_remote',
    'missing_modality_consults_rate',
    'died_in_outcome',
    'anx_ever',
    'dep_ever',
    'smi_ever',
]

# --- Imputation ---------------------------------------------------------------
# Columns handed to the imputation step: the four targets first, then the
# auxiliary predictors. Cardiac flags are deliberately left out here too.
IMPUTATION_PREDICTORS = [
    # targets
    'sex',
    'ethnicity',
    'imd_at_exposure',
    'pct_remote',
    # predictors
    'age_at_exposure',
    'total_consults',
    'total_consults_rate',
    'remote_consults',
    'remote_consults_rate',
    'f2f_consults',
    'f2f_consults_rate',
    'missing_modality_consults',
    'missing_modality_consults_rate',
    'active_days_in_exposure',
    'outcome_days_at_risk',
    'died_in_outcome',
    'anx_ever',
    'dep_ever',
    'smi_ever',
]

# Register every outcome automatically so that adding an entry to OUTCOMES is
# enough. Missingness comparisons use rates only (not absolute counts), while
# imputation receives both the count and its rate.
for outcome in OUTCOMES:
    MISSINGNESS_COMPARE_VARS += [f"{outcome}_rate"]
    IMPUTATION_PREDICTORS += [outcome, f"{outcome}_rate"]

# --- Descriptive tables -------------------------------------------------------
BASELINE_CHARACTERISTICS_VARS = [
    'age_at_exposure',
    'sex',
    'ethnicity',
    'imd_at_exposure',
    'total_consults',
    'pct_remote',
    'anx_ever',
    'dep_ever',
    'smi_ever',
    # 'af_ever', 'hf_ever', 'ihd_ever'  (disabled)
]

# --- Model formula ------------------------------------------------------------
# Right-hand side of the regression formula, one term per list entry.
# Disabled terms: ihd_ever, hf_ever, af_ever.
BASE_COVARIATES_STR = " + ".join([
    "pct_remote_10",
    "total_consults",
    "age_at_exposure",
    "imd_at_exposure",
    "C(period_id, Treatment(1))",
    "C(sex, Treatment('M'))",
    "C(ethnicity, Treatment('White'))",
    "smi_ever",
    "anx_ever",
    "dep_ever",
])

# Full GEE formula; the literal "{outcome}" placeholder is filled in later.
FORMULA_TEMPLATE = "{outcome} ~ " + BASE_COVARIATES_STR

# Terms tested for interaction with the exposure.
INTERACTION_VARS = [
    "C(ethnicity, Treatment('White'))",
    'imd_at_exposure',
    'age_at_exposure',
    "C(sex, Treatment('M'))",
]

# ==============================================================================
# ANALYSIS PIPELINE
# ==============================================================================
# End-to-end run, in order: period definition -> cohort summary -> cohort build
# -> distribution checks -> missingness analysis -> complete-case and imputed
# datasets -> baseline tables -> alpha optimisation -> regression.
# Every helper called below is defined elsewhere in the notebook, so the
# statement order matters: each stage consumes objects created by earlier ones.
#
# NOTE(review): RESULTS_DIR is defined as '../results/' (with a trailing
# slash), so the f"{RESULTS_DIR}/..." paths below contain a doubled '/'.
# Confirm that every helper tolerates this before normalising the paths.

# Define periods
periods_df = generate_and_visualize_periods(
    START_DATE, 
    END_DATE_LDN, 
    END_DATE_SLAM, 
    EXPOSURE_LENGTH, 
    OUTCOME_LENGTH, 
    WINDOW_STEP,
    PERIODS_START,
    visualize=False,
    save_png=True,  # ask the helper to save the figure (on-screen display is off above)
    results_dir=f"{RESULTS_DIR}/period_definition/"
)

# Summary of initial cohort
# The LDN window spans the exposure periods only; the SLAM window runs from the
# first exposure start to the last outcome end.
save_summary_statistics(
    verbose=False,
    results_dir=f"{RESULTS_DIR}/descriptive",
    ldn_dates=(periods_df.exposure_start.min(), periods_df.exposure_end.max()),
    slam_dates=(periods_df.exposure_start.min(), periods_df.outcome_end.max()) # TODO(review): consider outcome_start instead of outcome_end
)

# Populate periods with data
# NOTE(review): demos, consults, emergencies, admissions and demos_dynamic are
# not defined in this cell; presumably they are prepared by the loading cells
# earlier in the notebook -- confirm before running on a fresh kernel.
final_df = generate_cohort(
    demos,
    consults,
    emergencies,
    admissions,
    demos_dynamic,
    periods_df,
    extraction_date=EXTRACTION_DATE,
    use_single_random_period=True,
    sample_after_eligibility=True,
    random_seed=2025,  # fixed seed so the random period sampling is repeatable
    verbose=False,
    long_stay_threshold_days=7 # optional: only relevant when the long-stay outcome is used
)

# Normality and parametric testing
print('Checking distributions...')
distribution_checks = check_distribution_assumptions(
    final_df,
    OUTCOMES,
    results_dir=f"{RESULTS_DIR}/distribution_checks",
    verbose=False
)

# Missingness analysis
print('Checking overall missingness...')
missingness_results = run_missingness_analysis(
    final_df,
    target_vars=MISSINGNESS_TARGET_VARS,
    columns_to_compare=MISSINGNESS_COMPARE_VARS,
    verbose=False,
    results_dir=f"{RESULTS_DIR}/missingness",    
)
print('Checking GP missingness...')
gp_missingness_results = analyze_gp_missingness_grid(
    final_df,
    MISSINGNESS_TARGET_VARS,
    verbose=False,
    results_dir=f"{RESULTS_DIR}/missingness",
)

# Missing data and imputation
# The same helper is called twice on final_df: once with complete_case=True and
# once with complete_case=False plus the imputation settings.
print('Creating complete case dataset...')
df_complete_case = handle_missing_data(
    df=final_df,
    cols_to_impute=IMPUTATION_PREDICTORS, 
    targets_to_update=MISSINGNESS_TARGET_VARS,
    complete_case=True,
    keep_pct_remote_nan=False, 
    verbose=False,
    results_dir=f"{RESULTS_DIR}/imputation"
)

print('Creating imputed datasets...')
# NOTE(review): no random seed is passed here; check whether handle_missing_data
# fixes one internally, otherwise the imputed datasets differ between runs.
imputed_dict = handle_missing_data(
    df=final_df,
    cols_to_impute=IMPUTATION_PREDICTORS,
    targets_to_update=MISSINGNESS_TARGET_VARS,
    complete_case=False,
    keep_pct_remote_nan=False,
    verbose=False,
    n_datasets=5,
    n_iterations=25,
    mice_cycles=5,      
    results_dir=f"{RESULTS_DIR}/imputation"
)

# Create baseline characteristics tables
print('Generating table one: initial')
table_one_initial = generate_table_one(
    final_df, 
    f"{RESULTS_DIR}/descriptive", 
    filename='baseline_characteristics_initial_cohort.csv', 
    verbose=False,
    groupby_mode='period',
    show_pval=False,
    remote_as_continuous=True,
    selected_columns=BASELINE_CHARACTERISTICS_VARS
)
print('Generating table one: imputed (sample)')
# Only the entry stored under key 0 of imputed_dict is tabulated, as a sample.
table_one_imputed = generate_table_one(
    imputed_dict[0], 
    f"{RESULTS_DIR}/descriptive", 
    filename='baseline_characteristics_imputed_cohort.csv', 
    verbose=False,
    groupby_mode='period',
    show_pval=False,
    remote_as_continuous=True,
    selected_columns=BASELINE_CHARACTERISTICS_VARS
)

# Alpha optimisation
# The same candidate grid is searched for the complete-case frame and for the
# imputed datasets; each result feeds the matching regression run below.
print('Optimising alphas: complete case...')
optimal_alphas_cc = optimize_gee_alpha(
    df=df_complete_case,
    outcomes=OUTCOMES,
    alphas_to_test=[0.05, 0.1, 0.2, 0.5, 1.0, 1.5, 2.0, 3.0],
    formula_template=FORMULA_TEMPLATE,
    save_dir=f"{RESULTS_DIR}/regression",
    verbose=True
)
print('Optimising alphas: imputed...')
optimal_alphas_mice = optimize_gee_alpha(
    df=imputed_dict,
    outcomes=OUTCOMES,
    alphas_to_test=[0.05, 0.1, 0.2, 0.5, 1.0, 1.5, 2.0, 3.0],
    formula_template=FORMULA_TEMPLATE,
    save_dir=f"{RESULTS_DIR}/regression",
    verbose=True
)

# Regression
# Both runs write into the same regression folder.
print("Performing regression: complete case...")
run_full_analysis_pipeline(
    data=df_complete_case,
    outcomes=OUTCOMES,
    formula_template=FORMULA_TEMPLATE,
    alpha_dict=optimal_alphas_cc,
    results_dir=f"{RESULTS_DIR}/regression",
    interaction_vars=INTERACTION_VARS,
    run_main=True,
    run_diagnostics=True,     
    run_interactions=True,   
    verbose=True    
)
print("Performing regression: imputed...")
run_full_analysis_pipeline(
    data=imputed_dict,
    outcomes=OUTCOMES,
    formula_template=FORMULA_TEMPLATE,
    alpha_dict=optimal_alphas_mice,
    results_dir=f"{RESULTS_DIR}/regression",
    interaction_vars=INTERACTION_VARS,
    run_main=True,
    run_diagnostics=True,     
    run_interactions=True,
    verbose=True         
)

print('Finished!')

### Save all outputs and code to combined files

In [None]:
# Supplementary Materials Compiler for Jupyter Notebook
# Run this cell after defining RESULTS_DIR
# Install python-docx if needed (uncomment if not installed)
# !pip install python-docx pandas
import os
from pathlib import Path
import pandas as pd
from docx import Document
from docx.shared import Inches, Pt
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.oxml.ns import qn
from docx.oxml import OxmlElement
# ============================================================================
# OUTPUT CONFIGURATION - Choose which files to create
# ============================================================================
# Each flag switches one output on or off; the notebook-export step runs when
# either of the two EXPORT_* flags is True.
CREATE_SUPPLEMENTARY_MATERIALS = True   # Create Word doc with all results
EXPORT_NOTEBOOKS_AS_IPYNB = True        # Export cleaned .ipynb files (outputs removed)
EXPORT_NOTEBOOKS_AS_DOCX = True         # Export notebook code as Word docs
# ============================================================================
# ============================================================================
# FILES TO INCLUDE - Customize which files appear in supplementary materials
# Format: ['filename', 'Display Name for Appendix', Include (True/False)]
# Files will appear in the order listed below, grouped by folder
#
# How entries are used by the code further down in this cell:
#   * Files are matched by bare filename against the direct sub-folders of
#     RESULTS_DIR; the folder comments in the list are for the reader only.
#   * List position sets the order of files inside an appendix, and a folder's
#     first included entry sets the order of the appendices.
#   * An entry with Include=False is left out of the document but still counts
#     as "listed", so it is not reported as an unlisted file.
# ============================================================================
FILES_TO_INCLUDE = [
    # period_definition
    ['periods.csv', 'Study Period Details', True],
    ['timeline_visualization.txt', 'Timeline Visualization (Text)', False],
    ['timeline_visualization.png', 'Timeline Visualization', True],
    
    # descriptive
    ['summary_statistics.txt', 'Summary Statistics', True],
    ['baseline_characteristics_initial_cohort.csv', 'Baseline Characteristics by Period (Initial Cohort)', True],
    ['baseline_characteristics_imputed_cohort.csv', 'Baseline Characteristics by Period (Sample Imputed Cohort)', True],
    
    # distribution_checks
    ['distribution_assumptions.csv', 'Distribution Assumptions', True],
    
    # missingness
    ['missingness_summary_counts.csv', 'Missingness Summary Counts', True],
    ['missingness_table_ethnicity.csv', 'Missingness by Ethnicity', True],
    ['missingness_table_imd_at_exposure.csv', 'Missingness by IMD at Exposure', True],
    ['missingness_table_sex.csv', 'Missingness by Sex', True],
    ['missingness_table_pct_remote.csv', 'Missingness by Percent Remote', True],
    ['gp_missingness_grid.png', 'GP Missingness Grid', True],
    
    # imputation
    ['complete_case_summary.txt', 'Complete Case Summary', True],
    ['variables_used.txt', 'Variables Used in Imputation', True],
    ['imputation_stats_ethnicity.csv', 'Imputation Statistics by Ethnicity', True],
    ['imputation_stats_imd_at_exposure.csv', 'Imputation Statistics by IMD at Exposure', True],
    ['imputation_stats_pct_remote.csv', 'Imputation Statistics by Percent Remote', True],
    ['imputation_stats_sex.csv', 'Imputation Statistics by Sex', True],
    
    # regression
    # The per-outcome coefficient files follow OUTCOMES; update these entries
    # when an outcome is added or removed.
    ['gee_CC_summary.csv', 'GEE Complete Case Summary', False],
    ['gee_CC_summary_formatted.csv', 'GEE Complete Case Summary (Formatted)', True],
    ['gee_CC_coefficients_outcome_emergencies.csv', 'GEE Complete Case Coefficients: Emergencies', True],
    ['gee_CC_coefficients_outcome_admissions.csv', 'GEE Complete Case Coefficients: Admissions', True],
    ['gee_CC_coefficients_outcome_bed_days.csv', 'GEE Complete Case Coefficients: Bed Days', True],
    ['gee_CC_coefficients_outcome_mha_admissions.csv', 'GEE Complete Case Coefficients: MHA Admissions', True],
    ['gee_CC_fit_statistics.csv', 'GEE Complete Case Fit Statistics', True],
    ['gee_CC_alpha_optimization_results.csv', 'GEE Complete Case Alpha Optimization Results', True],
    ['gee_CC_all_coefficients.csv', 'GEE Complete Case All Coefficients', False],
    ['gee_CC_dharma_qq_plots.png', 'GEE Complete Case DHARMa QQ Plots', True],
    ['gee_CC_dharma_residuals_vs_predicted.png', 'GEE Complete Case DHARMa Residuals vs Predicted', True],
    ['gee_CC_interactions.csv', 'GEE Complete Case Interactions', True],
    ['gee_MICE_summary.csv', 'GEE MICE Summary', False],
    ['gee_MICE_summary_formatted.csv', 'GEE MICE Summary (Formatted)', True],
    ['gee_MICE_coefficients_outcome_emergencies.csv', 'GEE MICE Coefficients: Emergencies', True],
    ['gee_MICE_coefficients_outcome_admissions.csv', 'GEE MICE Coefficients: Admissions', True],
    ['gee_MICE_coefficients_outcome_bed_days.csv', 'GEE MICE Coefficients: Bed Days', True],
    ['gee_MICE_coefficients_outcome_mha_admissions.csv', 'GEE MICE Coefficients: MHA Admissions', True],
    ['gee_MICE_fit_statistics.csv', 'GEE MICE Fit Statistics', True],
    ['gee_MICE_alpha_optimization_results_pooled.csv', 'GEE MICE Alpha Optimization Results (Pooled)', True],
    ['gee_MICE_all_coefficients.csv', 'GEE MICE All Coefficients', False],
    ['gee_MICE_dharma_qq_plots.png', 'GEE MICE DHARMa QQ Plots', True],
    ['gee_MICE_dharma_residuals_vs_predicted.png', 'GEE MICE DHARMa Residuals vs Predicted', True],
    ['gee_MICE_mice_diagnostics.csv', 'GEE MICE Diagnostics', True],
    ['gee_MICE_interactions_pooled.csv', 'GEE MICE Interactions (Pooled)', True],
]
# ============================================================================
# ============================================================================
# FORMATTING PARAMETERS - Edit these to customize the document appearance
# ============================================================================
# FONT_NAME / FONT_SIZE are applied run by run to table cells, notes and text
# files; code cells in the notebook export use their own monospace settings.
FONT_NAME = 'Arial'          # Font family for all text
FONT_SIZE = 10               # Font size in points
IMAGE_WIDTH = 6.0            # Image width in inches
TABLE_STYLE = 'Light Grid Accent 1'  # Word table style (must exist in the document template)
# ============================================================================
def add_table_border(table):
    """Append a full set of single-line black borders to a python-docx table.

    A ``w:tblBorders`` element covering the four outer edges and both inner
    grid directions is appended to the table's ``w:tblPr`` properties; the
    properties element is created first when the table reports none.
    """
    tbl_element = table._element
    properties = tbl_element.tblPr
    if properties is None:
        properties = OxmlElement('w:tblPr')
        tbl_element.insert(0, properties)

    # Identical attributes for every edge, applied in this order.
    edge_attributes = (
        ('w:val', 'single'),
        ('w:sz', '4'),
        ('w:space', '0'),
        ('w:color', '000000'),
    )

    borders = OxmlElement('w:tblBorders')
    for side in ('top', 'left', 'bottom', 'right', 'insideH', 'insideV'):
        edge = OxmlElement(f'w:{side}')
        for attr_name, attr_value in edge_attributes:
            edge.set(qn(attr_name), attr_value)
        borders.append(edge)
    properties.append(borders)
def format_value(val):
    """Render one table value as display text.

    Missing values (as judged by ``pd.isna``) become an empty string, floats
    are fixed to three decimal places, and everything else is converted with
    ``str`` and has a leading 'outcome_' removed.
    """
    if pd.isna(val):
        return ''
    if isinstance(val, float):
        return format(val, '.3f')

    text = str(val)
    prefix = 'outcome_'
    # Only a prefix at the very start is stripped.
    return text[len(prefix):] if text.startswith(prefix) else text
def clean_column_name(col_name):
    """Turn a column label into header text.

    The label is converted to a string and stripped. Blank labels, the text
    'nan', and anything containing 'Unnamed' yield an empty string; otherwise
    a leading 'outcome_' is removed and the rest is returned.
    """
    label = str(col_name).strip()

    is_placeholder = label == '' or label == 'nan' or 'Unnamed' in label
    if is_placeholder:
        return ''

    prefix = 'outcome_'
    if label.startswith(prefix):
        return label[len(prefix):]
    return label
# Build the supplementary-materials Word document: one appendix per results
# sub-folder, one numbered sub-section per file selected in FILES_TO_INCLUDE.
if CREATE_SUPPLEMENTARY_MATERIALS:
    # --- Document skeleton: centred title followed by a blank paragraph ---
    doc = Document()
    title = doc.add_heading('Supplementary Materials', level=0)
    title.alignment = WD_ALIGN_PARAGRAPH.CENTER
    doc.add_paragraph()

    # --- Result folders: direct children of RESULTS_DIR, checkpoints skipped ---
    results_path = Path(RESULTS_DIR)
    subdirs = sorted(
        entry for entry in results_path.iterdir()
        if entry.is_dir() and entry.name != '.ipynb_checkpoints'
    )

    print(f"Processing supplementary materials from {len(subdirs)} folders...")

    # --- Index FILES_TO_INCLUDE by filename, remembering the list position ---
    file_config = {
        filename: {
            'display_name': display_name,
            'include': include,
            'order': idx,
            'found': False,
            'folder': None
        }
        for idx, (filename, display_name, include) in enumerate(FILES_TO_INCLUDE)
    }

    # --- Candidate files on disk: name -> list of paths, in folder order ---
    all_files_in_dirs = {}
    for subdir in subdirs:
        for pattern in ('*.csv', '*.txt', '*.png', '*.jpg', '*.jpeg'):
            for candidate in subdir.glob(pattern):
                all_files_in_dirs.setdefault(candidate.name, []).append(candidate)

    # --- Resolve every configured file to its first match on disk ---
    for filename, cfg in file_config.items():
        matches = all_files_in_dirs.get(filename)
        if not matches:
            continue
        if len(matches) > 1:
            print(f"⚠ Warning: '{filename}' found in multiple folders:")
            for duplicate in matches:
                print(f"    - {duplicate.parent.name}")
            print(f"    Using first occurrence: {matches[0].parent.name}")

        file_path = matches[0]
        cfg['found'] = True
        cfg['folder'] = file_path.parent.name
        cfg['path'] = file_path

    # --- Report wanted-but-missing files ---
    missing_files = [
        name for name, cfg in file_config.items()
        if cfg['include'] and not cfg['found']
    ]
    if missing_files:
        print(f"\n⚠ Warning: {len(missing_files)} file(s) in FILES_TO_INCLUDE not found:")
        for name in missing_files:
            print(f"    - {name}")

    # --- Report files on disk that FILES_TO_INCLUDE does not mention ---
    unlisted_files = set(all_files_in_dirs) - set(file_config)
    if unlisted_files:
        print(f"\n⚠ Warning: {len(unlisted_files)} file(s) found but not in FILES_TO_INCLUDE:")
        for name in sorted(unlisted_files):
            folder = all_files_in_dirs[name][0].parent.name
            print(f"    - {name} (in {folder})")

    # --- Group the selected files by folder ---
    files_by_folder = {}
    folder_first_appearance = {}  # folder -> list position of its first selected file
    for filename, cfg in file_config.items():
        if not (cfg['found'] and cfg['include']):
            continue
        folder = cfg['folder']
        if folder not in files_by_folder:
            files_by_folder[folder] = []
            folder_first_appearance[folder] = cfg['order']
        files_by_folder[folder].append({
            'filename': filename,
            'display_name': cfg['display_name'],
            'path': cfg['path'],
            'order': cfg['order']
        })

    # Files follow their FILES_TO_INCLUDE position; folders follow the position
    # of their first selected file (not alphabetical order).
    for entries in files_by_folder.values():
        entries.sort(key=lambda entry: entry['order'])
    folder_names = sorted(files_by_folder, key=lambda name: folder_first_appearance[name])

    total_included = sum(len(entries) for entries in files_by_folder.values())
    print(f"\nIncluding {total_included} file(s) across {len(folder_names)} folder(s)")

    # --- One appendix per folder ---
    for appendix_num, folder_name in enumerate(folder_names, start=1):
        print(f"\nAppendix {appendix_num}: {folder_name}")

        doc.add_page_break()
        appendix_title = f"Appendix {appendix_num}: {folder_name.replace('_', ' ').title()}"
        doc.add_heading(appendix_title, level=1)

        files_in_folder = files_by_folder[folder_name]
        print(f"  Including {len(files_in_folder)} file(s) (ordered by FILES_TO_INCLUDE)")

        for file_pos, file_info in enumerate(files_in_folder, start=1):
            sub_number = f"{appendix_num}.{file_pos}"
            file_path = file_info['path']
            full_title = f"Appendix {sub_number}: {file_info['display_name']}"
            suffix = file_path.suffix.lower()

            # A failure in one file is reported and the loop moves on.
            try:
                if suffix == '.csv':
                    df = pd.read_csv(file_path)
                    doc.add_heading(full_title, level=2)

                    # Cap very large tables at 100 rows x 15 columns.
                    n_rows, n_cols = len(df), len(df.columns)
                    if n_rows > 100 or n_cols > 15:
                        note_para = doc.add_paragraph(f"Note: Table truncated. Original: {n_rows} rows × {n_cols} cols")
                        for run in note_para.runs:
                            run.font.name = FONT_NAME
                            run.font.size = Pt(FONT_SIZE)
                            run.font.italic = True
                        df = df.iloc[:100, :15]

                    table = doc.add_table(rows=len(df) + 1, cols=len(df.columns))
                    table.style = TABLE_STYLE

                    # Header row (bold, cleaned column names)
                    header_cells = table.rows[0].cells
                    for col_idx, col in enumerate(df.columns):
                        cell = header_cells[col_idx]
                        cell.text = clean_column_name(col)
                        for paragraph in cell.paragraphs:
                            for run in paragraph.runs:
                                run.font.bold = True
                                run.font.name = FONT_NAME
                                run.font.size = Pt(FONT_SIZE)

                    # Data rows (floats shown with 3 decimal places)
                    for row_idx, record in enumerate(df.itertuples(index=False), start=1):
                        row_cells = table.rows[row_idx].cells
                        for col_idx, val in enumerate(record):
                            cell = row_cells[col_idx]
                            cell.text = format_value(val)
                            for paragraph in cell.paragraphs:
                                for run in paragraph.runs:
                                    run.font.name = FONT_NAME
                                    run.font.size = Pt(FONT_SIZE)

                    add_table_border(table)
                    doc.add_paragraph()
                    print(f"  ✓ {sub_number} Table: {file_path.name}")

                elif suffix == '.txt':
                    doc.add_heading(full_title, level=2)
                    content = file_path.read_text(encoding='utf-8')
                    text_para = doc.add_paragraph()
                    run = text_para.add_run(content)
                    run.font.name = FONT_NAME
                    run.font.size = Pt(FONT_SIZE)
                    doc.add_paragraph()
                    print(f"  ✓ {sub_number} Text: {file_path.name}")

                elif suffix in ('.png', '.jpg', '.jpeg'):
                    doc.add_heading(full_title, level=2)
                    doc.add_picture(str(file_path), width=Inches(IMAGE_WIDTH))
                    # Centre the paragraph just added for the picture.
                    doc.paragraphs[-1].alignment = WD_ALIGN_PARAGRAPH.CENTER
                    doc.add_paragraph()
                    print(f"  ✓ {sub_number} Image: {file_path.name}")

            except Exception as e:
                print(f"  ✗ {sub_number} Error with {file_path.name}: {e}")

    # --- Save the finished document next to the result folders ---
    output_path = results_path / 'Supplementary_Materials.docx'
    doc.save(str(output_path))
    print(f"\n{'='*50}")
    print(f"✓ COMPLETE! Saved to: {output_path}")
    print(f"  {len(folder_names)} appendices processed")
    print(f"  {total_included} files included")
    print(f"{'='*50}")
# ============================================================================
# NOTEBOOK EXPORT (with outputs removed for patient data privacy)
# ============================================================================
# For every *.ipynb file in the current working directory:
#   1. load the notebook JSON and strip everything that can carry rendered
#      results: code-cell outputs, execution counts and saved widget state;
#   2. optionally write the stripped copy to RESULTS_DIR as '<name>_cleaned.ipynb';
#   3. optionally write the cell sources to RESULTS_DIR as '<name>_code.docx'.
# The notebook is read from disk, so save it before running this cell.
if EXPORT_NOTEBOOKS_AS_IPYNB or EXPORT_NOTEBOOKS_AS_DOCX:
    print(f"\n{'='*50}")
    print("Exporting notebook code...")
    print(f"{'='*50}")
    try:
        import json
        import glob
        from docx.oxml import parse_xml

        # Find ALL .ipynb files in current directory (excluding checkpoints)
        ipynb_files = glob.glob('*.ipynb')
        ipynb_files = [f for f in ipynb_files if '.ipynb_checkpoints' not in f]

        if not ipynb_files:
            print("\n⚠ No notebook files found in current directory")
        else:
            print(f"\nFound {len(ipynb_files)} notebook(s) to export")

            for nb_filename in ipynb_files:
                nb_path = Path(nb_filename)
                nb_name = nb_path.stem
                print(f"\n  Processing: {nb_name}")

                # One broken notebook is reported and skipped; the rest still export.
                try:
                    with open(nb_path, 'r', encoding='utf-8') as f:
                        nb_data = json.load(f)

                    # Remove all outputs
                    for cell in nb_data.get('cells', []):
                        if cell.get('cell_type') == 'code':
                            cell['outputs'] = []
                            cell['execution_count'] = None

                    # Saved widget state lives in the notebook-level metadata and
                    # can embed rendered output, so it is dropped as well.
                    nb_data.get('metadata', {}).pop('widgets', None)

                    # Save cleaned .ipynb file (if enabled)
                    if EXPORT_NOTEBOOKS_AS_IPYNB:
                        clean_ipynb_path = Path(RESULTS_DIR) / f'{nb_name}_cleaned.ipynb'
                        with open(clean_ipynb_path, 'w', encoding='utf-8') as f:
                            json.dump(nb_data, f, indent=2)
                        print(f"    ✓ Saved cleaned notebook: {clean_ipynb_path}")

                    # Create Word document with code (if enabled)
                    if EXPORT_NOTEBOOKS_AS_DOCX:
                        code_doc = Document()
                        code_title = code_doc.add_heading('Notebook Code', level=0)
                        code_title.alignment = WD_ALIGN_PARAGRAPH.CENTER
                        code_doc.add_paragraph()

                        # Add notebook name subtitle
                        subtitle = code_doc.add_heading(nb_name, level=1)
                        subtitle.alignment = WD_ALIGN_PARAGRAPH.CENTER
                        code_doc.add_paragraph()

                        # cell_idx counts every cell, so numbering keeps gaps
                        # where empty cells were skipped.
                        for cell_idx, cell in enumerate(nb_data.get('cells', []), start=1):
                            cell_type = cell.get('cell_type')
                            source = ''.join(cell.get('source', []))

                            if not source.strip():
                                continue

                            if cell_type == 'markdown':
                                # Add markdown cells as regular text
                                code_doc.add_heading(f'Markdown Cell {cell_idx}', level=2)
                                para = code_doc.add_paragraph(source)
                                for run in para.runs:
                                    run.font.name = FONT_NAME
                                    run.font.size = Pt(FONT_SIZE)

                            elif cell_type == 'code':
                                # Add code cells in monospace
                                code_doc.add_heading(f'Code Cell {cell_idx}', level=2)
                                para = code_doc.add_paragraph()
                                run = para.add_run(source)
                                run.font.name = 'Courier New'
                                run.font.size = Pt(9)
                                # Light gray background for code; a fresh element
                                # is parsed for each paragraph it is attached to.
                                shading_elm = parse_xml(r'<w:shd {} w:fill="F0F0F0"/>'.format('xmlns:w="http://schemas.openxmlformats.org/wordprocessingml/2006/main"'))
                                para._element.get_or_add_pPr().append(shading_elm)

                            code_doc.add_paragraph()  # Spacing between cells

                        # Save code document
                        code_doc_path = Path(RESULTS_DIR) / f'{nb_name}_code.docx'
                        code_doc.save(str(code_doc_path))
                        print(f"    ✓ Saved code document: {code_doc_path}")

                except Exception as e:
                    print(f"    ✗ Error processing {nb_name}: {e}")

    except ImportError as e:
        print(f"\n⚠ Import error: {e}")
        print("  Make sure required packages are installed")
    except Exception as e:
        print(f"\n⚠ Error during notebook export: {e}")
        print("  Skipping notebook export")
else:
    print(f"\n{'='*50}")
    print("Notebook export disabled in configuration")
    print(f"{'='*50}")
print(f"\n{'='*50}")
print("✓ ALL TASKS COMPLETE!")
print(f"{'='*50}")