In [211]:
import os

path = os.getcwd()
print("Current directory:", path)

Current directory: D:\Work\PTSD\Linear


# DML:

In [212]:
import os
import warnings

warnings.filterwarnings("ignore")  # optional

# Option 1 (recommended)
new_path = r"D:\Work\PTSD\DML"

os.chdir(new_path)  # Change working directory
print("Changed to:", os.getcwd())


Changed to: D:\Work\PTSD\DML


In [213]:
import pandas as pd
df = pd.read_csv(r"data_baseline.csv")

In [214]:

# Step 1: Identify columns by prefix
cat_cols = [col for col in df.columns if col.startswith('CAT_')]
subcat_cols = [col for col in df.columns if col.startswith('SUBCAT_')]
subsubcat_cols = [col for col in df.columns if col.startswith('SubSubCat_')]

# Step 2: Print counts
print(f"Number of CAT_ columns: {len(cat_cols)}")
print(f"Number of SUBCAT_ columns: {len(subcat_cols)}")
print(f"Number of SubSubCat_ columns: {len(subsubcat_cols)}")

# Step 3: Column-wise sums
cat_sums = df[cat_cols].sum().sort_values(ascending=False)
subcat_sums = df[subcat_cols].sum().sort_values(ascending=False)
subsubcat_sums = df[subsubcat_cols].sum().sort_values(ascending=False)

# Step 4: Convert to DataFrames
cat_df = cat_sums.reset_index()
cat_df.columns = ['Column', 'Sum']

subcat_df = subcat_sums.reset_index()
subcat_df.columns = ['Column', 'Sum']

subsubcat_df = subsubcat_sums.reset_index()
subsubcat_df.columns = ['Column', 'Sum']

# Step 5: Write to Excel
with pd.ExcelWriter("category_column_sums.xlsx") as writer:
    cat_df.to_excel(writer, sheet_name="CAT_Sums", index=False)
    subcat_df.to_excel(writer, sheet_name="SUBCAT_Sums", index=False)
    subsubcat_df.to_excel(writer, sheet_name="SubSubCat_Sums", index=False)

print("✅ Exported to 'category_column_sums.xlsx'")


Number of CAT_ columns: 23
Number of SUBCAT_ columns: 42
Number of SubSubCat_ columns: 243
✅ Exported to 'category_column_sums.xlsx'


## CAT analysis:

In [215]:
# Import all necessary packages
import pandas as pd
import numpy as np
import re
# For visualization and future steps
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.experimental import enable_iterative_imputer  # Needed to enable the experimental feature
from sklearn.impute import IterativeImputer


In [216]:
# Check basic structure
print("Shape of dataset:", df.shape)
print("\nSample columns:", df.columns.tolist()[:10])
print("\nMissing values:\n", df.isnull().sum().sort_values(ascending=False).head(10))

# Remove duplicate rows
df = df.drop_duplicates()

# Confirm shape after removing duplicates
print("Shape after removing duplicates:", df.shape)


Shape of dataset: (6125, 465)

Sample columns: ['CIN5', 'StartDatum', 'BEH_MOD', 'BEHDAGEN_GEPLAND', 'AANTAL_PCL', 'TOESTWO', 'BEH_AFG', 'TK', 'MM_CAPS_IN', 'MM_CAPS_TK']

Missing values:
 instrument_SDV_IN    6125
Eaantal_TK           6125
Dcriterium_FU        6125
Cernst_FU            6125
Caantal_FU           6125
Ccriterium_FU        6125
Bernst_FU            6125
Baantal_FU           6125
Bcriterium_FU        6125
Eernst_TK            6125
dtype: int64
Shape after removing duplicates: (6125, 465)


In [217]:
import pyreadstat

# Load gender info
gender_df, meta = pyreadstat.read_sav("SDV_IN_Gender_2019_2024.sav")

# Just extract SDV_SEXE column and append to df
df["SDV_SEXE"] = gender_df["SDV_SEXE"].reset_index(drop=True)

# Optional: map to labels
gender_map = {1.0: "Male", 2.0: "Female", 3.0: "Other"}
df["gender_label"] = df["SDV_SEXE"].map(gender_map)

# Done! Check a sample
print(df[["SDV_SEXE", "gender_label"]].value_counts())


SDV_SEXE  gender_label
2.0       Female          4602
1.0       Male            1475
3.0       Other             48
Name: count, dtype: int64


In [218]:
# Assuming 'gender' and 'SDV_SEXE' columns are in df

# Gender dummy variables
df['gender_1'] = (df['gender'] == 1).astype(int)
df['gender_2'] = (df['gender'] == 2).astype(int)

# SDV_SEXE dummy variables
df['SDV_SEXE_1'] = (df['SDV_SEXE'] == 1).astype(int)
df['SDV_SEXE_2'] = (df['SDV_SEXE'] == 2).astype(int)
df['SDV_SEXE_3'] = (df['SDV_SEXE'] == 3).astype(int)

# Create binary columns
df['ethnicity_Dutch'] = np.where(df['ethnicity'] == 1, 1, 0)
df['ethnicity_other'] = np.where(df['ethnicity'] != 1, 1, 0)

In [219]:
# Columns manually identified for removal (example set from the R script)
cols_to_drop = [
    'gender', 'ethnicity', 'CIN5', 'SDV_SEXE', 'StartDatum', 'STARTDATUM', 'DROPOUT_EARLYCOMPLETER', 'TOEST_WO',
    'depressie_IN', 'TERUGKOMER', 'VROEGK_ST', 'gender_label',
    'depr_m_psychose_huid', 'depr_z_psychose_huid', 'depr_z_psychose_verl',
    'depr_m_psychose_verl', 'CAPS5score_followup', 'CAPS5_DAT_IN'
]

df = df.drop(columns=[col for col in cols_to_drop if col in df.columns])
print("Remaining columns:", df.shape[1])


Remaining columns: 458


In [220]:

if 'BEH_DAGEN' in df.columns:
    df.rename(columns={'BEH_DAGEN': 'treatmentdurationdays'}, inplace=True)

In [221]:
# Clean and standardize column names
df.columns = (
    df.columns
    .str.replace(r"\.+", "_", regex=True)
    .str.replace(r"[^a-zA-Z0-9_]", "", regex=True)
    .str.replace(" ", "_")
    .str.strip()
)

In [222]:
# Preview key outcome variables
outcome_vars = ['CAPS5score_baseline', 'CAPS5Score_TK']
for col in outcome_vars:
    if col in df.columns:
        print(f"{col}: {df[col].isnull().sum()} missing")

# Calculate change score
if 'CAPS5score_baseline' in df.columns and 'CAPS5Score_TK' in df.columns:
    df['caps5_change_baseline'] = df['CAPS5Score_TK'] - df['CAPS5score_baseline'] 


CAPS5score_baseline: 0 missing
CAPS5Score_TK: 0 missing


In [223]:
# Define exceptions to keep
protected_cols = [
    "DIAGNOSIS_ANXIETY_OCD",
    "DIAGNOSIS_PSYCHOTIC",
    "DIAGNOSIS_EATING_DISORDER",
    "DIAGNOSIS_SUBSTANCE_DISORDER", "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", 'SUBCAT_Selectieve_immunosuppresiva', 'treatmentdurationdays',
'SUBCAT_Corticosteroiden',
'SUBCAT_Immunomodulerend_Coxibs',
'SUBCAT_Aminosalicylaten',
'SUBCAT_calcineurineremmers',
'SUBCAT_Anti_epileptica_Benzodiazepine',
'SUBCAT_Paracetamol_overig_combinatie', 'SUBCAT_MAO_remmers', 'SUBCAT_psychostimulans_overige', 'SUBCAT_Interleukine_remmers'

]


# ----------------------------------------
# 1. Drop columns with >95% missing values (except protected)
thresh_missing = int(0.95 * len(df))
missing_cols = [col for col in df.columns if df[col].isnull().sum() > (len(df) - thresh_missing)]
missing_cols_to_drop = [col for col in missing_cols if col not in protected_cols]
df = df.drop(columns=missing_cols_to_drop)

# ----------------------------------------
# 2. Drop near-zero variance columns (except protected)
low_variance_cols = [col for col in df.columns if df[col].nunique(dropna=True) <= 1 and col not in protected_cols]
df = df.drop(columns=low_variance_cols)


In [224]:
df.to_csv("cleaned_data_baseline.csv", index=False)
print(" Step 1 Complete: Cleaned dataset saved.")

 Step 1 Complete: Cleaned dataset saved.


In [225]:
# Target variables:

In [226]:
#!pip install scikit-learn
# !pip install fancyimpute

In [227]:
from sklearn.impute import SimpleImputer
from fancyimpute import IterativeImputer
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.compose import ColumnTransformer

In [228]:
# Load the cleaned dataset from Step 1
df = pd.read_csv("cleaned_data_baseline.csv")

# Quick check
print(df.shape)
print(df.dtypes.head(10))

(6125, 204)
DIAGNOSIS_ANXIETY_OCD           float64
DIAGNOSIS_SMOKING               float64
DIAGNOSIS_EATING_DISORDER       float64
DIAGNOSIS_SUBSTANCE_DISORDER    float64
DIAGNOSIS_PSYCHOTIC             float64
DIAGNOSIS_SUICIDALITY           float64
DIAGNOSIS_SEXUAL_TRAUMA         float64
DIAGNOSIS_CHILDHOOD_TRAUMA        int64
DIAGNOSIS_CPTSD                 float64
treatmentdurationdays           float64
dtype: object


In [229]:
# Separate Numerical and Categorical Variables

In [230]:
# Identify numerical and categorical columns
numerical_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()

print(f"Numerical Columns: {len(numerical_cols)}")
print(f"Categorical Columns: {len(categorical_cols)}")

Numerical Columns: 204
Categorical Columns: 0


In [231]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6125 entries, 0 to 6124
Columns: 204 entries, DIAGNOSIS_ANXIETY_OCD to ethnicity_other
dtypes: float64(10), int64(194)
memory usage: 9.5 MB


In [232]:
# Save the fully prepared data
df.to_csv("final_prepared_data.csv", index=False)
print(" Step 2 Complete: Final prepared dataset saved as 'final_prepared_data.csv'.")

 Step 2 Complete: Final prepared dataset saved as 'final_prepared_data.csv'.


In [233]:
import os
import pandas as pd
import numpy as np
from sklearn.experimental import enable_iterative_imputer  # noqa
from sklearn.impute import IterativeImputer

# ========== CONFIG ==========
save_folder = "imputed_data"
os.makedirs(save_folder, exist_ok=True)
n_imputations = 5

# ========== LOAD ==========
# Ensure df is already defined
assert 'df' in globals(), "Please load the original DataFrame as `df` before running this script."

# ========== IDENTIFY NUMERIC COLUMNS ==========
numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()

# ========== STEP 1: MICE IMPUTATION ==========
print("=" * 50)
print("STEP 1: MICE IMPUTATION")
print("=" * 50)

imputed_dfs = []
for i in range(1, n_imputations + 1):
    print(f"\n=== Running MICE Imputation: Dataset {i} ===")
    #  NEW instance with different seed AND sample_posterior=True for randomness
    mice_imputer = IterativeImputer(
        max_iter=10, 
        random_state=42+i,  # Different base to avoid low numbers
        sample_posterior=True,  #  KEY: This adds randomness!
        n_nearest_features=None,
        initial_strategy='mean'
    )
    # Fit-transform on numeric columns
    imputed_array = mice_imputer.fit_transform(df[numeric_cols])
    # Replace numeric columns in a copy of the original df
    df_imputed = df.copy()
    df_imputed[numeric_cols] = pd.DataFrame(imputed_array, columns=numeric_cols, index=df.index)
    # Append to list
    imputed_dfs.append(df_imputed)
    print(f" Completed imputation {i}")

# ========== STEP 2: ROUNDING ==========
print("\n" + "=" * 50)
print("STEP 2: ROUNDING NUMERIC COLUMNS")
print("=" * 50)

def round_all_numeric_columns_all_imputations(imputed_dfs, decimals=0, verbose=True):
    rounded_dfs = []
    for i, df in enumerate(imputed_dfs):
        df_copy = df.copy()
        numeric_cols = df_copy.select_dtypes(include=[np.number]).columns
        df_copy[numeric_cols] = df_copy[numeric_cols].round(decimals)
        rounded_dfs.append(df_copy)
        if verbose:
            print(f" Imputation {i+1}: Rounded {len(numeric_cols)} numeric columns to {decimals} decimal place(s).")
    return rounded_dfs

# Apply rounding to all imputed datasets
imputed_dfs = round_all_numeric_columns_all_imputations(imputed_dfs)

# ========== STEP 3: SAVE FINAL DATASETS ==========
print("\n" + "=" * 50)
print("STEP 3: SAVING FINAL DATASETS")
print("=" * 50)

for i, df_imputed in enumerate(imputed_dfs, 1):
    # Save outputs
    pkl_path = f"{save_folder}/df_imputed_final_imp{i}.pkl"
    csv_path = f"{save_folder}/df_imputed_final_imp{i}.csv"
    excel_path = f"{save_folder}/df_imputed_final_imp{i}.xlsx"
    
    df_imputed.to_pickle(pkl_path)
    df_imputed.to_csv(csv_path, index=False)
    df_imputed.to_excel(excel_path, index=False)
    
    print(f" Saved files for imputation {i}:")
    print(f"   → {pkl_path}")
    print(f"   → {csv_path}")
    print(f"   → {excel_path}")

# ========== STEP 4: VERIFY DATASETS ARE DIFFERENT ==========
print("\n" + "=" * 50)
print("STEP 4: VERIFYING DATASET DIFFERENCES")
print("=" * 50)

def check_imputation_differences(imputed_dfs, verbose=True):
    """Check if imputed datasets are actually different from each other"""
    if len(imputed_dfs) < 2:
        print("  Only one dataset - cannot check differences")
        return
    
    # Get numeric columns that had missing values originally
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    missing_cols = [col for col in numeric_cols if df[col].isnull().any()]
    
    if not missing_cols:
        print("  No missing values found in original data")
        return
    
    print(f" Checking differences in {len(missing_cols)} columns that had missing values...")
    
    differences_found = False
    
    for col in missing_cols[:3]:  # Check first 3 columns with missing values
        # Compare first two datasets for this column
        values_1 = imputed_dfs[0][col].values
        values_2 = imputed_dfs[1][col].values
        
        if not np.array_equal(values_1, values_2):
            differences_found = True
            # Count how many values are different
            diff_count = np.sum(values_1 != values_2)
            print(f" Column '{col}': {diff_count} different values between datasets 1 & 2")
        else:
            print(f" Column '{col}': IDENTICAL values between datasets 1 & 2")
    
    if differences_found:
        print(f"\n SUCCESS: Datasets show proper variability!")
    else:
        print(f"\n  WARNING: Datasets appear identical - check random_state implementation")
    
    return differences_found

# Run the check
check_imputation_differences(imputed_dfs)

print("\n" + "=" * 50)
print(" MICE IMPUTATION COMPLETE!")
print("=" * 50)
print(f" Created {n_imputations} imputed datasets")
print(f" Applied rounding to all numeric columns")
print(f" Saved files in: {save_folder}/")
print("=" * 50)

STEP 1: MICE IMPUTATION

=== Running MICE Imputation: Dataset 1 ===


KeyboardInterrupt: 

In [None]:

# ========== METHOD 1: QUICK CHECK - Compare first 2 datasets ==========
def quick_difference_check(imputed_dfs):
    """Quick check to see if first two datasets are different"""
    if len(imputed_dfs) < 2:
        print("Need at least 2 datasets to compare")
        return
    
    df1 = imputed_dfs[0]
    df2 = imputed_dfs[1]
    
    # Check if dataframes are identical
    are_identical = df1.equals(df2)
    print(f"Dataset 1 vs Dataset 2: {'IDENTICAL ❌' if are_identical else 'DIFFERENT ✅'}")
    
    if not are_identical:
        # Count different values
        numeric_cols = df1.select_dtypes(include=[np.number]).columns
        total_diff = 0
        for col in numeric_cols:
            diff_count = np.sum(df1[col] != df2[col])
            if diff_count > 0:
                total_diff += diff_count
                print(f"  '{col}': {diff_count} different values")
        print(f"  Total different values: {total_diff}")

# ========== METHOD 2: DETAILED CHECK - All pairwise comparisons ==========
def detailed_difference_check(imputed_dfs):
    """Check differences between all pairs of datasets"""
    n_datasets = len(imputed_dfs)
    print(f"\n=== Checking all {n_datasets} datasets ===")
    
    numeric_cols = imputed_dfs[0].select_dtypes(include=[np.number]).columns
    
    for i in range(n_datasets):
        for j in range(i+1, n_datasets):
            are_identical = imputed_dfs[i].equals(imputed_dfs[j])
            print(f"Dataset {i+1} vs Dataset {j+1}: {'IDENTICAL ❌' if are_identical else 'DIFFERENT ✅'}")

# ========== METHOD 3: FOCUS ON ORIGINALLY MISSING VALUES ==========
def check_missing_value_differences(original_df, imputed_dfs):
    """Check differences only in originally missing positions"""
    print(f"\n=== Checking differences in originally missing positions ===")
    
    numeric_cols = original_df.select_dtypes(include=[np.number]).columns
    
    differences_found = False
    for col in numeric_cols:
        if original_df[col].isnull().any():
            missing_mask = original_df[col].isnull()
            print(f"\nColumn '{col}' ({missing_mask.sum()} missing values):")
            
            # Compare imputed values at missing positions
            for i in range(len(imputed_dfs)-1):
                imp1_values = imputed_dfs[i].loc[missing_mask, col]
                imp2_values = imputed_dfs[i+1].loc[missing_mask, col]
                
                are_same = np.array_equal(imp1_values.values, imp2_values.values)
                if not are_same:
                    differences_found = True
                    diff_count = np.sum(imp1_values.values != imp2_values.values)
                    print(f"  Dataset {i+1} vs {i+2}: {diff_count}/{len(imp1_values)} different imputed values ✅")
                else:
                    print(f"  Dataset {i+1} vs {i+2}: IDENTICAL imputed values ❌")
    
    return differences_found

# ========== METHOD 4: SAMPLE VALUES FROM EACH DATASET ==========
def show_sample_imputed_values(original_df, imputed_dfs, n_samples=5):
    """Show sample imputed values from each dataset"""
    print(f"\n=== Sample imputed values (first {n_samples} missing positions) ===")
    
    numeric_cols = original_df.select_dtypes(include=[np.number]).columns
    
    for col in numeric_cols:
        if original_df[col].isnull().any():
            missing_positions = original_df[original_df[col].isnull()].index[:n_samples]
            
            print(f"\nColumn '{col}' at positions {list(missing_positions)}:")
            for i, df_imp in enumerate(imputed_dfs):
                values = df_imp.loc[missing_positions, col].values
                print(f"  Dataset {i+1}: {values}")

# ========== RUN ALL CHECKS ==========
print("=" * 60)
print("CHECKING IMPUTATION DIFFERENCES")
print("=" * 60)

# Method 1: Quick check
quick_difference_check(imputed_dfs)

# Method 2: All pairwise comparisons  
detailed_difference_check(imputed_dfs)

# Method 3: Focus on originally missing values (assumes 'df' is your original dataframe)
if 'df' in globals():
    differences_found = check_missing_value_differences(df, imputed_dfs)
    if differences_found:
        print(f"\n🎉 SUCCESS: Found differences in imputed values!")
    else:
        print(f"\n⚠️ WARNING: No differences found in imputed values!")

# Method 4: Show sample values
if 'df' in globals():
    show_sample_imputed_values(df, imputed_dfs, n_samples=3)

In [None]:
imputed_folder = "imputed_data"
n_imputations = 5

# Lists to hold DataFrames and Y vectors
imputed_dfs = []
Y_list = []

for i in range(1, n_imputations + 1):
    file_path = f"{imputed_folder}/df_imputed_final_imp{i}.pkl"
    
    # Load imputed DataFrame
    df_imp = pd.read_pickle(file_path)
    imputed_dfs.append(df_imp)

    # Define Y for this imputation
    Y = df_imp["caps5_change_baseline"]
    Y_list.append(Y)

    print(f"Y for imputation {i} defined. Sample values:")
    print(Y.head())


In [None]:
covariates_CAT_ADHD = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden',
    'CAT_Z_drugs', 'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD',
    'DIAGNOSIS_SEXUAL_TRAUMA', 'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY',
    'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Aceetanilidederivaten = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Z_drugs = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Opioden = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica', 'CAT_Benzodiazepine', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_NSAIDs = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica', 'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]



covariates_CAT_Benzodiazepine = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Antihypertensiva = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Antihistaminica = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder"
]


covariates_CAT_Anti_epileptica = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Antidepressiva = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Antipsychotica = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_CAT_ALL_PSYCHOTROPICS_EXCL_BENZO = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Benzodiazepine', 'CAT_Anticonceptiva',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_CAT_ALL_PSYCHOTROPICS_EXCL_SEDATIVES_HYPNOTICS = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Benzodiazepine', 'CAT_Z_drugs', 'CAT_Anticonceptiva',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_ALL_PSYCHOTROPICS = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_ALL = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'DIAGNOSIS_CHILDHOOD_TRAUMA',
    'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY',
    'age', 'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

In [None]:
from collections import defaultdict

# This finds all variables that start with covariates_CAT_ or covariates_cat_
final_covariates_map = defaultdict(list)
final_covariates_map.update({
    var.replace("covariates_", ""): val
    for var, val in globals().items()
    if var.lower().startswith("covariates_cat_") and isinstance(val, list)
})

# Show detected group names
print("Groups found:", list(final_covariates_map.keys()))
medication_groups = list(final_covariates_map.keys())
print(medication_groups)

In [None]:
import os


def run_all_CAT_group_models(imputed_dfs):
    """
    Runs downstream analysis for each CAT medication group using imputed datasets.

    Parameters:
    - imputed_dfs: list of 5 imputed DataFrames (from df_imputed_final_imp1.pkl ... imp5.pkl)
    
    Notes:
    - Covariate lists must be defined as global variables: covariates_cat_<group>
    - Outputs are saved in: outputs/CAT_<GROUP>/
    """

    print(" Starting analysis for all CAT groups")

    for var_name in globals():
        if var_name.lower().startswith("covariates_cat_") and isinstance(globals()[var_name], list):
            group_name = var_name.replace("covariates_", "")
            group_name = group_name.replace("_", " ").title().replace(" ", "_")  # e.g., cat_z_drugs → Cat_Z_Drugs
            group_name = group_name.replace("Cat_", "CAT_")  # force prefix to uppercase

            covariates = globals()[var_name]
            output_dir = f"outputs/{group_name}"
            os.makedirs(output_dir, exist_ok=True)

            print(f"\n Processing {group_name}...")

            for k, df_imp in enumerate(imputed_dfs):
                print(f"  → Using imputation {k+1}")

                # Define X and Y
                X = df_imp[covariates]
                Y = df_imp["caps5_change_baseline"]

                # === Save X and Y as placeholder (replace with modeling later)
                X.to_csv(f"{output_dir}/X_imp{k+1}.csv", index=False)
                Y.to_frame(name="Y").to_csv(f"{output_dir}/Y_imp{k+1}.csv", index=False)

            print(f" Done: {group_name}")

    print("\n All CAT group analyses complete.")

# ========= STEP 4: Execute ========= #
run_all_CAT_group_models(imputed_dfs)


In [None]:
for treatment_var in medication_groups:
    print(f"\n {treatment_var}")
    
    for i, df in enumerate(imputed_dfs):
        if treatment_var not in df.columns:
            print(f"  Imp {i+1}:  Not found in columns.")
            continue

        treated = (df[treatment_var] == 1).sum()
        control = (df[treatment_var] == 0).sum()
        missing = df[treatment_var].isna().sum()

        print(f"  Imp {i+1}: Treated = {treated}, Control = {control}, Missing = {missing}")


In [None]:
import os
import pandas as pd
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from collections import defaultdict

# ✅ VIF computation function
def compute_vif(X):
    X = sm.add_constant(X, has_constant='add')
    vif_df = pd.DataFrame()
    vif_df["variable"] = X.columns
    vif_df["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
    return vif_df

# ✅ Process each group
for group in medication_groups:
    print(f"\n🔍 Processing VIF for {group}")

    if group not in final_covariates_map:
        print(f" ⚠️ No covariates found for {group}. Skipping.")
        continue

    covariates = final_covariates_map[group]
    vif_list = []

    for i, df_imp in enumerate(imputed_dfs):
        try:
            X = df_imp[covariates].copy()
            vif_df = compute_vif(X)
            vif_df["imputation"] = i + 1
            vif_list.append(vif_df)
        except Exception as e:
            print(f"   ❌ Failed on imputation {i+1} for {group}: {e}")

    if vif_list:
        all_vif = pd.concat(vif_list)
        pooled_vif = all_vif.groupby("variable")["VIF"].mean().reset_index()
        pooled_vif = pooled_vif.sort_values(by="VIF", ascending=False)

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)
        pooled_vif.to_csv(os.path.join(output_folder, "pooled_vif.csv"), index=False)

        print(f" ✅ Saved: {output_folder}/pooled_vif.csv")
    else:
        print(f" ⚠️ Skipped {group}: No valid imputations.")


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# ---------- PS Estimation Function ----------
def run_xgboost_ps_modeling(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f" Running PS estimation for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        covariates = final_covariates_map[group]
        ps_matrix = pd.DataFrame()
        auc_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not found in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop missing treatment rows
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            # Train-test split for ROC
            X_train, X_test, T_train, T_test = train_test_split(
                X, T, stratify=T, test_size=0.3, random_state=42
            )

            try:
                model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
                model.fit(X_train, T_train)

                ps_scores = model.predict_proba(X)[:, 1]
                ps_matrix[f"ps_imp{i+1}"] = pd.Series(ps_scores, index=valid_idx)

                # ROC & AUC
                auc = roc_auc_score(T_test, model.predict_proba(X_test)[:, 1])
                auc_list.append(auc)

                fpr, tpr, _ = roc_curve(T_test, model.predict_proba(X_test)[:, 1])
                plt.figure()
                plt.plot(fpr, tpr, label=f"AUC = {auc:.3f}")
                plt.plot([0, 1], [0, 1], 'k--')
                plt.xlabel("False Positive Rate")
                plt.ylabel("True Positive Rate")
                plt.title(f"ROC Curve - {group} (Imp {i+1})")
                plt.legend()
                plt.tight_layout()
                plt.savefig(os.path.join(output_folder, f"roc_curve_imp{i+1}.png"))
                plt.close()
                print(f"   Imp {i+1}: AUC = {auc:.3f}, ROC saved.")

            except Exception as e:
                print(f"   Error in {group} (imp {i+1}): {e}")

        # Save AUCs and Composite PS
        if not ps_matrix.empty:
            # Fill NaN rows (from dropped subjects in some imputations) with mean
            ps_matrix["composite_ps"] = ps_matrix.mean(axis=1)
            ps_matrix.to_excel(os.path.join(output_folder, "propensity_scores.xlsx"))

            auc_df = pd.DataFrame({
                "imputation": [f"imp{i+1}" for i in range(len(auc_list))],
                "AUC": auc_list
            })
            auc_df.loc[len(auc_df.index)] = ["mean", np.mean(auc_list) if auc_list else np.nan]
            auc_df.to_excel(os.path.join(output_folder, "auc_scores.xlsx"), index=False)

            print(f" Composite PS + AUC saved for {group}")
        else:
            print(f" No valid PS scores generated for {group}")

# ---------- Run ----------
run_xgboost_ps_modeling(imputed_dfs, medication_groups, final_covariates_map)


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from xgboost import XGBClassifier

def compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n Computing feature importance for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        covariates = final_covariates_map[group]
        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        importance_df_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop NaNs in treatment
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            try:
                model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
                model.fit(X, T)

                # Get feature importance
                importances = model.get_booster().get_score(importance_type='gain')
                df_feat = pd.DataFrame.from_dict(importances, orient='index', columns=[f"imp{i+1}"])
                df_feat.index.name = 'feature'
                importance_df_list.append(df_feat)

            except Exception as e:
                print(f"   Error during modeling: {e}")

        if importance_df_list:
            # Combine and average
            all_feat = pd.concat(importance_df_list, axis=1).fillna(0)
            all_feat["mean_importance"] = all_feat.mean(axis=1)

            # Filter top 30 non-zero
            non_zero = all_feat[all_feat["mean_importance"] > 0]
            top30 = non_zero.sort_values(by="mean_importance", ascending=False).head(30)

            # Save to CSV
            top30.to_csv(os.path.join(output_folder, "feature_importance.csv"))

            # Plot
            plt.figure(figsize=(10, 8))
            plt.barh(top30.index[::-1], top30["mean_importance"][::-1])  # plot top → bottom
            plt.xlabel("Mean Gain Importance")
            plt.title(f"Top 30 Feature Importance - {group}")
            plt.tight_layout()
            plt.savefig(os.path.join(output_folder, "feature_importance_top30.png"))
            plt.close()

            print(f" Saved feature importance plot and CSV for {group}")
        else:
            print(f" No valid models for {group}")

#  Run
compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map)


In [None]:
import os
import pandas as pd
import numpy as np


def compute_trimmed_clipped_iptw(ps_df, treatment, lower=0.05, upper=0.95, clip_max=10):
    weights = []
    keep_mask = (ps_df > lower) & (ps_df < upper)

    for i in range(ps_df.shape[1]):
        ps = ps_df.iloc[:, i].clip(lower=1e-6, upper=1 - 1e-6)  # avoid div by zero
        mask = keep_mask.iloc[:, i]
        w = pd.Series(np.nan, index=ps.index)

        w[mask & (treatment == 1)] = 1 / ps[mask & (treatment == 1)]
        w[mask & (treatment == 0)] = 1 / (1 - ps[mask & (treatment == 0)])
        w = w.clip(upper=clip_max)
        weights.append(w)

    return pd.concat(weights, axis=1)


def apply_rubins_rule_to_iptw(iptw_matrix):
    """
    Given an IPTW matrix (n rows × M imputations), return Rubin’s rule pooled mean, SD, SE.
    """
    M = iptw_matrix.shape[1]
    q_bar = iptw_matrix.mean(axis=1)
    u_bar = iptw_matrix.var(axis=1, ddof=1)
    B = iptw_matrix.apply(lambda x: x.mean(), axis=1).var(ddof=1)
    total_var = u_bar + (1 + 1/M) * B
    total_se = np.sqrt(total_var)
    return q_bar, u_bar.pow(0.5), total_se


def run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n🔍 Processing IPTW + trimming + clipping for {group}")
        output_folder = os.path.join("outputs", group)
        ps_path = os.path.join(output_folder, "propensity_scores.xlsx")

        if not os.path.exists(ps_path):
            print(f"⚠️ Missing PS file: {ps_path}. Skipping.")
            continue

        try:
            ps_all = pd.read_excel(ps_path, index_col=0)
            ps_cols = [col for col in ps_all.columns if col.startswith("ps_imp")]
            composite_index = ps_all.index

            # Get treatment from one imputed dataset
            T_full = None
            for df in imputed_dfs:
                if group in df.columns:
                    T_full = df.loc[composite_index, group]
                    break

            if T_full is None:
                print(f"❌ Treatment column {group} not found in any imputed dataset.")
                continue

            # Compute IPTW matrix (shape: n × M)
            iptw_matrix = compute_trimmed_clipped_iptw(ps_all[ps_cols], T_full)
            iptw_matrix.columns = [f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]

            # Apply Rubin’s Rule for mean, SD, SE
            iptw_matrix["iptw_mean"], iptw_matrix["iptw_sd"], iptw_matrix["iptw_se"] = apply_rubins_rule_to_iptw(
                iptw_matrix[[f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]]
            )

            # Save IPTW matrix separately
            iptw_matrix.to_excel(os.path.join(output_folder, "iptw_weights.xlsx"))
            print(f"✅ Saved IPTW weights for {group}")

            # Save trimmed & clipped imputed datasets with IPTW
            for i in range(5):
                df = imputed_dfs[i].copy()
                if group not in df.columns:
                    continue

                trimmed_idx = iptw_matrix.index.intersection(df.index)
                needed_cols = final_covariates_map[group] + [group, "caps5_change_baseline"]

                # Select only necessary columns
                df_trimmed = df.loc[trimmed_idx, needed_cols].copy()
                df_trimmed["iptw"] = iptw_matrix[f"iptw_imp{i+1}"].loc[trimmed_idx]

                # ✅ DROP rows with missing IPTW values
                before = len(df_trimmed)
                df_trimmed = df_trimmed.dropna(subset=["iptw"])
                after = len(df_trimmed)
                print(f"    ℹ️ Retained {after}/{before} rows after IPTW NaN drop.")

                # Save to .pkl
                df_trimmed.to_pickle(os.path.join(output_folder, f"trimmed_data_imp{i+1}.pkl"))
                print(f"  💾 Saved trimmed dataset: {output_folder}/trimmed_data_imp{i+1}.*")

        except Exception as e:
            print(f"❌ Error in {group}: {e}")


run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def plot_ps_overlap_all_groups(medication_groups):
    for group in medication_groups:
        print(f"\n📊 Plotting PS overlap for {group}")

        folder = os.path.join("outputs", group)
        ps_file = os.path.join(folder, "propensity_scores.xlsx")
        iptw_file = os.path.join(folder, "iptw_weights.xlsx")
        trimmed_file = os.path.join(folder, "trimmed_data_imp1.pkl")

        if not all(os.path.exists(f) for f in [ps_file, iptw_file, trimmed_file]):
            print(f"⚠️ Missing required files for {group}. Skipping.")
            continue

        try:
            ps_df = pd.read_excel(ps_file, index_col=0)
            iptw_df = pd.read_excel(iptw_file, index_col=0)
            trimmed_df = pd.read_pickle(trimmed_file)

            # Extract
            ps = ps_df["composite_ps"].reindex(trimmed_df.index)
            w = iptw_df["iptw_mean"].reindex(trimmed_df.index)
            T = trimmed_df[group]

            # Masks to remove NaNs
            treated_mask = (T == 1) & ps.notna() & w.notna()
            control_mask = (T == 0) & ps.notna() & w.notna()

            treated = ps[treated_mask]
            treated_w = w[treated_mask]

            control = ps[control_mask]
            control_w = w[control_mask]

            # === Unweighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Unweighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_unweighted.png"))
            plt.close()

            # === Weighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, weights=[treated_w, control_w], label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Weighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Weighted Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_weighted.png"))
            plt.close()

            print(f"✅ Saved unweighted and weighted PS plots for {group}")

        except Exception as e:
            print(f"❌ Error processing {group}: {e}")

# 🔁 Run
plot_ps_overlap_all_groups(medication_groups)


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Set up base output folder
output_base = "outputs"
ps_file = "propensity_scores.xlsx"
iptw_file = "iptw_weights.xlsx"
trimmed_data_file = "trimmed_data_imp1.pkl"

# Collect all treatment group folders
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

# Generate 4-panel overlap plots
for group in groups:
    group_path = os.path.join(output_base, group)
    try:
        # Load trimmed treatment info
        trimmed_df = pd.read_pickle(os.path.join(group_path, trimmed_data_file))
        index = trimmed_df.index

        # Fix: case-insensitive match for treatment variable
        possible_cols = [col for col in trimmed_df.columns if col.upper() == group.upper()]
        if not possible_cols:
            print(f" Treatment variable {group} not found in {group}, skipping.")
            continue
        treatment_var = possible_cols[0]
        T = trimmed_df[treatment_var]

        # Load composite PS (aligned to trimmed_df index)
        ps_df = pd.read_excel(os.path.join(group_path, ps_file), index_col=0)
        if 'composite_ps' not in ps_df.columns:
            print(f" Composite column missing in {ps_file}, skipping {group}.")
            continue
        ps = ps_df.loc[index, 'composite_ps']

        # Load IPTW weights (aligned to trimmed_df index)
        weights_df = pd.read_excel(os.path.join(group_path, iptw_file), index_col=0)
        if 'iptw_mean' not in weights_df.columns:
            print(f" IPTW weight column missing in {iptw_file}, skipping {group}.")
            continue
        weights = weights_df.loc[index, 'iptw_mean']

        # Prepare 4 datasets
        raw_treated = ps[T == 1]
        raw_control = ps[T == 0]
        weighted_treated = (ps[T == 1], weights[T == 1])
        weighted_control = (ps[T == 0], weights[T == 0])

        # Create plot
        fig, axs = plt.subplots(2, 2, figsize=(12, 8))
        fig.suptitle(f"Propensity Score Distribution - {group}", fontsize=14)

        axs[0, 0].hist(raw_treated, bins=20, alpha=0.7, color='blue')
        axs[0, 0].set_title("Raw Treated")

        axs[0, 1].hist(raw_control, bins=20, alpha=0.7, color='green')
        axs[0, 1].set_title("Raw Control")

        axs[1, 0].hist(weighted_treated[0], bins=20, weights=weighted_treated[1], alpha=0.7, color='blue')
        axs[1, 0].set_title("Weighted Treated")

        axs[1, 1].hist(weighted_control[0], bins=20, weights=weighted_control[1], alpha=0.7, color='green')
        axs[1, 1].set_title("Weighted Control")

        for ax in axs.flat:
            ax.set_xlim(0, 1)
            ax.set_xlabel("Propensity Score")
            ax.set_ylabel("Count")

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save figure
        plot_path = os.path.join(group_path, f"four_panel_overlap_{group}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f" ✅ Saved: {plot_path}")

    except Exception as e:
        print(f" ❌ Error in {group}: {e}")


In [None]:
# Weighted:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from econml.dml import LinearDML
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
n_repeats = 4
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T, weights):
    treated = X[T == 1]
    control = X[T == 0]
    w_treated = weights[T == 1]
    w_control = weights[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = np.average(treated[col], weights=w_treated), np.average(control[col], weights=w_control)
            s1 = np.sqrt(np.average((treated[col] - m1) ** 2, weights=w_treated))
            s0 = np.sqrt(np.average((control[col] - m0) ** 2, weights=w_control))
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# DML Main Loop
# -----------------------------
def run_dml_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running DML for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                X = df[covariates].copy()
                T = df[group]
                Y = df["caps5_change_baseline"]
                W = df["iptw"]

                for repeat in range(n_repeats):
                    kf = KFold(n_splits=5, shuffle=True, random_state=seed + repeat)
                    for train_idx, test_idx in kf.split(X):
                        try:
                            X_train, T_train, Y_train, W_train = (
                                X.iloc[train_idx],
                                T.iloc[train_idx],
                                Y.iloc[train_idx],
                                W.iloc[train_idx],
                            )

                            model_y = xgb.XGBRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1, random_state=seed)
                            model_t = xgb.XGBClassifier(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1,
                                                        use_label_encoder=False, eval_metric="logloss", random_state=seed)

                            dml = LinearDML(model_y=model_y, model_t=model_t, discrete_treatment=True,
                                            cv=KFold(n_splits=3, shuffle=True, random_state=seed), random_state=seed)
                            dml.fit(Y_train, T_train, X=X_train, sample_weight=W_train)

                            tau = dml.effect(X_train)
                            att = np.mean(tau)
                            influence = tau - att
                            se = np.sqrt(np.mean(influence ** 2) / len(tau))

                            att_list.append(att)
                            se_list.append(se)

                            Y_pred = model_y.fit(X_train, Y_train, sample_weight=W_train).predict(X_train)
                            residuals = Y_train - Y_pred
                            rmse = mean_squared_error(Y_train, Y_pred, squared=False)
                            r2 = r2_score(Y_train, Y_pred)
                            r2_list.append(r2)
                            rmse_list.append(rmse)
                            
                            # Collect residuals and fitted values for diagnostic plots
                            group_residuals.append(residuals.values)
                            group_fitted.append(Y_pred)

                            smd, vr = calculate_smd_vr(X_train, T_train, W_train)
                            smd_list.append(smd)
                            vr_list.append(vr)
                        except Exception as e:
                            print(f"⚠️ Error in {group}, seed {seed}, imp {imp}, rep {repeat}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("dml_rubin_summary_cats.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_cats.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_dml_with_trimmed_data(final_covariates_map)

In [None]:
# Unweighted:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from econml.dml import LinearDML
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
n_repeats = 4
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Plotting Functions
# -----------------------------
def create_diagnostic_plots(residuals_data, group_name, output_folder):
    """Create diagnostic plots for each group"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    all_residuals = residuals_data['residuals']
    all_fitted = residuals_data['fitted']
    
    if len(all_residuals) == 0:
        return
    
    # Create figure with subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=14, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0, 0].scatter(all_fitted, all_residuals, alpha=0.5, s=1)
    axes[0, 0].axhline(y=0, color='red', linestyle='--')
    axes[0, 0].set_xlabel('Fitted Values')
    axes[0, 0].set_ylabel('Residuals')
    axes[0, 0].set_title('Residuals vs Fitted')
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0, 1])
    axes[0, 1].set_title('QQ Plot (Normal)')
    
    # 3. Histogram of Residuals
    axes[1, 0].hist(all_residuals, bins=50, alpha=0.7, edgecolor='black')
    axes[1, 0].axvline(x=0, color='red', linestyle='--')
    axes[1, 0].set_xlabel('Residuals')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].set_title('Residual Distribution')
    
    # 4. Scale-Location Plot
    sqrt_abs_resid = np.sqrt(np.abs(all_residuals))
    axes[1, 1].scatter(all_fitted, sqrt_abs_resid, alpha=0.5, s=1)
    axes[1, 1].set_xlabel('Fitted Values')
    axes[1, 1].set_ylabel('√|Residuals|')
    axes[1, 1].set_title('Scale-Location Plot')
    
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, f'{group_name}_unweighted.png'), 
                dpi=300, bbox_inches='tight')
    plt.close()

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T):
    treated = X[T == 1]
    control = X[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = treated[col].mean(), control[col].mean()
            s1, s0 = treated[col].std(), control[col].std()
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# DML Main Loop
# -----------------------------
def run_dml_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running DML for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize residuals collection for this group
        group_residuals_data = {'residuals': [], 'fitted': []}

        for seed in seeds:
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns:
                    continue

                X = df[covariates].copy()
                T = df[group]
                Y = df["caps5_change_baseline"]

                for repeat in range(n_repeats):
                    kf = KFold(n_splits=5, shuffle=True, random_state=seed + repeat)
                    for train_idx, test_idx in kf.split(X):
                        try:
                            X_train, T_train, Y_train = (
                                X.iloc[train_idx],
                                T.iloc[train_idx],
                                Y.iloc[train_idx],
                            )

                            model_y = xgb.XGBRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1, random_state=seed)
                            model_t = xgb.XGBClassifier(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1,
                                                        use_label_encoder=False, eval_metric="logloss", random_state=seed)

                            dml = LinearDML(model_y=model_y, model_t=model_t, discrete_treatment=True,
                                            cv=KFold(n_splits=3, shuffle=True, random_state=seed), random_state=seed)
                            dml.fit(Y_train, T_train, X=X_train)

                            tau = dml.effect(X_train)
                            att = np.mean(tau)
                            influence = tau - att
                            se = np.sqrt(np.mean(influence ** 2) / len(tau))

                            att_list.append(att)
                            se_list.append(se)

                            Y_pred = model_y.fit(X_train, Y_train).predict(X_train)
                            residuals = Y_train - Y_pred
                            
                            # Collect residuals and fitted values for plotting
                            group_residuals_data['residuals'].extend(residuals.tolist())
                            group_residuals_data['fitted'].extend(Y_pred.tolist())
                            
                            rmse = mean_squared_error(Y_train, Y_pred, squared=False)
                            r2 = r2_score(Y_train, Y_pred)
                            r2_list.append(r2)
                            rmse_list.append(rmse)

                            smd, vr = calculate_smd_vr(X_train, T_train)
                            smd_list.append(smd)
                            vr_list.append(vr)
                        except Exception as e:
                            print(f"⚠️ Error in {group}, seed {seed}, imp {imp}, rep {repeat}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")
        
        # Create diagnostic plots for this group
        print(f"📊 Creating diagnostic plots for {group}...")
        create_diagnostic_plots(group_residuals_data, group, output_folder)

    # Save final output
    pd.DataFrame(att_results).to_excel("dml_rubin_summary_cats_unweighted.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_cats_unweighted.xlsx", index=False)
    print("\n🎯 All summary files saved.")
    print("📊 All diagnostic plots saved in outputs/plots/ folder.")

run_dml_with_trimmed_data(final_covariates_map)


In [None]:
import os
import pandas as pd
import numpy as np
from scipy.stats import sem, ttest_ind

# ----------------------------------
# File paths
# ----------------------------------
output_base = "outputs"
att_file = "dml_rubin_summary_cats.xlsx"
trimmed_file = "trimmed_data_imp1.pkl"
auc_file = "auc_scores.xlsx"  # NEW

# ----------------------------------
# Load ATT Summary
# ----------------------------------
if os.path.exists(att_file):
    att_df = pd.read_excel(att_file)
else:
    raise FileNotFoundError("❌ ATT summary file not found: dml_rubin_summary_cats.xlsx")

summary_rows = []

# ----------------------------------
# Loop over medication groups
# ----------------------------------
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

for med in groups:
    try:
        group_path = os.path.join(output_base, med)

        # Load trimmed data
        df = pd.read_pickle(os.path.join(group_path, trimmed_file))

        # Detect treatment column
        treatment_cols = [col for col in df.columns if col.upper() == med.upper()]
        if not treatment_cols:
            print(f"⚠️ Treatment column {med} not found in trimmed data. Skipping.")
            continue
        treatment_var = treatment_cols[0]

        # Extract treatment and outcome
        T = df[treatment_var]
        Y = df["caps5_change_baseline"]

        # Treated and control stats
        treated = Y[T == 1]
        control = Y[T == 0]

        mean_treat = treated.mean()
        se_treat = sem(treated) if len(treated) > 1 else np.nan

        mean_ctrl = control.mean()
        se_ctrl = sem(control) if len(control) > 1 else np.nan

        # Cohen's d (unadjusted)
        pooled_sd = np.sqrt(((treated.std() ** 2) + (control.std() ** 2)) / 2)
        cohen_d = (mean_treat - mean_ctrl) / pooled_sd if pooled_sd > 0 else np.nan

        # E-value (unadjusted)
        delta = mean_treat - mean_ctrl
        E = delta / abs(mean_ctrl) * 100 if mean_ctrl != 0 else np.nan

        # Unadjusted p-value
        try:
            t_stat, p_val = ttest_ind(treated, control, equal_var=False, nan_policy="omit")
            rounded_p = round(p_val, 5)
            formatted_p = "< 0.00001" if rounded_p < 0.00001 else rounded_p
        except Exception:
            formatted_p = np.nan

        # AUC from new auc_scores.xlsx file
        auc_val = np.nan
        auc_path = os.path.join(group_path, auc_file)
        if os.path.exists(auc_path):
            auc_df = pd.read_excel(auc_path)
            if "AUC" in auc_df.columns:
                auc_val = auc_df["AUC"].dropna().mean()

        # Adjusted stats from Rubin summary
        att_row = att_df[att_df["group"].str.strip().str.upper() == med.strip().upper()]
        if not att_row.empty:
            att = att_row.iloc[0]["att"]
            att_se = att_row.iloc[0]["se"]
            att_p_val = att_row.iloc[0]["p_value"]
            r2 = att_row.iloc[0]["r2"]
            rmse = att_row.iloc[0]["rmse"]

            try:
                rounded_att_p = round(float(att_p_val), 5)
                formatted_att_p = "< 0.00001" if rounded_att_p < 0.00001 else rounded_att_p
            except:
                formatted_att_p = att_p_val
        else:
            att, att_se, formatted_att_p, r2, rmse = np.nan, np.nan, np.nan, np.nan, np.nan

        # Append full row
        summary_rows.append({
            'Medication Group': med,
            'Mean Treated': mean_treat,
            'SE Treated': se_treat,
            'Mean Control': mean_ctrl,
            'SE Control': se_ctrl,
            'Cohen d': cohen_d,
            'E (Unadjusted)': E,
            'n Treated': len(treated),
            'n Control': len(control),
            #'Unadjusted p-value': formatted_p,
            'ATT Estimate': att,
            'ATT SE (Robust)': att_se,
            'ATT p-value': formatted_att_p,
            'R²': r2,
            'RMSE': rmse,
            'AUC': auc_val
        })

    except Exception as e:
        print(f"❌ Error in {med}: {e}")

# ----------------------------------
# Save final summary
# ----------------------------------
summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.sort_values("Medication Group")
summary_df.to_excel("Final_ATT_Summary_Cat.xlsx", index=False)
print("✅ Final_ATT_Summary_Cat saved.xlsx")


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# ✅ Load the final summary table
final_df = pd.read_excel("Final_ATT_Summary_Cat.xlsx")

# ✅ Parse DML p-values (handle "< 0.00001")
def parse_pval(p):
    try:
        if isinstance(p, str) and "<" in p:
            return 0.000001
        return float(p)
    except:
        return None

final_df['ATT p-value'] = final_df['ATT p-value'].apply(parse_pval)

# ✅ Plot settings
width = 0.35

# ✅ Plotting function for a single medication group
def plot_single_group(row):
    fig, ax = plt.subplots(figsize=(10, 6))
    
    bars1 = ax.bar(-width/2, row['Mean Control'], width, 
                   yerr=row['SE Control'], label='Control', hatch='//', color='gray', capsize=5)
    bars2 = ax.bar(+width/2, row['Mean Treated'], width, 
                   yerr=row['SE Treated'], label='Treated', color='steelblue', capsize=5)

    label = (
        f"ATT = {row['ATT Estimate']:.2f}\n"
        f"d = {row['Cohen d']:.2f}, p = {row['ATT p-value']:.3f}\n"
        f"nT = {row['n Treated']}, nC = {row['n Control']}\n"
        f"E = {row['E (Unadjusted)']:.1f}%"
    )
    max_y = max(row['Mean Control'], row['Mean Treated']) + 1.5
    ax.text(0, max_y, label, ha='center', va='bottom', fontsize=9, color='#FFD700')

    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_xticks([-width/2, +width/2])
    ax.set_xticklabels(['Control', 'Treated'])
    ax.set_title(f"Group: {row['Medication Group']}", fontsize=12, weight='bold')
    ax.set_ylabel("CAPS5 Change Score")
    ax.set_ylim(bottom=0, top=max_y + 2)
    ax.legend()
    fig.tight_layout()
    return fig

# ✅ Generate and save all plots into a multi-page PDF
with PdfPages("dml_att_barplot_cat.pdf") as pdf:
    for idx, row in final_df.iterrows():
        fig = plot_single_group(row)
        pdf.savefig(fig, dpi=300, bbox_inches='tight')
        plt.close(fig)

print("✅ dml_att_barplot_cat saved.xlsx")

In [None]:
# Love plot:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# ----------------------------------------
# Functions to calculate balance
# ----------------------------------------
def calculate_smd(x1, x2, w1=None, w2=None):
    def weighted_mean(x, w): return np.average(x, weights=w)
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    m1 = weighted_mean(x1, w1) if w1 is not None else np.mean(x1)
    m2 = weighted_mean(x2, w2) if w2 is not None else np.mean(x2)
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    pooled_sd = np.sqrt((v1 + v2) / 2)
    return np.abs(m1 - m2) / pooled_sd if pooled_sd > 0 else 0

def variance_ratio(x1, x2, w1=None, w2=None):
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    return max(v1 / v2, v2 / v1) if v1 > 0 and v2 > 0 else 1

# ----------------------------------------
# Setup
# ----------------------------------------
output_base = "outputs"
groups = [g for g in os.listdir(output_base) if os.path.isdir(os.path.join(output_base, g))]

# Create a case-insensitive mapping
final_covariates_map_lower = {k.lower(): v for k, v in final_covariates_map.items()}

# ----------------------------------------
# Main Loop
# ----------------------------------------
for group in groups:
    if group.lower() not in final_covariates_map_lower:
        continue

    print(f"\n🔍 Processing {group}...")

    try:
        group_path = os.path.join(output_base, group)
        covariates = final_covariates_map_lower[group.lower()]
        
        column_name = None
        for col in pd.read_pickle(os.path.join(group_path, "trimmed_data_imp1.pkl")).columns:
            if col.lower() == group.lower():
                column_name = col
                break
        if column_name is None:
            print(f"⚠️ Column not found for {group}, skipping.")
            continue

        smd_unw_all, smd_w_all = [], []
        vr_unw_all, vr_w_all = [], []

        for i in range(1, 6):
            df_path = os.path.join(group_path, f"trimmed_data_imp{i}.pkl")
            iptw_path = os.path.join(group_path, "iptw_weights.xlsx")

            if not os.path.exists(df_path) or not os.path.exists(iptw_path):
                print(f"⚠️ Missing data for {group} imp{i}, skipping.")
                continue

            df = pd.read_pickle(df_path)
            iptw_df = pd.read_excel(iptw_path, index_col=0)
            T = df[column_name]
            W = iptw_df.loc[df.index, "iptw_mean"]

            smd_unw_i, smd_w_i, vr_unw_i, vr_w_i = [], [], [], []

            for cov in covariates:
                x1, x0 = df.loc[T == 1, cov], df.loc[T == 0, cov]
                w1, w0 = W[T == 1], W[T == 0]

                su = calculate_smd(x1, x0)
                sw = calculate_smd(x1, x0, w1, w0)

                vu = variance_ratio(x1, x0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan
                vw = variance_ratio(x1, x0, w1, w0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan

                smd_unw_i.append(su)
                smd_w_i.append(sw)
                vr_unw_i.append(vu)
                vr_w_i.append(vw)

            smd_unw_all.append(smd_unw_i)
            smd_w_all.append(smd_w_i)
            vr_unw_all.append(vr_unw_i)
            vr_w_all.append(vr_w_i)

        smd_unw = np.mean(smd_unw_all, axis=0)
        smd_w = np.mean(smd_w_all, axis=0)
        vr_unw = np.nanmean(vr_unw_all, axis=0)
        vr_w = np.nanmean(vr_w_all, axis=0)

        severity = []
        for sw in smd_w:
            if sw <= 0.1:
                severity.append("Good")
            elif sw <= 0.2:
                severity.append("Moderate")
            else:
                severity.append("Poor")

        covariate_names = covariates
        numeric_df = pd.DataFrame({
            "Covariate": covariate_names,
            "SMD_Unweighted": smd_unw,
            "SMD_Weighted": smd_w,
            "Imbalance_Severity": severity,
            "VR_Unweighted": vr_unw,
            "VR_Weighted": vr_w
        })

        numeric_path = os.path.join(group_path, f"covariate_balance_table_{group}.xlsx")
        numeric_df.to_excel(numeric_path, index=False)
        print(f"📊 Exported numeric summary to: {numeric_path}")

        # -------------------------
        # Plot
        # -------------------------
        labels = covariates
        y_pos = np.arange(len(labels))

        fig, axes = plt.subplots(1, 2, figsize=(18, len(labels) * 0.45))

        axes[0].scatter(smd_unw, y_pos, color='red', label="Unweighted")
        axes[0].scatter(smd_w, y_pos, color='blue', label="Weighted")
        axes[0].axvline(0.1, color='gray', linestyle='--', label="Threshold 0.1")
        axes[0].axvline(0.2, color='black', linestyle='--', label="Threshold 0.2")
        axes[0].set_xlim(0, max(max(smd_unw), max(smd_w), 0.25) + 0.05)
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels(labels)
        axes[0].invert_yaxis()
        axes[0].set_title("Standardized Mean Differences (SMD)")
        axes[0].legend(loc="upper right")
        axes[0].grid(True)

        vr_mask = [cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] for cov in covariates]
        filtered_y = [i for i, b in enumerate(vr_mask) if b]
        filtered_labels = [labels[i] for i in filtered_y]
        filtered_vr_unw = [vr_unw[i] for i in filtered_y]
        filtered_vr_w = [vr_w[i] for i in filtered_y]

        axes[1].scatter(filtered_vr_unw, filtered_y, color='blue', marker='o', label="Unweighted")
        axes[1].scatter(filtered_vr_w, filtered_y, color='red', marker='x', label="Weighted")
        axes[1].axvline(2, color='gray', linestyle='--')
        axes[1].axvline(0.5, color='gray', linestyle='--')
        axes[1].set_xlim(0, max(filtered_vr_unw + filtered_vr_w + [2.5]) + 0.5)
        axes[1].set_yticks(filtered_y)
        axes[1].set_yticklabels(filtered_labels)
        axes[1].invert_yaxis()
        axes[1].set_title("Variance Ratio (VR)")
        axes[1].legend()
        axes[1].grid(True)

        fig.suptitle(f"Covariate Balance for {group.replace('CAT_', '')}", fontsize=14, weight='bold')
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        plot_path = os.path.join(group_path, f"love_plot_{group}.pdf")
        fig.savefig(plot_path, dpi=300)
        plt.close()
        print(f"✅ Saved love plot: {plot_path}")
        print(f"📏 Max weighted SMD for {group}: {np.max(smd_w):.3f}")

    except Exception as e:
        print(f"❌ Error in {group}: {e}")


In [None]:
# Heatmap:

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
#-----------------------------
# Generate heatmaps
# -------------------------------
for treatment_var in medication_groups:
    print(f"\n========== Creating Heatmap for {treatment_var} ==========")

    try:
        output_folder = os.path.join('outputs', treatment_var)
        balance_path = os.path.join(output_folder, f'covariate_balance_table_{treatment_var}.xlsx')

        if not os.path.exists(balance_path):
            print(f"❌ Balance file not found: {balance_path}")
            continue

        balance_df = pd.read_excel(balance_path)

        # ✅ Use finalized covariates + 'Propensity Score'
        covariates = final_covariates_map[treatment_var] + ['Propensity Score']
        balance_df = balance_df[balance_df['Covariate'].isin(covariates)]

        # ✅ Check for CAPS5score_baseline
        highlight_caps = 'CAPS5score_baseline' in balance_df['Covariate'].values

        # ✅ Format for heatmap
        heatmap_df = balance_df[['Covariate', 'SMD_Unweighted', 'SMD_Weighted']].copy()
        heatmap_df.columns = ['Covariate', 'Unweighted', 'Weighted']
        heatmap_df = heatmap_df.set_index('Covariate')
        heatmap_df = heatmap_df.sort_values(by='Unweighted', ascending=False)

        # ✅ Plot
        plt.figure(figsize=(12, max(10, len(heatmap_df) * 0.35)))
        ax = sns.heatmap(
            heatmap_df,
            cmap="coolwarm",
            annot=True,
            fmt=".2f",
            linewidths=0.6,
            linecolor='gray',
            cbar_kws={"label": "Standardized Mean Difference"}
        )

        plt.title(f"Covariate Balance Heatmap (Rubin IPTW)\n{treatment_var}", fontsize=15, weight='bold')
        plt.xlabel("Condition")
        plt.ylabel("Covariate")

        # ✅ Bold CAPS5score_baseline if present
        if highlight_caps:
            ylabels = [label.get_text() for label in ax.get_yticklabels()]
            ax.set_yticklabels([
                f"{label} ←" if label == 'CAPS5score_baseline' else label for label in ylabels
            ])

        plt.tight_layout()

        # ✅ Save image
        save_path = os.path.join(output_folder, f'heatmap_smd_{treatment_var}.png')
        plt.savefig(save_path, dpi=300)
        plt.close()

        print(f"✅ Heatmap saved: {save_path}")

    except Exception as e:
        print(f"⚠️ Error processing {treatment_var}: {e}")


## Subcat analysis:

In [None]:
covariates_SUBCAT_Antipsychotica_atypisch = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_TCA = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_SSRI = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_SNRI = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Tetracyclische_antidepressiva = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_Antidepressiva_overige = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_Systemische_antihistaminica = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_anxiolytica_Benzodiazepine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_hypnotica_Benzodiazepine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Amfetaminen = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD',
    'DIAGNOSIS_SEXUAL_TRAUMA', 'DIAGNOSIS_SUICIDALITY',
    'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Systemische_betablokkers = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Paracetamol_mono = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Anti_epileptica_stemmingsstabilisatoren = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age', 
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_Opioden = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_Z_drugs = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_NSAIDs = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

In [None]:
from collections import defaultdict

# This finds all variables that start with covariates_SUBCAT_
final_covariates_map = defaultdict(list)
final_covariates_map.update({
    var.replace("covariates_", ""): val
    for var, val in globals().items()
    if var.lower().startswith("covariates_subcat_") and isinstance(val, list)
})

# Show detected group names
print("Groups found:", list(final_covariates_map.keys()))
medication_groups = list(final_covariates_map.keys())
print(medication_groups)


In [None]:
import os


def run_all_SUBCAT_group_models(imputed_dfs):
    """
    Runs downstream analysis for each SUBCAT medisubcation group using imputed datasets.

    Parameters:
    - imputed_dfs: list of 5 imputed DataFrames (from df_imputed_final_imp1.pkl ... imp5.pkl)
    
    Notes:
    - Covariate lists must be defined as global variables: covariates_subcat_<group>
    - Outputs are saved in: outputs/SUBCAT_<GROUP>/
    """

    print(" Starting analysis for all SUBCAT groups")

    for var_name in globals():
        if var_name.lower().startswith("covariates_subcat_") and isinstance(globals()[var_name], list):
            group_name = var_name.replace("covariates_", "")
            group_name = group_name.replace("_", " ").title().replace(" ", "_")  # e.g., subcat_z_drugs → Subcat_Z_Drugs
            group_name = group_name.replace("Subcat_", "SUBCAT_")  # force prefix to uppercase

            covariates = globals()[var_name]
            output_dir = f"outputs/{group_name}"
            os.makedirs(output_dir, exist_ok=True)

            print(f"\n Processing {group_name}...")

            for k, df_imp in enumerate(imputed_dfs):
                print(f"  → Using imputation {k+1}")

                # Define X and Y
                X = df_imp[covariates]
                Y = df_imp["caps5_change_baseline"]

                # === Save X and Y as placeholder (replace with modeling later)
                X.to_csv(f"{output_dir}/X_imp{k+1}.csv", index=False)
                Y.to_frame(name="Y").to_csv(f"{output_dir}/Y_imp{k+1}.csv", index=False)

            print(f" Done: {group_name}")

    print("\n All SUBCAT group analyses complete.")

# ========= STEP 4: Execute ========= #
run_all_SUBCAT_group_models(imputed_dfs)


In [None]:
for treatment_var in medication_groups:
    print(f"\n {treatment_var}")
    
    for i, df in enumerate(imputed_dfs):
        if treatment_var not in df.columns:
            print(f"  Imp {i+1}:  Not found in columns.")
            continue

        treated = (df[treatment_var] == 1).sum()
        control = (df[treatment_var] == 0).sum()
        missing = df[treatment_var].isna().sum()

        print(f"  Imp {i+1}: Treated = {treated}, Control = {control}, Missing = {missing}")

In [None]:
import os
import pandas as pd
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from collections import defaultdict

# ✅ VIF computation function
def compute_vif(X):
    X = sm.add_constant(X, has_constant='add')
    vif_df = pd.DataFrame()
    vif_df["variable"] = X.columns
    vif_df["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
    return vif_df

# ✅ Process each group
for group in medication_groups:
    print(f"\n🔍 Processing VIF for {group}")

    if group not in final_covariates_map:
        print(f" ⚠️ No covariates found for {group}. Skipping.")
        continue

    covariates = final_covariates_map[group]
    vif_list = []

    for i, df_imp in enumerate(imputed_dfs):
        try:
            X = df_imp[covariates].copy()
            vif_df = compute_vif(X)
            vif_df["imputation"] = i + 1
            vif_list.append(vif_df)
        except Exception as e:
            print(f"   ❌ Failed on imputation {i+1} for {group}: {e}")

    if vif_list:
        all_vif = pd.concat(vif_list)
        pooled_vif = all_vif.groupby("variable")["VIF"].mean().reset_index()
        pooled_vif = pooled_vif.sort_values(by="VIF", ascending=False)

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)
        pooled_vif.to_csv(os.path.join(output_folder, "pooled_vif.csv"), index=False)

        print(f" ✅ Saved: {output_folder}/pooled_vif.csv")
    else:
        print(f" ⚠️ Skipped {group}: No valid imputations.")

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# ---------- PS Estimation Function ----------
def run_xgboost_ps_modeling(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f" Running PS estimation for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        covariates = final_covariates_map[group]
        ps_matrix = pd.DataFrame()
        auc_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not found in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop missing treatment rows
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            # Train-test split for ROC
            X_train, X_test, T_train, T_test = train_test_split(
                X, T, stratify=T, test_size=0.3, random_state=42
            )

            try:
                model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
                model.fit(X_train, T_train)

                ps_scores = model.predict_proba(X)[:, 1]
                ps_matrix[f"ps_imp{i+1}"] = pd.Series(ps_scores, index=valid_idx)

                # ROC & AUC
                auc = roc_auc_score(T_test, model.predict_proba(X_test)[:, 1])
                auc_list.append(auc)

                fpr, tpr, _ = roc_curve(T_test, model.predict_proba(X_test)[:, 1])
                plt.figure()
                plt.plot(fpr, tpr, label=f"AUC = {auc:.3f}")
                plt.plot([0, 1], [0, 1], 'k--')
                plt.xlabel("False Positive Rate")
                plt.ylabel("True Positive Rate")
                plt.title(f"ROC Curve - {group} (Imp {i+1})")
                plt.legend()
                plt.tight_layout()
                plt.savefig(os.path.join(output_folder, f"roc_curve_imp{i+1}.png"))
                plt.close()
                print(f"   Imp {i+1}: AUC = {auc:.3f}, ROC saved.")

            except Exception as e:
                print(f"   Error in {group} (imp {i+1}): {e}")

        # Save AUCs and Composite PS
        if not ps_matrix.empty:
            # Fill NaN rows (from dropped subjects in some imputations) with mean
            ps_matrix["composite_ps"] = ps_matrix.mean(axis=1)
            ps_matrix.to_excel(os.path.join(output_folder, "propensity_scores.xlsx"))

            auc_df = pd.DataFrame({
                "imputation": [f"imp{i+1}" for i in range(len(auc_list))],
                "AUC": auc_list
            })
            auc_df.loc[len(auc_df.index)] = ["mean", np.mean(auc_list) if auc_list else np.nan]
            auc_df.to_excel(os.path.join(output_folder, "auc_scores.xlsx"), index=False)

            print(f" Composite PS + AUC saved for {group}")
        else:
            print(f" No valid PS scores generated for {group}")

# ---------- Run ----------
run_xgboost_ps_modeling(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from xgboost import XGBClassifier

def compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n Computing feature importance for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        covariates = final_covariates_map[group]
        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        importance_df_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop NaNs in treatment
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            try:
                model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
                model.fit(X, T)

                # Get feature importance
                importances = model.get_booster().get_score(importance_type='gain')
                df_feat = pd.DataFrame.from_dict(importances, orient='index', columns=[f"imp{i+1}"])
                df_feat.index.name = 'feature'
                importance_df_list.append(df_feat)

            except Exception as e:
                print(f"   Error during modeling: {e}")

        if importance_df_list:
            # Combine and average
            all_feat = pd.concat(importance_df_list, axis=1).fillna(0)
            all_feat["mean_importance"] = all_feat.mean(axis=1)

            # Filter top 30 non-zero
            non_zero = all_feat[all_feat["mean_importance"] > 0]
            top30 = non_zero.sort_values(by="mean_importance", ascending=False).head(30)

            # Save to CSV
            top30.to_csv(os.path.join(output_folder, "feature_importance.csv"))

            # Plot
            plt.figure(figsize=(10, 8))
            plt.barh(top30.index[::-1], top30["mean_importance"][::-1])  # plot top → bottom
            plt.xlabel("Mean Gain Importance")
            plt.title(f"Top 30 Feature Importance - {group}")
            plt.tight_layout()
            plt.savefig(os.path.join(output_folder, "feature_importance_top30.png"))
            plt.close()

            print(f" Saved feature importance plot and CSV for {group}")
        else:
            print(f" No valid models for {group}")

#  Run
compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import numpy as np


def compute_trimmed_clipped_iptw(ps_df, treatment, lower=0.05, upper=0.95, clip_max=10):
    weights = []
    keep_mask = (ps_df > lower) & (ps_df < upper)

    for i in range(ps_df.shape[1]):
        ps = ps_df.iloc[:, i].clip(lower=1e-6, upper=1 - 1e-6)  # avoid div by zero
        mask = keep_mask.iloc[:, i]
        w = pd.Series(np.nan, index=ps.index)

        w[mask & (treatment == 1)] = 1 / ps[mask & (treatment == 1)]
        w[mask & (treatment == 0)] = 1 / (1 - ps[mask & (treatment == 0)])
        w = w.clip(upper=clip_max)
        weights.append(w)

    return pd.concat(weights, axis=1)


def apply_rubins_rule_to_iptw(iptw_matrix):
    """
    Given an IPTW matrix (n rows × M imputations), return Rubin’s rule pooled mean, SD, SE.
    """
    M = iptw_matrix.shape[1]
    q_bar = iptw_matrix.mean(axis=1)
    u_bar = iptw_matrix.var(axis=1, ddof=1)
    B = iptw_matrix.apply(lambda x: x.mean(), axis=1).var(ddof=1)
    total_var = u_bar + (1 + 1/M) * B
    total_se = np.sqrt(total_var)
    return q_bar, u_bar.pow(0.5), total_se


def run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n🔍 Processing IPTW + trimming + clipping for {group}")
        output_folder = os.path.join("outputs", group)
        ps_path = os.path.join(output_folder, "propensity_scores.xlsx")

        if not os.path.exists(ps_path):
            print(f"⚠️ Missing PS file: {ps_path}. Skipping.")
            continue

        try:
            ps_all = pd.read_excel(ps_path, index_col=0)
            ps_cols = [col for col in ps_all.columns if col.startswith("ps_imp")]
            composite_index = ps_all.index

            # Get treatment from one imputed dataset
            T_full = None
            for df in imputed_dfs:
                if group in df.columns:
                    T_full = df.loc[composite_index, group]
                    break

            if T_full is None:
                print(f"❌ Treatment column {group} not found in any imputed dataset.")
                continue

            # Compute IPTW matrix (shape: n × M)
            iptw_matrix = compute_trimmed_clipped_iptw(ps_all[ps_cols], T_full)
            iptw_matrix.columns = [f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]

            # Apply Rubin’s Rule for mean, SD, SE
            iptw_matrix["iptw_mean"], iptw_matrix["iptw_sd"], iptw_matrix["iptw_se"] = apply_rubins_rule_to_iptw(
                iptw_matrix[[f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]]
            )

            # Save IPTW matrix separately
            iptw_matrix.to_excel(os.path.join(output_folder, "iptw_weights.xlsx"))
            print(f"✅ Saved IPTW weights for {group}")

            # Save trimmed & clipped imputed datasets with IPTW
            for i in range(5):
                df = imputed_dfs[i].copy()
                if group not in df.columns:
                    continue

                trimmed_idx = iptw_matrix.index.intersection(df.index)
                needed_cols = final_covariates_map[group] + [group, "caps5_change_baseline"]

                # Select only necessary columns
                df_trimmed = df.loc[trimmed_idx, needed_cols].copy()
                df_trimmed["iptw"] = iptw_matrix[f"iptw_imp{i+1}"].loc[trimmed_idx]

                # ✅ DROP rows with missing IPTW values
                before = len(df_trimmed)
                df_trimmed = df_trimmed.dropna(subset=["iptw"])
                after = len(df_trimmed)
                print(f"    ℹ️ Retained {after}/{before} rows after IPTW NaN drop.")

                # Save to .pkl
                df_trimmed.to_pickle(os.path.join(output_folder, f"trimmed_data_imp{i+1}.pkl"))
                print(f"  💾 Saved trimmed dataset: {output_folder}/trimmed_data_imp{i+1}.*")

        except Exception as e:
            print(f"❌ Error in {group}: {e}")


run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def plot_ps_overlap_all_groups(medication_groups):
    for group in medication_groups:
        print(f"\n📊 Plotting PS overlap for {group}")

        folder = os.path.join("outputs", group)
        ps_file = os.path.join(folder, "propensity_scores.xlsx")
        iptw_file = os.path.join(folder, "iptw_weights.xlsx")
        trimmed_file = os.path.join(folder, "trimmed_data_imp1.pkl")

        if not all(os.path.exists(f) for f in [ps_file, iptw_file, trimmed_file]):
            print(f"⚠️ Missing required files for {group}. Skipping.")
            continue

        try:
            ps_df = pd.read_excel(ps_file, index_col=0)
            iptw_df = pd.read_excel(iptw_file, index_col=0)
            trimmed_df = pd.read_pickle(trimmed_file)

            # Extract
            ps = ps_df["composite_ps"].reindex(trimmed_df.index)
            w = iptw_df["iptw_mean"].reindex(trimmed_df.index)
            T = trimmed_df[group]

            # Masks to remove NaNs
            treated_mask = (T == 1) & ps.notna() & w.notna()
            control_mask = (T == 0) & ps.notna() & w.notna()

            treated = ps[treated_mask]
            treated_w = w[treated_mask]

            control = ps[control_mask]
            control_w = w[control_mask]

            # === Unweighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Unweighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_unweighted.png"))
            plt.close()

            # === Weighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, weights=[treated_w, control_w], label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Weighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Weighted Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_weighted.png"))
            plt.close()

            print(f"✅ Saved unweighted and weighted PS plots for {group}")

        except Exception as e:
            print(f"❌ Error processing {group}: {e}")

# 🔁 Run
plot_ps_overlap_all_groups(medication_groups)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Set up base output folder
output_base = "outputs"
ps_file = "propensity_scores.xlsx"
iptw_file = "iptw_weights.xlsx"
trimmed_data_file = "trimmed_data_imp1.pkl"

# Collect all treatment group folders
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]


# Generate 4-panel overlap plots
for group in groups:
    group_path = os.path.join(output_base, group)
    try:
        # Load trimmed treatment info
        trimmed_df = pd.read_pickle(os.path.join(group_path, trimmed_data_file))
        index = trimmed_df.index

        # Fix: case-insensitive match for treatment variable
        possible_cols = [col for col in trimmed_df.columns if col.upper() == group.upper()]
        if not possible_cols:
            print(f" Treatment variable {group} not found in {group}, skipping.")
            continue
        treatment_var = possible_cols[0]
        T = trimmed_df[treatment_var]

        # Load composite PS (aligned to trimmed_df index)
        ps_df = pd.read_excel(os.path.join(group_path, ps_file), index_col=0)
        if 'composite_ps' not in ps_df.columns:
            print(f" Composite column missing in {ps_file}, skipping {group}.")
            continue
        ps = ps_df.loc[index, 'composite_ps']

        # Load IPTW weights (aligned to trimmed_df index)
        weights_df = pd.read_excel(os.path.join(group_path, iptw_file), index_col=0)
        if 'iptw_mean' not in weights_df.columns:
            print(f" IPTW weight column missing in {iptw_file}, skipping {group}.")
            continue
        weights = weights_df.loc[index, 'iptw_mean']

        # Prepare 4 datasets
        raw_treated = ps[T == 1]
        raw_control = ps[T == 0]
        weighted_treated = (ps[T == 1], weights[T == 1])
        weighted_control = (ps[T == 0], weights[T == 0])

        # Create plot
        fig, axs = plt.subplots(2, 2, figsize=(12, 8))
        fig.suptitle(f"Propensity Score Distribution - {group}", fontsize=14)

        axs[0, 0].hist(raw_treated, bins=20, alpha=0.7, color='blue')
        axs[0, 0].set_title("Raw Treated")

        axs[0, 1].hist(raw_control, bins=20, alpha=0.7, color='green')
        axs[0, 1].set_title("Raw Control")

        axs[1, 0].hist(weighted_treated[0], bins=20, weights=weighted_treated[1], alpha=0.7, color='blue')
        axs[1, 0].set_title("Weighted Treated")

        axs[1, 1].hist(weighted_control[0], bins=20, weights=weighted_control[1], alpha=0.7, color='green')
        axs[1, 1].set_title("Weighted Control")

        for ax in axs.flat:
            ax.set_xlim(0, 1)
            ax.set_xlabel("Propensity Score")
            ax.set_ylabel("Count")

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save figure
        plot_path = os.path.join(group_path, f"four_panel_overlap_{group}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f" ✅ Saved: {plot_path}")

    except Exception as e:
        print(f" ❌ Error in {group}: {e}")

In [None]:
# ATT calculation:

In [None]:
# Weighted:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from econml.dml import LinearDML
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
n_repeats = 4
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T, weights):
    treated = X[T == 1]
    control = X[T == 0]
    w_treated = weights[T == 1]
    w_control = weights[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = np.average(treated[col], weights=w_treated), np.average(control[col], weights=w_control)
            s1 = np.sqrt(np.average((treated[col] - m1) ** 2, weights=w_treated))
            s0 = np.sqrt(np.average((control[col] - m0) ** 2, weights=w_control))
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# DML Main Loop
# -----------------------------
def run_dml_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running DML for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                X = df[covariates].copy()
                T = df[group]
                Y = df["caps5_change_baseline"]
                W = df["iptw"]

                for repeat in range(n_repeats):
                    kf = KFold(n_splits=5, shuffle=True, random_state=seed + repeat)
                    for train_idx, test_idx in kf.split(X):
                        try:
                            X_train, T_train, Y_train, W_train = (
                                X.iloc[train_idx],
                                T.iloc[train_idx],
                                Y.iloc[train_idx],
                                W.iloc[train_idx],
                            )

                            model_y = xgb.XGBRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1, random_state=seed)
                            model_t = xgb.XGBClassifier(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1,
                                                        use_label_encoder=False, eval_metric="logloss", random_state=seed)

                            dml = LinearDML(model_y=model_y, model_t=model_t, discrete_treatment=True,
                                            cv=KFold(n_splits=3, shuffle=True, random_state=seed), random_state=seed)
                            dml.fit(Y_train, T_train, X=X_train, sample_weight=W_train)

                            tau = dml.effect(X_train)
                            att = np.mean(tau)
                            influence = tau - att
                            se = np.sqrt(np.mean(influence ** 2) / len(tau))

                            att_list.append(att)
                            se_list.append(se)

                            Y_pred = model_y.fit(X_train, Y_train, sample_weight=W_train).predict(X_train)
                            residuals = Y_train - Y_pred
                            rmse = mean_squared_error(Y_train, Y_pred, squared=False)
                            r2 = r2_score(Y_train, Y_pred)
                            r2_list.append(r2)
                            rmse_list.append(rmse)
                            
                            # Collect residuals and fitted values for diagnostic plots
                            group_residuals.append(residuals.values)
                            group_fitted.append(Y_pred)

                            smd, vr = calculate_smd_vr(X_train, T_train, W_train)
                            smd_list.append(smd)
                            vr_list.append(vr)
                        except Exception as e:
                            print(f"⚠️ Error in {group}, seed {seed}, imp {imp}, rep {repeat}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("dml_rubin_summary_subcats.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_subcats.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_dml_with_trimmed_data(final_covariates_map)

In [None]:
# Unweighted:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from econml.dml import LinearDML
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
n_repeats = 4
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Plotting Functions
# -----------------------------
def create_diagnostic_plots(residuals_data, group_name, output_folder):
    """Create diagnostic plots for each group"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    all_residuals = residuals_data['residuals']
    all_fitted = residuals_data['fitted']
    
    if len(all_residuals) == 0:
        return
    
    # Create figure with subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=14, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0, 0].scatter(all_fitted, all_residuals, alpha=0.5, s=1)
    axes[0, 0].axhline(y=0, color='red', linestyle='--')
    axes[0, 0].set_xlabel('Fitted Values')
    axes[0, 0].set_ylabel('Residuals')
    axes[0, 0].set_title('Residuals vs Fitted')
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0, 1])
    axes[0, 1].set_title('QQ Plot (Normal)')
    
    # 3. Histogram of Residuals
    axes[1, 0].hist(all_residuals, bins=50, alpha=0.7, edgecolor='black')
    axes[1, 0].axvline(x=0, color='red', linestyle='--')
    axes[1, 0].set_xlabel('Residuals')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].set_title('Residual Distribution')
    
    # 4. Scale-Location Plot
    sqrt_abs_resid = np.sqrt(np.abs(all_residuals))
    axes[1, 1].scatter(all_fitted, sqrt_abs_resid, alpha=0.5, s=1)
    axes[1, 1].set_xlabel('Fitted Values')
    axes[1, 1].set_ylabel('√|Residuals|')
    axes[1, 1].set_title('Scale-Location Plot')
    
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, f'{group_name}_unweighted.png'), 
                dpi=300, bbox_inches='tight')
    plt.close()

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T):
    treated = X[T == 1]
    control = X[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = treated[col].mean(), control[col].mean()
            s1, s0 = treated[col].std(), control[col].std()
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# DML Main Loop
# -----------------------------
def run_dml_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running DML for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize residuals collection for this group
        group_residuals_data = {'residuals': [], 'fitted': []}

        for seed in seeds:
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns:
                    continue

                X = df[covariates].copy()
                T = df[group]
                Y = df["caps5_change_baseline"]

                for repeat in range(n_repeats):
                    kf = KFold(n_splits=5, shuffle=True, random_state=seed + repeat)
                    for train_idx, test_idx in kf.split(X):
                        try:
                            X_train, T_train, Y_train = (
                                X.iloc[train_idx],
                                T.iloc[train_idx],
                                Y.iloc[train_idx],
                            )

                            model_y = xgb.XGBRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1, random_state=seed)
                            model_t = xgb.XGBClassifier(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1,
                                                        use_label_encoder=False, eval_metric="logloss", random_state=seed)

                            dml = LinearDML(model_y=model_y, model_t=model_t, discrete_treatment=True,
                                            cv=KFold(n_splits=3, shuffle=True, random_state=seed), random_state=seed)
                            dml.fit(Y_train, T_train, X=X_train)

                            tau = dml.effect(X_train)
                            att = np.mean(tau)
                            influence = tau - att
                            se = np.sqrt(np.mean(influence ** 2) / len(tau))

                            att_list.append(att)
                            se_list.append(se)

                            Y_pred = model_y.fit(X_train, Y_train).predict(X_train)
                            residuals = Y_train - Y_pred
                            
                            # Collect residuals and fitted values for plotting
                            group_residuals_data['residuals'].extend(residuals.tolist())
                            group_residuals_data['fitted'].extend(Y_pred.tolist())
                            
                            rmse = mean_squared_error(Y_train, Y_pred, squared=False)
                            r2 = r2_score(Y_train, Y_pred)
                            r2_list.append(r2)
                            rmse_list.append(rmse)

                            smd, vr = calculate_smd_vr(X_train, T_train)
                            smd_list.append(smd)
                            vr_list.append(vr)
                        except Exception as e:
                            print(f"⚠️ Error in {group}, seed {seed}, imp {imp}, rep {repeat}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")
        
        # Create diagnostic plots for this group
        print(f"📊 Creating diagnostic plots for {group}...")
        create_diagnostic_plots(group_residuals_data, group, output_folder)

    # Save final output
    pd.DataFrame(att_results).to_excel("dml_rubin_summary_subcats_unweighted.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_subcats_unweighted.xlsx", index=False)
    print("\n🎯 All summary files saved.")
    print("📊 All diagnostic plots saved in outputs/plots/ folder.")

run_dml_with_trimmed_data(final_covariates_map)

In [None]:
import os
import pandas as pd
import numpy as np
from scipy.stats import sem, ttest_ind

# ----------------------------------
# File paths
# ----------------------------------
output_base = "outputs"
att_file = "dml_rubin_summary_subcats.xlsx"
trimmed_file = "trimmed_data_imp1.pkl"
auc_file = "auc_scores.xlsx"  

# ----------------------------------
# Load ATT Summary
# ----------------------------------
if os.path.exists(att_file):
    att_df = pd.read_excel(att_file)
else:
    raise FileNotFoundError("❌ ATT summary file not found: dml_rubin_summary_subcats.xlsx")

summary_rows = []

# ----------------------------------
# Loop over medication groups
# ----------------------------------
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

for med in groups:
    try:
        group_path = os.path.join(output_base, med)

        # Load trimmed data
        df = pd.read_pickle(os.path.join(group_path, trimmed_file))

        # Detect treatment column
        treatment_cols = [col for col in df.columns if col.upper() == med.upper()]
        if not treatment_cols:
            print(f"⚠️ Treatment column {med} not found in trimmed data. Skipping.")
            continue
        treatment_var = treatment_cols[0]

        # Extract treatment and outcome
        T = df[treatment_var]
        Y = df["caps5_change_baseline"]

        # Treated and control stats
        treated = Y[T == 1]
        control = Y[T == 0]

        mean_treat = treated.mean()
        se_treat = sem(treated) if len(treated) > 1 else np.nan

        mean_ctrl = control.mean()
        se_ctrl = sem(control) if len(control) > 1 else np.nan

        # Cohen's d (unadjusted)
        pooled_sd = np.sqrt(((treated.std() ** 2) + (control.std() ** 2)) / 2)
        cohen_d = (mean_treat - mean_ctrl) / pooled_sd if pooled_sd > 0 else np.nan

        # E-value (unadjusted)
        delta = mean_treat - mean_ctrl
        E = delta / abs(mean_ctrl) * 100 if mean_ctrl != 0 else np.nan

        # Unadjusted p-value
        try:
            t_stat, p_val = ttest_ind(treated, control, equal_var=False, nan_policy="omit")
            rounded_p = round(p_val, 5)
            formatted_p = "< 0.00001" if rounded_p < 0.00001 else rounded_p
        except Exception:
            formatted_p = np.nan

        # AUC from new auc_scores.xlsx file
        auc_val = np.nan
        auc_path = os.path.join(group_path, auc_file)
        if os.path.exists(auc_path):
            auc_df = pd.read_excel(auc_path)
            if "AUC" in auc_df.columns:
                auc_val = auc_df["AUC"].dropna().mean()

        # Adjusted stats from Rubin summary
        att_row = att_df[att_df["group"].str.strip().str.upper() == med.strip().upper()]
        if not att_row.empty:
            att = att_row.iloc[0]["att"]
            att_se = att_row.iloc[0]["se"]
            att_p_val = att_row.iloc[0]["p_value"]
            r2 = att_row.iloc[0]["r2"]
            rmse = att_row.iloc[0]["rmse"]

            try:
                rounded_att_p = round(float(att_p_val), 5)
                formatted_att_p = "< 0.00001" if rounded_att_p < 0.00001 else rounded_att_p
            except:
                formatted_att_p = att_p_val
        else:
            att, att_se, formatted_att_p, r2, rmse = np.nan, np.nan, np.nan, np.nan, np.nan

        # Append full row
        summary_rows.append({
            'Medication Group': med,
            'Mean Treated': mean_treat,
            'SE Treated': se_treat,
            'Mean Control': mean_ctrl,
            'SE Control': se_ctrl,
            'Cohen d': cohen_d,
            'E (Unadjusted)': E,
            'n Treated': len(treated),
            'n Control': len(control),
            #'Unadjusted p-value': formatted_p,
            'ATT Estimate': att,
            'ATT SE (Robust)': att_se,
            'ATT p-value': formatted_att_p,
            'R²': r2,
            'RMSE': rmse,
            'AUC': auc_val
        })

    except Exception as e:
        print(f"❌ Error in {med}: {e}")

# ----------------------------------
# Save final summary
# ----------------------------------
summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.sort_values("Medication Group")
summary_df.to_excel("Final_ATT_Summary_SubCat.xlsx", index=False)
print("✅ Final_ATT_Summary_SubCat saved")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# ✅ Load the final summary table
final_df = pd.read_excel("Final_ATT_Summary_SubCat.xlsx")

# ✅ Parse DML p-values (handle "< 0.00001")
def parse_pval(p):
    try:
        if isinstance(p, str) and "<" in p:
            return 0.000001
        return float(p)
    except:
        return None

final_df['ATT p-value'] = final_df['ATT p-value'].apply(parse_pval)

# ✅ Plot settings
width = 0.35

# ✅ Plotting function for a single medication group
def plot_single_group(row):
    fig, ax = plt.subplots(figsize=(10, 6))
    
    bars1 = ax.bar(-width/2, row['Mean Control'], width, 
                   yerr=row['SE Control'], label='Control', hatch='//', color='gray', capsize=5)
    bars2 = ax.bar(+width/2, row['Mean Treated'], width, 
                   yerr=row['SE Treated'], label='Treated', color='steelblue', capsize=5)

    label = (
        f"ATT = {row['ATT Estimate']:.2f}\n"
        f"d = {row['Cohen d']:.2f}, p = {row['ATT p-value']:.3f}\n"
        f"nT = {row['n Treated']}, nC = {row['n Control']}\n"
        f"E = {row['E (Unadjusted)']:.1f}%"
    )
    max_y = max(row['Mean Control'], row['Mean Treated']) + 1.5
    ax.text(0, max_y, label, ha='center', va='bottom', fontsize=9, color='#FFD700')

    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_xticks([-width/2, +width/2])
    ax.set_xticklabels(['Control', 'Treated'])
    ax.set_title(f"Group: {row['Medication Group']}", fontsize=12, weight='bold')
    ax.set_ylabel("CAPS5 Change Score")
    ax.set_ylim(bottom=0, top=max_y + 2)
    ax.legend()
    fig.tight_layout()
    return fig

# ✅ Generate and save all plots into a multi-page PDF
with PdfPages("dml_att_barplot_subcat.pdf") as pdf:
    for idx, row in final_df.iterrows():
        fig = plot_single_group(row)
        pdf.savefig(fig, dpi=300, bbox_inches='tight')
        plt.close(fig)
print("✅ dml_att_barplot_subcat saved")

In [None]:
# Love plot:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# ----------------------------------------
# Functions to calculate balance
# ----------------------------------------
def calculate_smd(x1, x2, w1=None, w2=None):
    def weighted_mean(x, w): return np.average(x, weights=w)
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    m1 = weighted_mean(x1, w1) if w1 is not None else np.mean(x1)
    m2 = weighted_mean(x2, w2) if w2 is not None else np.mean(x2)
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    pooled_sd = np.sqrt((v1 + v2) / 2)
    return np.abs(m1 - m2) / pooled_sd if pooled_sd > 0 else 0

def variance_ratio(x1, x2, w1=None, w2=None):
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    return max(v1 / v2, v2 / v1) if v1 > 0 and v2 > 0 else 1

# ----------------------------------------
# Setup
# ----------------------------------------
output_base = "outputs"
groups = [g for g in os.listdir(output_base) if os.path.isdir(os.path.join(output_base, g))]

# Create a case-insensitive mapping
final_covariates_map_lower = {k.lower(): v for k, v in final_covariates_map.items()}

# ----------------------------------------
# Main Loop
# ----------------------------------------
for group in groups:
    if group.lower() not in final_covariates_map_lower:
        continue

    print(f"\n🔍 Processing {group}...")

    try:
        group_path = os.path.join(output_base, group)
        covariates = final_covariates_map_lower[group.lower()]
        
        column_name = None
        for col in pd.read_pickle(os.path.join(group_path, "trimmed_data_imp1.pkl")).columns:
            if col.lower() == group.lower():
                column_name = col
                break
        if column_name is None:
            print(f"⚠️ Column not found for {group}, skipping.")
            continue

        smd_unw_all, smd_w_all = [], []
        vr_unw_all, vr_w_all = [], []

        for i in range(1, 6):
            df_path = os.path.join(group_path, f"trimmed_data_imp{i}.pkl")
            iptw_path = os.path.join(group_path, "iptw_weights.xlsx")

            if not os.path.exists(df_path) or not os.path.exists(iptw_path):
                print(f"⚠️ Missing data for {group} imp{i}, skipping.")
                continue

            df = pd.read_pickle(df_path)
            iptw_df = pd.read_excel(iptw_path, index_col=0)
            T = df[column_name]
            W = iptw_df.loc[df.index, "iptw_mean"]

            smd_unw_i, smd_w_i, vr_unw_i, vr_w_i = [], [], [], []

            for cov in covariates:
                x1, x0 = df.loc[T == 1, cov], df.loc[T == 0, cov]
                w1, w0 = W[T == 1], W[T == 0]

                su = calculate_smd(x1, x0)
                sw = calculate_smd(x1, x0, w1, w0)

                vu = variance_ratio(x1, x0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan
                vw = variance_ratio(x1, x0, w1, w0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan

                smd_unw_i.append(su)
                smd_w_i.append(sw)
                vr_unw_i.append(vu)
                vr_w_i.append(vw)

            smd_unw_all.append(smd_unw_i)
            smd_w_all.append(smd_w_i)
            vr_unw_all.append(vr_unw_i)
            vr_w_all.append(vr_w_i)

        smd_unw = np.mean(smd_unw_all, axis=0)
        smd_w = np.mean(smd_w_all, axis=0)
        vr_unw = np.nanmean(vr_unw_all, axis=0)
        vr_w = np.nanmean(vr_w_all, axis=0)

        severity = []
        for sw in smd_w:
            if sw <= 0.1:
                severity.append("Good")
            elif sw <= 0.2:
                severity.append("Moderate")
            else:
                severity.append("Poor")

        covariate_names = covariates
        numeric_df = pd.DataFrame({
            "Covariate": covariate_names,
            "SMD_Unweighted": smd_unw,
            "SMD_Weighted": smd_w,
            "Imbalance_Severity": severity,
            "VR_Unweighted": vr_unw,
            "VR_Weighted": vr_w
        })

        numeric_path = os.path.join(group_path, f"covariate_balance_table_{group}.xlsx")
        numeric_df.to_excel(numeric_path, index=False)
        print(f"📊 Exported numeric summary to: {numeric_path}")

        # -------------------------
        # Plot
        # -------------------------
        labels = covariates
        y_pos = np.arange(len(labels))

        fig, axes = plt.subplots(1, 2, figsize=(18, len(labels) * 0.45))

        axes[0].scatter(smd_unw, y_pos, color='red', label="Unweighted")
        axes[0].scatter(smd_w, y_pos, color='blue', label="Weighted")
        axes[0].axvline(0.1, color='gray', linestyle='--', label="Threshold 0.1")
        axes[0].axvline(0.2, color='black', linestyle='--', label="Threshold 0.2")
        axes[0].set_xlim(0, max(max(smd_unw), max(smd_w), 0.25) + 0.05)
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels(labels)
        axes[0].invert_yaxis()
        axes[0].set_title("Standardized Mean Differences (SMD)")
        axes[0].legend(loc="upper right")
        axes[0].grid(True)

        vr_mask = [cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] for cov in covariates]
        filtered_y = [i for i, b in enumerate(vr_mask) if b]
        filtered_labels = [labels[i] for i in filtered_y]
        filtered_vr_unw = [vr_unw[i] for i in filtered_y]
        filtered_vr_w = [vr_w[i] for i in filtered_y]

        axes[1].scatter(filtered_vr_unw, filtered_y, color='blue', marker='o', label="Unweighted")
        axes[1].scatter(filtered_vr_w, filtered_y, color='red', marker='x', label="Weighted")
        axes[1].axvline(2, color='gray', linestyle='--')
        axes[1].axvline(0.5, color='gray', linestyle='--')
        axes[1].set_xlim(0, max(filtered_vr_unw + filtered_vr_w + [2.5]) + 0.5)
        axes[1].set_yticks(filtered_y)
        axes[1].set_yticklabels(filtered_labels)
        axes[1].invert_yaxis()
        axes[1].set_title("Variance Ratio (VR)")
        axes[1].legend()
        axes[1].grid(True)

        fig.suptitle(f"Covariate Balance for {group.replace('CAT_', '')}", fontsize=14, weight='bold')
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        plot_path = os.path.join(group_path, f"love_plot_{group}.pdf")
        fig.savefig(plot_path, dpi=300)
        plt.close()
        print(f"✅ Saved love plot: {plot_path}")
        print(f"📏 Max weighted SMD for {group}: {np.max(smd_w):.3f}")

    except Exception as e:
        print(f"❌ Error in {group}: {e}")


In [None]:
# Heatmap:

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
#-----------------------------
# Generate heatmaps
# -------------------------------
for treatment_var in medication_groups:
    print(f"\n========== Creating Heatmap for {treatment_var} ==========")

    try:
        output_folder = os.path.join('outputs', treatment_var)
        balance_path = os.path.join(output_folder, f'covariate_balance_table_{treatment_var}.xlsx')

        if not os.path.exists(balance_path):
            print(f"❌ Balance file not found: {balance_path}")
            continue

        balance_df = pd.read_excel(balance_path)

        # ✅ Use finalized covariates + 'Propensity Score'
        covariates = final_covariates_map[treatment_var] + ['Propensity Score']
        balance_df = balance_df[balance_df['Covariate'].isin(covariates)]

        # ✅ Check for CAPS5score_baseline
        highlight_caps = 'CAPS5score_baseline' in balance_df['Covariate'].values

        # ✅ Format for heatmap
        heatmap_df = balance_df[['Covariate', 'SMD_Unweighted', 'SMD_Weighted']].copy()
        heatmap_df.columns = ['Covariate', 'Unweighted', 'Weighted']
        heatmap_df = heatmap_df.set_index('Covariate')
        heatmap_df = heatmap_df.sort_values(by='Unweighted', ascending=False)

        # ✅ Plot
        plt.figure(figsize=(12, max(10, len(heatmap_df) * 0.35)))
        ax = sns.heatmap(
            heatmap_df,
            cmap="coolwarm",
            annot=True,
            fmt=".2f",
            linewidths=0.6,
            linecolor='gray',
            cbar_kws={"label": "Standardized Mean Difference"}
        )

        plt.title(f"Covariate Balance Heatmap (Rubin IPTW)\n{treatment_var}", fontsize=15, weight='bold')
        plt.xlabel("Condition")
        plt.ylabel("Covariate")

        # ✅ Bold CAPS5score_baseline if present
        if highlight_caps:
            ylabels = [label.get_text() for label in ax.get_yticklabels()]
            ax.set_yticklabels([
                f"{label} ←" if label == 'CAPS5score_baseline' else label for label in ylabels
            ])

        plt.tight_layout()

        # ✅ Save image
        save_path = os.path.join(output_folder, f'heatmap_smd_{treatment_var}.png')
        plt.savefig(save_path, dpi=300)
        plt.close()

        print(f"✅ Heatmap saved: {save_path}")

    except Exception as e:
        print(f"⚠️ Error processing {treatment_var}: {e}")

## SubSubCat Analysis:

In [None]:
covariates_SubSubCat_Oxazepam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Diazepam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Paracetamol = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Lorazepam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Mirtazapine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Escitalopram = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Sertraline = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Temazepam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Citalopram = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Quetiapine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]




covariates_SubSubCat_Amitriptyline = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Venlafaxine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Fluoxetine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Topiramaat = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Tramadol = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica', 'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]



covariates_SubSubCat_Zopiclon = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Loprazolam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Alprazolam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_promethazine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Paroxetine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Bupropion = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Methylfenidaat = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD',
    'DIAGNOSIS_SEXUAL_TRAUMA', 'DIAGNOSIS_SUICIDALITY',
    'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Olanzapine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Zolpidem = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

In [None]:
from collections import defaultdict

# This finds all variables that start with covariates_SUbSubCAT_ or covariates_SubSubcat_
final_covariates_map = defaultdict(list)
final_covariates_map.update({
    var.replace("covariates_", ""): val
    for var, val in globals().items()
    if var.lower().startswith("covariates_subsubcat_") and isinstance(val, list)
})

# Show detected group names
print("Groups found:", list(final_covariates_map.keys()))
medication_groups = list(final_covariates_map.keys())
print(medication_groups)

In [None]:
import os


def run_all_SUBSUBCAT_group_models(imputed_dfs):
    """
    Runs downstream analysis for each SUBSUBCAT medisubsubcation group using imputed datasets.

    Parameters:
    - imputed_dfs: list of 5 imputed DataFrames (from df_imputed_final_imp1.pkl ... imp5.pkl)
    
    Notes:
    - Covariate lists must be defined as global variables: covariates_subsubcat_<group>
    - Outputs are saved in: outputs/SUBSUBCAT_<GROUP>/
    """

    print(" Starting analysis for all SUBSUBCAT groups")

    for var_name in globals():
        if var_name.lower().startswith("covariates_subsubcat_") and isinstance(globals()[var_name], list):
            group_name = var_name.replace("covariates_", "")
            group_name = group_name.replace("_", " ").title().replace(" ", "_")  # e.g., subsubcat_z_drugs → Subsubcat_Z_Drugs
            group_name = group_name.replace("Subsubcat_", "SUBSUBCAT_")  # force prefix to uppercase

            covariates = globals()[var_name]
            output_dir = f"outputs/{group_name}"
            os.makedirs(output_dir, exist_ok=True)

            print(f"\n Processing {group_name}...")

            for k, df_imp in enumerate(imputed_dfs):
                print(f"  → Using imputation {k+1}")

                # Define X and Y
                X = df_imp[covariates]
                Y = df_imp["caps5_change_baseline"]

                # === Save X and Y as placeholder (replace with modeling later)
                X.to_csv(f"{output_dir}/X_imp{k+1}.csv", index=False)
                Y.to_frame(name="Y").to_csv(f"{output_dir}/Y_imp{k+1}.csv", index=False)

            print(f" Done: {group_name}")

    print("\n All SUBSUBCAT group analyses complete.")

# ========= STEP 4: Execute ========= #
run_all_SUBSUBCAT_group_models(imputed_dfs)

In [None]:
for treatment_var in medication_groups:
    print(f"\n {treatment_var}")
    
    for i, df in enumerate(imputed_dfs):
        if treatment_var not in df.columns:
            print(f"  Imp {i+1}:  Not found in columns.")
            continue

        treated = (df[treatment_var] == 1).sum()
        control = (df[treatment_var] == 0).sum()
        missing = df[treatment_var].isna().sum()

        print(f"  Imp {i+1}: Treated = {treated}, Control = {control}, Missing = {missing}")

In [None]:
import os
import pandas as pd
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from collections import defaultdict

# ✅ VIF computation function
def compute_vif(X):
    X = sm.add_constant(X, has_constant='add')
    vif_df = pd.DataFrame()
    vif_df["variable"] = X.columns
    vif_df["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
    return vif_df

# ✅ Process each group
for group in medication_groups:
    print(f"\n🔍 Processing VIF for {group}")

    if group not in final_covariates_map:
        print(f" ⚠️ No covariates found for {group}. Skipping.")
        continue

    covariates = final_covariates_map[group]
    vif_list = []

    for i, df_imp in enumerate(imputed_dfs):
        try:
            X = df_imp[covariates].copy()
            vif_df = compute_vif(X)
            vif_df["imputation"] = i + 1
            vif_list.append(vif_df)
        except Exception as e:
            print(f"   ❌ Failed on imputation {i+1} for {group}: {e}")

    if vif_list:
        all_vif = pd.concat(vif_list)
        pooled_vif = all_vif.groupby("variable")["VIF"].mean().reset_index()
        pooled_vif = pooled_vif.sort_values(by="VIF", ascending=False)

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)
        pooled_vif.to_csv(os.path.join(output_folder, "pooled_vif.csv"), index=False)

        print(f" ✅ Saved: {output_folder}/pooled_vif.csv")
    else:
        print(f" ⚠️ Skipped {group}: No valid imputations.")

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# ---------- PS Estimation Function ----------
def run_xgboost_ps_modeling(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f" Running PS estimation for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        covariates = final_covariates_map[group]
        ps_matrix = pd.DataFrame()
        auc_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not found in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop missing treatment rows
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            # Train-test split for ROC
            X_train, X_test, T_train, T_test = train_test_split(
                X, T, stratify=T, test_size=0.3, random_state=42
            )

            try:
                model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
                model.fit(X_train, T_train)

                ps_scores = model.predict_proba(X)[:, 1]
                ps_matrix[f"ps_imp{i+1}"] = pd.Series(ps_scores, index=valid_idx)

                # ROC & AUC
                auc = roc_auc_score(T_test, model.predict_proba(X_test)[:, 1])
                auc_list.append(auc)

                fpr, tpr, _ = roc_curve(T_test, model.predict_proba(X_test)[:, 1])
                plt.figure()
                plt.plot(fpr, tpr, label=f"AUC = {auc:.3f}")
                plt.plot([0, 1], [0, 1], 'k--')
                plt.xlabel("False Positive Rate")
                plt.ylabel("True Positive Rate")
                plt.title(f"ROC Curve - {group} (Imp {i+1})")
                plt.legend()
                plt.tight_layout()
                plt.savefig(os.path.join(output_folder, f"roc_curve_imp{i+1}.png"))
                plt.close()
                print(f"   Imp {i+1}: AUC = {auc:.3f}, ROC saved.")

            except Exception as e:
                print(f"   Error in {group} (imp {i+1}): {e}")

        # Save AUCs and Composite PS
        if not ps_matrix.empty:
            # Fill NaN rows (from dropped subjects in some imputations) with mean
            ps_matrix["composite_ps"] = ps_matrix.mean(axis=1)
            ps_matrix.to_excel(os.path.join(output_folder, "propensity_scores.xlsx"))

            auc_df = pd.DataFrame({
                "imputation": [f"imp{i+1}" for i in range(len(auc_list))],
                "AUC": auc_list
            })
            auc_df.loc[len(auc_df.index)] = ["mean", np.mean(auc_list) if auc_list else np.nan]
            auc_df.to_excel(os.path.join(output_folder, "auc_scores.xlsx"), index=False)

            print(f" Composite PS + AUC saved for {group}")
        else:
            print(f" No valid PS scores generated for {group}")

# ---------- Run ----------
run_xgboost_ps_modeling(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from xgboost import XGBClassifier

def compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n Computing feature importance for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        covariates = final_covariates_map[group]
        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        importance_df_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop NaNs in treatment
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            try:
                model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
                model.fit(X, T)

                # Get feature importance
                importances = model.get_booster().get_score(importance_type='gain')
                df_feat = pd.DataFrame.from_dict(importances, orient='index', columns=[f"imp{i+1}"])
                df_feat.index.name = 'feature'
                importance_df_list.append(df_feat)

            except Exception as e:
                print(f"   Error during modeling: {e}")

        if importance_df_list:
            # Combine and average
            all_feat = pd.concat(importance_df_list, axis=1).fillna(0)
            all_feat["mean_importance"] = all_feat.mean(axis=1)

            # Filter top 30 non-zero
            non_zero = all_feat[all_feat["mean_importance"] > 0]
            top30 = non_zero.sort_values(by="mean_importance", ascending=False).head(30)

            # Save to CSV
            top30.to_csv(os.path.join(output_folder, "feature_importance.csv"))

            # Plot
            plt.figure(figsize=(10, 8))
            plt.barh(top30.index[::-1], top30["mean_importance"][::-1])  # plot top → bottom
            plt.xlabel("Mean Gain Importance")
            plt.title(f"Top 30 Feature Importance - {group}")
            plt.tight_layout()
            plt.savefig(os.path.join(output_folder, "feature_importance_top30.png"))
            plt.close()

            print(f" Saved feature importance plot and CSV for {group}")
        else:
            print(f" No valid models for {group}")

#  Run
compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import numpy as np


def compute_trimmed_clipped_iptw(ps_df, treatment, lower=0.05, upper=0.95, clip_max=10):
    weights = []
    keep_mask = (ps_df > lower) & (ps_df < upper)

    for i in range(ps_df.shape[1]):
        ps = ps_df.iloc[:, i].clip(lower=1e-6, upper=1 - 1e-6)  # avoid div by zero
        mask = keep_mask.iloc[:, i]
        w = pd.Series(np.nan, index=ps.index)

        w[mask & (treatment == 1)] = 1 / ps[mask & (treatment == 1)]
        w[mask & (treatment == 0)] = 1 / (1 - ps[mask & (treatment == 0)])
        w = w.clip(upper=clip_max)
        weights.append(w)

    return pd.concat(weights, axis=1)


def apply_rubins_rule_to_iptw(iptw_matrix):
    """
    Given an IPTW matrix (n rows × M imputations), return Rubin’s rule pooled mean, SD, SE.
    """
    M = iptw_matrix.shape[1]
    q_bar = iptw_matrix.mean(axis=1)
    u_bar = iptw_matrix.var(axis=1, ddof=1)
    B = iptw_matrix.apply(lambda x: x.mean(), axis=1).var(ddof=1)
    total_var = u_bar + (1 + 1/M) * B
    total_se = np.sqrt(total_var)
    return q_bar, u_bar.pow(0.5), total_se


def run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n🔍 Processing IPTW + trimming + clipping for {group}")
        output_folder = os.path.join("outputs", group)
        ps_path = os.path.join(output_folder, "propensity_scores.xlsx")

        if not os.path.exists(ps_path):
            print(f"⚠️ Missing PS file: {ps_path}. Skipping.")
            continue

        try:
            ps_all = pd.read_excel(ps_path, index_col=0)
            ps_cols = [col for col in ps_all.columns if col.startswith("ps_imp")]
            composite_index = ps_all.index

            # Get treatment from one imputed dataset
            T_full = None
            for df in imputed_dfs:
                if group in df.columns:
                    T_full = df.loc[composite_index, group]
                    break

            if T_full is None:
                print(f"❌ Treatment column {group} not found in any imputed dataset.")
                continue

            # Compute IPTW matrix (shape: n × M)
            iptw_matrix = compute_trimmed_clipped_iptw(ps_all[ps_cols], T_full)
            iptw_matrix.columns = [f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]

            # Apply Rubin’s Rule for mean, SD, SE
            iptw_matrix["iptw_mean"], iptw_matrix["iptw_sd"], iptw_matrix["iptw_se"] = apply_rubins_rule_to_iptw(
                iptw_matrix[[f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]]
            )

            # Save IPTW matrix separately
            iptw_matrix.to_excel(os.path.join(output_folder, "iptw_weights.xlsx"))
            print(f"✅ Saved IPTW weights for {group}")

            # Save trimmed & clipped imputed datasets with IPTW
            for i in range(5):
                df = imputed_dfs[i].copy()
                if group not in df.columns:
                    continue

                trimmed_idx = iptw_matrix.index.intersection(df.index)
                needed_cols = final_covariates_map[group] + [group, "caps5_change_baseline"]

                # Select only necessary columns
                df_trimmed = df.loc[trimmed_idx, needed_cols].copy()
                df_trimmed["iptw"] = iptw_matrix[f"iptw_imp{i+1}"].loc[trimmed_idx]

                # ✅ DROP rows with missing IPTW values
                before = len(df_trimmed)
                df_trimmed = df_trimmed.dropna(subset=["iptw"])
                after = len(df_trimmed)
                print(f"    ℹ️ Retained {after}/{before} rows after IPTW NaN drop.")

                # Save to .pkl
                df_trimmed.to_pickle(os.path.join(output_folder, f"trimmed_data_imp{i+1}.pkl"))
                print(f"  💾 Saved trimmed dataset: {output_folder}/trimmed_data_imp{i+1}.*")

        except Exception as e:
            print(f"❌ Error in {group}: {e}")


run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def plot_ps_overlap_all_groups(medication_groups):
    for group in medication_groups:
        print(f"\n📊 Plotting PS overlap for {group}")

        folder = os.path.join("outputs", group)
        ps_file = os.path.join(folder, "propensity_scores.xlsx")
        iptw_file = os.path.join(folder, "iptw_weights.xlsx")
        trimmed_file = os.path.join(folder, "trimmed_data_imp1.pkl")

        if not all(os.path.exists(f) for f in [ps_file, iptw_file, trimmed_file]):
            print(f"⚠️ Missing required files for {group}. Skipping.")
            continue

        try:
            ps_df = pd.read_excel(ps_file, index_col=0)
            iptw_df = pd.read_excel(iptw_file, index_col=0)
            trimmed_df = pd.read_pickle(trimmed_file)

            # Extract
            ps = ps_df["composite_ps"].reindex(trimmed_df.index)
            w = iptw_df["iptw_mean"].reindex(trimmed_df.index)
            T = trimmed_df[group]

            # Masks to remove NaNs
            treated_mask = (T == 1) & ps.notna() & w.notna()
            control_mask = (T == 0) & ps.notna() & w.notna()

            treated = ps[treated_mask]
            treated_w = w[treated_mask]

            control = ps[control_mask]
            control_w = w[control_mask]

            # === Unweighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Unweighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_unweighted.png"))
            plt.close()

            # === Weighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, weights=[treated_w, control_w], label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Weighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Weighted Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_weighted.png"))
            plt.close()

            print(f"✅ Saved unweighted and weighted PS plots for {group}")

        except Exception as e:
            print(f"❌ Error processing {group}: {e}")

# 🔁 Run
plot_ps_overlap_all_groups(medication_groups)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Set up base output folder
output_base = "outputs"
ps_file = "propensity_scores.xlsx"
iptw_file = "iptw_weights.xlsx"
trimmed_data_file = "trimmed_data_imp1.pkl"

# Collect all treatment group folders
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

# Generate 4-panel overlap plots
for group in groups:
    group_path = os.path.join(output_base, group)
    try:
        # Load trimmed treatment info
        trimmed_df = pd.read_pickle(os.path.join(group_path, trimmed_data_file))
        index = trimmed_df.index

        # Fix: case-insensitive match for treatment variable
        possible_cols = [col for col in trimmed_df.columns if col.upper() == group.upper()]
        if not possible_cols:
            print(f" Treatment variable {group} not found in {group}, skipping.")
            continue
        treatment_var = possible_cols[0]
        T = trimmed_df[treatment_var]

        # Load composite PS (aligned to trimmed_df index)
        ps_df = pd.read_excel(os.path.join(group_path, ps_file), index_col=0)
        if 'composite_ps' not in ps_df.columns:
            print(f" Composite column missing in {ps_file}, skipping {group}.")
            continue
        ps = ps_df.loc[index, 'composite_ps']

        # Load IPTW weights (aligned to trimmed_df index)
        weights_df = pd.read_excel(os.path.join(group_path, iptw_file), index_col=0)
        if 'iptw_mean' not in weights_df.columns:
            print(f" IPTW weight column missing in {iptw_file}, skipping {group}.")
            continue
        weights = weights_df.loc[index, 'iptw_mean']

        # Prepare 4 datasets
        raw_treated = ps[T == 1]
        raw_control = ps[T == 0]
        weighted_treated = (ps[T == 1], weights[T == 1])
        weighted_control = (ps[T == 0], weights[T == 0])

        # Create plot
        fig, axs = plt.subplots(2, 2, figsize=(12, 8))
        fig.suptitle(f"Propensity Score Distribution - {group}", fontsize=14)

        axs[0, 0].hist(raw_treated, bins=20, alpha=0.7, color='blue')
        axs[0, 0].set_title("Raw Treated")

        axs[0, 1].hist(raw_control, bins=20, alpha=0.7, color='green')
        axs[0, 1].set_title("Raw Control")

        axs[1, 0].hist(weighted_treated[0], bins=20, weights=weighted_treated[1], alpha=0.7, color='blue')
        axs[1, 0].set_title("Weighted Treated")

        axs[1, 1].hist(weighted_control[0], bins=20, weights=weighted_control[1], alpha=0.7, color='green')
        axs[1, 1].set_title("Weighted Control")

        for ax in axs.flat:
            ax.set_xlim(0, 1)
            ax.set_xlabel("Propensity Score")
            ax.set_ylabel("Count")

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save figure
        plot_path = os.path.join(group_path, f"four_panel_overlap_{group}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f" ✅ Saved: {plot_path}")

    except Exception as e:
        print(f" ❌ Error in {group}: {e}")

In [None]:
# ATT calculation:

In [None]:
# Weighted:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from econml.dml import LinearDML
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
n_repeats = 4
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T, weights):
    treated = X[T == 1]
    control = X[T == 0]
    w_treated = weights[T == 1]
    w_control = weights[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = np.average(treated[col], weights=w_treated), np.average(control[col], weights=w_control)
            s1 = np.sqrt(np.average((treated[col] - m1) ** 2, weights=w_treated))
            s0 = np.sqrt(np.average((control[col] - m0) ** 2, weights=w_control))
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# DML Main Loop
# -----------------------------
def run_dml_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running DML for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                X = df[covariates].copy()
                T = df[group]
                Y = df["caps5_change_baseline"]
                W = df["iptw"]

                for repeat in range(n_repeats):
                    kf = KFold(n_splits=5, shuffle=True, random_state=seed + repeat)
                    for train_idx, test_idx in kf.split(X):
                        try:
                            X_train, T_train, Y_train, W_train = (
                                X.iloc[train_idx],
                                T.iloc[train_idx],
                                Y.iloc[train_idx],
                                W.iloc[train_idx],
                            )

                            model_y = xgb.XGBRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1, random_state=seed)
                            model_t = xgb.XGBClassifier(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1,
                                                        use_label_encoder=False, eval_metric="logloss", random_state=seed)

                            dml = LinearDML(model_y=model_y, model_t=model_t, discrete_treatment=True,
                                            cv=KFold(n_splits=3, shuffle=True, random_state=seed), random_state=seed)
                            dml.fit(Y_train, T_train, X=X_train, sample_weight=W_train)

                            tau = dml.effect(X_train)
                            att = np.mean(tau)
                            influence = tau - att
                            se = np.sqrt(np.mean(influence ** 2) / len(tau))

                            att_list.append(att)
                            se_list.append(se)

                            Y_pred = model_y.fit(X_train, Y_train, sample_weight=W_train).predict(X_train)
                            residuals = Y_train - Y_pred
                            rmse = mean_squared_error(Y_train, Y_pred, squared=False)
                            r2 = r2_score(Y_train, Y_pred)
                            r2_list.append(r2)
                            rmse_list.append(rmse)
                            
                            # Collect residuals and fitted values for diagnostic plots
                            group_residuals.append(residuals.values)
                            group_fitted.append(Y_pred)

                            smd, vr = calculate_smd_vr(X_train, T_train, W_train)
                            smd_list.append(smd)
                            vr_list.append(vr)
                        except Exception as e:
                            print(f"⚠️ Error in {group}, seed {seed}, imp {imp}, rep {repeat}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("dml_rubin_summary_subsubcats.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_subsubcats.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_dml_with_trimmed_data(final_covariates_map)

In [None]:
# Unweighted:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from econml.dml import LinearDML
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
n_repeats = 4
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Plotting Functions
# -----------------------------
def create_diagnostic_plots(residuals_data, group_name, output_folder):
    """Create diagnostic plots for each group"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    all_residuals = residuals_data['residuals']
    all_fitted = residuals_data['fitted']
    
    if len(all_residuals) == 0:
        return
    
    # Create figure with subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=14, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0, 0].scatter(all_fitted, all_residuals, alpha=0.5, s=1)
    axes[0, 0].axhline(y=0, color='red', linestyle='--')
    axes[0, 0].set_xlabel('Fitted Values')
    axes[0, 0].set_ylabel('Residuals')
    axes[0, 0].set_title('Residuals vs Fitted')
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0, 1])
    axes[0, 1].set_title('QQ Plot (Normal)')
    
    # 3. Histogram of Residuals
    axes[1, 0].hist(all_residuals, bins=50, alpha=0.7, edgecolor='black')
    axes[1, 0].axvline(x=0, color='red', linestyle='--')
    axes[1, 0].set_xlabel('Residuals')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].set_title('Residual Distribution')
    
    # 4. Scale-Location Plot
    sqrt_abs_resid = np.sqrt(np.abs(all_residuals))
    axes[1, 1].scatter(all_fitted, sqrt_abs_resid, alpha=0.5, s=1)
    axes[1, 1].set_xlabel('Fitted Values')
    axes[1, 1].set_ylabel('√|Residuals|')
    axes[1, 1].set_title('Scale-Location Plot')
    
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, f'{group_name}_unweighted.png'), 
                dpi=300, bbox_inches='tight')
    plt.close()

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T):
    treated = X[T == 1]
    control = X[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = treated[col].mean(), control[col].mean()
            s1, s0 = treated[col].std(), control[col].std()
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# DML Main Loop
# -----------------------------
def run_dml_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running DML for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize residuals collection for this group
        group_residuals_data = {'residuals': [], 'fitted': []}

        for seed in seeds:
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns:
                    continue

                X = df[covariates].copy()
                T = df[group]
                Y = df["caps5_change_baseline"]

                for repeat in range(n_repeats):
                    kf = KFold(n_splits=5, shuffle=True, random_state=seed + repeat)
                    for train_idx, test_idx in kf.split(X):
                        try:
                            X_train, T_train, Y_train = (
                                X.iloc[train_idx],
                                T.iloc[train_idx],
                                Y.iloc[train_idx],
                            )

                            model_y = xgb.XGBRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1, random_state=seed)
                            model_t = xgb.XGBClassifier(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1,
                                                        use_label_encoder=False, eval_metric="logloss", random_state=seed)

                            dml = LinearDML(model_y=model_y, model_t=model_t, discrete_treatment=True,
                                            cv=KFold(n_splits=3, shuffle=True, random_state=seed), random_state=seed)
                            dml.fit(Y_train, T_train, X=X_train)

                            tau = dml.effect(X_train)
                            att = np.mean(tau)
                            influence = tau - att
                            se = np.sqrt(np.mean(influence ** 2) / len(tau))

                            att_list.append(att)
                            se_list.append(se)

                            Y_pred = model_y.fit(X_train, Y_train).predict(X_train)
                            residuals = Y_train - Y_pred
                            
                            # Collect residuals and fitted values for plotting
                            group_residuals_data['residuals'].extend(residuals.tolist())
                            group_residuals_data['fitted'].extend(Y_pred.tolist())
                            
                            rmse = mean_squared_error(Y_train, Y_pred, squared=False)
                            r2 = r2_score(Y_train, Y_pred)
                            r2_list.append(r2)
                            rmse_list.append(rmse)

                            smd, vr = calculate_smd_vr(X_train, T_train)
                            smd_list.append(smd)
                            vr_list.append(vr)
                        except Exception as e:
                            print(f"⚠️ Error in {group}, seed {seed}, imp {imp}, rep {repeat}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")
        
        # Create diagnostic plots for this group
        print(f"📊 Creating diagnostic plots for {group}...")
        create_diagnostic_plots(group_residuals_data, group, output_folder)

    # Save final output
    pd.DataFrame(att_results).to_excel("dml_rubin_summary_subsubcats_unweighted.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_subsubcats_unweighted.xlsx", index=False)
    print("\n🎯 All summary files saved.")
    print("📊 All diagnostic plots saved in outputs/plots/ folder.")

run_dml_with_trimmed_data(final_covariates_map)


In [None]:
import os
import pandas as pd
import numpy as np
from scipy.stats import sem, ttest_ind

# ----------------------------------
# File paths
# ----------------------------------
output_base = "outputs"
att_file = "dml_rubin_summary_subsubcats.xlsx"
trimmed_file = "trimmed_data_imp1.pkl"
auc_file = "auc_scores.xlsx"  # NEW

# ----------------------------------
# Load ATT Summary
# ----------------------------------
if os.path.exists(att_file):
    att_df = pd.read_excel(att_file)
else:
    raise FileNotFoundError("❌ ATT summary file not found: dml_rubin_summary_subsubcats.xlsx")

summary_rows = []

# ----------------------------------
# Loop over medication groups
# ----------------------------------
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

for med in groups:
    try:
        group_path = os.path.join(output_base, med)

        # Load trimmed data
        df = pd.read_pickle(os.path.join(group_path, trimmed_file))

        # Detect treatment column
        treatment_cols = [col for col in df.columns if col.upper() == med.upper()]
        if not treatment_cols:
            print(f"⚠️ Treatment column {med} not found in trimmed data. Skipping.")
            continue
        treatment_var = treatment_cols[0]

        # Extract treatment and outcome
        T = df[treatment_var]
        Y = df["caps5_change_baseline"]

        # Treated and control stats
        treated = Y[T == 1]
        control = Y[T == 0]

        mean_treat = treated.mean()
        se_treat = sem(treated) if len(treated) > 1 else np.nan

        mean_ctrl = control.mean()
        se_ctrl = sem(control) if len(control) > 1 else np.nan

        # Cohen's d (unadjusted)
        pooled_sd = np.sqrt(((treated.std() ** 2) + (control.std() ** 2)) / 2)
        cohen_d = (mean_treat - mean_ctrl) / pooled_sd if pooled_sd > 0 else np.nan

        # E-value (unadjusted)
        delta = mean_treat - mean_ctrl
        E = delta / abs(mean_ctrl) * 100 if mean_ctrl != 0 else np.nan

        # Unadjusted p-value
        try:
            t_stat, p_val = ttest_ind(treated, control, equal_var=False, nan_policy="omit")
            rounded_p = round(p_val, 5)
            formatted_p = "< 0.00001" if rounded_p < 0.00001 else rounded_p
        except Exception:
            formatted_p = np.nan

        # AUC from new auc_scores.xlsx file
        auc_val = np.nan
        auc_path = os.path.join(group_path, auc_file)
        if os.path.exists(auc_path):
            auc_df = pd.read_excel(auc_path)
            if "AUC" in auc_df.columns:
                auc_val = auc_df["AUC"].dropna().mean()

        # Adjusted stats from Rubin summary
        att_row = att_df[att_df["group"].str.strip().str.upper() == med.strip().upper()]
        if not att_row.empty:
            att = att_row.iloc[0]["att"]
            att_se = att_row.iloc[0]["se"]
            att_p_val = att_row.iloc[0]["p_value"]
            r2 = att_row.iloc[0]["r2"]
            rmse = att_row.iloc[0]["rmse"]

            try:
                rounded_att_p = round(float(att_p_val), 5)
                formatted_att_p = "< 0.00001" if rounded_att_p < 0.00001 else rounded_att_p
            except:
                formatted_att_p = att_p_val
        else:
            att, att_se, formatted_att_p, r2, rmse = np.nan, np.nan, np.nan, np.nan, np.nan

        # Append full row
        summary_rows.append({
            'Medication Group': med,
            'Mean Treated': mean_treat,
            'SE Treated': se_treat,
            'Mean Control': mean_ctrl,
            'SE Control': se_ctrl,
            'Cohen d': cohen_d,
            'E (Unadjusted)': E,
            'n Treated': len(treated),
            'n Control': len(control),
            #'Unadjusted p-value': formatted_p,
            'ATT Estimate': att,
            'ATT SE (Robust)': att_se,
            'ATT p-value': formatted_att_p,
            'R²': r2,
            'RMSE': rmse,
            'AUC': auc_val
        })

    except Exception as e:
        print(f"❌ Error in {med}: {e}")

# ----------------------------------
# Save final summary
# ----------------------------------
summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.sort_values("Medication Group")
summary_df.to_excel("Final_ATT_Summary_SubSubCat.xlsx", index=False)
print("✅ Final_ATT_Summary_SubSubCat saved")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# ✅ Load the final summary table
final_df = pd.read_excel("Final_ATT_Summary_SubSubCat.xlsx")

# ✅ Parse DML p-values (handle "< 0.00001")
def parse_pval(p):
    try:
        if isinstance(p, str) and "<" in p:
            return 0.000001
        return float(p)
    except:
        return None

final_df['ATT p-value'] = final_df['ATT p-value'].apply(parse_pval)

# ✅ Plot settings
width = 0.35

# ✅ Plotting function for a single medication group
def plot_single_group(row):
    fig, ax = plt.subplots(figsize=(10, 6))
    
    bars1 = ax.bar(-width/2, row['Mean Control'], width, 
                   yerr=row['SE Control'], label='Control', hatch='//', color='gray', capsize=5)
    bars2 = ax.bar(+width/2, row['Mean Treated'], width, 
                   yerr=row['SE Treated'], label='Treated', color='steelblue', capsize=5)

    label = (
        f"ATT = {row['ATT Estimate']:.2f}\n"
        f"d = {row['Cohen d']:.2f}, p = {row['ATT p-value']:.3f}\n"
        f"nT = {row['n Treated']}, nC = {row['n Control']}\n"
        f"E = {row['E (Unadjusted)']:.1f}%"
    )
    max_y = max(row['Mean Control'], row['Mean Treated']) + 1.5
    ax.text(0, max_y, label, ha='center', va='bottom', fontsize=9, color='#FFD700')

    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_xticks([-width/2, +width/2])
    ax.set_xticklabels(['Control', 'Treated'])
    ax.set_title(f"Group: {row['Medication Group']}", fontsize=12, weight='bold')
    ax.set_ylabel("CAPS5 Change Score")
    ax.set_ylim(bottom=0, top=max_y + 2)
    ax.legend()
    fig.tight_layout()
    return fig

# ✅ Generate and save all plots into a multi-page PDF
with PdfPages("dml_att_barplot_subsubcat.pdf") as pdf:
    for idx, row in final_df.iterrows():
        fig = plot_single_group(row)
        pdf.savefig(fig, dpi=300, bbox_inches='tight')
        plt.close(fig)
print("✅ dml_att_barplot_subsubcat is saved.")

In [None]:
# Love plot:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# ----------------------------------------
# Functions to calculate balance
# ----------------------------------------
def calculate_smd(x1, x2, w1=None, w2=None):
    def weighted_mean(x, w): return np.average(x, weights=w)
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    m1 = weighted_mean(x1, w1) if w1 is not None else np.mean(x1)
    m2 = weighted_mean(x2, w2) if w2 is not None else np.mean(x2)
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    pooled_sd = np.sqrt((v1 + v2) / 2)
    return np.abs(m1 - m2) / pooled_sd if pooled_sd > 0 else 0

def variance_ratio(x1, x2, w1=None, w2=None):
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    return max(v1 / v2, v2 / v1) if v1 > 0 and v2 > 0 else 1

# ----------------------------------------
# Setup
# ----------------------------------------
output_base = "outputs"
groups = [g for g in os.listdir(output_base) if os.path.isdir(os.path.join(output_base, g))]

# Create a case-insensitive mapping
final_covariates_map_lower = {k.lower(): v for k, v in final_covariates_map.items()}

# ----------------------------------------
# Main Loop
# ----------------------------------------
for group in groups:
    if group.lower() not in final_covariates_map_lower:
        continue

    print(f"\n🔍 Processing {group}...")

    try:
        group_path = os.path.join(output_base, group)
        covariates = final_covariates_map_lower[group.lower()]
        
        column_name = None
        for col in pd.read_pickle(os.path.join(group_path, "trimmed_data_imp1.pkl")).columns:
            if col.lower() == group.lower():
                column_name = col
                break
        if column_name is None:
            print(f"⚠️ Column not found for {group}, skipping.")
            continue

        smd_unw_all, smd_w_all = [], []
        vr_unw_all, vr_w_all = [], []

        for i in range(1, 6):
            df_path = os.path.join(group_path, f"trimmed_data_imp{i}.pkl")
            iptw_path = os.path.join(group_path, "iptw_weights.xlsx")

            if not os.path.exists(df_path) or not os.path.exists(iptw_path):
                print(f"⚠️ Missing data for {group} imp{i}, skipping.")
                continue

            df = pd.read_pickle(df_path)
            iptw_df = pd.read_excel(iptw_path, index_col=0)
            T = df[column_name]
            W = iptw_df.loc[df.index, "iptw_mean"]

            smd_unw_i, smd_w_i, vr_unw_i, vr_w_i = [], [], [], []

            for cov in covariates:
                x1, x0 = df.loc[T == 1, cov], df.loc[T == 0, cov]
                w1, w0 = W[T == 1], W[T == 0]

                su = calculate_smd(x1, x0)
                sw = calculate_smd(x1, x0, w1, w0)

                vu = variance_ratio(x1, x0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan
                vw = variance_ratio(x1, x0, w1, w0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan

                smd_unw_i.append(su)
                smd_w_i.append(sw)
                vr_unw_i.append(vu)
                vr_w_i.append(vw)

            smd_unw_all.append(smd_unw_i)
            smd_w_all.append(smd_w_i)
            vr_unw_all.append(vr_unw_i)
            vr_w_all.append(vr_w_i)

        smd_unw = np.mean(smd_unw_all, axis=0)
        smd_w = np.mean(smd_w_all, axis=0)
        vr_unw = np.nanmean(vr_unw_all, axis=0)
        vr_w = np.nanmean(vr_w_all, axis=0)

        severity = []
        for sw in smd_w:
            if sw <= 0.1:
                severity.append("Good")
            elif sw <= 0.2:
                severity.append("Moderate")
            else:
                severity.append("Poor")

        covariate_names = covariates
        numeric_df = pd.DataFrame({
            "Covariate": covariate_names,
            "SMD_Unweighted": smd_unw,
            "SMD_Weighted": smd_w,
            "Imbalance_Severity": severity,
            "VR_Unweighted": vr_unw,
            "VR_Weighted": vr_w
        })

        numeric_path = os.path.join(group_path, f"covariate_balance_table_{group}.xlsx")
        numeric_df.to_excel(numeric_path, index=False)
        print(f"📊 Exported numeric summary to: {numeric_path}")

        # -------------------------
        # Plot
        # -------------------------
        labels = covariates
        y_pos = np.arange(len(labels))

        fig, axes = plt.subplots(1, 2, figsize=(18, len(labels) * 0.45))

        axes[0].scatter(smd_unw, y_pos, color='red', label="Unweighted")
        axes[0].scatter(smd_w, y_pos, color='blue', label="Weighted")
        axes[0].axvline(0.1, color='gray', linestyle='--', label="Threshold 0.1")
        axes[0].axvline(0.2, color='black', linestyle='--', label="Threshold 0.2")
        axes[0].set_xlim(0, max(max(smd_unw), max(smd_w), 0.25) + 0.05)
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels(labels)
        axes[0].invert_yaxis()
        axes[0].set_title("Standardized Mean Differences (SMD)")
        axes[0].legend(loc="upper right")
        axes[0].grid(True)

        vr_mask = [cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] for cov in covariates]
        filtered_y = [i for i, b in enumerate(vr_mask) if b]
        filtered_labels = [labels[i] for i in filtered_y]
        filtered_vr_unw = [vr_unw[i] for i in filtered_y]
        filtered_vr_w = [vr_w[i] for i in filtered_y]

        axes[1].scatter(filtered_vr_unw, filtered_y, color='blue', marker='o', label="Unweighted")
        axes[1].scatter(filtered_vr_w, filtered_y, color='red', marker='x', label="Weighted")
        axes[1].axvline(2, color='gray', linestyle='--')
        axes[1].axvline(0.5, color='gray', linestyle='--')
        axes[1].set_xlim(0, max(filtered_vr_unw + filtered_vr_w + [2.5]) + 0.5)
        axes[1].set_yticks(filtered_y)
        axes[1].set_yticklabels(filtered_labels)
        axes[1].invert_yaxis()
        axes[1].set_title("Variance Ratio (VR)")
        axes[1].legend()
        axes[1].grid(True)

        fig.suptitle(f"Covariate Balance for {group.replace('CAT_', '')}", fontsize=14, weight='bold')
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        plot_path = os.path.join(group_path, f"love_plot_{group}.pdf")
        fig.savefig(plot_path, dpi=300)
        plt.close()
        print(f"✅ Saved love plot: {plot_path}")
        print(f"📏 Max weighted SMD for {group}: {np.max(smd_w):.3f}")

    except Exception as e:
        print(f"❌ Error in {group}: {e}")


In [None]:
# Heatmap:

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
#-----------------------------
# Generate heatmaps
# -------------------------------
for treatment_var in medication_groups:
    print(f"\n========== Creating Heatmap for {treatment_var} ==========")

    try:
        output_folder = os.path.join('outputs', treatment_var)
        balance_path = os.path.join(output_folder, f'covariate_balance_table_{treatment_var}.xlsx')

        if not os.path.exists(balance_path):
            print(f"❌ Balance file not found: {balance_path}")
            continue

        balance_df = pd.read_excel(balance_path)

        # ✅ Use finalized covariates + 'Propensity Score'
        covariates = final_covariates_map[treatment_var] + ['Propensity Score']
        balance_df = balance_df[balance_df['Covariate'].isin(covariates)]

        # ✅ Check for CAPS5score_baseline
        highlight_caps = 'CAPS5score_baseline' in balance_df['Covariate'].values

        # ✅ Format for heatmap
        heatmap_df = balance_df[['Covariate', 'SMD_Unweighted', 'SMD_Weighted']].copy()
        heatmap_df.columns = ['Covariate', 'Unweighted', 'Weighted']
        heatmap_df = heatmap_df.set_index('Covariate')
        heatmap_df = heatmap_df.sort_values(by='Unweighted', ascending=False)

        # ✅ Plot
        plt.figure(figsize=(12, max(10, len(heatmap_df) * 0.35)))
        ax = sns.heatmap(
            heatmap_df,
            cmap="coolwarm",
            annot=True,
            fmt=".2f",
            linewidths=0.6,
            linecolor='gray',
            cbar_kws={"label": "Standardized Mean Difference"}
        )

        plt.title(f"Covariate Balance Heatmap (Rubin IPTW)\n{treatment_var}", fontsize=15, weight='bold')
        plt.xlabel("Condition")
        plt.ylabel("Covariate")

        # ✅ Bold CAPS5score_baseline if present
        if highlight_caps:
            ylabels = [label.get_text() for label in ax.get_yticklabels()]
            ax.set_yticklabels([
                f"{label} ←" if label == 'CAPS5score_baseline' else label for label in ylabels
            ])

        plt.tight_layout()

        # ✅ Save image
        save_path = os.path.join(output_folder, f'heatmap_smd_{treatment_var}.png')
        plt.savefig(save_path, dpi=300)
        plt.close()

        print(f"✅ Heatmap saved: {save_path}")

    except Exception as e:
        print(f"⚠️ Error processing {treatment_var}: {e}")

#### XGBOOST:

In [None]:
import os
import warnings

warnings.filterwarnings("ignore")  # optional

# Option 1 (recommended)
new_path = r"D:\Work\PTSD\XGBoost"

os.chdir(new_path)  # Change working directory
print("Changed to:", os.getcwd())

In [None]:
import pandas as pd
df = pd.read_csv(r"data_baseline.csv")

## CAT analysis:

In [None]:
# Import all necessary packages
import pandas as pd
import numpy as np
import re
# For visualization and future steps
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.experimental import enable_iterative_imputer  # Needed to enable the experimental feature
from sklearn.impute import IterativeImputer

In [None]:
# Check basic structure
print("Shape of dataset:", df.shape)
print("\nSample columns:", df.columns.tolist()[:10])
print("\nMissing values:\n", df.isnull().sum().sort_values(ascending=False).head(10))

# Remove duplicate rows
df = df.drop_duplicates()

# Confirm shape after removing duplicates
print("Shape after removing duplicates:", df.shape)

In [None]:
import pyreadstat

# Load gender info
gender_df, meta = pyreadstat.read_sav("SDV_IN_Gender_2019_2024.sav")

# Just extract SDV_SEXE column and append to df
df["SDV_SEXE"] = gender_df["SDV_SEXE"].reset_index(drop=True)

# Optional: map to labels
gender_map = {1.0: "Male", 2.0: "Female", 3.0: "Other"}
df["gender_label"] = df["SDV_SEXE"].map(gender_map)

# Done! Check a sample
print(df[["SDV_SEXE", "gender_label"]].value_counts())

In [None]:
# Assuming 'gender' and 'SDV_SEXE' columns are in df

# Gender dummy variables
df['gender_1'] = (df['gender'] == 1).astype(int)
df['gender_2'] = (df['gender'] == 2).astype(int)

# SDV_SEXE dummy variables
df['SDV_SEXE_1'] = (df['SDV_SEXE'] == 1).astype(int)
df['SDV_SEXE_2'] = (df['SDV_SEXE'] == 2).astype(int)
df['SDV_SEXE_3'] = (df['SDV_SEXE'] == 3).astype(int)

# Create binary columns
df['ethnicity_Dutch'] = np.where(df['ethnicity'] == 1, 1, 0)
df['ethnicity_other'] = np.where(df['ethnicity'] != 1, 1, 0)

In [None]:
# Columns manually identified for removal (example set from the R script)
cols_to_drop = [
    'gender', 'ethnicity', 'CIN5', 'SDV_SEXE', 'StartDatum', 'STARTDATUM', 'DROPOUT_EARLYCOMPLETER', 'TOEST_WO',
    'depressie_IN', 'TERUGKOMER', 'VROEGK_ST', 'gender_label',
    'depr_m_psychose_huid', 'depr_z_psychose_huid', 'depr_z_psychose_verl',
    'depr_m_psychose_verl', 'CAPS5score_followup', 'CAPS5_DAT_IN'
]

df = df.drop(columns=[col for col in cols_to_drop if col in df.columns])
print("Remaining columns:", df.shape[1])

In [None]:
if 'BEH_DAGEN' in df.columns:
    df.rename(columns={'BEH_DAGEN': 'treatmentdurationdays'}, inplace=True)

In [None]:
# Clean and standardize column names
df.columns = (
    df.columns
    .str.replace(r"\.+", "_", regex=True)
    .str.replace(r"[^a-zA-Z0-9_]", "", regex=True)
    .str.replace(" ", "_")
    .str.strip()
)

In [None]:
# Preview key outcome variables
outcome_vars = ['CAPS5score_baseline', 'CAPS5Score_TK']
for col in outcome_vars:
    if col in df.columns:
        print(f"{col}: {df[col].isnull().sum()} missing")

# Calculate change score
if 'CAPS5score_baseline' in df.columns and 'CAPS5Score_TK' in df.columns:
    df['caps5_change_baseline'] = df['CAPS5Score_TK'] - df['CAPS5score_baseline'] 

In [None]:
# Define exceptions to keep
protected_cols = [
    "DIAGNOSIS_ANXIETY_OCD",
    "DIAGNOSIS_PSYCHOTIC",
    "DIAGNOSIS_EATING_DISORDER",
    "DIAGNOSIS_SUBSTANCE_DISORDER", "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", 'SUBCAT_Selectieve_immunosuppresiva', 'treatmentdurationdays',
'SUBCAT_Corticosteroiden',
'SUBCAT_Immunomodulerend_Coxibs',
'SUBCAT_Aminosalicylaten',
'SUBCAT_calcineurineremmers',
'SUBCAT_Anti_epileptica_Benzodiazepine',
'SUBCAT_Paracetamol_overig_combinatie', 'SUBCAT_MAO_remmers', 'SUBCAT_psychostimulans_overige', 'SUBCAT_Interleukine_remmers'

]


# ----------------------------------------
# 1. Drop columns with >95% missing values (except protected)
thresh_missing = int(0.95 * len(df))
missing_cols = [col for col in df.columns if df[col].isnull().sum() > (len(df) - thresh_missing)]
missing_cols_to_drop = [col for col in missing_cols if col not in protected_cols]
df = df.drop(columns=missing_cols_to_drop)

# ----------------------------------------
# 2. Drop near-zero variance columns (except protected)
low_variance_cols = [col for col in df.columns if df[col].nunique(dropna=True) <= 1 and col not in protected_cols]
df = df.drop(columns=low_variance_cols)

In [None]:
df.to_csv("cleaned_data_baseline.csv", index=False)
print(" Step 1 Complete: Cleaned dataset saved.")

In [None]:
from sklearn.impute import SimpleImputer
from fancyimpute import IterativeImputer
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.compose import ColumnTransformer

In [None]:
# Load the cleaned dataset from Step 1
df = pd.read_csv("cleaned_data_baseline.csv")

# Quick check
print(df.shape)
print(df.dtypes.head(10))

In [None]:
# Separate Numerical and Categorical Variables

In [None]:
# Identify numerical and categorical columns
numerical_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()

print(f"Numerical Columns: {len(numerical_cols)}")
print(f"Categorical Columns: {len(categorical_cols)}")

In [None]:
df.info()

In [None]:
# Save the fully prepared data
df.to_csv("final_prepared_data.csv", index=False)
print(" Step 2 Complete: Final prepared dataset saved as 'final_prepared_data.csv'.")

In [None]:
import os
import pandas as pd
import numpy as np
from sklearn.experimental import enable_iterative_imputer  # noqa
from sklearn.impute import IterativeImputer

# ========== CONFIG ==========
save_folder = "imputed_data"
os.makedirs(save_folder, exist_ok=True)
n_imputations = 5

# ========== LOAD ==========
# Ensure df is already defined
assert 'df' in globals(), "Please load the original DataFrame as `df` before running this script."

# ========== IDENTIFY NUMERIC COLUMNS ==========
numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()

# ========== STEP 1: MICE IMPUTATION ==========
print("=" * 50)
print("STEP 1: MICE IMPUTATION")
print("=" * 50)

imputed_dfs = []
for i in range(1, n_imputations + 1):
    print(f"\n=== Running MICE Imputation: Dataset {i} ===")
    #  NEW instance with different seed AND sample_posterior=True for randomness
    mice_imputer = IterativeImputer(
        max_iter=10, 
        random_state=42+i,  # Different base to avoid low numbers
        sample_posterior=True,  #  KEY: This adds randomness!
        n_nearest_features=None,
        initial_strategy='mean'
    )
    # Fit-transform on numeric columns
    imputed_array = mice_imputer.fit_transform(df[numeric_cols])
    # Replace numeric columns in a copy of the original df
    df_imputed = df.copy()
    df_imputed[numeric_cols] = pd.DataFrame(imputed_array, columns=numeric_cols, index=df.index)
    # Append to list
    imputed_dfs.append(df_imputed)
    print(f" Completed imputation {i}")

# ========== STEP 2: ROUNDING ==========
print("\n" + "=" * 50)
print("STEP 2: ROUNDING NUMERIC COLUMNS")
print("=" * 50)

def round_all_numeric_columns_all_imputations(imputed_dfs, decimals=0, verbose=True):
    rounded_dfs = []
    for i, df in enumerate(imputed_dfs):
        df_copy = df.copy()
        numeric_cols = df_copy.select_dtypes(include=[np.number]).columns
        df_copy[numeric_cols] = df_copy[numeric_cols].round(decimals)
        rounded_dfs.append(df_copy)
        if verbose:
            print(f" Imputation {i+1}: Rounded {len(numeric_cols)} numeric columns to {decimals} decimal place(s).")
    return rounded_dfs

# Apply rounding to all imputed datasets
imputed_dfs = round_all_numeric_columns_all_imputations(imputed_dfs)

# ========== STEP 3: SAVE FINAL DATASETS ==========
print("\n" + "=" * 50)
print("STEP 3: SAVING FINAL DATASETS")
print("=" * 50)

for i, df_imputed in enumerate(imputed_dfs, 1):
    # Save outputs
    pkl_path = f"{save_folder}/df_imputed_final_imp{i}.pkl"
    csv_path = f"{save_folder}/df_imputed_final_imp{i}.csv"
    excel_path = f"{save_folder}/df_imputed_final_imp{i}.xlsx"
    
    df_imputed.to_pickle(pkl_path)
    df_imputed.to_csv(csv_path, index=False)
    df_imputed.to_excel(excel_path, index=False)
    
    print(f" Saved files for imputation {i}:")
    print(f"   → {pkl_path}")
    print(f"   → {csv_path}")
    print(f"   → {excel_path}")

# ========== STEP 4: VERIFY DATASETS ARE DIFFERENT ==========
print("\n" + "=" * 50)
print("STEP 4: VERIFYING DATASET DIFFERENCES")
print("=" * 50)

def check_imputation_differences(imputed_dfs, verbose=True):
    """Check if imputed datasets are actually different from each other"""
    if len(imputed_dfs) < 2:
        print("  Only one dataset - cannot check differences")
        return
    
    # Get numeric columns that had missing values originally
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    missing_cols = [col for col in numeric_cols if df[col].isnull().any()]
    
    if not missing_cols:
        print("  No missing values found in original data")
        return
    
    print(f" Checking differences in {len(missing_cols)} columns that had missing values...")
    
    differences_found = False
    
    for col in missing_cols[:3]:  # Check first 3 columns with missing values
        # Compare first two datasets for this column
        values_1 = imputed_dfs[0][col].values
        values_2 = imputed_dfs[1][col].values
        
        if not np.array_equal(values_1, values_2):
            differences_found = True
            # Count how many values are different
            diff_count = np.sum(values_1 != values_2)
            print(f" Column '{col}': {diff_count} different values between datasets 1 & 2")
        else:
            print(f" Column '{col}': IDENTICAL values between datasets 1 & 2")
    
    if differences_found:
        print(f"\n SUCCESS: Datasets show proper variability!")
    else:
        print(f"\n  WARNING: Datasets appear identical - check random_state implementation")
    
    return differences_found

# Run the check
check_imputation_differences(imputed_dfs)

print("\n" + "=" * 50)
print(" MICE IMPUTATION COMPLETE!")
print("=" * 50)
print(f" Created {n_imputations} imputed datasets")
print(f" Applied rounding to all numeric columns")
print(f" Saved files in: {save_folder}/")
print("=" * 50)

In [None]:
import pandas as pd
import numpy as np

# ========== METHOD 1: QUICK CHECK - Compare first 2 datasets ==========
def quick_difference_check(imputed_dfs):
    """Quick check to see if first two datasets are different"""
    if len(imputed_dfs) < 2:
        print("Need at least 2 datasets to compare")
        return
    
    df1 = imputed_dfs[0]
    df2 = imputed_dfs[1]
    
    # Check if dataframes are identical
    are_identical = df1.equals(df2)
    print(f"Dataset 1 vs Dataset 2: {'IDENTICAL ❌' if are_identical else 'DIFFERENT ✅'}")
    
    if not are_identical:
        # Count different values
        numeric_cols = df1.select_dtypes(include=[np.number]).columns
        total_diff = 0
        for col in numeric_cols:
            diff_count = np.sum(df1[col] != df2[col])
            if diff_count > 0:
                total_diff += diff_count
                print(f"  '{col}': {diff_count} different values")
        print(f"  Total different values: {total_diff}")

# ========== METHOD 2: DETAILED CHECK - All pairwise comparisons ==========
def detailed_difference_check(imputed_dfs):
    """Check differences between all pairs of datasets"""
    n_datasets = len(imputed_dfs)
    print(f"\n=== Checking all {n_datasets} datasets ===")
    
    numeric_cols = imputed_dfs[0].select_dtypes(include=[np.number]).columns
    
    for i in range(n_datasets):
        for j in range(i+1, n_datasets):
            are_identical = imputed_dfs[i].equals(imputed_dfs[j])
            print(f"Dataset {i+1} vs Dataset {j+1}: {'IDENTICAL ❌' if are_identical else 'DIFFERENT ✅'}")

# ========== METHOD 3: FOCUS ON ORIGINALLY MISSING VALUES ==========
def check_missing_value_differences(original_df, imputed_dfs):
    """Check differences only in originally missing positions"""
    print(f"\n=== Checking differences in originally missing positions ===")
    
    numeric_cols = original_df.select_dtypes(include=[np.number]).columns
    
    differences_found = False
    for col in numeric_cols:
        if original_df[col].isnull().any():
            missing_mask = original_df[col].isnull()
            print(f"\nColumn '{col}' ({missing_mask.sum()} missing values):")
            
            # Compare imputed values at missing positions
            for i in range(len(imputed_dfs)-1):
                imp1_values = imputed_dfs[i].loc[missing_mask, col]
                imp2_values = imputed_dfs[i+1].loc[missing_mask, col]
                
                are_same = np.array_equal(imp1_values.values, imp2_values.values)
                if not are_same:
                    differences_found = True
                    diff_count = np.sum(imp1_values.values != imp2_values.values)
                    print(f"  Dataset {i+1} vs {i+2}: {diff_count}/{len(imp1_values)} different imputed values ✅")
                else:
                    print(f"  Dataset {i+1} vs {i+2}: IDENTICAL imputed values ❌")
    
    return differences_found

# ========== METHOD 4: SAMPLE VALUES FROM EACH DATASET ==========
def show_sample_imputed_values(original_df, imputed_dfs, n_samples=5):
    """Show sample imputed values from each dataset"""
    print(f"\n=== Sample imputed values (first {n_samples} missing positions) ===")
    
    numeric_cols = original_df.select_dtypes(include=[np.number]).columns
    
    for col in numeric_cols:
        if original_df[col].isnull().any():
            missing_positions = original_df[original_df[col].isnull()].index[:n_samples]
            
            print(f"\nColumn '{col}' at positions {list(missing_positions)}:")
            for i, df_imp in enumerate(imputed_dfs):
                values = df_imp.loc[missing_positions, col].values
                print(f"  Dataset {i+1}: {values}")

# ========== RUN ALL CHECKS ==========
print("=" * 60)
print("CHECKING IMPUTATION DIFFERENCES")
print("=" * 60)

# Method 1: Quick check
quick_difference_check(imputed_dfs)

# Method 2: All pairwise comparisons  
detailed_difference_check(imputed_dfs)

# Method 3: Focus on originally missing values (assumes 'df' is your original dataframe)
if 'df' in globals():
    differences_found = check_missing_value_differences(df, imputed_dfs)
    if differences_found:
        print(f"\n🎉 SUCCESS: Found differences in imputed values!")
    else:
        print(f"\n⚠️ WARNING: No differences found in imputed values!")

# Method 4: Show sample values
if 'df' in globals():
    show_sample_imputed_values(df, imputed_dfs, n_samples=3)

In [None]:
imputed_folder = "imputed_data"
n_imputations = 5

# Lists to hold DataFrames and Y vectors
imputed_dfs = []
Y_list = []

for i in range(1, n_imputations + 1):
    file_path = f"{imputed_folder}/df_imputed_final_imp{i}.pkl"
    
    # Load imputed DataFrame
    df_imp = pd.read_pickle(file_path)
    imputed_dfs.append(df_imp)

    # Define Y for this imputation
    Y = df_imp["caps5_change_baseline"]
    Y_list.append(Y)

    print(f"Y for imputation {i} defined. Sample values:")
    print(Y.head())

In [None]:
import pandas as pd

# Load imputed DataFrames from saved files
imputed_folder = "imputed_data"
n_imputations = 5

for i in range(1, n_imputations + 1):
    print(f"\n=== Imputed Dataset {i} ===")

    # Load each imputed dataset
    df_imp = pd.read_pickle(f"{imputed_folder}/df_imputed_final_imp{i}.pkl")

    # Get all CAT_* columns
    cat_columns = [col for col in df_imp.columns if col.startswith('CAT_')]

    print("Medication Groups:")
    print(cat_columns)
    print("Total Medication Groups Found:", len(cat_columns))

In [None]:
covariates_CAT_ADHD = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden',
    'CAT_Z_drugs', 'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD',
    'DIAGNOSIS_SEXUAL_TRAUMA', 'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY',
    'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Aceetanilidederivaten = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Z_drugs = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Opioden = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica', 'CAT_Benzodiazepine', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_NSAIDs = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica', 'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]



covariates_CAT_Benzodiazepine = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Antihypertensiva = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Antihistaminica = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder"
]


covariates_CAT_Anti_epileptica = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Antidepressiva = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Antipsychotica = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_CAT_ALL_PSYCHOTROPICS_EXCL_BENZO = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Benzodiazepine', 'CAT_Anticonceptiva',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_CAT_ALL_PSYCHOTROPICS_EXCL_SEDATIVES_HYPNOTICS = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Benzodiazepine', 'CAT_Z_drugs', 'CAT_Anticonceptiva',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_ALL_PSYCHOTROPICS = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_ALL = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'DIAGNOSIS_CHILDHOOD_TRAUMA',
    'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY',
    'age', 'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

In [None]:
from collections import defaultdict

# This finds all variables that start with covariates_CAT_ or covariates_cat_
final_covariates_map = defaultdict(list)
final_covariates_map.update({
    var.replace("covariates_", ""): val
    for var, val in globals().items()
    if var.lower().startswith("covariates_cat_") and isinstance(val, list)
})

# Show detected group names
print("Groups found:", list(final_covariates_map.keys()))
medication_groups = list(final_covariates_map.keys())
print(medication_groups)

In [None]:
import os


def run_all_CAT_group_models(imputed_dfs):
    """
    Runs downstream analysis for each CAT medication group using imputed datasets.

    Parameters:
    - imputed_dfs: list of 5 imputed DataFrames (from df_imputed_final_imp1.pkl ... imp5.pkl)
    
    Notes:
    - Covariate lists must be defined as global variables: covariates_cat_<group>
    - Outputs are saved in: outputs/CAT_<GROUP>/
    """

    print(" Starting analysis for all CAT groups")

    for var_name in globals():
        if var_name.lower().startswith("covariates_cat_") and isinstance(globals()[var_name], list):
            group_name = var_name.replace("covariates_", "")
            group_name = group_name.replace("_", " ").title().replace(" ", "_")  # e.g., cat_z_drugs → Cat_Z_Drugs
            group_name = group_name.replace("Cat_", "CAT_")  # force prefix to uppercase

            covariates = globals()[var_name]
            output_dir = f"outputs/{group_name}"
            os.makedirs(output_dir, exist_ok=True)

            print(f"\n Processing {group_name}...")

            for k, df_imp in enumerate(imputed_dfs):
                print(f"  → Using imputation {k+1}")

                # Define X and Y
                X = df_imp[covariates]
                Y = df_imp["caps5_change_baseline"]

                # === Save X and Y as placeholder (replace with modeling later)
                X.to_csv(f"{output_dir}/X_imp{k+1}.csv", index=False)
                Y.to_frame(name="Y").to_csv(f"{output_dir}/Y_imp{k+1}.csv", index=False)

            print(f" Done: {group_name}")

    print("\n All CAT group analyses complete.")

# ========= STEP 4: Execute ========= #
run_all_CAT_group_models(imputed_dfs)

In [None]:
for treatment_var in medication_groups:
    print(f"\n {treatment_var}")
    
    for i, df in enumerate(imputed_dfs):
        if treatment_var not in df.columns:
            print(f"  Imp {i+1}:  Not found in columns.")
            continue

        treated = (df[treatment_var] == 1).sum()
        control = (df[treatment_var] == 0).sum()
        missing = df[treatment_var].isna().sum()

        print(f"  Imp {i+1}: Treated = {treated}, Control = {control}, Missing = {missing}")

In [None]:
import os
import pandas as pd
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from collections import defaultdict

# ✅ VIF computation function
def compute_vif(X):
    X = sm.add_constant(X, has_constant='add')
    vif_df = pd.DataFrame()
    vif_df["variable"] = X.columns
    vif_df["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
    return vif_df

# ✅ Process each group
for group in medication_groups:
    print(f"\n🔍 Processing VIF for {group}")

    if group not in final_covariates_map:
        print(f" ⚠️ No covariates found for {group}. Skipping.")
        continue

    covariates = final_covariates_map[group]
    vif_list = []

    for i, df_imp in enumerate(imputed_dfs):
        try:
            X = df_imp[covariates].copy()
            vif_df = compute_vif(X)
            vif_df["imputation"] = i + 1
            vif_list.append(vif_df)
        except Exception as e:
            print(f"   ❌ Failed on imputation {i+1} for {group}: {e}")

    if vif_list:
        all_vif = pd.concat(vif_list)
        pooled_vif = all_vif.groupby("variable")["VIF"].mean().reset_index()
        pooled_vif = pooled_vif.sort_values(by="VIF", ascending=False)

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)
        pooled_vif.to_csv(os.path.join(output_folder, "pooled_vif.csv"), index=False)

        print(f" ✅ Saved: {output_folder}/pooled_vif.csv")
    else:
        print(f" ⚠️ Skipped {group}: No valid imputations.")

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# ---------- PS Estimation Function ----------
def run_xgboost_ps_modeling(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f" Running PS estimation for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        covariates = final_covariates_map[group]
        ps_matrix = pd.DataFrame()
        auc_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not found in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop missing treatment rows
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            # Train-test split for ROC
            X_train, X_test, T_train, T_test = train_test_split(
                X, T, stratify=T, test_size=0.3, random_state=42
            )

            try:
                model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
                model.fit(X_train, T_train)

                ps_scores = model.predict_proba(X)[:, 1]
                ps_matrix[f"ps_imp{i+1}"] = pd.Series(ps_scores, index=valid_idx)

                # ROC & AUC
                auc = roc_auc_score(T_test, model.predict_proba(X_test)[:, 1])
                auc_list.append(auc)

                fpr, tpr, _ = roc_curve(T_test, model.predict_proba(X_test)[:, 1])
                plt.figure()
                plt.plot(fpr, tpr, label=f"AUC = {auc:.3f}")
                plt.plot([0, 1], [0, 1], 'k--')
                plt.xlabel("False Positive Rate")
                plt.ylabel("True Positive Rate")
                plt.title(f"ROC Curve - {group} (Imp {i+1})")
                plt.legend()
                plt.tight_layout()
                plt.savefig(os.path.join(output_folder, f"roc_curve_imp{i+1}.png"))
                plt.close()
                print(f"   Imp {i+1}: AUC = {auc:.3f}, ROC saved.")

            except Exception as e:
                print(f"   Error in {group} (imp {i+1}): {e}")

        # Save AUCs and Composite PS
        if not ps_matrix.empty:
            # Fill NaN rows (from dropped subjects in some imputations) with mean
            ps_matrix["composite_ps"] = ps_matrix.mean(axis=1)
            ps_matrix.to_excel(os.path.join(output_folder, "propensity_scores.xlsx"))

            auc_df = pd.DataFrame({
                "imputation": [f"imp{i+1}" for i in range(len(auc_list))],
                "AUC": auc_list
            })
            auc_df.loc[len(auc_df.index)] = ["mean", np.mean(auc_list) if auc_list else np.nan]
            auc_df.to_excel(os.path.join(output_folder, "auc_scores.xlsx"), index=False)

            print(f" Composite PS + AUC saved for {group}")
        else:
            print(f" No valid PS scores generated for {group}")

# ---------- Run ----------
run_xgboost_ps_modeling(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from xgboost import XGBClassifier

def compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n Computing feature importance for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        covariates = final_covariates_map[group]
        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        importance_df_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop NaNs in treatment
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            try:
                model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
                model.fit(X, T)

                # Get feature importance
                importances = model.get_booster().get_score(importance_type='gain')
                df_feat = pd.DataFrame.from_dict(importances, orient='index', columns=[f"imp{i+1}"])
                df_feat.index.name = 'feature'
                importance_df_list.append(df_feat)

            except Exception as e:
                print(f"   Error during modeling: {e}")

        if importance_df_list:
            # Combine and average
            all_feat = pd.concat(importance_df_list, axis=1).fillna(0)
            all_feat["mean_importance"] = all_feat.mean(axis=1)

            # Filter top 30 non-zero
            non_zero = all_feat[all_feat["mean_importance"] > 0]
            top30 = non_zero.sort_values(by="mean_importance", ascending=False).head(30)

            # Save to CSV
            top30.to_csv(os.path.join(output_folder, "feature_importance.csv"))

            # Plot
            plt.figure(figsize=(10, 8))
            plt.barh(top30.index[::-1], top30["mean_importance"][::-1])  # plot top → bottom
            plt.xlabel("Mean Gain Importance")
            plt.title(f"Top 30 Feature Importance - {group}")
            plt.tight_layout()
            plt.savefig(os.path.join(output_folder, "feature_importance_top30.png"))
            plt.close()

            print(f" Saved feature importance plot and CSV for {group}")
        else:
            print(f" No valid models for {group}")

#  Run
compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import numpy as np


def compute_trimmed_clipped_iptw(ps_df, treatment, lower=0.05, upper=0.95, clip_max=10):
    weights = []
    keep_mask = (ps_df > lower) & (ps_df < upper)

    for i in range(ps_df.shape[1]):
        ps = ps_df.iloc[:, i].clip(lower=1e-6, upper=1 - 1e-6)  # avoid div by zero
        mask = keep_mask.iloc[:, i]
        w = pd.Series(np.nan, index=ps.index)

        w[mask & (treatment == 1)] = 1 / ps[mask & (treatment == 1)]
        w[mask & (treatment == 0)] = 1 / (1 - ps[mask & (treatment == 0)])
        w = w.clip(upper=clip_max)
        weights.append(w)

    return pd.concat(weights, axis=1)


def apply_rubins_rule_to_iptw(iptw_matrix):
    """
    Given an IPTW matrix (n rows × M imputations), return Rubin’s rule pooled mean, SD, SE.
    """
    M = iptw_matrix.shape[1]
    q_bar = iptw_matrix.mean(axis=1)
    u_bar = iptw_matrix.var(axis=1, ddof=1)
    B = iptw_matrix.apply(lambda x: x.mean(), axis=1).var(ddof=1)
    total_var = u_bar + (1 + 1/M) * B
    total_se = np.sqrt(total_var)
    return q_bar, u_bar.pow(0.5), total_se


def run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n🔍 Processing IPTW + trimming + clipping for {group}")
        output_folder = os.path.join("outputs", group)
        ps_path = os.path.join(output_folder, "propensity_scores.xlsx")

        if not os.path.exists(ps_path):
            print(f"⚠️ Missing PS file: {ps_path}. Skipping.")
            continue

        try:
            ps_all = pd.read_excel(ps_path, index_col=0)
            ps_cols = [col for col in ps_all.columns if col.startswith("ps_imp")]
            composite_index = ps_all.index

            # Get treatment from one imputed dataset
            T_full = None
            for df in imputed_dfs:
                if group in df.columns:
                    T_full = df.loc[composite_index, group]
                    break

            if T_full is None:
                print(f"❌ Treatment column {group} not found in any imputed dataset.")
                continue

            # Compute IPTW matrix (shape: n × M)
            iptw_matrix = compute_trimmed_clipped_iptw(ps_all[ps_cols], T_full)
            iptw_matrix.columns = [f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]

            # Apply Rubin’s Rule for mean, SD, SE
            iptw_matrix["iptw_mean"], iptw_matrix["iptw_sd"], iptw_matrix["iptw_se"] = apply_rubins_rule_to_iptw(
                iptw_matrix[[f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]]
            )

            # Save IPTW matrix separately
            iptw_matrix.to_excel(os.path.join(output_folder, "iptw_weights.xlsx"))
            print(f"✅ Saved IPTW weights for {group}")

            # Save trimmed & clipped imputed datasets with IPTW
            for i in range(5):
                df = imputed_dfs[i].copy()
                if group not in df.columns:
                    continue

                trimmed_idx = iptw_matrix.index.intersection(df.index)
                needed_cols = final_covariates_map[group] + [group, "caps5_change_baseline"]

                # Select only necessary columns
                df_trimmed = df.loc[trimmed_idx, needed_cols].copy()
                df_trimmed["iptw"] = iptw_matrix[f"iptw_imp{i+1}"].loc[trimmed_idx]

                # ✅ DROP rows with missing IPTW values
                before = len(df_trimmed)
                df_trimmed = df_trimmed.dropna(subset=["iptw"])
                after = len(df_trimmed)
                print(f"    ℹ️ Retained {after}/{before} rows after IPTW NaN drop.")

                # Save to .pkl
                df_trimmed.to_pickle(os.path.join(output_folder, f"trimmed_data_imp{i+1}.pkl"))
                print(f"  💾 Saved trimmed dataset: {output_folder}/trimmed_data_imp{i+1}.*")

        except Exception as e:
            print(f"❌ Error in {group}: {e}")


run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def plot_ps_overlap_all_groups(medication_groups):
    for group in medication_groups:
        print(f"\n📊 Plotting PS overlap for {group}")

        folder = os.path.join("outputs", group)
        ps_file = os.path.join(folder, "propensity_scores.xlsx")
        iptw_file = os.path.join(folder, "iptw_weights.xlsx")
        trimmed_file = os.path.join(folder, "trimmed_data_imp1.pkl")

        if not all(os.path.exists(f) for f in [ps_file, iptw_file, trimmed_file]):
            print(f"⚠️ Missing required files for {group}. Skipping.")
            continue

        try:
            ps_df = pd.read_excel(ps_file, index_col=0)
            iptw_df = pd.read_excel(iptw_file, index_col=0)
            trimmed_df = pd.read_pickle(trimmed_file)

            # Extract
            ps = ps_df["composite_ps"].reindex(trimmed_df.index)
            w = iptw_df["iptw_mean"].reindex(trimmed_df.index)
            T = trimmed_df[group]

            # Masks to remove NaNs
            treated_mask = (T == 1) & ps.notna() & w.notna()
            control_mask = (T == 0) & ps.notna() & w.notna()

            treated = ps[treated_mask]
            treated_w = w[treated_mask]

            control = ps[control_mask]
            control_w = w[control_mask]

            # === Unweighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Unweighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_unweighted.png"))
            plt.close()

            # === Weighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, weights=[treated_w, control_w], label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Weighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Weighted Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_weighted.png"))
            plt.close()

            print(f"✅ Saved unweighted and weighted PS plots for {group}")

        except Exception as e:
            print(f"❌ Error processing {group}: {e}")

# 🔁 Run
plot_ps_overlap_all_groups(medication_groups)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Set up base output folder
output_base = "outputs"
ps_file = "propensity_scores.xlsx"
iptw_file = "iptw_weights.xlsx"
trimmed_data_file = "trimmed_data_imp1.pkl"

# Collect all treatment group folders
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

# Generate 4-panel overlap plots
for group in groups:
    group_path = os.path.join(output_base, group)
    try:
        # Load trimmed treatment info
        trimmed_df = pd.read_pickle(os.path.join(group_path, trimmed_data_file))
        index = trimmed_df.index

        # Fix: case-insensitive match for treatment variable
        possible_cols = [col for col in trimmed_df.columns if col.upper() == group.upper()]
        if not possible_cols:
            print(f" Treatment variable {group} not found in {group}, skipping.")
            continue
        treatment_var = possible_cols[0]
        T = trimmed_df[treatment_var]

        # Load composite PS (aligned to trimmed_df index)
        ps_df = pd.read_excel(os.path.join(group_path, ps_file), index_col=0)
        if 'composite_ps' not in ps_df.columns:
            print(f" Composite column missing in {ps_file}, skipping {group}.")
            continue
        ps = ps_df.loc[index, 'composite_ps']

        # Load IPTW weights (aligned to trimmed_df index)
        weights_df = pd.read_excel(os.path.join(group_path, iptw_file), index_col=0)
        if 'iptw_mean' not in weights_df.columns:
            print(f" IPTW weight column missing in {iptw_file}, skipping {group}.")
            continue
        weights = weights_df.loc[index, 'iptw_mean']

        # Prepare 4 datasets
        raw_treated = ps[T == 1]
        raw_control = ps[T == 0]
        weighted_treated = (ps[T == 1], weights[T == 1])
        weighted_control = (ps[T == 0], weights[T == 0])

        # Create plot
        fig, axs = plt.subplots(2, 2, figsize=(12, 8))
        fig.suptitle(f"Propensity Score Distribution - {group}", fontsize=14)

        axs[0, 0].hist(raw_treated, bins=20, alpha=0.7, color='blue')
        axs[0, 0].set_title("Raw Treated")

        axs[0, 1].hist(raw_control, bins=20, alpha=0.7, color='green')
        axs[0, 1].set_title("Raw Control")

        axs[1, 0].hist(weighted_treated[0], bins=20, weights=weighted_treated[1], alpha=0.7, color='blue')
        axs[1, 0].set_title("Weighted Treated")

        axs[1, 1].hist(weighted_control[0], bins=20, weights=weighted_control[1], alpha=0.7, color='green')
        axs[1, 1].set_title("Weighted Control")

        for ax in axs.flat:
            ax.set_xlim(0, 1)
            ax.set_xlabel("Propensity Score")
            ax.set_ylabel("Count")

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save figure
        plot_path = os.path.join(group_path, f"four_panel_overlap_{group}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f" ✅ Saved: {plot_path}")

    except Exception as e:
        print(f" ❌ Error in {group}: {e}")

In [None]:
# ATT calculation:

In [None]:
# Weighted:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
n_repeats = 1
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T, weights):
    treated = X[T == 1]
    control = X[T == 0]
    w_treated = weights[T == 1]
    w_control = weights[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = np.average(treated[col], weights=w_treated), np.average(control[col], weights=w_control)
            s1 = np.sqrt(np.average((treated[col] - m1) ** 2, weights=w_treated))
            s0 = np.sqrt(np.average((control[col] - m0) ** 2, weights=w_control))
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# XGBoost Main Loop (No DML)
# -----------------------------
def run_xgboost_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running XGBoost for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                X = df[covariates].copy()
                T = df[group]
                Y = df["caps5_change_baseline"]
                W = df["iptw"]

                for repeat in range(n_repeats):
                    kf = KFold(n_splits=5, shuffle=True, random_state=seed + repeat)
                    for train_idx, test_idx in kf.split(X):
                        try:
                            X_train, T_train, Y_train, W_train = (
                                X.iloc[train_idx],
                                T.iloc[train_idx],
                                Y.iloc[train_idx],
                                W.iloc[train_idx],
                            )

                            # XGBoost regression model
                            model = xgb.XGBRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1, random_state=seed)
                            
                            # Add treatment variable to features
                            X_train_with_T = X_train.copy()
                            X_train_with_T[group] = T_train
                            
                            # Fit model
                            model.fit(X_train_with_T, Y_train, sample_weight=W_train)
                            
                            # Predict outcomes for treated and control groups
                            X_treated = X_train.copy()
                            X_treated[group] = 1
                            X_control = X_train.copy()
                            X_control[group] = 0
                            
                            Y_pred_treated = model.predict(X_treated)
                            Y_pred_control = model.predict(X_control)
                            
                            # Calculate ATT (Average Treatment Effect on Treated)
                            treated_mask = T_train == 1
                            if np.any(treated_mask):
                                att = np.average(Y_pred_treated[treated_mask] - Y_pred_control[treated_mask], 
                                               weights=W_train[treated_mask])
                                
                                # Calculate standard error (approximate)
                                treatment_effects = Y_pred_treated[treated_mask] - Y_pred_control[treated_mask]
                                residual = treatment_effects - att
                                se = np.sqrt(np.sum((W_train[treated_mask] * residual) ** 2)) / np.sum(W_train[treated_mask])

                                
                                att_list.append(att)
                                se_list.append(se)

                            # Model performance metrics
                            Y_pred = model.predict(X_train_with_T)
                            residuals = Y_train - Y_pred
                            rmse = mean_squared_error(Y_train, Y_pred, squared=False)
                            r2 = r2_score(Y_train, Y_pred)
                            r2_list.append(r2)
                            rmse_list.append(rmse)
                            
                            # Collect residuals and fitted values for diagnostic plots
                            group_residuals.append(residuals.values)
                            group_fitted.append(Y_pred)

                            smd, vr = calculate_smd_vr(X_train, T_train, W_train)
                            smd_list.append(smd)
                            vr_list.append(vr)
                        except Exception as e:
                            print(f"⚠️ Error in {group}, seed {seed}, imp {imp}, rep {repeat}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("xgb_rubin_summary_cats.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_cats.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_xgboost_with_trimmed_data(final_covariates_map)

In [None]:
# Unweighted:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
n_repeats = 1
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}_unweighted.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T):
    treated = X[T == 1]
    control = X[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = np.mean(treated[col]), np.mean(control[col])
            s1 = np.std(treated[col])
            s0 = np.std(control[col])
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# XGBoost Main Loop (No DML)
# -----------------------------
def run_xgboost_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running XGBoost for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                X = df[covariates].copy()
                T = df[group]
                Y = df["caps5_change_baseline"]
                #W = df["iptw"]

                for repeat in range(n_repeats):
                    kf = KFold(n_splits=5, shuffle=True, random_state=seed + repeat)
                    for train_idx, test_idx in kf.split(X):
                        try:
                            X_train = X.iloc[train_idx]
                            T_train = T.iloc[train_idx]
                            Y_train = Y.iloc[train_idx]
                        

                            # XGBoost regression model
                            model = xgb.XGBRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1, random_state=seed)
                            
                            # Add treatment variable to features
                            X_train_with_T = X_train.copy()
                            X_train_with_T[group] = T_train
                            
                            # Fit model
                            model.fit(X_train_with_T, Y_train)
                            
                            # Predict outcomes for treated and control groups
                            X_treated = X_train.copy()
                            X_treated[group] = 1
                            X_control = X_train.copy()
                            X_control[group] = 0
                            
                            Y_pred_treated = model.predict(X_treated)
                            Y_pred_control = model.predict(X_control)
                            
                            # Calculate ATT (Average Treatment Effect on Treated)
                            treated_mask = T_train == 1
                            if np.any(treated_mask):
                                att = np.mean(Y_pred_treated[treated_mask] - Y_pred_control[treated_mask])
                                
                                # Calculate standard error (approximate)
                                treatment_effects = Y_pred_treated[treated_mask] - Y_pred_control[treated_mask]
                                residual = treatment_effects - att
                                se = np.sqrt(np.sum((residual) ** 2)) / np.sum(treated_mask)
                                att_list.append(att)
                                se_list.append(se)

                            # Model performance metrics
                            Y_pred = model.predict(X_train_with_T)
                            residuals = Y_train - Y_pred
                            rmse = mean_squared_error(Y_train, Y_pred, squared=False)
                            r2 = r2_score(Y_train, Y_pred)
                            r2_list.append(r2)
                            rmse_list.append(rmse)
                            
                            # Collect residuals and fitted values for diagnostic plots
                            group_residuals.append(residuals.values)
                            group_fitted.append(Y_pred)

                            smd, vr = calculate_smd_vr(X_train, T_train)
                            smd_list.append(smd)
                            vr_list.append(vr)
                        except Exception as e:
                            print(f"⚠️ Error in {group}, seed {seed}, imp {imp}, rep {repeat}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("xgb_rubin_summary_cats_unweighted.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_cats_unweighted.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_xgboost_with_trimmed_data(final_covariates_map)

In [None]:
import os
import pandas as pd
import numpy as np
from scipy.stats import sem, ttest_ind

# ----------------------------------
# File paths
# ----------------------------------
output_base = "outputs"
att_file = "xgb_rubin_summary_cats.xlsx"
trimmed_file = "trimmed_data_imp1.pkl"
auc_file = "auc_scores.xlsx"  # NEW

# ----------------------------------
# Load ATT Summary
# ----------------------------------
if os.path.exists(att_file):
    att_df = pd.read_excel(att_file)
else:
    raise FileNotFoundError("❌ ATT summary file not found: xgb_rubin_summary_cats.xlsx")

summary_rows = []

# ----------------------------------
# Loop over medication groups
# ----------------------------------
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

for med in groups:
    try:
        group_path = os.path.join(output_base, med)

        # Load trimmed data
        df = pd.read_pickle(os.path.join(group_path, trimmed_file))

        # Detect treatment column
        treatment_cols = [col for col in df.columns if col.upper() == med.upper()]
        if not treatment_cols:
            print(f"⚠️ Treatment column {med} not found in trimmed data. Skipping.")
            continue
        treatment_var = treatment_cols[0]

        # Extract treatment and outcome
        T = df[treatment_var]
        Y = df["caps5_change_baseline"]

        # Treated and control stats
        treated = Y[T == 1]
        control = Y[T == 0]

        mean_treat = treated.mean()
        se_treat = sem(treated) if len(treated) > 1 else np.nan

        mean_ctrl = control.mean()
        se_ctrl = sem(control) if len(control) > 1 else np.nan

        # Cohen's d (unadjusted)
        pooled_sd = np.sqrt(((treated.std() ** 2) + (control.std() ** 2)) / 2)
        cohen_d = (mean_treat - mean_ctrl) / pooled_sd if pooled_sd > 0 else np.nan

        # E-value (unadjusted)
        delta = mean_treat - mean_ctrl
        E = delta / abs(mean_ctrl) * 100 if mean_ctrl != 0 else np.nan

        # Unadjusted p-value
        try:
            t_stat, p_val = ttest_ind(treated, control, equal_var=False, nan_policy="omit")
            rounded_p = round(p_val, 5)
            formatted_p = "< 0.00001" if rounded_p < 0.00001 else rounded_p
        except Exception:
            formatted_p = np.nan

        # AUC from new auc_scores.xlsx file
        auc_val = np.nan
        auc_path = os.path.join(group_path, auc_file)
        if os.path.exists(auc_path):
            auc_df = pd.read_excel(auc_path)
            if "AUC" in auc_df.columns:
                auc_val = auc_df["AUC"].dropna().mean()

        # Adjusted stats from Rubin summary
        att_row = att_df[att_df["group"].str.strip().str.upper() == med.strip().upper()]
        if not att_row.empty:
            att = att_row.iloc[0]["att"]
            att_se = att_row.iloc[0]["se"]
            att_p_val = att_row.iloc[0]["p_value"]
            r2 = att_row.iloc[0]["r2"]
            rmse = att_row.iloc[0]["rmse"]

            try:
                rounded_att_p = round(float(att_p_val), 5)
                formatted_att_p = "< 0.00001" if rounded_att_p < 0.00001 else rounded_att_p
            except:
                formatted_att_p = att_p_val
        else:
            att, att_se, formatted_att_p, r2, rmse = np.nan, np.nan, np.nan, np.nan, np.nan

        # Append full row
        summary_rows.append({
            'Medication Group': med,
            'Mean Treated': mean_treat,
            'SE Treated': se_treat,
            'Mean Control': mean_ctrl,
            'SE Control': se_ctrl,
            'Cohen d': cohen_d,
            'E (Unadjusted)': E,
            'n Treated': len(treated),
            'n Control': len(control),
            #'Unadjusted p-value': formatted_p,
            'ATT Estimate': att,
            'ATT SE (Robust)': att_se,
            'ATT p-value': formatted_att_p,
            'R²': r2,
            'RMSE': rmse,
            'AUC': auc_val
        })

    except Exception as e:
        print(f"❌ Error in {med}: {e}")

# ----------------------------------
# Save final summary
# ----------------------------------
summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.sort_values("Medication Group")
summary_df.to_excel("Final_ATT_Summary_Cat.xlsx", index=False)
print("✅ Final_ATT_Summary_Cat saved")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# ✅ Load the final summary table
final_df = pd.read_excel("Final_ATT_Summary_Cat.xlsx")

# ✅ Parse DML p-values (handle "< 0.00001")
def parse_pval(p):
    try:
        if isinstance(p, str) and "<" in p:
            return 0.000001
        return float(p)
    except:
        return None

final_df['ATT p-value'] = final_df['ATT p-value'].apply(parse_pval)

# ✅ Plot settings
width = 0.35

# ✅ Plotting function for a single medication group
def plot_single_group(row):
    fig, ax = plt.subplots(figsize=(10, 6))
    
    bars1 = ax.bar(-width/2, row['Mean Control'], width, 
                   yerr=row['SE Control'], label='Control', hatch='//', color='gray', capsize=5)
    bars2 = ax.bar(+width/2, row['Mean Treated'], width, 
                   yerr=row['SE Treated'], label='Treated', color='steelblue', capsize=5)

    label = (
        f"ATT = {row['ATT Estimate']:.2f}\n"
        f"d = {row['Cohen d']:.2f}, p = {row['ATT p-value']:.3f}\n"
        f"nT = {row['n Treated']}, nC = {row['n Control']}\n"
        f"E = {row['E (Unadjusted)']:.1f}%"
    )
    max_y = max(row['Mean Control'], row['Mean Treated']) + 1.5
    ax.text(0, max_y, label, ha='center', va='bottom', fontsize=9, color='#FFD700')

    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_xticks([-width/2, +width/2])
    ax.set_xticklabels(['Control', 'Treated'])
    ax.set_title(f"Group: {row['Medication Group']}", fontsize=12, weight='bold')
    ax.set_ylabel("CAPS5 Change Score")
    ax.set_ylim(bottom=0, top=max_y + 2)
    ax.legend()
    fig.tight_layout()
    return fig

# ✅ Generate and save all plots into a multi-page PDF
with PdfPages("xgb_att_barplot_cat.pdf") as pdf:
    for idx, row in final_df.iterrows():
        fig = plot_single_group(row)
        pdf.savefig(fig, dpi=300, bbox_inches='tight')
        plt.close(fig)
print("✅ xgb_att_barplot_cat saved")

In [None]:
# Love plot:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# ----------------------------------------
# Functions to calculate balance
# ----------------------------------------
def calculate_smd(x1, x2, w1=None, w2=None):
    def weighted_mean(x, w): return np.average(x, weights=w)
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    m1 = weighted_mean(x1, w1) if w1 is not None else np.mean(x1)
    m2 = weighted_mean(x2, w2) if w2 is not None else np.mean(x2)
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    pooled_sd = np.sqrt((v1 + v2) / 2)
    return np.abs(m1 - m2) / pooled_sd if pooled_sd > 0 else 0

def variance_ratio(x1, x2, w1=None, w2=None):
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    return max(v1 / v2, v2 / v1) if v1 > 0 and v2 > 0 else 1

# ----------------------------------------
# Setup
# ----------------------------------------
output_base = "outputs"
groups = [g for g in os.listdir(output_base) if os.path.isdir(os.path.join(output_base, g))]

# Create a case-insensitive mapping
final_covariates_map_lower = {k.lower(): v for k, v in final_covariates_map.items()}

# ----------------------------------------
# Main Loop
# ----------------------------------------
for group in groups:
    if group.lower() not in final_covariates_map_lower:
        continue

    print(f"\n🔍 Processing {group}...")

    try:
        group_path = os.path.join(output_base, group)
        covariates = final_covariates_map_lower[group.lower()]
        
        column_name = None
        for col in pd.read_pickle(os.path.join(group_path, "trimmed_data_imp1.pkl")).columns:
            if col.lower() == group.lower():
                column_name = col
                break
        if column_name is None:
            print(f"⚠️ Column not found for {group}, skipping.")
            continue

        smd_unw_all, smd_w_all = [], []
        vr_unw_all, vr_w_all = [], []

        for i in range(1, 6):
            df_path = os.path.join(group_path, f"trimmed_data_imp{i}.pkl")
            iptw_path = os.path.join(group_path, "iptw_weights.xlsx")

            if not os.path.exists(df_path) or not os.path.exists(iptw_path):
                print(f"⚠️ Missing data for {group} imp{i}, skipping.")
                continue

            df = pd.read_pickle(df_path)
            iptw_df = pd.read_excel(iptw_path, index_col=0)
            T = df[column_name]
            W = iptw_df.loc[df.index, "iptw_mean"]

            smd_unw_i, smd_w_i, vr_unw_i, vr_w_i = [], [], [], []

            for cov in covariates:
                x1, x0 = df.loc[T == 1, cov], df.loc[T == 0, cov]
                w1, w0 = W[T == 1], W[T == 0]

                su = calculate_smd(x1, x0)
                sw = calculate_smd(x1, x0, w1, w0)

                vu = variance_ratio(x1, x0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan
                vw = variance_ratio(x1, x0, w1, w0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan

                smd_unw_i.append(su)
                smd_w_i.append(sw)
                vr_unw_i.append(vu)
                vr_w_i.append(vw)

            smd_unw_all.append(smd_unw_i)
            smd_w_all.append(smd_w_i)
            vr_unw_all.append(vr_unw_i)
            vr_w_all.append(vr_w_i)

        smd_unw = np.mean(smd_unw_all, axis=0)
        smd_w = np.mean(smd_w_all, axis=0)
        vr_unw = np.nanmean(vr_unw_all, axis=0)
        vr_w = np.nanmean(vr_w_all, axis=0)

        severity = []
        for sw in smd_w:
            if sw <= 0.1:
                severity.append("Good")
            elif sw <= 0.2:
                severity.append("Moderate")
            else:
                severity.append("Poor")

        covariate_names = covariates
        numeric_df = pd.DataFrame({
            "Covariate": covariate_names,
            "SMD_Unweighted": smd_unw,
            "SMD_Weighted": smd_w,
            "Imbalance_Severity": severity,
            "VR_Unweighted": vr_unw,
            "VR_Weighted": vr_w
        })

        numeric_path = os.path.join(group_path, f"covariate_balance_table_{group}.xlsx")
        numeric_df.to_excel(numeric_path, index=False)
        print(f"📊 Exported numeric summary to: {numeric_path}")

        # -------------------------
        # Plot
        # -------------------------
        labels = covariates
        y_pos = np.arange(len(labels))

        fig, axes = plt.subplots(1, 2, figsize=(18, len(labels) * 0.45))

        axes[0].scatter(smd_unw, y_pos, color='red', label="Unweighted")
        axes[0].scatter(smd_w, y_pos, color='blue', label="Weighted")
        axes[0].axvline(0.1, color='gray', linestyle='--', label="Threshold 0.1")
        axes[0].axvline(0.2, color='black', linestyle='--', label="Threshold 0.2")
        axes[0].set_xlim(0, max(max(smd_unw), max(smd_w), 0.25) + 0.05)
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels(labels)
        axes[0].invert_yaxis()
        axes[0].set_title("Standardized Mean Differences (SMD)")
        axes[0].legend(loc="upper right")
        axes[0].grid(True)

        vr_mask = [cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] for cov in covariates]
        filtered_y = [i for i, b in enumerate(vr_mask) if b]
        filtered_labels = [labels[i] for i in filtered_y]
        filtered_vr_unw = [vr_unw[i] for i in filtered_y]
        filtered_vr_w = [vr_w[i] for i in filtered_y]

        axes[1].scatter(filtered_vr_unw, filtered_y, color='blue', marker='o', label="Unweighted")
        axes[1].scatter(filtered_vr_w, filtered_y, color='red', marker='x', label="Weighted")
        axes[1].axvline(2, color='gray', linestyle='--')
        axes[1].axvline(0.5, color='gray', linestyle='--')
        axes[1].set_xlim(0, max(filtered_vr_unw + filtered_vr_w + [2.5]) + 0.5)
        axes[1].set_yticks(filtered_y)
        axes[1].set_yticklabels(filtered_labels)
        axes[1].invert_yaxis()
        axes[1].set_title("Variance Ratio (VR)")
        axes[1].legend()
        axes[1].grid(True)

        fig.suptitle(f"Covariate Balance for {group.replace('CAT_', '')}", fontsize=14, weight='bold')
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        plot_path = os.path.join(group_path, f"love_plot_{group}.pdf")
        fig.savefig(plot_path, dpi=300)
        plt.close()
        print(f"✅ Saved love plot: {plot_path}")
        print(f"📏 Max weighted SMD for {group}: {np.max(smd_w):.3f}")

    except Exception as e:
        print(f"❌ Error in {group}: {e}")

In [None]:
# Heatmap:

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
#-----------------------------
# Generate heatmaps
# -------------------------------
for treatment_var in medication_groups:
    print(f"\n========== Creating Heatmap for {treatment_var} ==========")

    try:
        output_folder = os.path.join('outputs', treatment_var)
        balance_path = os.path.join(output_folder, f'covariate_balance_table_{treatment_var}.xlsx')

        if not os.path.exists(balance_path):
            print(f"❌ Balance file not found: {balance_path}")
            continue

        balance_df = pd.read_excel(balance_path)

        # ✅ Use finalized covariates + 'Propensity Score'
        covariates = final_covariates_map[treatment_var] + ['Propensity Score']
        balance_df = balance_df[balance_df['Covariate'].isin(covariates)]

        # ✅ Check for CAPS5score_baseline
        highlight_caps = 'CAPS5score_baseline' in balance_df['Covariate'].values

        # ✅ Format for heatmap
        heatmap_df = balance_df[['Covariate', 'SMD_Unweighted', 'SMD_Weighted']].copy()
        heatmap_df.columns = ['Covariate', 'Unweighted', 'Weighted']
        heatmap_df = heatmap_df.set_index('Covariate')
        heatmap_df = heatmap_df.sort_values(by='Unweighted', ascending=False)

        # ✅ Plot
        plt.figure(figsize=(12, max(10, len(heatmap_df) * 0.35)))
        ax = sns.heatmap(
            heatmap_df,
            cmap="coolwarm",
            annot=True,
            fmt=".2f",
            linewidths=0.6,
            linecolor='gray',
            cbar_kws={"label": "Standardized Mean Difference"}
        )

        plt.title(f"Covariate Balance Heatmap (Rubin IPTW)\n{treatment_var}", fontsize=15, weight='bold')
        plt.xlabel("Condition")
        plt.ylabel("Covariate")

        # ✅ Bold CAPS5score_baseline if present
        if highlight_caps:
            ylabels = [label.get_text() for label in ax.get_yticklabels()]
            ax.set_yticklabels([
                f"{label} ←" if label == 'CAPS5score_baseline' else label for label in ylabels
            ])

        plt.tight_layout()

        # ✅ Save image
        save_path = os.path.join(output_folder, f'heatmap_smd_{treatment_var}.png')
        plt.savefig(save_path, dpi=300)
        plt.close()

        print(f"✅ Heatmap saved: {save_path}")

    except Exception as e:
        print(f"⚠️ Error processing {treatment_var}: {e}")

## Subcat analysis:

In [None]:
covariates_SUBCAT_Antipsychotica_atypisch = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_TCA = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_SSRI = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_SNRI = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Tetracyclische_antidepressiva = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_Antidepressiva_overige = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_Systemische_antihistaminica = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_anxiolytica_Benzodiazepine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_hypnotica_Benzodiazepine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Amfetaminen = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD',
    'DIAGNOSIS_SEXUAL_TRAUMA', 'DIAGNOSIS_SUICIDALITY',
    'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Systemische_betablokkers = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Paracetamol_mono = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Anti_epileptica_stemmingsstabilisatoren = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age', 
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_Opioden = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_Z_drugs = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_NSAIDs = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

In [None]:
from collections import defaultdict

# This finds all variables that start with covariates_SUBCAT_
final_covariates_map = defaultdict(list)
final_covariates_map.update({
    var.replace("covariates_", ""): val
    for var, val in globals().items()
    if var.lower().startswith("covariates_subcat_") and isinstance(val, list)
})

# Show detected group names
print("Groups found:", list(final_covariates_map.keys()))
medication_groups = list(final_covariates_map.keys())
print(medication_groups)

In [None]:
import os


def run_all_SUBCAT_group_models(imputed_dfs):
    """
    Runs downstream analysis for each SUBCAT medisubcation group using imputed datasets.

    Parameters:
    - imputed_dfs: list of 5 imputed DataFrames (from df_imputed_final_imp1.pkl ... imp5.pkl)
    
    Notes:
    - Covariate lists must be defined as global variables: covariates_subcat_<group>
    - Outputs are saved in: outputs/SUBCAT_<GROUP>/
    """

    print(" Starting analysis for all SUBCAT groups")

    for var_name in globals():
        if var_name.lower().startswith("covariates_subcat_") and isinstance(globals()[var_name], list):
            group_name = var_name.replace("covariates_", "")
            group_name = group_name.replace("_", " ").title().replace(" ", "_")  # e.g., subcat_z_drugs → Subcat_Z_Drugs
            group_name = group_name.replace("Subcat_", "SUBCAT_")  # force prefix to uppercase

            covariates = globals()[var_name]
            output_dir = f"outputs/{group_name}"
            os.makedirs(output_dir, exist_ok=True)

            print(f"\n Processing {group_name}...")

            for k, df_imp in enumerate(imputed_dfs):
                print(f"  → Using imputation {k+1}")

                # Define X and Y
                X = df_imp[covariates]
                Y = df_imp["caps5_change_baseline"]

                # === Save X and Y as placeholder (replace with modeling later)
                X.to_csv(f"{output_dir}/X_imp{k+1}.csv", index=False)
                Y.to_frame(name="Y").to_csv(f"{output_dir}/Y_imp{k+1}.csv", index=False)

            print(f" Done: {group_name}")

    print("\n All SUBCAT group analyses complete.")

# ========= STEP 4: Execute ========= #
run_all_SUBCAT_group_models(imputed_dfs)

In [None]:
for treatment_var in medication_groups:
    print(f"\n {treatment_var}")
    
    for i, df in enumerate(imputed_dfs):
        if treatment_var not in df.columns:
            print(f"  Imp {i+1}:  Not found in columns.")
            continue

        treated = (df[treatment_var] == 1).sum()
        control = (df[treatment_var] == 0).sum()
        missing = df[treatment_var].isna().sum()

        print(f"  Imp {i+1}: Treated = {treated}, Control = {control}, Missing = {missing}")

In [None]:
import os
import pandas as pd
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from collections import defaultdict

# ✅ VIF computation function
def compute_vif(X):
    X = sm.add_constant(X, has_constant='add')
    vif_df = pd.DataFrame()
    vif_df["variable"] = X.columns
    vif_df["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
    return vif_df

# ✅ Process each group
for group in medication_groups:
    print(f"\n🔍 Processing VIF for {group}")

    if group not in final_covariates_map:
        print(f" ⚠️ No covariates found for {group}. Skipping.")
        continue

    covariates = final_covariates_map[group]
    vif_list = []

    for i, df_imp in enumerate(imputed_dfs):
        try:
            X = df_imp[covariates].copy()
            vif_df = compute_vif(X)
            vif_df["imputation"] = i + 1
            vif_list.append(vif_df)
        except Exception as e:
            print(f"   ❌ Failed on imputation {i+1} for {group}: {e}")

    if vif_list:
        all_vif = pd.concat(vif_list)
        pooled_vif = all_vif.groupby("variable")["VIF"].mean().reset_index()
        pooled_vif = pooled_vif.sort_values(by="VIF", ascending=False)

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)
        pooled_vif.to_csv(os.path.join(output_folder, "pooled_vif.csv"), index=False)

        print(f" ✅ Saved: {output_folder}/pooled_vif.csv")
    else:
        print(f" ⚠️ Skipped {group}: No valid imputations.")

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# ---------- PS Estimation Function ----------
def run_xgboost_ps_modeling(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f" Running PS estimation for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        covariates = final_covariates_map[group]
        ps_matrix = pd.DataFrame()
        auc_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not found in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop missing treatment rows
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            # Train-test split for ROC
            X_train, X_test, T_train, T_test = train_test_split(
                X, T, stratify=T, test_size=0.3, random_state=42
            )

            try:
                model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
                model.fit(X_train, T_train)

                ps_scores = model.predict_proba(X)[:, 1]
                ps_matrix[f"ps_imp{i+1}"] = pd.Series(ps_scores, index=valid_idx)

                # ROC & AUC
                auc = roc_auc_score(T_test, model.predict_proba(X_test)[:, 1])
                auc_list.append(auc)

                fpr, tpr, _ = roc_curve(T_test, model.predict_proba(X_test)[:, 1])
                plt.figure()
                plt.plot(fpr, tpr, label=f"AUC = {auc:.3f}")
                plt.plot([0, 1], [0, 1], 'k--')
                plt.xlabel("False Positive Rate")
                plt.ylabel("True Positive Rate")
                plt.title(f"ROC Curve - {group} (Imp {i+1})")
                plt.legend()
                plt.tight_layout()
                plt.savefig(os.path.join(output_folder, f"roc_curve_imp{i+1}.png"))
                plt.close()
                print(f"   Imp {i+1}: AUC = {auc:.3f}, ROC saved.")

            except Exception as e:
                print(f"   Error in {group} (imp {i+1}): {e}")

        # Save AUCs and Composite PS
        if not ps_matrix.empty:
            # Fill NaN rows (from dropped subjects in some imputations) with mean
            ps_matrix["composite_ps"] = ps_matrix.mean(axis=1)
            ps_matrix.to_excel(os.path.join(output_folder, "propensity_scores.xlsx"))

            auc_df = pd.DataFrame({
                "imputation": [f"imp{i+1}" for i in range(len(auc_list))],
                "AUC": auc_list
            })
            auc_df.loc[len(auc_df.index)] = ["mean", np.mean(auc_list) if auc_list else np.nan]
            auc_df.to_excel(os.path.join(output_folder, "auc_scores.xlsx"), index=False)

            print(f" Composite PS + AUC saved for {group}")
        else:
            print(f" No valid PS scores generated for {group}")

# ---------- Run ----------
run_xgboost_ps_modeling(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from xgboost import XGBClassifier

def compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n Computing feature importance for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        covariates = final_covariates_map[group]
        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        importance_df_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop NaNs in treatment
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            try:
                model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
                model.fit(X, T)

                # Get feature importance
                importances = model.get_booster().get_score(importance_type='gain')
                df_feat = pd.DataFrame.from_dict(importances, orient='index', columns=[f"imp{i+1}"])
                df_feat.index.name = 'feature'
                importance_df_list.append(df_feat)

            except Exception as e:
                print(f"   Error during modeling: {e}")

        if importance_df_list:
            # Combine and average
            all_feat = pd.concat(importance_df_list, axis=1).fillna(0)
            all_feat["mean_importance"] = all_feat.mean(axis=1)

            # Filter top 30 non-zero
            non_zero = all_feat[all_feat["mean_importance"] > 0]
            top30 = non_zero.sort_values(by="mean_importance", ascending=False).head(30)

            # Save to CSV
            top30.to_csv(os.path.join(output_folder, "feature_importance.csv"))

            # Plot
            plt.figure(figsize=(10, 8))
            plt.barh(top30.index[::-1], top30["mean_importance"][::-1])  # plot top → bottom
            plt.xlabel("Mean Gain Importance")
            plt.title(f"Top 30 Feature Importance - {group}")
            plt.tight_layout()
            plt.savefig(os.path.join(output_folder, "feature_importance_top30.png"))
            plt.close()

            print(f" Saved feature importance plot and CSV for {group}")
        else:
            print(f" No valid models for {group}")

#  Run
compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import numpy as np


def compute_trimmed_clipped_iptw(ps_df, treatment, lower=0.05, upper=0.95, clip_max=10):
    weights = []
    keep_mask = (ps_df > lower) & (ps_df < upper)

    for i in range(ps_df.shape[1]):
        ps = ps_df.iloc[:, i].clip(lower=1e-6, upper=1 - 1e-6)  # avoid div by zero
        mask = keep_mask.iloc[:, i]
        w = pd.Series(np.nan, index=ps.index)

        w[mask & (treatment == 1)] = 1 / ps[mask & (treatment == 1)]
        w[mask & (treatment == 0)] = 1 / (1 - ps[mask & (treatment == 0)])
        w = w.clip(upper=clip_max)
        weights.append(w)

    return pd.concat(weights, axis=1)


def apply_rubins_rule_to_iptw(iptw_matrix):
    """
    Given an IPTW matrix (n rows × M imputations), return Rubin’s rule pooled mean, SD, SE.
    """
    M = iptw_matrix.shape[1]
    q_bar = iptw_matrix.mean(axis=1)
    u_bar = iptw_matrix.var(axis=1, ddof=1)
    B = iptw_matrix.apply(lambda x: x.mean(), axis=1).var(ddof=1)
    total_var = u_bar + (1 + 1/M) * B
    total_se = np.sqrt(total_var)
    return q_bar, u_bar.pow(0.5), total_se


def run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n🔍 Processing IPTW + trimming + clipping for {group}")
        output_folder = os.path.join("outputs", group)
        ps_path = os.path.join(output_folder, "propensity_scores.xlsx")

        if not os.path.exists(ps_path):
            print(f"⚠️ Missing PS file: {ps_path}. Skipping.")
            continue

        try:
            ps_all = pd.read_excel(ps_path, index_col=0)
            ps_cols = [col for col in ps_all.columns if col.startswith("ps_imp")]
            composite_index = ps_all.index

            # Get treatment from one imputed dataset
            T_full = None
            for df in imputed_dfs:
                if group in df.columns:
                    T_full = df.loc[composite_index, group]
                    break

            if T_full is None:
                print(f"❌ Treatment column {group} not found in any imputed dataset.")
                continue

            # Compute IPTW matrix (shape: n × M)
            iptw_matrix = compute_trimmed_clipped_iptw(ps_all[ps_cols], T_full)
            iptw_matrix.columns = [f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]

            # Apply Rubin’s Rule for mean, SD, SE
            iptw_matrix["iptw_mean"], iptw_matrix["iptw_sd"], iptw_matrix["iptw_se"] = apply_rubins_rule_to_iptw(
                iptw_matrix[[f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]]
            )

            # Save IPTW matrix separately
            iptw_matrix.to_excel(os.path.join(output_folder, "iptw_weights.xlsx"))
            print(f"✅ Saved IPTW weights for {group}")

            # Save trimmed & clipped imputed datasets with IPTW
            for i in range(5):
                df = imputed_dfs[i].copy()
                if group not in df.columns:
                    continue

                trimmed_idx = iptw_matrix.index.intersection(df.index)
                needed_cols = final_covariates_map[group] + [group, "caps5_change_baseline"]

                # Select only necessary columns
                df_trimmed = df.loc[trimmed_idx, needed_cols].copy()
                df_trimmed["iptw"] = iptw_matrix[f"iptw_imp{i+1}"].loc[trimmed_idx]

                # ✅ DROP rows with missing IPTW values
                before = len(df_trimmed)
                df_trimmed = df_trimmed.dropna(subset=["iptw"])
                after = len(df_trimmed)
                print(f"    ℹ️ Retained {after}/{before} rows after IPTW NaN drop.")

                # Save to .pkl
                df_trimmed.to_pickle(os.path.join(output_folder, f"trimmed_data_imp{i+1}.pkl"))
                print(f"  💾 Saved trimmed dataset: {output_folder}/trimmed_data_imp{i+1}.*")

        except Exception as e:
            print(f"❌ Error in {group}: {e}")


run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def plot_ps_overlap_all_groups(medication_groups):
    for group in medication_groups:
        print(f"\n📊 Plotting PS overlap for {group}")

        folder = os.path.join("outputs", group)
        ps_file = os.path.join(folder, "propensity_scores.xlsx")
        iptw_file = os.path.join(folder, "iptw_weights.xlsx")
        trimmed_file = os.path.join(folder, "trimmed_data_imp1.pkl")

        if not all(os.path.exists(f) for f in [ps_file, iptw_file, trimmed_file]):
            print(f"⚠️ Missing required files for {group}. Skipping.")
            continue

        try:
            ps_df = pd.read_excel(ps_file, index_col=0)
            iptw_df = pd.read_excel(iptw_file, index_col=0)
            trimmed_df = pd.read_pickle(trimmed_file)

            # Extract
            ps = ps_df["composite_ps"].reindex(trimmed_df.index)
            w = iptw_df["iptw_mean"].reindex(trimmed_df.index)
            T = trimmed_df[group]

            # Masks to remove NaNs
            treated_mask = (T == 1) & ps.notna() & w.notna()
            control_mask = (T == 0) & ps.notna() & w.notna()

            treated = ps[treated_mask]
            treated_w = w[treated_mask]

            control = ps[control_mask]
            control_w = w[control_mask]

            # === Unweighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Unweighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_unweighted.png"))
            plt.close()

            # === Weighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, weights=[treated_w, control_w], label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Weighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Weighted Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_weighted.png"))
            plt.close()

            print(f"✅ Saved unweighted and weighted PS plots for {group}")

        except Exception as e:
            print(f"❌ Error processing {group}: {e}")

# 🔁 Run
plot_ps_overlap_all_groups(medication_groups)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Set up base output folder
output_base = "outputs"
ps_file = "propensity_scores.xlsx"
iptw_file = "iptw_weights.xlsx"
trimmed_data_file = "trimmed_data_imp1.pkl"

# Collect all treatment group folders
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]


# Generate 4-panel overlap plots
for group in groups:
    group_path = os.path.join(output_base, group)
    try:
        # Load trimmed treatment info
        trimmed_df = pd.read_pickle(os.path.join(group_path, trimmed_data_file))
        index = trimmed_df.index

        # Fix: case-insensitive match for treatment variable
        possible_cols = [col for col in trimmed_df.columns if col.upper() == group.upper()]
        if not possible_cols:
            print(f" Treatment variable {group} not found in {group}, skipping.")
            continue
        treatment_var = possible_cols[0]
        T = trimmed_df[treatment_var]

        # Load composite PS (aligned to trimmed_df index)
        ps_df = pd.read_excel(os.path.join(group_path, ps_file), index_col=0)
        if 'composite_ps' not in ps_df.columns:
            print(f" Composite column missing in {ps_file}, skipping {group}.")
            continue
        ps = ps_df.loc[index, 'composite_ps']

        # Load IPTW weights (aligned to trimmed_df index)
        weights_df = pd.read_excel(os.path.join(group_path, iptw_file), index_col=0)
        if 'iptw_mean' not in weights_df.columns:
            print(f" IPTW weight column missing in {iptw_file}, skipping {group}.")
            continue
        weights = weights_df.loc[index, 'iptw_mean']

        # Prepare 4 datasets
        raw_treated = ps[T == 1]
        raw_control = ps[T == 0]
        weighted_treated = (ps[T == 1], weights[T == 1])
        weighted_control = (ps[T == 0], weights[T == 0])

        # Create plot
        fig, axs = plt.subplots(2, 2, figsize=(12, 8))
        fig.suptitle(f"Propensity Score Distribution - {group}", fontsize=14)

        axs[0, 0].hist(raw_treated, bins=20, alpha=0.7, color='blue')
        axs[0, 0].set_title("Raw Treated")

        axs[0, 1].hist(raw_control, bins=20, alpha=0.7, color='green')
        axs[0, 1].set_title("Raw Control")

        axs[1, 0].hist(weighted_treated[0], bins=20, weights=weighted_treated[1], alpha=0.7, color='blue')
        axs[1, 0].set_title("Weighted Treated")

        axs[1, 1].hist(weighted_control[0], bins=20, weights=weighted_control[1], alpha=0.7, color='green')
        axs[1, 1].set_title("Weighted Control")

        for ax in axs.flat:
            ax.set_xlim(0, 1)
            ax.set_xlabel("Propensity Score")
            ax.set_ylabel("Count")

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save figure
        plot_path = os.path.join(group_path, f"four_panel_overlap_{group}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f" ✅ Saved: {plot_path}")

    except Exception as e:
        print(f" ❌ Error in {group}: {e}")

In [None]:
# ATT calculation:

In [None]:
# Weighted:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
n_repeats = 1
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T, weights):
    treated = X[T == 1]
    control = X[T == 0]
    w_treated = weights[T == 1]
    w_control = weights[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = np.average(treated[col], weights=w_treated), np.average(control[col], weights=w_control)
            s1 = np.sqrt(np.average((treated[col] - m1) ** 2, weights=w_treated))
            s0 = np.sqrt(np.average((control[col] - m0) ** 2, weights=w_control))
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# XGBoost Main Loop (No DML)
# -----------------------------
def run_xgboost_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running XGBoost for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                X = df[covariates].copy()
                T = df[group]
                Y = df["caps5_change_baseline"]
                W = df["iptw"]

                for repeat in range(n_repeats):
                    kf = KFold(n_splits=5, shuffle=True, random_state=seed + repeat)
                    for train_idx, test_idx in kf.split(X):
                        try:
                            X_train, T_train, Y_train, W_train = (
                                X.iloc[train_idx],
                                T.iloc[train_idx],
                                Y.iloc[train_idx],
                                W.iloc[train_idx],
                            )

                            # XGBoost regression model
                            model = xgb.XGBRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1, random_state=seed)
                            
                            # Add treatment variable to features
                            X_train_with_T = X_train.copy()
                            X_train_with_T[group] = T_train
                            
                            # Fit model
                            model.fit(X_train_with_T, Y_train, sample_weight=W_train)
                            
                            # Predict outcomes for treated and control groups
                            X_treated = X_train.copy()
                            X_treated[group] = 1
                            X_control = X_train.copy()
                            X_control[group] = 0
                            
                            Y_pred_treated = model.predict(X_treated)
                            Y_pred_control = model.predict(X_control)
                            
                            # Calculate ATT (Average Treatment Effect on Treated)
                            treated_mask = T_train == 1
                            if np.any(treated_mask):
                                att = np.average(Y_pred_treated[treated_mask] - Y_pred_control[treated_mask], 
                                               weights=W_train[treated_mask])
                                
                                # Calculate standard error (approximate)
                                treatment_effects = Y_pred_treated[treated_mask] - Y_pred_control[treated_mask]
                                residual = treatment_effects - att
                                se = np.sqrt(np.sum((W_train[treated_mask] * residual) ** 2)) / np.sum(W_train[treated_mask])

                                
                                att_list.append(att)
                                se_list.append(se)

                            # Model performance metrics
                            Y_pred = model.predict(X_train_with_T)
                            residuals = Y_train - Y_pred
                            rmse = mean_squared_error(Y_train, Y_pred, squared=False)
                            r2 = r2_score(Y_train, Y_pred)
                            r2_list.append(r2)
                            rmse_list.append(rmse)
                            
                            # Collect residuals and fitted values for diagnostic plots
                            group_residuals.append(residuals.values)
                            group_fitted.append(Y_pred)

                            smd, vr = calculate_smd_vr(X_train, T_train, W_train)
                            smd_list.append(smd)
                            vr_list.append(vr)
                        except Exception as e:
                            print(f"⚠️ Error in {group}, seed {seed}, imp {imp}, rep {repeat}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("xgb_rubin_summary_subcats.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_subcats.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_xgboost_with_trimmed_data(final_covariates_map)


In [None]:
# Unweighted:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
n_repeats = 1
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}_unweighted.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T):
    treated = X[T == 1]
    control = X[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = np.mean(treated[col]), np.mean(control[col])
            s1 = np.std(treated[col])
            s0 = np.std(control[col])
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# XGBoost Main Loop (No DML)
# -----------------------------
def run_xgboost_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running XGBoost for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                X = df[covariates].copy()
                T = df[group]
                Y = df["caps5_change_baseline"]
                #W = df["iptw"]

                for repeat in range(n_repeats):
                    kf = KFold(n_splits=5, shuffle=True, random_state=seed + repeat)
                    for train_idx, test_idx in kf.split(X):
                        try:
                            X_train = X.iloc[train_idx]
                            T_train = T.iloc[train_idx]
                            Y_train = Y.iloc[train_idx]
                            

                            # XGBoost regression model
                            model = xgb.XGBRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1, random_state=seed)
                            
                            # Add treatment variable to features
                            X_train_with_T = X_train.copy()
                            X_train_with_T[group] = T_train
                            
                            # Fit model
                            model.fit(X_train_with_T, Y_train)
                            
                            # Predict outcomes for treated and control groups
                            X_treated = X_train.copy()
                            X_treated[group] = 1
                            X_control = X_train.copy()
                            X_control[group] = 0
                            
                            Y_pred_treated = model.predict(X_treated)
                            Y_pred_control = model.predict(X_control)
                            
                            # Calculate ATT (Average Treatment Effect on Treated)
                            treated_mask = T_train == 1
                            if np.any(treated_mask):
                                att = np.mean(Y_pred_treated[treated_mask] - Y_pred_control[treated_mask])
                                
                                # Calculate standard error (approximate)
                                treatment_effects = Y_pred_treated[treated_mask] - Y_pred_control[treated_mask]
                                residual = treatment_effects - att
                                se = np.sqrt(np.sum((residual) ** 2)) / np.sum(treated_mask)
                                att_list.append(att)
                                se_list.append(se)

                            # Model performance metrics
                            Y_pred = model.predict(X_train_with_T)
                            residuals = Y_train - Y_pred
                            rmse = mean_squared_error(Y_train, Y_pred, squared=False)
                            r2 = r2_score(Y_train, Y_pred)
                            r2_list.append(r2)
                            rmse_list.append(rmse)
                            
                            # Collect residuals and fitted values for diagnostic plots
                            group_residuals.append(residuals.values)
                            group_fitted.append(Y_pred)

                            smd, vr = calculate_smd_vr(X_train, T_train)
                            smd_list.append(smd)
                            vr_list.append(vr)
                        except Exception as e:
                            print(f"⚠️ Error in {group}, seed {seed}, imp {imp}, rep {repeat}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("xgb_rubin_summary_subcats_unweighted.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_subcats_unweighted.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_xgboost_with_trimmed_data(final_covariates_map)


In [None]:
import os
import pandas as pd
import numpy as np
from scipy.stats import sem, ttest_ind

# ----------------------------------
# File paths
# ----------------------------------
output_base = "outputs"
att_file = "xgb_rubin_summary_subcats.xlsx"
trimmed_file = "trimmed_data_imp1.pkl"
auc_file = "auc_scores.xlsx"  

# ----------------------------------
# Load ATT Summary
# ----------------------------------
if os.path.exists(att_file):
    att_df = pd.read_excel(att_file)
else:
    raise FileNotFoundError("❌ ATT summary file not found: xgb_rubin_summary_subcats.xlsx")

summary_rows = []

# ----------------------------------
# Loop over medication groups
# ----------------------------------
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

for med in groups:
    try:
        group_path = os.path.join(output_base, med)

        # Load trimmed data
        df = pd.read_pickle(os.path.join(group_path, trimmed_file))

        # Detect treatment column
        treatment_cols = [col for col in df.columns if col.upper() == med.upper()]
        if not treatment_cols:
            print(f"⚠️ Treatment column {med} not found in trimmed data. Skipping.")
            continue
        treatment_var = treatment_cols[0]

        # Extract treatment and outcome
        T = df[treatment_var]
        Y = df["caps5_change_baseline"]

        # Treated and control stats
        treated = Y[T == 1]
        control = Y[T == 0]

        mean_treat = treated.mean()
        se_treat = sem(treated) if len(treated) > 1 else np.nan

        mean_ctrl = control.mean()
        se_ctrl = sem(control) if len(control) > 1 else np.nan

        # Cohen's d (unadjusted)
        pooled_sd = np.sqrt(((treated.std() ** 2) + (control.std() ** 2)) / 2)
        cohen_d = (mean_treat - mean_ctrl) / pooled_sd if pooled_sd > 0 else np.nan

        # E-value (unadjusted)
        delta = mean_treat - mean_ctrl
        E = delta / abs(mean_ctrl) * 100 if mean_ctrl != 0 else np.nan

        # Unadjusted p-value
        try:
            t_stat, p_val = ttest_ind(treated, control, equal_var=False, nan_policy="omit")
            rounded_p = round(p_val, 5)
            formatted_p = "< 0.00001" if rounded_p < 0.00001 else rounded_p
        except Exception:
            formatted_p = np.nan

        # AUC from new auc_scores.xlsx file
        auc_val = np.nan
        auc_path = os.path.join(group_path, auc_file)
        if os.path.exists(auc_path):
            auc_df = pd.read_excel(auc_path)
            if "AUC" in auc_df.columns:
                auc_val = auc_df["AUC"].dropna().mean()

        # Adjusted stats from Rubin summary
        att_row = att_df[att_df["group"].str.strip().str.upper() == med.strip().upper()]
        if not att_row.empty:
            att = att_row.iloc[0]["att"]
            att_se = att_row.iloc[0]["se"]
            att_p_val = att_row.iloc[0]["p_value"]
            r2 = att_row.iloc[0]["r2"]
            rmse = att_row.iloc[0]["rmse"]

            try:
                rounded_att_p = round(float(att_p_val), 5)
                formatted_att_p = "< 0.00001" if rounded_att_p < 0.00001 else rounded_att_p
            except:
                formatted_att_p = att_p_val
        else:
            att, att_se, formatted_att_p, r2, rmse = np.nan, np.nan, np.nan, np.nan, np.nan

        # Append full row
        summary_rows.append({
            'Medication Group': med,
            'Mean Treated': mean_treat,
            'SE Treated': se_treat,
            'Mean Control': mean_ctrl,
            'SE Control': se_ctrl,
            'Cohen d': cohen_d,
            'E (Unadjusted)': E,
            'n Treated': len(treated),
            'n Control': len(control),
            #'Unadjusted p-value': formatted_p,
            'ATT Estimate': att,
            'ATT SE (Robust)': att_se,
            'ATT p-value': formatted_att_p,
            'R²': r2,
            'RMSE': rmse,
            'AUC': auc_val
        })

    except Exception as e:
        print(f"❌ Error in {med}: {e}")

# ----------------------------------
# Save final summary
# ----------------------------------
summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.sort_values("Medication Group")
summary_df.to_excel("Final_ATT_Summary_SubCat.xlsx", index=False)
print("✅ Final_ATT_Summary_SubCat saved")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# ✅ Load the final summary table
final_df = pd.read_excel("Final_ATT_Summary_SubCat.xlsx")

# ✅ Parse DML p-values (handle "< 0.00001")
def parse_pval(p):
    try:
        if isinstance(p, str) and "<" in p:
            return 0.000001
        return float(p)
    except:
        return None

final_df['ATT p-value'] = final_df['ATT p-value'].apply(parse_pval)

# ✅ Plot settings
width = 0.35

# ✅ Plotting function for a single medication group
def plot_single_group(row):
    fig, ax = plt.subplots(figsize=(10, 6))
    
    bars1 = ax.bar(-width/2, row['Mean Control'], width, 
                   yerr=row['SE Control'], label='Control', hatch='//', color='gray', capsize=5)
    bars2 = ax.bar(+width/2, row['Mean Treated'], width, 
                   yerr=row['SE Treated'], label='Treated', color='steelblue', capsize=5)

    label = (
        f"ATT = {row['ATT Estimate']:.2f}\n"
        f"d = {row['Cohen d']:.2f}, p = {row['ATT p-value']:.3f}\n"
        f"nT = {row['n Treated']}, nC = {row['n Control']}\n"
        f"E = {row['E (Unadjusted)']:.1f}%"
    )
    max_y = max(row['Mean Control'], row['Mean Treated']) + 1.5
    ax.text(0, max_y, label, ha='center', va='bottom', fontsize=9, color='#FFD700')

    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_xticks([-width/2, +width/2])
    ax.set_xticklabels(['Control', 'Treated'])
    ax.set_title(f"Group: {row['Medication Group']}", fontsize=12, weight='bold')
    ax.set_ylabel("CAPS5 Change Score")
    ax.set_ylim(bottom=0, top=max_y + 2)
    ax.legend()
    fig.tight_layout()
    return fig

# ✅ Generate and save all plots into a multi-page PDF
with PdfPages("xgb_att_barplot_subcat.pdf") as pdf:
    for idx, row in final_df.iterrows():
        fig = plot_single_group(row)
        pdf.savefig(fig, dpi=300, bbox_inches='tight')
        plt.close(fig)
print("✅ xgb_att_barplot_subcat saved")

In [None]:
# Love plot:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# ----------------------------------------
# Functions to calculate balance
# ----------------------------------------
def calculate_smd(x1, x2, w1=None, w2=None):
    def weighted_mean(x, w): return np.average(x, weights=w)
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    m1 = weighted_mean(x1, w1) if w1 is not None else np.mean(x1)
    m2 = weighted_mean(x2, w2) if w2 is not None else np.mean(x2)
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    pooled_sd = np.sqrt((v1 + v2) / 2)
    return np.abs(m1 - m2) / pooled_sd if pooled_sd > 0 else 0

def variance_ratio(x1, x2, w1=None, w2=None):
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    return max(v1 / v2, v2 / v1) if v1 > 0 and v2 > 0 else 1

# ----------------------------------------
# Setup
# ----------------------------------------
output_base = "outputs"
groups = [g for g in os.listdir(output_base) if os.path.isdir(os.path.join(output_base, g))]

# Create a case-insensitive mapping
final_covariates_map_lower = {k.lower(): v for k, v in final_covariates_map.items()}

# ----------------------------------------
# Main Loop
# ----------------------------------------
for group in groups:
    if group.lower() not in final_covariates_map_lower:
        continue

    print(f"\n🔍 Processing {group}...")

    try:
        group_path = os.path.join(output_base, group)
        covariates = final_covariates_map_lower[group.lower()]
        
        column_name = None
        for col in pd.read_pickle(os.path.join(group_path, "trimmed_data_imp1.pkl")).columns:
            if col.lower() == group.lower():
                column_name = col
                break
        if column_name is None:
            print(f"⚠️ Column not found for {group}, skipping.")
            continue

        smd_unw_all, smd_w_all = [], []
        vr_unw_all, vr_w_all = [], []

        for i in range(1, 6):
            df_path = os.path.join(group_path, f"trimmed_data_imp{i}.pkl")
            iptw_path = os.path.join(group_path, "iptw_weights.xlsx")

            if not os.path.exists(df_path) or not os.path.exists(iptw_path):
                print(f"⚠️ Missing data for {group} imp{i}, skipping.")
                continue

            df = pd.read_pickle(df_path)
            iptw_df = pd.read_excel(iptw_path, index_col=0)
            T = df[column_name]
            W = iptw_df.loc[df.index, "iptw_mean"]

            smd_unw_i, smd_w_i, vr_unw_i, vr_w_i = [], [], [], []

            for cov in covariates:
                x1, x0 = df.loc[T == 1, cov], df.loc[T == 0, cov]
                w1, w0 = W[T == 1], W[T == 0]

                su = calculate_smd(x1, x0)
                sw = calculate_smd(x1, x0, w1, w0)

                vu = variance_ratio(x1, x0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan
                vw = variance_ratio(x1, x0, w1, w0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan

                smd_unw_i.append(su)
                smd_w_i.append(sw)
                vr_unw_i.append(vu)
                vr_w_i.append(vw)

            smd_unw_all.append(smd_unw_i)
            smd_w_all.append(smd_w_i)
            vr_unw_all.append(vr_unw_i)
            vr_w_all.append(vr_w_i)

        smd_unw = np.mean(smd_unw_all, axis=0)
        smd_w = np.mean(smd_w_all, axis=0)
        vr_unw = np.nanmean(vr_unw_all, axis=0)
        vr_w = np.nanmean(vr_w_all, axis=0)

        severity = []
        for sw in smd_w:
            if sw <= 0.1:
                severity.append("Good")
            elif sw <= 0.2:
                severity.append("Moderate")
            else:
                severity.append("Poor")

        covariate_names = covariates
        numeric_df = pd.DataFrame({
            "Covariate": covariate_names,
            "SMD_Unweighted": smd_unw,
            "SMD_Weighted": smd_w,
            "Imbalance_Severity": severity,
            "VR_Unweighted": vr_unw,
            "VR_Weighted": vr_w
        })

        numeric_path = os.path.join(group_path, f"covariate_balance_table_{group}.xlsx")
        numeric_df.to_excel(numeric_path, index=False)
        print(f"📊 Exported numeric summary to: {numeric_path}")

        # -------------------------
        # Plot
        # -------------------------
        labels = covariates
        y_pos = np.arange(len(labels))

        fig, axes = plt.subplots(1, 2, figsize=(18, len(labels) * 0.45))

        axes[0].scatter(smd_unw, y_pos, color='red', label="Unweighted")
        axes[0].scatter(smd_w, y_pos, color='blue', label="Weighted")
        axes[0].axvline(0.1, color='gray', linestyle='--', label="Threshold 0.1")
        axes[0].axvline(0.2, color='black', linestyle='--', label="Threshold 0.2")
        axes[0].set_xlim(0, max(max(smd_unw), max(smd_w), 0.25) + 0.05)
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels(labels)
        axes[0].invert_yaxis()
        axes[0].set_title("Standardized Mean Differences (SMD)")
        axes[0].legend(loc="upper right")
        axes[0].grid(True)

        vr_mask = [cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] for cov in covariates]
        filtered_y = [i for i, b in enumerate(vr_mask) if b]
        filtered_labels = [labels[i] for i in filtered_y]
        filtered_vr_unw = [vr_unw[i] for i in filtered_y]
        filtered_vr_w = [vr_w[i] for i in filtered_y]

        axes[1].scatter(filtered_vr_unw, filtered_y, color='blue', marker='o', label="Unweighted")
        axes[1].scatter(filtered_vr_w, filtered_y, color='red', marker='x', label="Weighted")
        axes[1].axvline(2, color='gray', linestyle='--')
        axes[1].axvline(0.5, color='gray', linestyle='--')
        axes[1].set_xlim(0, max(filtered_vr_unw + filtered_vr_w + [2.5]) + 0.5)
        axes[1].set_yticks(filtered_y)
        axes[1].set_yticklabels(filtered_labels)
        axes[1].invert_yaxis()
        axes[1].set_title("Variance Ratio (VR)")
        axes[1].legend()
        axes[1].grid(True)

        fig.suptitle(f"Covariate Balance for {group.replace('CAT_', '')}", fontsize=14, weight='bold')
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        plot_path = os.path.join(group_path, f"love_plot_{group}.pdf")
        fig.savefig(plot_path, dpi=300)
        plt.close()
        print(f"✅ Saved love plot: {plot_path}")
        print(f"📏 Max weighted SMD for {group}: {np.max(smd_w):.3f}")

    except Exception as e:
        print(f"❌ Error in {group}: {e}")

In [None]:
# Heatmap:

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
#-----------------------------
# Generate heatmaps
# -------------------------------
for treatment_var in medication_groups:
    print(f"\n========== Creating Heatmap for {treatment_var} ==========")

    try:
        output_folder = os.path.join('outputs', treatment_var)
        balance_path = os.path.join(output_folder, f'covariate_balance_table_{treatment_var}.xlsx')

        if not os.path.exists(balance_path):
            print(f"❌ Balance file not found: {balance_path}")
            continue

        balance_df = pd.read_excel(balance_path)

        # ✅ Use finalized covariates + 'Propensity Score'
        covariates = final_covariates_map[treatment_var] + ['Propensity Score']
        balance_df = balance_df[balance_df['Covariate'].isin(covariates)]

        # ✅ Check for CAPS5score_baseline
        highlight_caps = 'CAPS5score_baseline' in balance_df['Covariate'].values

        # ✅ Format for heatmap
        heatmap_df = balance_df[['Covariate', 'SMD_Unweighted', 'SMD_Weighted']].copy()
        heatmap_df.columns = ['Covariate', 'Unweighted', 'Weighted']
        heatmap_df = heatmap_df.set_index('Covariate')
        heatmap_df = heatmap_df.sort_values(by='Unweighted', ascending=False)

        # ✅ Plot
        plt.figure(figsize=(12, max(10, len(heatmap_df) * 0.35)))
        ax = sns.heatmap(
            heatmap_df,
            cmap="coolwarm",
            annot=True,
            fmt=".2f",
            linewidths=0.6,
            linecolor='gray',
            cbar_kws={"label": "Standardized Mean Difference"}
        )

        plt.title(f"Covariate Balance Heatmap (Rubin IPTW)\n{treatment_var}", fontsize=15, weight='bold')
        plt.xlabel("Condition")
        plt.ylabel("Covariate")

        # ✅ Bold CAPS5score_baseline if present
        if highlight_caps:
            ylabels = [label.get_text() for label in ax.get_yticklabels()]
            ax.set_yticklabels([
                f"{label} ←" if label == 'CAPS5score_baseline' else label for label in ylabels
            ])

        plt.tight_layout()

        # ✅ Save image
        save_path = os.path.join(output_folder, f'heatmap_smd_{treatment_var}.png')
        plt.savefig(save_path, dpi=300)
        plt.close()

        print(f"✅ Heatmap saved: {save_path}")

    except Exception as e:
        print(f"⚠️ Error processing {treatment_var}: {e}")

## SubSubCat Analysis:

In [None]:
covariates_SubSubCat_Oxazepam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Diazepam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Paracetamol = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Lorazepam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Mirtazapine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Escitalopram = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Sertraline = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Temazepam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Citalopram = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Quetiapine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]




covariates_SubSubCat_Amitriptyline = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Venlafaxine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Fluoxetine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Topiramaat = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Tramadol = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica', 'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]



covariates_SubSubCat_Zopiclon = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Loprazolam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Alprazolam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_promethazine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Paroxetine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Bupropion = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Methylfenidaat = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD',
    'DIAGNOSIS_SEXUAL_TRAUMA', 'DIAGNOSIS_SUICIDALITY',
    'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Olanzapine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Zolpidem = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

In [None]:
from collections import defaultdict

# This finds all variables that start with covariates_SUbSubCAT_ or covariates_SubSubcat_
final_covariates_map = defaultdict(list)
final_covariates_map.update({
    var.replace("covariates_", ""): val
    for var, val in globals().items()
    if var.lower().startswith("covariates_subsubcat_") and isinstance(val, list)
})

# Show detected group names
print("Groups found:", list(final_covariates_map.keys()))
medication_groups = list(final_covariates_map.keys())
print(medication_groups)

In [None]:
import os


def run_all_SUBSUBCAT_group_models(imputed_dfs):
    """
    Runs downstream analysis for each SUBSUBCAT medisubsubcation group using imputed datasets.

    Parameters:
    - imputed_dfs: list of 5 imputed DataFrames (from df_imputed_final_imp1.pkl ... imp5.pkl)
    
    Notes:
    - Covariate lists must be defined as global variables: covariates_subsubcat_<group>
    - Outputs are saved in: outputs/SUBSUBCAT_<GROUP>/
    """

    print(" Starting analysis for all SUBSUBCAT groups")

    for var_name in globals():
        if var_name.lower().startswith("covariates_subsubcat_") and isinstance(globals()[var_name], list):
            group_name = var_name.replace("covariates_", "")
            group_name = group_name.replace("_", " ").title().replace(" ", "_")  # e.g., subsubcat_z_drugs → Subsubcat_Z_Drugs
            group_name = group_name.replace("Subsubcat_", "SUBSUBCAT_")  # force prefix to uppercase

            covariates = globals()[var_name]
            output_dir = f"outputs/{group_name}"
            os.makedirs(output_dir, exist_ok=True)

            print(f"\n Processing {group_name}...")

            for k, df_imp in enumerate(imputed_dfs):
                print(f"  → Using imputation {k+1}")

                # Define X and Y
                X = df_imp[covariates]
                Y = df_imp["caps5_change_baseline"]

                # === Save X and Y as placeholder (replace with modeling later)
                X.to_csv(f"{output_dir}/X_imp{k+1}.csv", index=False)
                Y.to_frame(name="Y").to_csv(f"{output_dir}/Y_imp{k+1}.csv", index=False)

            print(f" Done: {group_name}")

    print("\n All SUBSUBCAT group analyses complete.")

# ========= STEP 4: Execute ========= #
run_all_SUBSUBCAT_group_models(imputed_dfs)

In [None]:
for treatment_var in medication_groups:
    print(f"\n {treatment_var}")
    
    for i, df in enumerate(imputed_dfs):
        if treatment_var not in df.columns:
            print(f"  Imp {i+1}:  Not found in columns.")
            continue

        treated = (df[treatment_var] == 1).sum()
        control = (df[treatment_var] == 0).sum()
        missing = df[treatment_var].isna().sum()

        print(f"  Imp {i+1}: Treated = {treated}, Control = {control}, Missing = {missing}")

In [None]:
import os
import pandas as pd
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from collections import defaultdict

# ✅ VIF computation function
def compute_vif(X):
    X = sm.add_constant(X, has_constant='add')
    vif_df = pd.DataFrame()
    vif_df["variable"] = X.columns
    vif_df["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
    return vif_df

# ✅ Process each group
for group in medication_groups:
    print(f"\n🔍 Processing VIF for {group}")

    if group not in final_covariates_map:
        print(f" ⚠️ No covariates found for {group}. Skipping.")
        continue

    covariates = final_covariates_map[group]
    vif_list = []

    for i, df_imp in enumerate(imputed_dfs):
        try:
            X = df_imp[covariates].copy()
            vif_df = compute_vif(X)
            vif_df["imputation"] = i + 1
            vif_list.append(vif_df)
        except Exception as e:
            print(f"   ❌ Failed on imputation {i+1} for {group}: {e}")

    if vif_list:
        all_vif = pd.concat(vif_list)
        pooled_vif = all_vif.groupby("variable")["VIF"].mean().reset_index()
        pooled_vif = pooled_vif.sort_values(by="VIF", ascending=False)

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)
        pooled_vif.to_csv(os.path.join(output_folder, "pooled_vif.csv"), index=False)

        print(f" ✅ Saved: {output_folder}/pooled_vif.csv")
    else:
        print(f" ⚠️ Skipped {group}: No valid imputations.")

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# ---------- PS Estimation Function ----------
def run_xgboost_ps_modeling(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f" Running PS estimation for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        covariates = final_covariates_map[group]
        ps_matrix = pd.DataFrame()
        auc_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not found in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop missing treatment rows
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            # Train-test split for ROC
            X_train, X_test, T_train, T_test = train_test_split(
                X, T, stratify=T, test_size=0.3, random_state=42
            )

            try:
                model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
                model.fit(X_train, T_train)

                ps_scores = model.predict_proba(X)[:, 1]
                ps_matrix[f"ps_imp{i+1}"] = pd.Series(ps_scores, index=valid_idx)

                # ROC & AUC
                auc = roc_auc_score(T_test, model.predict_proba(X_test)[:, 1])
                auc_list.append(auc)

                fpr, tpr, _ = roc_curve(T_test, model.predict_proba(X_test)[:, 1])
                plt.figure()
                plt.plot(fpr, tpr, label=f"AUC = {auc:.3f}")
                plt.plot([0, 1], [0, 1], 'k--')
                plt.xlabel("False Positive Rate")
                plt.ylabel("True Positive Rate")
                plt.title(f"ROC Curve - {group} (Imp {i+1})")
                plt.legend()
                plt.tight_layout()
                plt.savefig(os.path.join(output_folder, f"roc_curve_imp{i+1}.png"))
                plt.close()
                print(f"   Imp {i+1}: AUC = {auc:.3f}, ROC saved.")

            except Exception as e:
                print(f"   Error in {group} (imp {i+1}): {e}")

        # Save AUCs and Composite PS
        if not ps_matrix.empty:
            # Fill NaN rows (from dropped subjects in some imputations) with mean
            ps_matrix["composite_ps"] = ps_matrix.mean(axis=1)
            ps_matrix.to_excel(os.path.join(output_folder, "propensity_scores.xlsx"))

            auc_df = pd.DataFrame({
                "imputation": [f"imp{i+1}" for i in range(len(auc_list))],
                "AUC": auc_list
            })
            auc_df.loc[len(auc_df.index)] = ["mean", np.mean(auc_list) if auc_list else np.nan]
            auc_df.to_excel(os.path.join(output_folder, "auc_scores.xlsx"), index=False)

            print(f" Composite PS + AUC saved for {group}")
        else:
            print(f" No valid PS scores generated for {group}")

# ---------- Run ----------
run_xgboost_ps_modeling(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from xgboost import XGBClassifier

def compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n Computing feature importance for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        covariates = final_covariates_map[group]
        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        importance_df_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop NaNs in treatment
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            try:
                model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
                model.fit(X, T)

                # Get feature importance
                importances = model.get_booster().get_score(importance_type='gain')
                df_feat = pd.DataFrame.from_dict(importances, orient='index', columns=[f"imp{i+1}"])
                df_feat.index.name = 'feature'
                importance_df_list.append(df_feat)

            except Exception as e:
                print(f"   Error during modeling: {e}")

        if importance_df_list:
            # Combine and average
            all_feat = pd.concat(importance_df_list, axis=1).fillna(0)
            all_feat["mean_importance"] = all_feat.mean(axis=1)

            # Filter top 30 non-zero
            non_zero = all_feat[all_feat["mean_importance"] > 0]
            top30 = non_zero.sort_values(by="mean_importance", ascending=False).head(30)

            # Save to CSV
            top30.to_csv(os.path.join(output_folder, "feature_importance.csv"))

            # Plot
            plt.figure(figsize=(10, 8))
            plt.barh(top30.index[::-1], top30["mean_importance"][::-1])  # plot top → bottom
            plt.xlabel("Mean Gain Importance")
            plt.title(f"Top 30 Feature Importance - {group}")
            plt.tight_layout()
            plt.savefig(os.path.join(output_folder, "feature_importance_top30.png"))
            plt.close()

            print(f" Saved feature importance plot and CSV for {group}")
        else:
            print(f" No valid models for {group}")

#  Run
compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import numpy as np


def compute_trimmed_clipped_iptw(ps_df, treatment, lower=0.05, upper=0.95, clip_max=10):
    weights = []
    keep_mask = (ps_df > lower) & (ps_df < upper)

    for i in range(ps_df.shape[1]):
        ps = ps_df.iloc[:, i].clip(lower=1e-6, upper=1 - 1e-6)  # avoid div by zero
        mask = keep_mask.iloc[:, i]
        w = pd.Series(np.nan, index=ps.index)

        w[mask & (treatment == 1)] = 1 / ps[mask & (treatment == 1)]
        w[mask & (treatment == 0)] = 1 / (1 - ps[mask & (treatment == 0)])
        w = w.clip(upper=clip_max)
        weights.append(w)

    return pd.concat(weights, axis=1)


def apply_rubins_rule_to_iptw(iptw_matrix):
    """
    Given an IPTW matrix (n rows × M imputations), return Rubin’s rule pooled mean, SD, SE.
    """
    M = iptw_matrix.shape[1]
    q_bar = iptw_matrix.mean(axis=1)
    u_bar = iptw_matrix.var(axis=1, ddof=1)
    B = iptw_matrix.apply(lambda x: x.mean(), axis=1).var(ddof=1)
    total_var = u_bar + (1 + 1/M) * B
    total_se = np.sqrt(total_var)
    return q_bar, u_bar.pow(0.5), total_se


def run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n🔍 Processing IPTW + trimming + clipping for {group}")
        output_folder = os.path.join("outputs", group)
        ps_path = os.path.join(output_folder, "propensity_scores.xlsx")

        if not os.path.exists(ps_path):
            print(f"⚠️ Missing PS file: {ps_path}. Skipping.")
            continue

        try:
            ps_all = pd.read_excel(ps_path, index_col=0)
            ps_cols = [col for col in ps_all.columns if col.startswith("ps_imp")]
            composite_index = ps_all.index

            # Get treatment from one imputed dataset
            T_full = None
            for df in imputed_dfs:
                if group in df.columns:
                    T_full = df.loc[composite_index, group]
                    break

            if T_full is None:
                print(f"❌ Treatment column {group} not found in any imputed dataset.")
                continue

            # Compute IPTW matrix (shape: n × M)
            iptw_matrix = compute_trimmed_clipped_iptw(ps_all[ps_cols], T_full)
            iptw_matrix.columns = [f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]

            # Apply Rubin’s Rule for mean, SD, SE
            iptw_matrix["iptw_mean"], iptw_matrix["iptw_sd"], iptw_matrix["iptw_se"] = apply_rubins_rule_to_iptw(
                iptw_matrix[[f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]]
            )

            # Save IPTW matrix separately
            iptw_matrix.to_excel(os.path.join(output_folder, "iptw_weights.xlsx"))
            print(f"✅ Saved IPTW weights for {group}")

            # Save trimmed & clipped imputed datasets with IPTW
            for i in range(5):
                df = imputed_dfs[i].copy()
                if group not in df.columns:
                    continue

                trimmed_idx = iptw_matrix.index.intersection(df.index)
                needed_cols = final_covariates_map[group] + [group, "caps5_change_baseline"]

                # Select only necessary columns
                df_trimmed = df.loc[trimmed_idx, needed_cols].copy()
                df_trimmed["iptw"] = iptw_matrix[f"iptw_imp{i+1}"].loc[trimmed_idx]

                # ✅ DROP rows with missing IPTW values
                before = len(df_trimmed)
                df_trimmed = df_trimmed.dropna(subset=["iptw"])
                after = len(df_trimmed)
                print(f"    ℹ️ Retained {after}/{before} rows after IPTW NaN drop.")

                # Save to .pkl
                df_trimmed.to_pickle(os.path.join(output_folder, f"trimmed_data_imp{i+1}.pkl"))
                print(f"  💾 Saved trimmed dataset: {output_folder}/trimmed_data_imp{i+1}.*")

        except Exception as e:
            print(f"❌ Error in {group}: {e}")


run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def plot_ps_overlap_all_groups(medication_groups):
    for group in medication_groups:
        print(f"\n📊 Plotting PS overlap for {group}")

        folder = os.path.join("outputs", group)
        ps_file = os.path.join(folder, "propensity_scores.xlsx")
        iptw_file = os.path.join(folder, "iptw_weights.xlsx")
        trimmed_file = os.path.join(folder, "trimmed_data_imp1.pkl")

        if not all(os.path.exists(f) for f in [ps_file, iptw_file, trimmed_file]):
            print(f"⚠️ Missing required files for {group}. Skipping.")
            continue

        try:
            ps_df = pd.read_excel(ps_file, index_col=0)
            iptw_df = pd.read_excel(iptw_file, index_col=0)
            trimmed_df = pd.read_pickle(trimmed_file)

            # Extract
            ps = ps_df["composite_ps"].reindex(trimmed_df.index)
            w = iptw_df["iptw_mean"].reindex(trimmed_df.index)
            T = trimmed_df[group]

            # Masks to remove NaNs
            treated_mask = (T == 1) & ps.notna() & w.notna()
            control_mask = (T == 0) & ps.notna() & w.notna()

            treated = ps[treated_mask]
            treated_w = w[treated_mask]

            control = ps[control_mask]
            control_w = w[control_mask]

            # === Unweighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Unweighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_unweighted.png"))
            plt.close()

            # === Weighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, weights=[treated_w, control_w], label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Weighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Weighted Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_weighted.png"))
            plt.close()

            print(f"✅ Saved unweighted and weighted PS plots for {group}")

        except Exception as e:
            print(f"❌ Error processing {group}: {e}")

# 🔁 Run
plot_ps_overlap_all_groups(medication_groups)

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Set up base output folder
output_base = "outputs"
ps_file = "propensity_scores.xlsx"
iptw_file = "iptw_weights.xlsx"
trimmed_data_file = "trimmed_data_imp1.pkl"

# Collect all treatment group folders
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

# Generate 4-panel overlap plots
for group in groups:
    group_path = os.path.join(output_base, group)
    try:
        # Load trimmed treatment info
        trimmed_df = pd.read_pickle(os.path.join(group_path, trimmed_data_file))
        index = trimmed_df.index

        # Fix: case-insensitive match for treatment variable
        possible_cols = [col for col in trimmed_df.columns if col.upper() == group.upper()]
        if not possible_cols:
            print(f" Treatment variable {group} not found in {group}, skipping.")
            continue
        treatment_var = possible_cols[0]
        T = trimmed_df[treatment_var]

        # Load composite PS (aligned to trimmed_df index)
        ps_df = pd.read_excel(os.path.join(group_path, ps_file), index_col=0)
        if 'composite_ps' not in ps_df.columns:
            print(f" Composite column missing in {ps_file}, skipping {group}.")
            continue
        ps = ps_df.loc[index, 'composite_ps']

        # Load IPTW weights (aligned to trimmed_df index)
        weights_df = pd.read_excel(os.path.join(group_path, iptw_file), index_col=0)
        if 'iptw_mean' not in weights_df.columns:
            print(f" IPTW weight column missing in {iptw_file}, skipping {group}.")
            continue
        weights = weights_df.loc[index, 'iptw_mean']

        # Prepare 4 datasets
        raw_treated = ps[T == 1]
        raw_control = ps[T == 0]
        weighted_treated = (ps[T == 1], weights[T == 1])
        weighted_control = (ps[T == 0], weights[T == 0])

        # Create plot
        fig, axs = plt.subplots(2, 2, figsize=(12, 8))
        fig.suptitle(f"Propensity Score Distribution - {group}", fontsize=14)

        axs[0, 0].hist(raw_treated, bins=20, alpha=0.7, color='blue')
        axs[0, 0].set_title("Raw Treated")

        axs[0, 1].hist(raw_control, bins=20, alpha=0.7, color='green')
        axs[0, 1].set_title("Raw Control")

        axs[1, 0].hist(weighted_treated[0], bins=20, weights=weighted_treated[1], alpha=0.7, color='blue')
        axs[1, 0].set_title("Weighted Treated")

        axs[1, 1].hist(weighted_control[0], bins=20, weights=weighted_control[1], alpha=0.7, color='green')
        axs[1, 1].set_title("Weighted Control")

        for ax in axs.flat:
            ax.set_xlim(0, 1)
            ax.set_xlabel("Propensity Score")
            ax.set_ylabel("Count")

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save figure
        plot_path = os.path.join(group_path, f"four_panel_overlap_{group}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f" ✅ Saved: {plot_path}")

    except Exception as e:
        print(f" ❌ Error in {group}: {e}")

# ATT calculation:

In [None]:
# Weighted

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
n_repeats = 1
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T, weights):
    treated = X[T == 1]
    control = X[T == 0]
    w_treated = weights[T == 1]
    w_control = weights[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = np.average(treated[col], weights=w_treated), np.average(control[col], weights=w_control)
            s1 = np.sqrt(np.average((treated[col] - m1) ** 2, weights=w_treated))
            s0 = np.sqrt(np.average((control[col] - m0) ** 2, weights=w_control))
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# XGBoost Main Loop (No DML)
# -----------------------------
def run_xgboost_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running XGBoost for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                X = df[covariates].copy()
                T = df[group]
                Y = df["caps5_change_baseline"]
                W = df["iptw"]

                for repeat in range(n_repeats):
                    kf = KFold(n_splits=5, shuffle=True, random_state=seed + repeat)
                    for train_idx, test_idx in kf.split(X):
                        try:
                            X_train, T_train, Y_train, W_train = (
                                X.iloc[train_idx],
                                T.iloc[train_idx],
                                Y.iloc[train_idx],
                                W.iloc[train_idx],
                            )

                            # XGBoost regression model
                            model = xgb.XGBRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1, random_state=seed)
                            
                            # Add treatment variable to features
                            X_train_with_T = X_train.copy()
                            X_train_with_T[group] = T_train
                            
                            # Fit model
                            model.fit(X_train_with_T, Y_train, sample_weight=W_train)
                            
                            # Predict outcomes for treated and control groups
                            X_treated = X_train.copy()
                            X_treated[group] = 1
                            X_control = X_train.copy()
                            X_control[group] = 0
                            
                            Y_pred_treated = model.predict(X_treated)
                            Y_pred_control = model.predict(X_control)
                            
                            # Calculate ATT (Average Treatment Effect on Treated)
                            treated_mask = T_train == 1
                            if np.any(treated_mask):
                                att = np.average(Y_pred_treated[treated_mask] - Y_pred_control[treated_mask], 
                                               weights=W_train[treated_mask])
                                
                                # Calculate standard error (approximate)
                                treatment_effects = Y_pred_treated[treated_mask] - Y_pred_control[treated_mask]
                                residual = treatment_effects - att
                                se = np.sqrt(np.sum((W_train[treated_mask] * residual) ** 2)) / np.sum(W_train[treated_mask])

                                
                                att_list.append(att)
                                se_list.append(se)

                            # Model performance metrics
                            Y_pred = model.predict(X_train_with_T)
                            residuals = Y_train - Y_pred
                            rmse = mean_squared_error(Y_train, Y_pred, squared=False)
                            r2 = r2_score(Y_train, Y_pred)
                            r2_list.append(r2)
                            rmse_list.append(rmse)
                            
                            # Collect residuals and fitted values for diagnostic plots
                            group_residuals.append(residuals.values)
                            group_fitted.append(Y_pred)

                            smd, vr = calculate_smd_vr(X_train, T_train, W_train)
                            smd_list.append(smd)
                            vr_list.append(vr)
                        except Exception as e:
                            print(f"⚠️ Error in {group}, seed {seed}, imp {imp}, rep {repeat}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("xgb_rubin_summary_subsubcats.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_subsubcats.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_xgboost_with_trimmed_data(final_covariates_map)


In [None]:
# Unweighted:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
n_repeats = 1
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}_unweighted.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T):
    treated = X[T == 1]
    control = X[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = np.mean(treated[col]), np.mean(control[col])
            s1 = np.std(treated[col])
            s0 = np.std(control[col])
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# XGBoost Main Loop (No DML)
# -----------------------------
def run_xgboost_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running XGBoost for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                X = df[covariates].copy()
                T = df[group]
                Y = df["caps5_change_baseline"]
                #W = df["iptw"]

                for repeat in range(n_repeats):
                    kf = KFold(n_splits=5, shuffle=True, random_state=seed + repeat)
                    for train_idx, test_idx in kf.split(X):
                        try:
                            X_train = X.iloc[train_idx]
                            T_train = T.iloc[train_idx]
                            Y_train = Y.iloc[train_idx]
                            
                            # XGBoost regression model
                            model = xgb.XGBRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, n_jobs=1, random_state=seed)
                            
                            # Add treatment variable to features
                            X_train_with_T = X_train.copy()
                            X_train_with_T[group] = T_train
                            
                            # Fit model
                            model.fit(X_train_with_T, Y_train)
                            
                            # Predict outcomes for treated and control groups
                            X_treated = X_train.copy()
                            X_treated[group] = 1
                            X_control = X_train.copy()
                            X_control[group] = 0
                            
                            Y_pred_treated = model.predict(X_treated)
                            Y_pred_control = model.predict(X_control)
                            
                            # Calculate ATT (Average Treatment Effect on Treated)
                            treated_mask = T_train == 1
                            if np.any(treated_mask):
                                att = np.mean(Y_pred_treated[treated_mask] - Y_pred_control[treated_mask])
                                
                                # Calculate standard error (approximate)
                                treatment_effects = Y_pred_treated[treated_mask] - Y_pred_control[treated_mask]
                                residual = treatment_effects - att
                                se = np.sqrt(np.sum((residual) ** 2)) / np.sum(treated_mask)
                                att_list.append(att)
                                se_list.append(se)

                            # Model performance metrics
                            Y_pred = model.predict(X_train_with_T)
                            residuals = Y_train - Y_pred
                            rmse = mean_squared_error(Y_train, Y_pred, squared=False)
                            r2 = r2_score(Y_train, Y_pred)
                            r2_list.append(r2)
                            rmse_list.append(rmse)
                            
                            # Collect residuals and fitted values for diagnostic plots
                            group_residuals.append(residuals.values)
                            group_fitted.append(Y_pred)

                            smd, vr = calculate_smd_vr(X_train, T_train)
                            smd_list.append(smd)
                            vr_list.append(vr)
                        except Exception as e:
                            print(f"⚠️ Error in {group}, seed {seed}, imp {imp}, rep {repeat}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("xgb_rubin_summary_subsubcats_unweighted.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_subsubcats_unweighted.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_xgboost_with_trimmed_data(final_covariates_map)


In [None]:
import os
import pandas as pd
import numpy as np
from scipy.stats import sem, ttest_ind

# ----------------------------------
# File paths
# ----------------------------------
output_base = "outputs"
att_file = "xgb_rubin_summary_subsubcats.xlsx"
trimmed_file = "trimmed_data_imp1.pkl"
auc_file = "auc_scores.xlsx"  # NEW

# ----------------------------------
# Load ATT Summary
# ----------------------------------
if os.path.exists(att_file):
    att_df = pd.read_excel(att_file)
else:
    raise FileNotFoundError("❌ ATT summary file not found: xgb_rubin_summary_subsubcats.xlsx")

summary_rows = []

# ----------------------------------
# Loop over medication groups
# ----------------------------------
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

for med in groups:
    try:
        group_path = os.path.join(output_base, med)

        # Load trimmed data
        df = pd.read_pickle(os.path.join(group_path, trimmed_file))

        # Detect treatment column
        treatment_cols = [col for col in df.columns if col.upper() == med.upper()]
        if not treatment_cols:
            print(f"⚠️ Treatment column {med} not found in trimmed data. Skipping.")
            continue
        treatment_var = treatment_cols[0]

        # Extract treatment and outcome
        T = df[treatment_var]
        Y = df["caps5_change_baseline"]

        # Treated and control stats
        treated = Y[T == 1]
        control = Y[T == 0]

        mean_treat = treated.mean()
        se_treat = sem(treated) if len(treated) > 1 else np.nan

        mean_ctrl = control.mean()
        se_ctrl = sem(control) if len(control) > 1 else np.nan

        # Cohen's d (unadjusted)
        pooled_sd = np.sqrt(((treated.std() ** 2) + (control.std() ** 2)) / 2)
        cohen_d = (mean_treat - mean_ctrl) / pooled_sd if pooled_sd > 0 else np.nan

        # E-value (unadjusted)
        delta = mean_treat - mean_ctrl
        E = delta / abs(mean_ctrl) * 100 if mean_ctrl != 0 else np.nan

        # Unadjusted p-value
        try:
            t_stat, p_val = ttest_ind(treated, control, equal_var=False, nan_policy="omit")
            rounded_p = round(p_val, 5)
            formatted_p = "< 0.00001" if rounded_p < 0.00001 else rounded_p
        except Exception:
            formatted_p = np.nan

        # AUC from new auc_scores.xlsx file
        auc_val = np.nan
        auc_path = os.path.join(group_path, auc_file)
        if os.path.exists(auc_path):
            auc_df = pd.read_excel(auc_path)
            if "AUC" in auc_df.columns:
                auc_val = auc_df["AUC"].dropna().mean()

        # Adjusted stats from Rubin summary
        att_row = att_df[att_df["group"].str.strip().str.upper() == med.strip().upper()]
        if not att_row.empty:
            att = att_row.iloc[0]["att"]
            att_se = att_row.iloc[0]["se"]
            att_p_val = att_row.iloc[0]["p_value"]
            r2 = att_row.iloc[0]["r2"]
            rmse = att_row.iloc[0]["rmse"]

            try:
                rounded_att_p = round(float(att_p_val), 5)
                formatted_att_p = "< 0.00001" if rounded_att_p < 0.00001 else rounded_att_p
            except:
                formatted_att_p = att_p_val
        else:
            att, att_se, formatted_att_p, r2, rmse = np.nan, np.nan, np.nan, np.nan, np.nan

        # Append full row
        summary_rows.append({
            'Medication Group': med,
            'Mean Treated': mean_treat,
            'SE Treated': se_treat,
            'Mean Control': mean_ctrl,
            'SE Control': se_ctrl,
            'Cohen d': cohen_d,
            'E (Unadjusted)': E,
            'n Treated': len(treated),
            'n Control': len(control),
            #'Unadjusted p-value': formatted_p,
            'ATT Estimate': att,
            'ATT SE (Robust)': att_se,
            'ATT p-value': formatted_att_p,
            'R²': r2,
            'RMSE': rmse,
            'AUC': auc_val
        })

    except Exception as e:
        print(f"❌ Error in {med}: {e}")

# ----------------------------------
# Save final summary
# ----------------------------------
summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.sort_values("Medication Group")
summary_df.to_excel("Final_ATT_Summary_SubSubCat.xlsx", index=False)
print("✅ Final_ATT_Summary_SubSubCat saved.xlsx")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# ✅ Load the final summary table
final_df = pd.read_excel("Final_ATT_Summary_SubSubCat.xlsx")

# ✅ Parse DML p-values (handle "< 0.00001")
def parse_pval(p):
    try:
        if isinstance(p, str) and "<" in p:
            return 0.000001
        return float(p)
    except:
        return None

final_df['ATT p-value'] = final_df['ATT p-value'].apply(parse_pval)

# ✅ Plot settings
width = 0.35

# ✅ Plotting function for a single medication group
def plot_single_group(row):
    fig, ax = plt.subplots(figsize=(10, 6))
    
    bars1 = ax.bar(-width/2, row['Mean Control'], width, 
                   yerr=row['SE Control'], label='Control', hatch='//', color='gray', capsize=5)
    bars2 = ax.bar(+width/2, row['Mean Treated'], width, 
                   yerr=row['SE Treated'], label='Treated', color='steelblue', capsize=5)

    label = (
        f"ATT = {row['ATT Estimate']:.2f}\n"
        f"d = {row['Cohen d']:.2f}, p = {row['ATT p-value']:.3f}\n"
        f"nT = {row['n Treated']}, nC = {row['n Control']}\n"
        f"E = {row['E (Unadjusted)']:.1f}%"
    )
    max_y = max(row['Mean Control'], row['Mean Treated']) + 1.5
    ax.text(0, max_y, label, ha='center', va='bottom', fontsize=9, color='#FFD700')

    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_xticks([-width/2, +width/2])
    ax.set_xticklabels(['Control', 'Treated'])
    ax.set_title(f"Group: {row['Medication Group']}", fontsize=12, weight='bold')
    ax.set_ylabel("CAPS5 Change Score")
    ax.set_ylim(bottom=0, top=max_y + 2)
    ax.legend()
    fig.tight_layout()
    return fig

# ✅ Generate and save all plots into a multi-page PDF
with PdfPages("dml_att_barplot_subsubcat.pdf") as pdf:
    for idx, row in final_df.iterrows():
        fig = plot_single_group(row)
        pdf.savefig(fig, dpi=300, bbox_inches='tight')
        plt.close(fig)

print("✅ dml_att_barplot_subsubcat saved.xlsx")

In [None]:
# Love plot:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# ----------------------------------------
# Functions to calculate balance
# ----------------------------------------
def calculate_smd(x1, x2, w1=None, w2=None):
    def weighted_mean(x, w): return np.average(x, weights=w)
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    m1 = weighted_mean(x1, w1) if w1 is not None else np.mean(x1)
    m2 = weighted_mean(x2, w2) if w2 is not None else np.mean(x2)
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    pooled_sd = np.sqrt((v1 + v2) / 2)
    return np.abs(m1 - m2) / pooled_sd if pooled_sd > 0 else 0

def variance_ratio(x1, x2, w1=None, w2=None):
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    return max(v1 / v2, v2 / v1) if v1 > 0 and v2 > 0 else 1

# ----------------------------------------
# Setup
# ----------------------------------------
output_base = "outputs"
groups = [g for g in os.listdir(output_base) if os.path.isdir(os.path.join(output_base, g))]

# Create a case-insensitive mapping
final_covariates_map_lower = {k.lower(): v for k, v in final_covariates_map.items()}

# ----------------------------------------
# Main Loop
# ----------------------------------------
for group in groups:
    if group.lower() not in final_covariates_map_lower:
        continue

    print(f"\n🔍 Processing {group}...")

    try:
        group_path = os.path.join(output_base, group)
        covariates = final_covariates_map_lower[group.lower()]
        
        column_name = None
        for col in pd.read_pickle(os.path.join(group_path, "trimmed_data_imp1.pkl")).columns:
            if col.lower() == group.lower():
                column_name = col
                break
        if column_name is None:
            print(f"⚠️ Column not found for {group}, skipping.")
            continue

        smd_unw_all, smd_w_all = [], []
        vr_unw_all, vr_w_all = [], []

        for i in range(1, 6):
            df_path = os.path.join(group_path, f"trimmed_data_imp{i}.pkl")
            iptw_path = os.path.join(group_path, "iptw_weights.xlsx")

            if not os.path.exists(df_path) or not os.path.exists(iptw_path):
                print(f"⚠️ Missing data for {group} imp{i}, skipping.")
                continue

            df = pd.read_pickle(df_path)
            iptw_df = pd.read_excel(iptw_path, index_col=0)
            T = df[column_name]
            W = iptw_df.loc[df.index, "iptw_mean"]

            smd_unw_i, smd_w_i, vr_unw_i, vr_w_i = [], [], [], []

            for cov in covariates:
                x1, x0 = df.loc[T == 1, cov], df.loc[T == 0, cov]
                w1, w0 = W[T == 1], W[T == 0]

                su = calculate_smd(x1, x0)
                sw = calculate_smd(x1, x0, w1, w0)

                vu = variance_ratio(x1, x0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan
                vw = variance_ratio(x1, x0, w1, w0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan

                smd_unw_i.append(su)
                smd_w_i.append(sw)
                vr_unw_i.append(vu)
                vr_w_i.append(vw)

            smd_unw_all.append(smd_unw_i)
            smd_w_all.append(smd_w_i)
            vr_unw_all.append(vr_unw_i)
            vr_w_all.append(vr_w_i)

        smd_unw = np.mean(smd_unw_all, axis=0)
        smd_w = np.mean(smd_w_all, axis=0)
        vr_unw = np.nanmean(vr_unw_all, axis=0)
        vr_w = np.nanmean(vr_w_all, axis=0)

        severity = []
        for sw in smd_w:
            if sw <= 0.1:
                severity.append("Good")
            elif sw <= 0.2:
                severity.append("Moderate")
            else:
                severity.append("Poor")

        covariate_names = covariates
        numeric_df = pd.DataFrame({
            "Covariate": covariate_names,
            "SMD_Unweighted": smd_unw,
            "SMD_Weighted": smd_w,
            "Imbalance_Severity": severity,
            "VR_Unweighted": vr_unw,
            "VR_Weighted": vr_w
        })

        numeric_path = os.path.join(group_path, f"covariate_balance_table_{group}.xlsx")
        numeric_df.to_excel(numeric_path, index=False)
        print(f"📊 Exported numeric summary to: {numeric_path}")

        # -------------------------
        # Plot
        # -------------------------
        labels = covariates
        y_pos = np.arange(len(labels))

        fig, axes = plt.subplots(1, 2, figsize=(18, len(labels) * 0.45))

        axes[0].scatter(smd_unw, y_pos, color='red', label="Unweighted")
        axes[0].scatter(smd_w, y_pos, color='blue', label="Weighted")
        axes[0].axvline(0.1, color='gray', linestyle='--', label="Threshold 0.1")
        axes[0].axvline(0.2, color='black', linestyle='--', label="Threshold 0.2")
        axes[0].set_xlim(0, max(max(smd_unw), max(smd_w), 0.25) + 0.05)
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels(labels)
        axes[0].invert_yaxis()
        axes[0].set_title("Standardized Mean Differences (SMD)")
        axes[0].legend(loc="upper right")
        axes[0].grid(True)

        vr_mask = [cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] for cov in covariates]
        filtered_y = [i for i, b in enumerate(vr_mask) if b]
        filtered_labels = [labels[i] for i in filtered_y]
        filtered_vr_unw = [vr_unw[i] for i in filtered_y]
        filtered_vr_w = [vr_w[i] for i in filtered_y]

        axes[1].scatter(filtered_vr_unw, filtered_y, color='blue', marker='o', label="Unweighted")
        axes[1].scatter(filtered_vr_w, filtered_y, color='red', marker='x', label="Weighted")
        axes[1].axvline(2, color='gray', linestyle='--')
        axes[1].axvline(0.5, color='gray', linestyle='--')
        axes[1].set_xlim(0, max(filtered_vr_unw + filtered_vr_w + [2.5]) + 0.5)
        axes[1].set_yticks(filtered_y)
        axes[1].set_yticklabels(filtered_labels)
        axes[1].invert_yaxis()
        axes[1].set_title("Variance Ratio (VR)")
        axes[1].legend()
        axes[1].grid(True)

        fig.suptitle(f"Covariate Balance for {group.replace('CAT_', '')}", fontsize=14, weight='bold')
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        plot_path = os.path.join(group_path, f"love_plot_{group}.pdf")
        fig.savefig(plot_path, dpi=300)
        plt.close()
        print(f"✅ Saved love plot: {plot_path}")
        print(f"📏 Max weighted SMD for {group}: {np.max(smd_w):.3f}")

    except Exception as e:
        print(f"❌ Error in {group}: {e}")

In [None]:
# Heatmap:

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
#-----------------------------
# Generate heatmaps
# -------------------------------
for treatment_var in medication_groups:
    print(f"\n========== Creating Heatmap for {treatment_var} ==========")

    try:
        output_folder = os.path.join('outputs', treatment_var)
        balance_path = os.path.join(output_folder, f'covariate_balance_table_{treatment_var}.xlsx')

        if not os.path.exists(balance_path):
            print(f"❌ Balance file not found: {balance_path}")
            continue

        balance_df = pd.read_excel(balance_path)

        # ✅ Use finalized covariates + 'Propensity Score'
        covariates = final_covariates_map[treatment_var] + ['Propensity Score']
        balance_df = balance_df[balance_df['Covariate'].isin(covariates)]

        # ✅ Check for CAPS5score_baseline
        highlight_caps = 'CAPS5score_baseline' in balance_df['Covariate'].values

        # ✅ Format for heatmap
        heatmap_df = balance_df[['Covariate', 'SMD_Unweighted', 'SMD_Weighted']].copy()
        heatmap_df.columns = ['Covariate', 'Unweighted', 'Weighted']
        heatmap_df = heatmap_df.set_index('Covariate')
        heatmap_df = heatmap_df.sort_values(by='Unweighted', ascending=False)

        # ✅ Plot
        plt.figure(figsize=(12, max(10, len(heatmap_df) * 0.35)))
        ax = sns.heatmap(
            heatmap_df,
            cmap="coolwarm",
            annot=True,
            fmt=".2f",
            linewidths=0.6,
            linecolor='gray',
            cbar_kws={"label": "Standardized Mean Difference"}
        )

        plt.title(f"Covariate Balance Heatmap (Rubin IPTW)\n{treatment_var}", fontsize=15, weight='bold')
        plt.xlabel("Condition")
        plt.ylabel("Covariate")

        # ✅ Bold CAPS5score_baseline if present
        if highlight_caps:
            ylabels = [label.get_text() for label in ax.get_yticklabels()]
            ax.set_yticklabels([
                f"{label} ←" if label == 'CAPS5score_baseline' else label for label in ylabels
            ])

        plt.tight_layout()

        # ✅ Save image
        save_path = os.path.join(output_folder, f'heatmap_smd_{treatment_var}.png')
        plt.savefig(save_path, dpi=300)
        plt.close()

        print(f"✅ Heatmap saved: {save_path}")

    except Exception as e:
        print(f"⚠️ Error processing {treatment_var}: {e}")

#### Linear:

In [1]:
import os
import warnings

warnings.filterwarnings("ignore")  # optional

# Option 1 (recommended)
new_path = r"D:\Work\PTSD\Linear"

os.chdir(new_path)  # Change working directory
print("Changed to:", os.getcwd())

Changed to: D:\Work\PTSD\Linear


In [2]:
import pandas as pd
df = pd.read_csv(r"data_baseline.csv")

### CAT analysis:

In [3]:
# Import all necessary packages
import pandas as pd
import numpy as np
import re
# For visualization and future steps
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.experimental import enable_iterative_imputer  # Needed to enable the experimental feature
from sklearn.impute import IterativeImputer

In [4]:
# Check basic structure
print("Shape of dataset:", df.shape)
print("\nSample columns:", df.columns.tolist()[:10])
print("\nMissing values:\n", df.isnull().sum().sort_values(ascending=False).head(10))

# Remove duplicate rows
df = df.drop_duplicates()

# Confirm shape after removing duplicates
print("Shape after removing duplicates:", df.shape)

Shape of dataset: (6125, 465)

Sample columns: ['CIN5', 'StartDatum', 'BEH_MOD', 'BEHDAGEN_GEPLAND', 'AANTAL_PCL', 'TOESTWO', 'BEH_AFG', 'TK', 'MM_CAPS_IN', 'MM_CAPS_TK']

Missing values:
 instrument_SDV_IN    6125
Eaantal_TK           6125
Dcriterium_FU        6125
Cernst_FU            6125
Caantal_FU           6125
Ccriterium_FU        6125
Bernst_FU            6125
Baantal_FU           6125
Bcriterium_FU        6125
Eernst_TK            6125
dtype: int64
Shape after removing duplicates: (6125, 465)


In [5]:
import pyreadstat

# Load gender info
gender_df, meta = pyreadstat.read_sav("SDV_IN_Gender_2019_2024.sav")

# Just extract SDV_SEXE column and append to df
df["SDV_SEXE"] = gender_df["SDV_SEXE"].reset_index(drop=True)

# Optional: map to labels
gender_map = {1.0: "Male", 2.0: "Female", 3.0: "Other"}
df["gender_label"] = df["SDV_SEXE"].map(gender_map)

# Done! Check a sample
print(df[["SDV_SEXE", "gender_label"]].value_counts())

SDV_SEXE  gender_label
2.0       Female          4602
1.0       Male            1475
3.0       Other             48
Name: count, dtype: int64


In [6]:
# Gender dummy variables
df['gender_1'] = (df['gender'] == 1).astype(int)
df['gender_2'] = (df['gender'] == 2).astype(int)

# SDV_SEXE dummy variables
df['SDV_SEXE_1'] = (df['SDV_SEXE'] == 1).astype(int)
df['SDV_SEXE_2'] = (df['SDV_SEXE'] == 2).astype(int)
df['SDV_SEXE_3'] = (df['SDV_SEXE'] == 3).astype(int)

# Create binary columns
df['ethnicity_Dutch'] = np.where(df['ethnicity'] == 1, 1, 0)
df['ethnicity_other'] = np.where(df['ethnicity'] != 1, 1, 0)

In [7]:
# Columns manually identified for removal (example set from the R script)
cols_to_drop = [
    'gender', 'ethnicity', 'CIN5', 'SDV_SEXE', 'StartDatum', 'STARTDATUM', 'DROPOUT_EARLYCOMPLETER', 'TOEST_WO',
    'depressie_IN', 'TERUGKOMER', 'VROEGK_ST', 'gender_label',
    'depr_m_psychose_huid', 'depr_z_psychose_huid', 'depr_z_psychose_verl',
    'depr_m_psychose_verl', 'CAPS5score_followup', 'CAPS5_DAT_IN'
]

df = df.drop(columns=[col for col in cols_to_drop if col in df.columns])
print("Remaining columns:", df.shape[1])

Remaining columns: 458


In [8]:
if 'BEH_DAGEN' in df.columns:
    df.rename(columns={'BEH_DAGEN': 'treatmentdurationdays'}, inplace=True)

In [9]:
# Clean and standardize column names
df.columns = (
    df.columns
    .str.replace(r"\.+", "_", regex=True)
    .str.replace(r"[^a-zA-Z0-9_]", "", regex=True)
    .str.replace(" ", "_")
    .str.strip()
)

In [10]:
# Preview key outcome variables
outcome_vars = ['CAPS5score_baseline', 'CAPS5Score_TK']
for col in outcome_vars:
    if col in df.columns:
        print(f"{col}: {df[col].isnull().sum()} missing")

# Calculate change score
if 'CAPS5score_baseline' in df.columns and 'CAPS5Score_TK' in df.columns:
    df['caps5_change_baseline'] = df['CAPS5Score_TK'] - df['CAPS5score_baseline']

CAPS5score_baseline: 0 missing
CAPS5Score_TK: 0 missing


In [11]:
# Define exceptions to keep
protected_cols = [
    "DIAGNOSIS_ANXIETY_OCD",
    "DIAGNOSIS_PSYCHOTIC",
    "DIAGNOSIS_EATING_DISORDER",
    "DIAGNOSIS_SUBSTANCE_DISORDER", "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", 'SUBCAT_Selectieve_immunosuppresiva', 'treatmentdurationdays',
'SUBCAT_Corticosteroiden',
'SUBCAT_Immunomodulerend_Coxibs',
'SUBCAT_Aminosalicylaten',
'SUBCAT_calcineurineremmers',
'SUBCAT_Anti_epileptica_Benzodiazepine',
'SUBCAT_Paracetamol_overig_combinatie', 'SUBCAT_MAO_remmers', 'SUBCAT_psychostimulans_overige', 'SUBCAT_Interleukine_remmers'

]


# ----------------------------------------
# 1. Drop columns with >95% missing values (except protected)
thresh_missing = int(0.95 * len(df))
missing_cols = [col for col in df.columns if df[col].isnull().sum() > (len(df) - thresh_missing)]
missing_cols_to_drop = [col for col in missing_cols if col not in protected_cols]
df = df.drop(columns=missing_cols_to_drop)

# ----------------------------------------
# 2. Drop near-zero variance columns (except protected)
low_variance_cols = [col for col in df.columns if df[col].nunique(dropna=True) <= 1 and col not in protected_cols]
df = df.drop(columns=low_variance_cols)

In [12]:
df.to_csv("cleaned_data_baseline.csv", index=False)
print(" Step 1 Complete: Cleaned dataset saved.")

 Step 1 Complete: Cleaned dataset saved.


In [13]:
# Target variables:

In [14]:
from sklearn.impute import SimpleImputer
from fancyimpute import IterativeImputer
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.compose import ColumnTransformer

In [15]:
# Load the cleaned dataset from Step 1
df = pd.read_csv("cleaned_data_baseline.csv")

# Quick check
print(df.shape)
print(df.dtypes.head(10))

(6125, 204)
DIAGNOSIS_ANXIETY_OCD           float64
DIAGNOSIS_SMOKING               float64
DIAGNOSIS_EATING_DISORDER       float64
DIAGNOSIS_SUBSTANCE_DISORDER    float64
DIAGNOSIS_PSYCHOTIC             float64
DIAGNOSIS_SUICIDALITY           float64
DIAGNOSIS_SEXUAL_TRAUMA         float64
DIAGNOSIS_CHILDHOOD_TRAUMA        int64
DIAGNOSIS_CPTSD                 float64
treatmentdurationdays           float64
dtype: object


In [16]:
# Separate Numerical and Categorical Variables

In [17]:
# Identify numerical and categorical columns
numerical_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()

print(f"Numerical Columns: {len(numerical_cols)}")
print(f"Categorical Columns: {len(categorical_cols)}")

Numerical Columns: 204
Categorical Columns: 0


In [18]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6125 entries, 0 to 6124
Columns: 204 entries, DIAGNOSIS_ANXIETY_OCD to ethnicity_other
dtypes: float64(10), int64(194)
memory usage: 9.5 MB


In [19]:
# Save the fully prepared data
df.to_csv("final_prepared_data.csv", index=False)
print(" Step 2 Complete: Final prepared dataset saved as 'final_prepared_data.csv'.")

 Step 2 Complete: Final prepared dataset saved as 'final_prepared_data.csv'.


In [20]:
import os
import pandas as pd
import numpy as np
from sklearn.experimental import enable_iterative_imputer  # noqa
from sklearn.impute import IterativeImputer

# ========== CONFIG ==========
save_folder = "imputed_data"
os.makedirs(save_folder, exist_ok=True)
n_imputations = 5

# ========== LOAD ==========
# Ensure df is already defined
assert 'df' in globals(), "Please load the original DataFrame as `df` before running this script."

# ========== IDENTIFY NUMERIC COLUMNS ==========
numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()

# ========== STEP 1: MICE IMPUTATION ==========
print("=" * 50)
print("STEP 1: MICE IMPUTATION")
print("=" * 50)

imputed_dfs = []
for i in range(1, n_imputations + 1):
    print(f"\n=== Running MICE Imputation: Dataset {i} ===")
    #  NEW instance with different seed AND sample_posterior=True for randomness
    mice_imputer = IterativeImputer(
        max_iter=10, 
        random_state=42+i,  # Different base to avoid low numbers
        sample_posterior=True,  #  KEY: This adds randomness!
        n_nearest_features=None,
        initial_strategy='mean'
    )
    # Fit-transform on numeric columns
    imputed_array = mice_imputer.fit_transform(df[numeric_cols])
    # Replace numeric columns in a copy of the original df
    df_imputed = df.copy()
    df_imputed[numeric_cols] = pd.DataFrame(imputed_array, columns=numeric_cols, index=df.index)
    # Append to list
    imputed_dfs.append(df_imputed)
    print(f" Completed imputation {i}")

# ========== STEP 2: ROUNDING ==========
print("\n" + "=" * 50)
print("STEP 2: ROUNDING NUMERIC COLUMNS")
print("=" * 50)

def round_all_numeric_columns_all_imputations(imputed_dfs, decimals=0, verbose=True):
    rounded_dfs = []
    for i, df in enumerate(imputed_dfs):
        df_copy = df.copy()
        numeric_cols = df_copy.select_dtypes(include=[np.number]).columns
        df_copy[numeric_cols] = df_copy[numeric_cols].round(decimals)
        rounded_dfs.append(df_copy)
        if verbose:
            print(f" Imputation {i+1}: Rounded {len(numeric_cols)} numeric columns to {decimals} decimal place(s).")
    return rounded_dfs

# Apply rounding to all imputed datasets
imputed_dfs = round_all_numeric_columns_all_imputations(imputed_dfs)

# ========== STEP 3: SAVE FINAL DATASETS ==========
print("\n" + "=" * 50)
print("STEP 3: SAVING FINAL DATASETS")
print("=" * 50)

for i, df_imputed in enumerate(imputed_dfs, 1):
    # Save outputs
    pkl_path = f"{save_folder}/df_imputed_final_imp{i}.pkl"
    csv_path = f"{save_folder}/df_imputed_final_imp{i}.csv"
    excel_path = f"{save_folder}/df_imputed_final_imp{i}.xlsx"
    
    df_imputed.to_pickle(pkl_path)
    df_imputed.to_csv(csv_path, index=False)
    df_imputed.to_excel(excel_path, index=False)
    
    print(f" Saved files for imputation {i}:")
    print(f"   → {pkl_path}")
    print(f"   → {csv_path}")
    print(f"   → {excel_path}")

# ========== STEP 4: VERIFY DATASETS ARE DIFFERENT ==========
print("\n" + "=" * 50)
print("STEP 4: VERIFYING DATASET DIFFERENCES")
print("=" * 50)

def check_imputation_differences(imputed_dfs, verbose=True):
    """Check if imputed datasets are actually different from each other"""
    if len(imputed_dfs) < 2:
        print("  Only one dataset - cannot check differences")
        return
    
    # Get numeric columns that had missing values originally
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    missing_cols = [col for col in numeric_cols if df[col].isnull().any()]
    
    if not missing_cols:
        print("  No missing values found in original data")
        return
    
    print(f" Checking differences in {len(missing_cols)} columns that had missing values...")
    
    differences_found = False
    
    for col in missing_cols[:3]:  # Check first 3 columns with missing values
        # Compare first two datasets for this column
        values_1 = imputed_dfs[0][col].values
        values_2 = imputed_dfs[1][col].values
        
        if not np.array_equal(values_1, values_2):
            differences_found = True
            # Count how many values are different
            diff_count = np.sum(values_1 != values_2)
            print(f" Column '{col}': {diff_count} different values between datasets 1 & 2")
        else:
            print(f" Column '{col}': IDENTICAL values between datasets 1 & 2")
    
    if differences_found:
        print(f"\n SUCCESS: Datasets show proper variability!")
    else:
        print(f"\n  WARNING: Datasets appear identical - check random_state implementation")
    
    return differences_found

# Run the check
check_imputation_differences(imputed_dfs)

print("\n" + "=" * 50)
print(" MICE IMPUTATION COMPLETE!")
print("=" * 50)
print(f" Created {n_imputations} imputed datasets")
print(f" Applied rounding to all numeric columns")
print(f" Saved files in: {save_folder}/")
print("=" * 50)

STEP 1: MICE IMPUTATION

=== Running MICE Imputation: Dataset 1 ===
 Completed imputation 1

=== Running MICE Imputation: Dataset 2 ===
 Completed imputation 2

=== Running MICE Imputation: Dataset 3 ===
 Completed imputation 3

=== Running MICE Imputation: Dataset 4 ===
 Completed imputation 4

=== Running MICE Imputation: Dataset 5 ===
 Completed imputation 5

STEP 2: ROUNDING NUMERIC COLUMNS
 Imputation 1: Rounded 204 numeric columns to 0 decimal place(s).
 Imputation 2: Rounded 204 numeric columns to 0 decimal place(s).
 Imputation 3: Rounded 204 numeric columns to 0 decimal place(s).
 Imputation 4: Rounded 204 numeric columns to 0 decimal place(s).
 Imputation 5: Rounded 204 numeric columns to 0 decimal place(s).

STEP 3: SAVING FINAL DATASETS
 Saved files for imputation 1:
   → imputed_data/df_imputed_final_imp1.pkl
   → imputed_data/df_imputed_final_imp1.csv
   → imputed_data/df_imputed_final_imp1.xlsx
 Saved files for imputation 2:
   → imputed_data/df_imputed_final_imp2.pkl
  

In [21]:
import pandas as pd
import numpy as np

# ========== METHOD 1: QUICK CHECK - Compare first 2 datasets ==========
def quick_difference_check(imputed_dfs):
    """Quick check to see if first two datasets are different"""
    if len(imputed_dfs) < 2:
        print("Need at least 2 datasets to compare")
        return
    
    df1 = imputed_dfs[0]
    df2 = imputed_dfs[1]
    
    # Check if dataframes are identical
    are_identical = df1.equals(df2)
    print(f"Dataset 1 vs Dataset 2: {'IDENTICAL ❌' if are_identical else 'DIFFERENT ✅'}")
    
    if not are_identical:
        # Count different values
        numeric_cols = df1.select_dtypes(include=[np.number]).columns
        total_diff = 0
        for col in numeric_cols:
            diff_count = np.sum(df1[col] != df2[col])
            if diff_count > 0:
                total_diff += diff_count
                print(f"  '{col}': {diff_count} different values")
        print(f"  Total different values: {total_diff}")

# ========== METHOD 2: DETAILED CHECK - All pairwise comparisons ==========
def detailed_difference_check(imputed_dfs):
    """Check differences between all pairs of datasets"""
    n_datasets = len(imputed_dfs)
    print(f"\n=== Checking all {n_datasets} datasets ===")
    
    numeric_cols = imputed_dfs[0].select_dtypes(include=[np.number]).columns
    
    for i in range(n_datasets):
        for j in range(i+1, n_datasets):
            are_identical = imputed_dfs[i].equals(imputed_dfs[j])
            print(f"Dataset {i+1} vs Dataset {j+1}: {'IDENTICAL ❌' if are_identical else 'DIFFERENT ✅'}")

# ========== METHOD 3: FOCUS ON ORIGINALLY MISSING VALUES ==========
def check_missing_value_differences(original_df, imputed_dfs):
    """Check differences only in originally missing positions"""
    print(f"\n=== Checking differences in originally missing positions ===")
    
    numeric_cols = original_df.select_dtypes(include=[np.number]).columns
    
    differences_found = False
    for col in numeric_cols:
        if original_df[col].isnull().any():
            missing_mask = original_df[col].isnull()
            print(f"\nColumn '{col}' ({missing_mask.sum()} missing values):")
            
            # Compare imputed values at missing positions
            for i in range(len(imputed_dfs)-1):
                imp1_values = imputed_dfs[i].loc[missing_mask, col]
                imp2_values = imputed_dfs[i+1].loc[missing_mask, col]
                
                are_same = np.array_equal(imp1_values.values, imp2_values.values)
                if not are_same:
                    differences_found = True
                    diff_count = np.sum(imp1_values.values != imp2_values.values)
                    print(f"  Dataset {i+1} vs {i+2}: {diff_count}/{len(imp1_values)} different imputed values ✅")
                else:
                    print(f"  Dataset {i+1} vs {i+2}: IDENTICAL imputed values ❌")
    
    return differences_found

# ========== METHOD 4: SAMPLE VALUES FROM EACH DATASET ==========
def show_sample_imputed_values(original_df, imputed_dfs, n_samples=5):
    """Show sample imputed values from each dataset"""
    print(f"\n=== Sample imputed values (first {n_samples} missing positions) ===")
    
    numeric_cols = original_df.select_dtypes(include=[np.number]).columns
    
    for col in numeric_cols:
        if original_df[col].isnull().any():
            missing_positions = original_df[original_df[col].isnull()].index[:n_samples]
            
            print(f"\nColumn '{col}' at positions {list(missing_positions)}:")
            for i, df_imp in enumerate(imputed_dfs):
                values = df_imp.loc[missing_positions, col].values
                print(f"  Dataset {i+1}: {values}")

# ========== RUN ALL CHECKS ==========
print("=" * 60)
print("CHECKING IMPUTATION DIFFERENCES")
print("=" * 60)

# Method 1: Quick check
quick_difference_check(imputed_dfs)

# Method 2: All pairwise comparisons  
detailed_difference_check(imputed_dfs)

# Method 3: Focus on originally missing values (assumes 'df' is your original dataframe)
if 'df' in globals():
    differences_found = check_missing_value_differences(df, imputed_dfs)
    if differences_found:
        print(f"\n🎉 SUCCESS: Found differences in imputed values!")
    else:
        print(f"\n⚠️ WARNING: No differences found in imputed values!")

# Method 4: Show sample values
if 'df' in globals():
    show_sample_imputed_values(df, imputed_dfs, n_samples=3)

CHECKING IMPUTATION DIFFERENCES
Dataset 1 vs Dataset 2: DIFFERENT ✅
  'DIAGNOSIS_ANXIETY_OCD': 694 different values
  'DIAGNOSIS_SMOKING': 50 different values
  'DIAGNOSIS_SUBSTANCE_DISORDER': 240 different values
  'DIAGNOSIS_PSYCHOTIC': 164 different values
  'DIAGNOSIS_SUICIDALITY': 9 different values
  'DIAGNOSIS_SEXUAL_TRAUMA': 15 different values
  'DIAGNOSIS_CPTSD': 36 different values
  'treatmentdurationdays': 1170 different values
  'Bipolar_and_Mood_disorder': 928 different values
  Total different values: 3306

=== Checking all 5 datasets ===
Dataset 1 vs Dataset 2: DIFFERENT ✅
Dataset 1 vs Dataset 3: DIFFERENT ✅
Dataset 1 vs Dataset 4: DIFFERENT ✅
Dataset 1 vs Dataset 5: DIFFERENT ✅
Dataset 2 vs Dataset 3: DIFFERENT ✅
Dataset 2 vs Dataset 4: DIFFERENT ✅
Dataset 2 vs Dataset 5: DIFFERENT ✅
Dataset 3 vs Dataset 4: DIFFERENT ✅
Dataset 3 vs Dataset 5: DIFFERENT ✅
Dataset 4 vs Dataset 5: DIFFERENT ✅

=== Checking differences in originally missing positions ===

Column 'DIAGNOSI

In [22]:
imputed_folder = "imputed_data"
n_imputations = 5

# Lists to hold DataFrames and Y vectors
imputed_dfs = []
Y_list = []

for i in range(1, n_imputations + 1):
    file_path = f"{imputed_folder}/df_imputed_final_imp{i}.pkl"
    
    # Load imputed DataFrame
    df_imp = pd.read_pickle(file_path)
    imputed_dfs.append(df_imp)

    # Define Y for this imputation
    Y = df_imp["caps5_change_baseline"]
    Y_list.append(Y)

    print(f"Y for imputation {i} defined. Sample values:")
    print(Y.head())

Y for imputation 1 defined. Sample values:
0   -41.0
1   -15.0
2   -46.0
3   -41.0
4   -20.0
Name: caps5_change_baseline, dtype: float64
Y for imputation 2 defined. Sample values:
0   -41.0
1   -15.0
2   -46.0
3   -41.0
4   -20.0
Name: caps5_change_baseline, dtype: float64
Y for imputation 3 defined. Sample values:
0   -41.0
1   -15.0
2   -46.0
3   -41.0
4   -20.0
Name: caps5_change_baseline, dtype: float64
Y for imputation 4 defined. Sample values:
0   -41.0
1   -15.0
2   -46.0
3   -41.0
4   -20.0
Name: caps5_change_baseline, dtype: float64
Y for imputation 5 defined. Sample values:
0   -41.0
1   -15.0
2   -46.0
3   -41.0
4   -20.0
Name: caps5_change_baseline, dtype: float64


In [23]:
import pandas as pd

# Load imputed DataFrames from saved files
imputed_folder = "imputed_data"
n_imputations = 5

for i in range(1, n_imputations + 1):
    print(f"\n=== Imputed Dataset {i} ===")

    # Load each imputed dataset
    df_imp = pd.read_pickle(f"{imputed_folder}/df_imputed_final_imp{i}.pkl")

    # Get all CAT_* columns
    cat_columns = [col for col in df_imp.columns if col.startswith('CAT_')]

    print("Medication Groups:")
    print(cat_columns)
    print("Total Medication Groups Found:", len(cat_columns))


=== Imputed Dataset 1 ===
Medication Groups:
['CAT_Antidepressiva', 'CAT_Benzodiazepine', 'CAT_Anti_epileptica', 'CAT_Antihistaminica', 'CAT_Opioden', 'CAT_Antipsychotica', 'CAT_Aceetanilidederivaten', 'CAT_Antihypertensiva', 'CAT_Salicylaat', 'CAT_NSAIDs', 'CAT_Migrainemiddelen', 'CAT_ADHD', 'CAT_Anticonceptiva', 'CAT_Z_drugs', 'CAT_Spierrelaxantia', 'CAT_Immunomodulerende_middelen', 'CAT_Alcoholverslaving', 'CAT_Stemmingsstabilisatoren', 'CAT_Parkinson', 'CAT_ALL', 'CAT_ALL_PSYCHOTROPICS', 'CAT_ALL_PSYCHOTROPICS_EXCL_BENZO', 'CAT_ALL_PSYCHOTROPICS_EXCL_SEDATIVES_HYPNOTICS']
Total Medication Groups Found: 23

=== Imputed Dataset 2 ===
Medication Groups:
['CAT_Antidepressiva', 'CAT_Benzodiazepine', 'CAT_Anti_epileptica', 'CAT_Antihistaminica', 'CAT_Opioden', 'CAT_Antipsychotica', 'CAT_Aceetanilidederivaten', 'CAT_Antihypertensiva', 'CAT_Salicylaat', 'CAT_NSAIDs', 'CAT_Migrainemiddelen', 'CAT_ADHD', 'CAT_Anticonceptiva', 'CAT_Z_drugs', 'CAT_Spierrelaxantia', 'CAT_Immunomodulerende_midd

In [24]:
covariates_CAT_ADHD = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden',
    'CAT_Z_drugs', 'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD',
    'DIAGNOSIS_SEXUAL_TRAUMA', 'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY',
    'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Aceetanilidederivaten = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Z_drugs = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Opioden = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica', 'CAT_Benzodiazepine', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_NSAIDs = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica', 'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]



covariates_CAT_Benzodiazepine = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Antihypertensiva = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Antihistaminica = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder"
]


covariates_CAT_Anti_epileptica = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Antidepressiva = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_Antipsychotica = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Anti_epileptica',
    'CAT_Antidepressiva',
    'CAT_Benzodiazepine', 'CAT_Opioden', 'CAT_Z_drugs',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_CAT_ALL_PSYCHOTROPICS_EXCL_BENZO = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Benzodiazepine', 'CAT_Anticonceptiva',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_CAT_ALL_PSYCHOTROPICS_EXCL_SEDATIVES_HYPNOTICS = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Benzodiazepine', 'CAT_Z_drugs', 'CAT_Anticonceptiva',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_ALL_PSYCHOTROPICS = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_CAT_ALL = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'DIAGNOSIS_CHILDHOOD_TRAUMA',
    'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SMOKING', 'DIAGNOSIS_SUICIDALITY',
    'age', 'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "EB_NON_TF_THERAPY", "OTHER_TREATM_APPROACH", "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

In [25]:
from collections import defaultdict

# This finds all variables that start with covariates_CAT_ or covariates_cat_
final_covariates_map = defaultdict(list)
final_covariates_map.update({
    var.replace("covariates_", ""): val
    for var, val in globals().items()
    if var.lower().startswith("covariates_cat_") and isinstance(val, list)
})

# Show detected group names
print("Groups found:", list(final_covariates_map.keys()))
medication_groups = list(final_covariates_map.keys())
print(medication_groups)

Groups found: ['CAT_ADHD', 'CAT_Aceetanilidederivaten', 'CAT_Z_drugs', 'CAT_Opioden', 'CAT_NSAIDs', 'CAT_Benzodiazepine', 'CAT_Antihypertensiva', 'CAT_Antihistaminica', 'CAT_Anti_epileptica', 'CAT_Antidepressiva', 'CAT_Antipsychotica', 'CAT_ALL_PSYCHOTROPICS_EXCL_BENZO', 'CAT_ALL_PSYCHOTROPICS_EXCL_SEDATIVES_HYPNOTICS', 'CAT_ALL_PSYCHOTROPICS', 'CAT_ALL']
['CAT_ADHD', 'CAT_Aceetanilidederivaten', 'CAT_Z_drugs', 'CAT_Opioden', 'CAT_NSAIDs', 'CAT_Benzodiazepine', 'CAT_Antihypertensiva', 'CAT_Antihistaminica', 'CAT_Anti_epileptica', 'CAT_Antidepressiva', 'CAT_Antipsychotica', 'CAT_ALL_PSYCHOTROPICS_EXCL_BENZO', 'CAT_ALL_PSYCHOTROPICS_EXCL_SEDATIVES_HYPNOTICS', 'CAT_ALL_PSYCHOTROPICS', 'CAT_ALL']


In [26]:
import os


def run_all_CAT_group_models(imputed_dfs):
    """
    Runs downstream analysis for each CAT medication group using imputed datasets.

    Parameters:
    - imputed_dfs: list of 5 imputed DataFrames (from df_imputed_final_imp1.pkl ... imp5.pkl)
    
    Notes:
    - Covariate lists must be defined as global variables: covariates_cat_<group>
    - Outputs are saved in: outputs/CAT_<GROUP>/
    """

    print(" Starting analysis for all CAT groups")

    for var_name in globals():
        if var_name.lower().startswith("covariates_cat_") and isinstance(globals()[var_name], list):
            group_name = var_name.replace("covariates_", "")
            group_name = group_name.replace("_", " ").title().replace(" ", "_")  # e.g., cat_z_drugs → Cat_Z_Drugs
            group_name = group_name.replace("Cat_", "CAT_")  # force prefix to uppercase

            covariates = globals()[var_name]
            output_dir = f"outputs/{group_name}"
            os.makedirs(output_dir, exist_ok=True)

            print(f"\n Processing {group_name}...")

            for k, df_imp in enumerate(imputed_dfs):
                print(f"  → Using imputation {k+1}")

                # Define X and Y
                X = df_imp[covariates]
                Y = df_imp["caps5_change_baseline"]

                # === Save X and Y as placeholder (replace with modeling later)
                X.to_csv(f"{output_dir}/X_imp{k+1}.csv", index=False)
                Y.to_frame(name="Y").to_csv(f"{output_dir}/Y_imp{k+1}.csv", index=False)

            print(f" Done: {group_name}")

    print("\n All CAT group analyses complete.")

# ========= STEP 4: Execute ========= #
run_all_CAT_group_models(imputed_dfs)

 Starting analysis for all CAT groups

 Processing CAT_Adhd...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: CAT_Adhd

 Processing CAT_Aceetanilidederivaten...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: CAT_Aceetanilidederivaten

 Processing CAT_Z_Drugs...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: CAT_Z_Drugs

 Processing CAT_Opioden...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: CAT_Opioden

 Processing CAT_Nsaids...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: CAT_Nsaids

 Processing CAT_Benzodiazepine...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → U

In [27]:
for treatment_var in medication_groups:
    print(f"\n {treatment_var}")
    
    for i, df in enumerate(imputed_dfs):
        if treatment_var not in df.columns:
            print(f"  Imp {i+1}:  Not found in columns.")
            continue

        treated = (df[treatment_var] == 1).sum()
        control = (df[treatment_var] == 0).sum()
        missing = df[treatment_var].isna().sum()

        print(f"  Imp {i+1}: Treated = {treated}, Control = {control}, Missing = {missing}")


 CAT_ADHD
  Imp 1: Treated = 88, Control = 6037, Missing = 0
  Imp 2: Treated = 88, Control = 6037, Missing = 0
  Imp 3: Treated = 88, Control = 6037, Missing = 0
  Imp 4: Treated = 88, Control = 6037, Missing = 0
  Imp 5: Treated = 88, Control = 6037, Missing = 0

 CAT_Aceetanilidederivaten
  Imp 1: Treated = 77, Control = 6048, Missing = 0
  Imp 2: Treated = 77, Control = 6048, Missing = 0
  Imp 3: Treated = 77, Control = 6048, Missing = 0
  Imp 4: Treated = 77, Control = 6048, Missing = 0
  Imp 5: Treated = 77, Control = 6048, Missing = 0

 CAT_Z_drugs
  Imp 1: Treated = 89, Control = 6036, Missing = 0
  Imp 2: Treated = 89, Control = 6036, Missing = 0
  Imp 3: Treated = 89, Control = 6036, Missing = 0
  Imp 4: Treated = 89, Control = 6036, Missing = 0
  Imp 5: Treated = 89, Control = 6036, Missing = 0

 CAT_Opioden
  Imp 1: Treated = 79, Control = 6046, Missing = 0
  Imp 2: Treated = 79, Control = 6046, Missing = 0
  Imp 3: Treated = 79, Control = 6046, Missing = 0
  Imp 4: Treate

In [28]:
import os
import pandas as pd
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from collections import defaultdict

# ✅ VIF computation function
def compute_vif(X):
    X = sm.add_constant(X, has_constant='add')
    vif_df = pd.DataFrame()
    vif_df["variable"] = X.columns
    vif_df["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
    return vif_df

# ✅ Process each group
for group in medication_groups:
    print(f"\n🔍 Processing VIF for {group}")

    if group not in final_covariates_map:
        print(f" ⚠️ No covariates found for {group}. Skipping.")
        continue

    covariates = final_covariates_map[group]
    vif_list = []

    for i, df_imp in enumerate(imputed_dfs):
        try:
            X = df_imp[covariates].copy()
            vif_df = compute_vif(X)
            vif_df["imputation"] = i + 1
            vif_list.append(vif_df)
        except Exception as e:
            print(f"   ❌ Failed on imputation {i+1} for {group}: {e}")

    if vif_list:
        all_vif = pd.concat(vif_list)
        pooled_vif = all_vif.groupby("variable")["VIF"].mean().reset_index()
        pooled_vif = pooled_vif.sort_values(by="VIF", ascending=False)

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)
        pooled_vif.to_csv(os.path.join(output_folder, "pooled_vif.csv"), index=False)

        print(f" ✅ Saved: {output_folder}/pooled_vif.csv")
    else:
        print(f" ⚠️ Skipped {group}: No valid imputations.")


🔍 Processing VIF for CAT_ADHD
 ✅ Saved: outputs\CAT_ADHD/pooled_vif.csv

🔍 Processing VIF for CAT_Aceetanilidederivaten
 ✅ Saved: outputs\CAT_Aceetanilidederivaten/pooled_vif.csv

🔍 Processing VIF for CAT_Z_drugs
 ✅ Saved: outputs\CAT_Z_drugs/pooled_vif.csv

🔍 Processing VIF for CAT_Opioden
 ✅ Saved: outputs\CAT_Opioden/pooled_vif.csv

🔍 Processing VIF for CAT_NSAIDs
 ✅ Saved: outputs\CAT_NSAIDs/pooled_vif.csv

🔍 Processing VIF for CAT_Benzodiazepine
 ✅ Saved: outputs\CAT_Benzodiazepine/pooled_vif.csv

🔍 Processing VIF for CAT_Antihypertensiva
 ✅ Saved: outputs\CAT_Antihypertensiva/pooled_vif.csv

🔍 Processing VIF for CAT_Antihistaminica
 ✅ Saved: outputs\CAT_Antihistaminica/pooled_vif.csv

🔍 Processing VIF for CAT_Anti_epileptica
 ✅ Saved: outputs\CAT_Anti_epileptica/pooled_vif.csv

🔍 Processing VIF for CAT_Antidepressiva
 ✅ Saved: outputs\CAT_Antidepressiva/pooled_vif.csv

🔍 Processing VIF for CAT_Antipsychotica
 ✅ Saved: outputs\CAT_Antipsychotica/pooled_vif.csv

🔍 Processing VIF f

In [29]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# ---------- PS Estimation Function ----------
def run_logistic_ps_modeling(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f" Running PS estimation for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        covariates = final_covariates_map[group]
        ps_matrix = pd.DataFrame()
        auc_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not found in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop missing treatment rows
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            # Train-test split for ROC
            X_train, X_test, T_train, T_test = train_test_split(
                X, T, stratify=T, test_size=0.3, random_state=42
            )

            try:
                model = LogisticRegression(random_state=42, max_iter=1000)
                model.fit(X_train, T_train)

                ps_scores = model.predict_proba(X)[:, 1]
                ps_matrix[f"ps_imp{i+1}"] = pd.Series(ps_scores, index=valid_idx)

                # ROC & AUC
                auc = roc_auc_score(T_test, model.predict_proba(X_test)[:, 1])
                auc_list.append(auc)

                fpr, tpr, _ = roc_curve(T_test, model.predict_proba(X_test)[:, 1])
                plt.figure()
                plt.plot(fpr, tpr, label=f"AUC = {auc:.3f}")
                plt.plot([0, 1], [0, 1], 'k--')
                plt.xlabel("False Positive Rate")
                plt.ylabel("True Positive Rate")
                plt.title(f"ROC Curve - {group} (Imp {i+1})")
                plt.legend()
                plt.tight_layout()
                plt.savefig(os.path.join(output_folder, f"roc_curve_imp{i+1}.png"))
                plt.close()
                print(f"   Imp {i+1}: AUC = {auc:.3f}, ROC saved.")

            except Exception as e:
                print(f"   Error in {group} (imp {i+1}): {e}")

        # Save AUCs and Composite PS
        if not ps_matrix.empty:
            # Fill NaN rows (from dropped subjects in some imputations) with mean
            ps_matrix["composite_ps"] = ps_matrix.mean(axis=1)
            ps_matrix.to_excel(os.path.join(output_folder, "propensity_scores.xlsx"))

            auc_df = pd.DataFrame({
                "imputation": [f"imp{i+1}" for i in range(len(auc_list))],
                "AUC": auc_list
            })
            auc_df.loc[len(auc_df.index)] = ["mean", np.mean(auc_list) if auc_list else np.nan]
            auc_df.to_excel(os.path.join(output_folder, "auc_scores.xlsx"), index=False)

            print(f" Composite PS + AUC saved for {group}")
        else:
            print(f" No valid PS scores generated for {group}")

# ---------- Run ----------
run_logistic_ps_modeling(imputed_dfs, medication_groups, final_covariates_map)

 Running PS estimation for CAT_ADHD
   Imp 1: AUC = 0.722, ROC saved.
   Imp 2: AUC = 0.717, ROC saved.
   Imp 3: AUC = 0.710, ROC saved.
   Imp 4: AUC = 0.730, ROC saved.
   Imp 5: AUC = 0.732, ROC saved.
 Composite PS + AUC saved for CAT_ADHD
 Running PS estimation for CAT_Aceetanilidederivaten
   Imp 1: AUC = 0.858, ROC saved.
   Imp 2: AUC = 0.868, ROC saved.
   Imp 3: AUC = 0.861, ROC saved.
   Imp 4: AUC = 0.862, ROC saved.
   Imp 5: AUC = 0.879, ROC saved.
 Composite PS + AUC saved for CAT_Aceetanilidederivaten
 Running PS estimation for CAT_Z_drugs
   Imp 1: AUC = 0.893, ROC saved.
   Imp 2: AUC = 0.886, ROC saved.
   Imp 3: AUC = 0.890, ROC saved.
   Imp 4: AUC = 0.885, ROC saved.
   Imp 5: AUC = 0.883, ROC saved.
 Composite PS + AUC saved for CAT_Z_drugs
 Running PS estimation for CAT_Opioden
   Imp 1: AUC = 0.832, ROC saved.
   Imp 2: AUC = 0.842, ROC saved.
   Imp 3: AUC = 0.842, ROC saved.
   Imp 4: AUC = 0.857, ROC saved.
   Imp 5: AUC = 0.839, ROC saved.
 Composite PS + 

In [30]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LogisticRegression

def compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n Computing feature importance for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        covariates = final_covariates_map[group]
        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        importance_df_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop NaNs in treatment
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            try:
                model = LogisticRegression(random_state=42, max_iter=1000)
                model.fit(X, T)

                # Get feature importance (absolute coefficients)
                importances = np.abs(model.coef_[0])
                importance_dict = dict(zip(X.columns, importances))
                df_feat = pd.DataFrame.from_dict(importance_dict, orient='index', columns=[f"imp{i+1}"])
                df_feat.index.name = 'feature'
                importance_df_list.append(df_feat)

            except Exception as e:
                print(f"   Error during modeling: {e}")

        if importance_df_list:
            # Combine and average
            all_feat = pd.concat(importance_df_list, axis=1).fillna(0)
            all_feat["mean_importance"] = all_feat.mean(axis=1)

            # Filter top 30 non-zero
            non_zero = all_feat[all_feat["mean_importance"] > 0]
            top30 = non_zero.sort_values(by="mean_importance", ascending=False).head(30)

            # Save to CSV
            top30.to_csv(os.path.join(output_folder, "feature_importance.csv"))

            # Plot
            plt.figure(figsize=(10, 8))
            plt.barh(top30.index[::-1], top30["mean_importance"][::-1])  # plot top → bottom
            plt.xlabel("Mean Gain Importance")
            plt.title(f"Top 30 Feature Importance - {group}")
            plt.tight_layout()
            plt.savefig(os.path.join(output_folder, "feature_importance_top30.png"))
            plt.close()

            print(f" Saved feature importance plot and CSV for {group}")
        else:
            print(f" No valid models for {group}")

#  Run
compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map)


 Computing feature importance for CAT_ADHD
 Saved feature importance plot and CSV for CAT_ADHD

 Computing feature importance for CAT_Aceetanilidederivaten
 Saved feature importance plot and CSV for CAT_Aceetanilidederivaten

 Computing feature importance for CAT_Z_drugs
 Saved feature importance plot and CSV for CAT_Z_drugs

 Computing feature importance for CAT_Opioden
 Saved feature importance plot and CSV for CAT_Opioden

 Computing feature importance for CAT_NSAIDs
 Saved feature importance plot and CSV for CAT_NSAIDs

 Computing feature importance for CAT_Benzodiazepine
 Saved feature importance plot and CSV for CAT_Benzodiazepine

 Computing feature importance for CAT_Antihypertensiva
 Saved feature importance plot and CSV for CAT_Antihypertensiva

 Computing feature importance for CAT_Antihistaminica
 Saved feature importance plot and CSV for CAT_Antihistaminica

 Computing feature importance for CAT_Anti_epileptica
 Saved feature importance plot and CSV for CAT_Anti_epileptic

In [31]:
import os
import pandas as pd
import numpy as np


def compute_trimmed_clipped_iptw(ps_df, treatment, lower=0.05, upper=0.95, clip_max=10):
    weights = []
    keep_mask = (ps_df > lower) & (ps_df < upper)

    for i in range(ps_df.shape[1]):
        ps = ps_df.iloc[:, i].clip(lower=1e-6, upper=1 - 1e-6)  # avoid div by zero
        mask = keep_mask.iloc[:, i]
        w = pd.Series(np.nan, index=ps.index)

        w[mask & (treatment == 1)] = 1 / ps[mask & (treatment == 1)]
        w[mask & (treatment == 0)] = 1 / (1 - ps[mask & (treatment == 0)])
        w = w.clip(upper=clip_max)
        weights.append(w)

    return pd.concat(weights, axis=1)


def apply_rubins_rule_to_iptw(iptw_matrix):
    """
    Given an IPTW matrix (n rows × M imputations), return Rubin’s rule pooled mean, SD, SE.
    """
    M = iptw_matrix.shape[1]
    q_bar = iptw_matrix.mean(axis=1)
    u_bar = iptw_matrix.var(axis=1, ddof=1)
    B = iptw_matrix.apply(lambda x: x.mean(), axis=1).var(ddof=1)
    total_var = u_bar + (1 + 1/M) * B
    total_se = np.sqrt(total_var)
    return q_bar, u_bar.pow(0.5), total_se


def run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n🔍 Processing IPTW + trimming + clipping for {group}")
        output_folder = os.path.join("outputs", group)
        ps_path = os.path.join(output_folder, "propensity_scores.xlsx")

        if not os.path.exists(ps_path):
            print(f"⚠️ Missing PS file: {ps_path}. Skipping.")
            continue

        try:
            ps_all = pd.read_excel(ps_path, index_col=0)
            ps_cols = [col for col in ps_all.columns if col.startswith("ps_imp")]
            composite_index = ps_all.index

            # Get treatment from one imputed dataset
            T_full = None
            for df in imputed_dfs:
                if group in df.columns:
                    T_full = df.loc[composite_index, group]
                    break

            if T_full is None:
                print(f"❌ Treatment column {group} not found in any imputed dataset.")
                continue

            # Compute IPTW matrix (shape: n × M)
            iptw_matrix = compute_trimmed_clipped_iptw(ps_all[ps_cols], T_full)
            iptw_matrix.columns = [f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]

            # Apply Rubin’s Rule for mean, SD, SE
            iptw_matrix["iptw_mean"], iptw_matrix["iptw_sd"], iptw_matrix["iptw_se"] = apply_rubins_rule_to_iptw(
                iptw_matrix[[f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]]
            )

            # Save IPTW matrix separately
            iptw_matrix.to_excel(os.path.join(output_folder, "iptw_weights.xlsx"))
            print(f"✅ Saved IPTW weights for {group}")

            # Save trimmed & clipped imputed datasets with IPTW
            for i in range(5):
                df = imputed_dfs[i].copy()
                if group not in df.columns:
                    continue

                trimmed_idx = iptw_matrix.index.intersection(df.index)
                needed_cols = final_covariates_map[group] + [group, "caps5_change_baseline"]

                # Select only necessary columns
                df_trimmed = df.loc[trimmed_idx, needed_cols].copy()
                df_trimmed["iptw"] = iptw_matrix[f"iptw_imp{i+1}"].loc[trimmed_idx]

                # ✅ DROP rows with missing IPTW values
                before = len(df_trimmed)
                df_trimmed = df_trimmed.dropna(subset=["iptw"])
                after = len(df_trimmed)
                print(f"    ℹ️ Retained {after}/{before} rows after IPTW NaN drop.")

                # Save to .pkl
                df_trimmed.to_pickle(os.path.join(output_folder, f"trimmed_data_imp{i+1}.pkl"))
                print(f"  💾 Saved trimmed dataset: {output_folder}/trimmed_data_imp{i+1}.*")

        except Exception as e:
            print(f"❌ Error in {group}: {e}")


run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map)


🔍 Processing IPTW + trimming + clipping for CAT_ADHD
✅ Saved IPTW weights for CAT_ADHD
    ℹ️ Retained 392/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\CAT_ADHD/trimmed_data_imp1.*
    ℹ️ Retained 383/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\CAT_ADHD/trimmed_data_imp2.*
    ℹ️ Retained 395/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\CAT_ADHD/trimmed_data_imp3.*
    ℹ️ Retained 388/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\CAT_ADHD/trimmed_data_imp4.*
    ℹ️ Retained 404/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\CAT_ADHD/trimmed_data_imp5.*

🔍 Processing IPTW + trimming + clipping for CAT_Aceetanilidederivaten
✅ Saved IPTW weights for CAT_Aceetanilidederivaten
    ℹ️ Retained 176/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\CAT_Aceetanilidederivaten/trimmed_data_imp1.*
    ℹ️ Retained 181/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outpu

In [32]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def plot_ps_overlap_all_groups(medication_groups):
    for group in medication_groups:
        print(f"\n📊 Plotting PS overlap for {group}")

        folder = os.path.join("outputs", group)
        ps_file = os.path.join(folder, "propensity_scores.xlsx")
        iptw_file = os.path.join(folder, "iptw_weights.xlsx")
        trimmed_file = os.path.join(folder, "trimmed_data_imp1.pkl")

        if not all(os.path.exists(f) for f in [ps_file, iptw_file, trimmed_file]):
            print(f"⚠️ Missing required files for {group}. Skipping.")
            continue

        try:
            ps_df = pd.read_excel(ps_file, index_col=0)
            iptw_df = pd.read_excel(iptw_file, index_col=0)
            trimmed_df = pd.read_pickle(trimmed_file)

            # Extract
            ps = ps_df["composite_ps"].reindex(trimmed_df.index)
            w = iptw_df["iptw_mean"].reindex(trimmed_df.index)
            T = trimmed_df[group]

            # Masks to remove NaNs
            treated_mask = (T == 1) & ps.notna() & w.notna()
            control_mask = (T == 0) & ps.notna() & w.notna()

            treated = ps[treated_mask]
            treated_w = w[treated_mask]

            control = ps[control_mask]
            control_w = w[control_mask]

            # === Unweighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Unweighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_unweighted.png"))
            plt.close()

            # === Weighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, weights=[treated_w, control_w], label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Weighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Weighted Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_weighted.png"))
            plt.close()

            print(f"✅ Saved unweighted and weighted PS plots for {group}")

        except Exception as e:
            print(f"❌ Error processing {group}: {e}")

# 🔁 Run
plot_ps_overlap_all_groups(medication_groups)


📊 Plotting PS overlap for CAT_ADHD
✅ Saved unweighted and weighted PS plots for CAT_ADHD

📊 Plotting PS overlap for CAT_Aceetanilidederivaten
✅ Saved unweighted and weighted PS plots for CAT_Aceetanilidederivaten

📊 Plotting PS overlap for CAT_Z_drugs
✅ Saved unweighted and weighted PS plots for CAT_Z_drugs

📊 Plotting PS overlap for CAT_Opioden
✅ Saved unweighted and weighted PS plots for CAT_Opioden

📊 Plotting PS overlap for CAT_NSAIDs
✅ Saved unweighted and weighted PS plots for CAT_NSAIDs

📊 Plotting PS overlap for CAT_Benzodiazepine
✅ Saved unweighted and weighted PS plots for CAT_Benzodiazepine

📊 Plotting PS overlap for CAT_Antihypertensiva
✅ Saved unweighted and weighted PS plots for CAT_Antihypertensiva

📊 Plotting PS overlap for CAT_Antihistaminica
✅ Saved unweighted and weighted PS plots for CAT_Antihistaminica

📊 Plotting PS overlap for CAT_Anti_epileptica
✅ Saved unweighted and weighted PS plots for CAT_Anti_epileptica

📊 Plotting PS overlap for CAT_Antidepressiva
✅ Save

In [33]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Set up base output folder
output_base = "outputs"
ps_file = "propensity_scores.xlsx"
iptw_file = "iptw_weights.xlsx"
trimmed_data_file = "trimmed_data_imp1.pkl"

# Collect all treatment group folders
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

# Generate 4-panel overlap plots
for group in groups:
    group_path = os.path.join(output_base, group)
    try:
        # Load trimmed treatment info
        trimmed_df = pd.read_pickle(os.path.join(group_path, trimmed_data_file))
        index = trimmed_df.index

        # Fix: case-insensitive match for treatment variable
        possible_cols = [col for col in trimmed_df.columns if col.upper() == group.upper()]
        if not possible_cols:
            print(f" Treatment variable {group} not found in {group}, skipping.")
            continue
        treatment_var = possible_cols[0]
        T = trimmed_df[treatment_var]

        # Load composite PS (aligned to trimmed_df index)
        ps_df = pd.read_excel(os.path.join(group_path, ps_file), index_col=0)
        if 'composite_ps' not in ps_df.columns:
            print(f" Composite column missing in {ps_file}, skipping {group}.")
            continue
        ps = ps_df.loc[index, 'composite_ps']

        # Load IPTW weights (aligned to trimmed_df index)
        weights_df = pd.read_excel(os.path.join(group_path, iptw_file), index_col=0)
        if 'iptw_mean' not in weights_df.columns:
            print(f" IPTW weight column missing in {iptw_file}, skipping {group}.")
            continue
        weights = weights_df.loc[index, 'iptw_mean']

        # Prepare 4 datasets
        raw_treated = ps[T == 1]
        raw_control = ps[T == 0]
        weighted_treated = (ps[T == 1], weights[T == 1])
        weighted_control = (ps[T == 0], weights[T == 0])

        # Create plot
        fig, axs = plt.subplots(2, 2, figsize=(12, 8))
        fig.suptitle(f"Propensity Score Distribution - {group}", fontsize=14)

        axs[0, 0].hist(raw_treated, bins=20, alpha=0.7, color='blue')
        axs[0, 0].set_title("Raw Treated")

        axs[0, 1].hist(raw_control, bins=20, alpha=0.7, color='green')
        axs[0, 1].set_title("Raw Control")

        axs[1, 0].hist(weighted_treated[0], bins=20, weights=weighted_treated[1], alpha=0.7, color='blue')
        axs[1, 0].set_title("Weighted Treated")

        axs[1, 1].hist(weighted_control[0], bins=20, weights=weighted_control[1], alpha=0.7, color='green')
        axs[1, 1].set_title("Weighted Control")

        for ax in axs.flat:
            ax.set_xlim(0, 1)
            ax.set_xlabel("Propensity Score")
            ax.set_ylabel("Count")

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save figure
        plot_path = os.path.join(group_path, f"four_panel_overlap_{group}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f" ✅ Saved: {plot_path}")

    except Exception as e:
        print(f" ❌ Error in {group}: {e}")

 ✅ Saved: outputs\CAT_ADHD\four_panel_overlap_CAT_ADHD.png
 ✅ Saved: outputs\CAT_Aceetanilidederivaten\four_panel_overlap_CAT_Aceetanilidederivaten.png
 ✅ Saved: outputs\CAT_Z_drugs\four_panel_overlap_CAT_Z_drugs.png
 ✅ Saved: outputs\CAT_Opioden\four_panel_overlap_CAT_Opioden.png
 ✅ Saved: outputs\CAT_NSAIDs\four_panel_overlap_CAT_NSAIDs.png
 ✅ Saved: outputs\CAT_Benzodiazepine\four_panel_overlap_CAT_Benzodiazepine.png
 ✅ Saved: outputs\CAT_Antihypertensiva\four_panel_overlap_CAT_Antihypertensiva.png
 ✅ Saved: outputs\CAT_Antihistaminica\four_panel_overlap_CAT_Antihistaminica.png
 ✅ Saved: outputs\CAT_Anti_epileptica\four_panel_overlap_CAT_Anti_epileptica.png
 ✅ Saved: outputs\CAT_Antidepressiva\four_panel_overlap_CAT_Antidepressiva.png
 ✅ Saved: outputs\CAT_Antipsychotica\four_panel_overlap_CAT_Antipsychotica.png
 ✅ Saved: outputs\CAT_ALL_PSYCHOTROPICS_EXCL_BENZO\four_panel_overlap_CAT_ALL_PSYCHOTROPICS_EXCL_BENZO.png
 ✅ Saved: outputs\CAT_ALL_PSYCHOTROPICS_EXCL_SEDATIVES_HYPNOTICS\f

In [34]:
# ATT calculation:

In [35]:
# Weighted:

In [39]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
import statsmodels.api as sm
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T, weights):
    treated = X[T == 1]
    control = X[T == 0]
    w_treated = weights[T == 1]
    w_control = weights[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = np.average(treated[col], weights=w_treated), np.average(control[col], weights=w_control)
            s1 = np.sqrt(np.average((treated[col] - m1) ** 2, weights=w_treated))
            s0 = np.sqrt(np.average((control[col] - m0) ** 2, weights=w_control))
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# OLS Main Loop
# -----------------------------
def run_dml_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running OLS for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            # Set random seed for this iteration
            np.random.seed(seed)
            
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                # Add bootstrap sampling with seed-based randomization
                n_samples = len(df)
                bootstrap_idx = np.random.choice(n_samples, size=n_samples, replace=True)
                df_bootstrap = df.iloc[bootstrap_idx].reset_index(drop=True)

                X = df_bootstrap[covariates].copy()
                T = df_bootstrap[group]
                Y = df_bootstrap["caps5_change_baseline"]
                W = df_bootstrap["iptw"]

                try:
                    # Create design matrix with treatment variable and covariates
                    X_ols = pd.concat([T, X], axis=1)
                    X_ols = sm.add_constant(X_ols)
                    
                    # Fit weighted OLS with robust standard errors
                    ols_model = sm.WLS(Y, X_ols, weights=W).fit(cov_type='HC1')
                    
                    # Extract treatment effect (coefficient of treatment variable)
                    att = ols_model.params[group]  # Treatment coefficient
                    se = ols_model.bse[group]  # Robust standard error for treatment
                    
                    att_list.append(att)
                    se_list.append(se)

                    # Calculate model fit statistics
                    Y_pred = ols_model.fittedvalues
                    residuals = ols_model.resid
                    rmse = mean_squared_error(Y, Y_pred, squared=False)
                    r2 = ols_model.rsquared
                    r2_list.append(r2)
                    rmse_list.append(rmse)
                    
                    # Collect residuals and fitted values for diagnostic plots
                    group_residuals.append(residuals.values)
                    group_fitted.append(Y_pred.values)

                    smd, vr = calculate_smd_vr(X, T, W)
                    smd_list.append(smd)
                    vr_list.append(vr)
                except Exception as e:
                    print(f"⚠️ Error in {group}, seed {seed}, imp {imp}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("ols_rubin_summary_cats.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_cats.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_dml_with_trimmed_data(final_covariates_map)


🚀 Running OLS for CAT_ADHD
✅ CAT_ADHD | Seed 1: ATT = 0.8360, SE = 2.4919, p = 0.75412
✅ CAT_ADHD | Seed 2: ATT = 1.2408, SE = 2.4817, p = 0.64335
✅ CAT_ADHD | Seed 3: ATT = -0.4415, SE = 3.3983, p = 0.90290
✅ CAT_ADHD | Seed 4: ATT = 1.2011, SE = 3.9538, p = 0.77644
✅ CAT_ADHD | Seed 5: ATT = 1.3198, SE = 2.6504, p = 0.64464
✅ CAT_ADHD | Seed 6: ATT = -1.1362, SE = 1.7887, p = 0.55982
✅ CAT_ADHD | Seed 7: ATT = 0.8207, SE = 2.8346, p = 0.78656
✅ CAT_ADHD | Seed 8: ATT = 0.1788, SE = 2.4396, p = 0.94509
✅ CAT_ADHD | Seed 9: ATT = -0.2881, SE = 3.3107, p = 0.93483
✅ CAT_ADHD | Seed 10: ATT = 0.2123, SE = 3.1963, p = 0.95023
📊 Diagnostic plots saved for CAT_ADHD
🏆 Best result for CAT_ADHD → Seed 6 | SE = 1.7887

🚀 Running OLS for CAT_Aceetanilidederivaten
✅ CAT_Aceetanilidederivaten | Seed 1: ATT = 1.3972, SE = 6.8213, p = 0.84771
✅ CAT_Aceetanilidederivaten | Seed 2: ATT = 2.2173, SE = 5.4355, p = 0.70422
✅ CAT_Aceetanilidederivaten | Seed 3: ATT = 2.8694, SE = 6.5133, p = 0.68230
✅ CA

In [40]:
# Unweighted:

In [41]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
import statsmodels.api as sm
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}_unweighted.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T):
    treated = X[T == 1]
    control = X[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = treated[col].mean(), control[col].mean()
            s1 = treated[col].std()
            s0 = control[col].std()
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# OLS Main Loop
# -----------------------------
def run_dml_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running OLS for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            # Set random seed for this iteration
            np.random.seed(seed)
            
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                # Add bootstrap sampling with seed-based randomization
                n_samples = len(df)
                bootstrap_idx = np.random.choice(n_samples, size=n_samples, replace=True)
                df_bootstrap = df.iloc[bootstrap_idx].reset_index(drop=True)

                X = df_bootstrap[covariates].copy()
                T = df_bootstrap[group]
                Y = df_bootstrap["caps5_change_baseline"]
                #W = df_bootstrap["iptw"]

                try:
                    # Create design matrix with treatment variable and covariates
                    X_ols = pd.concat([T, X], axis=1)
                    X_ols = sm.add_constant(X_ols)
                    
                    # Fit OLS with robust standard errors (unweighted)
                    ols_model = sm.OLS(Y, X_ols).fit(cov_type='HC1')
                    
                    # Extract treatment effect (coefficient of treatment variable)
                    att = ols_model.params[group]  # Treatment coefficient
                    se = ols_model.bse[group]  # Robust standard error for treatment
                    
                    att_list.append(att)
                    se_list.append(se)

                    # Calculate model fit statistics
                    Y_pred = ols_model.fittedvalues
                    residuals = ols_model.resid
                    rmse = mean_squared_error(Y, Y_pred, squared=False)
                    r2 = ols_model.rsquared
                    r2_list.append(r2)
                    rmse_list.append(rmse)
                    
                    # Collect residuals and fitted values for diagnostic plots
                    group_residuals.append(residuals.values)
                    group_fitted.append(Y_pred.values)

                    smd, vr = calculate_smd_vr(X, T)
                    smd_list.append(smd)
                    vr_list.append(vr)
                except Exception as e:
                    print(f"⚠️ Error in {group}, seed {seed}, imp {imp}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("ols_rubin_summary_cats_unweighted.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_cats_unweighted.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_dml_with_trimmed_data(final_covariates_map)


🚀 Running OLS for CAT_ADHD
✅ CAT_ADHD | Seed 1: ATT = 0.8977, SE = 3.1989, p = 0.79292
✅ CAT_ADHD | Seed 2: ATT = 0.6329, SE = 2.8531, p = 0.83530
✅ CAT_ADHD | Seed 3: ATT = -0.1555, SE = 3.5564, p = 0.96721
✅ CAT_ADHD | Seed 4: ATT = 0.8576, SE = 4.7931, p = 0.86670
✅ CAT_ADHD | Seed 5: ATT = 0.7333, SE = 2.6519, p = 0.79584
✅ CAT_ADHD | Seed 6: ATT = -1.4120, SE = 2.4237, p = 0.59144
✅ CAT_ADHD | Seed 7: ATT = 1.2545, SE = 2.9350, p = 0.69107
✅ CAT_ADHD | Seed 8: ATT = 0.0107, SE = 2.8703, p = 0.99721
✅ CAT_ADHD | Seed 9: ATT = -0.1739, SE = 3.5807, p = 0.96359
✅ CAT_ADHD | Seed 10: ATT = -0.0825, SE = 3.9422, p = 0.98431
📊 Diagnostic plots saved for CAT_ADHD
🏆 Best result for CAT_ADHD → Seed 6 | SE = 2.4237

🚀 Running OLS for CAT_Aceetanilidederivaten
✅ CAT_Aceetanilidederivaten | Seed 1: ATT = 1.7266, SE = 6.9697, p = 0.81654
✅ CAT_Aceetanilidederivaten | Seed 2: ATT = 2.7754, SE = 4.4979, p = 0.57060
✅ CAT_Aceetanilidederivaten | Seed 3: ATT = 3.0011, SE = 6.7271, p = 0.67860
✅ C

In [42]:
import os
import pandas as pd
import numpy as np
from scipy.stats import sem, ttest_ind

# ----------------------------------
# File paths
# ----------------------------------
output_base = "outputs"
att_file = "ols_rubin_summary_cats.xlsx"
trimmed_file = "trimmed_data_imp1.pkl"
auc_file = "auc_scores.xlsx"  # NEW

# ----------------------------------
# Load ATT Summary
# ----------------------------------
if os.path.exists(att_file):
    att_df = pd.read_excel(att_file)
else:
    raise FileNotFoundError("❌ ATT summary file not found: ols_rubin_summary_cats.xlsx")

summary_rows = []

# ----------------------------------
# Loop over medication groups
# ----------------------------------
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

for med in groups:
    try:
        group_path = os.path.join(output_base, med)

        # Load trimmed data
        df = pd.read_pickle(os.path.join(group_path, trimmed_file))

        # Detect treatment column
        treatment_cols = [col for col in df.columns if col.upper() == med.upper()]
        if not treatment_cols:
            print(f"⚠️ Treatment column {med} not found in trimmed data. Skipping.")
            continue
        treatment_var = treatment_cols[0]

        # Extract treatment and outcome
        T = df[treatment_var]
        Y = df["caps5_change_baseline"]

        # Treated and control stats
        treated = Y[T == 1]
        control = Y[T == 0]

        mean_treat = treated.mean()
        se_treat = sem(treated) if len(treated) > 1 else np.nan

        mean_ctrl = control.mean()
        se_ctrl = sem(control) if len(control) > 1 else np.nan

        # Cohen's d (unadjusted)
        pooled_sd = np.sqrt(((treated.std() ** 2) + (control.std() ** 2)) / 2)
        cohen_d = (mean_treat - mean_ctrl) / pooled_sd if pooled_sd > 0 else np.nan

        # E-value (unadjusted)
        delta = mean_treat - mean_ctrl
        E = delta / abs(mean_ctrl) * 100 if mean_ctrl != 0 else np.nan

        # Unadjusted p-value
        try:
            t_stat, p_val = ttest_ind(treated, control, equal_var=False, nan_policy="omit")
            rounded_p = round(p_val, 5)
            formatted_p = "< 0.00001" if rounded_p < 0.00001 else rounded_p
        except Exception:
            formatted_p = np.nan

        # AUC from new auc_scores.xlsx file
        auc_val = np.nan
        auc_path = os.path.join(group_path, auc_file)
        if os.path.exists(auc_path):
            auc_df = pd.read_excel(auc_path)
            if "AUC" in auc_df.columns:
                auc_val = auc_df["AUC"].dropna().mean()

        # Adjusted stats from Rubin summary
        att_row = att_df[att_df["group"].str.strip().str.upper() == med.strip().upper()]
        if not att_row.empty:
            att = att_row.iloc[0]["att"]
            att_se = att_row.iloc[0]["se"]
            att_p_val = att_row.iloc[0]["p_value"]
            r2 = att_row.iloc[0]["r2"]
            rmse = att_row.iloc[0]["rmse"]

            try:
                rounded_att_p = round(float(att_p_val), 5)
                formatted_att_p = "< 0.00001" if rounded_att_p < 0.00001 else rounded_att_p
            except:
                formatted_att_p = att_p_val
        else:
            att, att_se, formatted_att_p, r2, rmse = np.nan, np.nan, np.nan, np.nan, np.nan

        # Append full row
        summary_rows.append({
            'Medication Group': med,
            'Mean Treated': mean_treat,
            'SE Treated': se_treat,
            'Mean Control': mean_ctrl,
            'SE Control': se_ctrl,
            'Cohen d': cohen_d,
            'E (Unadjusted)': E,
            'n Treated': len(treated),
            'n Control': len(control),
            #'Unadjusted p-value': formatted_p,
            'ATT Estimate': att,
            'ATT SE (Robust)': att_se,
            'ATT p-value': formatted_att_p,
            'R²': r2,
            'RMSE': rmse,
            'AUC': auc_val
        })

    except Exception as e:
        print(f"❌ Error in {med}: {e}")

# ----------------------------------
# Save final summary
# ----------------------------------
summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.sort_values("Medication Group")
summary_df.to_excel("Final_ATT_Summary_Cat.xlsx", index=False)
print("✅ Final_ATT_Summary_Cat saved")

✅ Final_ATT_Summary_Cat saved


In [43]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# ✅ Load the final summary table
final_df = pd.read_excel("Final_ATT_Summary_Cat.xlsx")

# ✅ Parse DML p-values (handle "< 0.00001")
def parse_pval(p):
    try:
        if isinstance(p, str) and "<" in p:
            return 0.000001
        return float(p)
    except:
        return None

final_df['ATT p-value'] = final_df['ATT p-value'].apply(parse_pval)

# ✅ Plot settings
width = 0.35

# ✅ Plotting function for a single medication group
def plot_single_group(row):
    fig, ax = plt.subplots(figsize=(10, 6))
    
    bars1 = ax.bar(-width/2, row['Mean Control'], width, 
                   yerr=row['SE Control'], label='Control', hatch='//', color='gray', capsize=5)
    bars2 = ax.bar(+width/2, row['Mean Treated'], width, 
                   yerr=row['SE Treated'], label='Treated', color='steelblue', capsize=5)

    label = (
        f"ATT = {row['ATT Estimate']:.2f}\n"
        f"d = {row['Cohen d']:.2f}, p = {row['ATT p-value']:.3f}\n"
        f"nT = {row['n Treated']}, nC = {row['n Control']}\n"
        f"E = {row['E (Unadjusted)']:.1f}%"
    )
    max_y = max(row['Mean Control'], row['Mean Treated']) + 1.5
    ax.text(0, max_y, label, ha='center', va='bottom', fontsize=9, color='#FFD700')

    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_xticks([-width/2, +width/2])
    ax.set_xticklabels(['Control', 'Treated'])
    ax.set_title(f"Group: {row['Medication Group']}", fontsize=12, weight='bold')
    ax.set_ylabel("CAPS5 Change Score")
    ax.set_ylim(bottom=0, top=max_y + 2)
    ax.legend()
    fig.tight_layout()
    return fig

# ✅ Generate and save all plots into a multi-page PDF
with PdfPages("ols_att_barplot_cat.pdf") as pdf:
    for idx, row in final_df.iterrows():
        fig = plot_single_group(row)
        pdf.savefig(fig, dpi=300, bbox_inches='tight')
        plt.close(fig)
print("✅ ols_att_barplot_cat saved")

✅ ols_att_barplot_cat saved


In [44]:
# Love plot:

In [45]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# ----------------------------------------
# Functions to calculate balance
# ----------------------------------------
def calculate_smd(x1, x2, w1=None, w2=None):
    def weighted_mean(x, w): return np.average(x, weights=w)
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    m1 = weighted_mean(x1, w1) if w1 is not None else np.mean(x1)
    m2 = weighted_mean(x2, w2) if w2 is not None else np.mean(x2)
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    pooled_sd = np.sqrt((v1 + v2) / 2)
    return np.abs(m1 - m2) / pooled_sd if pooled_sd > 0 else 0

def variance_ratio(x1, x2, w1=None, w2=None):
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    return max(v1 / v2, v2 / v1) if v1 > 0 and v2 > 0 else 1

# ----------------------------------------
# Setup
# ----------------------------------------
output_base = "outputs"
groups = [g for g in os.listdir(output_base) if os.path.isdir(os.path.join(output_base, g))]

# Create a case-insensitive mapping
final_covariates_map_lower = {k.lower(): v for k, v in final_covariates_map.items()}

# ----------------------------------------
# Main Loop
# ----------------------------------------
for group in groups:
    if group.lower() not in final_covariates_map_lower:
        continue

    print(f"\n🔍 Processing {group}...")

    try:
        group_path = os.path.join(output_base, group)
        covariates = final_covariates_map_lower[group.lower()]
        
        column_name = None
        for col in pd.read_pickle(os.path.join(group_path, "trimmed_data_imp1.pkl")).columns:
            if col.lower() == group.lower():
                column_name = col
                break
        if column_name is None:
            print(f"⚠️ Column not found for {group}, skipping.")
            continue

        smd_unw_all, smd_w_all = [], []
        vr_unw_all, vr_w_all = [], []

        for i in range(1, 6):
            df_path = os.path.join(group_path, f"trimmed_data_imp{i}.pkl")
            iptw_path = os.path.join(group_path, "iptw_weights.xlsx")

            if not os.path.exists(df_path) or not os.path.exists(iptw_path):
                print(f"⚠️ Missing data for {group} imp{i}, skipping.")
                continue

            df = pd.read_pickle(df_path)
            iptw_df = pd.read_excel(iptw_path, index_col=0)
            T = df[column_name]
            W = iptw_df.loc[df.index, "iptw_mean"]

            smd_unw_i, smd_w_i, vr_unw_i, vr_w_i = [], [], [], []

            for cov in covariates:
                x1, x0 = df.loc[T == 1, cov], df.loc[T == 0, cov]
                w1, w0 = W[T == 1], W[T == 0]

                su = calculate_smd(x1, x0)
                sw = calculate_smd(x1, x0, w1, w0)

                vu = variance_ratio(x1, x0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan
                vw = variance_ratio(x1, x0, w1, w0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan

                smd_unw_i.append(su)
                smd_w_i.append(sw)
                vr_unw_i.append(vu)
                vr_w_i.append(vw)

            smd_unw_all.append(smd_unw_i)
            smd_w_all.append(smd_w_i)
            vr_unw_all.append(vr_unw_i)
            vr_w_all.append(vr_w_i)

        smd_unw = np.mean(smd_unw_all, axis=0)
        smd_w = np.mean(smd_w_all, axis=0)
        vr_unw = np.nanmean(vr_unw_all, axis=0)
        vr_w = np.nanmean(vr_w_all, axis=0)

        severity = []
        for sw in smd_w:
            if sw <= 0.1:
                severity.append("Good")
            elif sw <= 0.2:
                severity.append("Moderate")
            else:
                severity.append("Poor")

        covariate_names = covariates
        numeric_df = pd.DataFrame({
            "Covariate": covariate_names,
            "SMD_Unweighted": smd_unw,
            "SMD_Weighted": smd_w,
            "Imbalance_Severity": severity,
            "VR_Unweighted": vr_unw,
            "VR_Weighted": vr_w
        })

        numeric_path = os.path.join(group_path, f"covariate_balance_table_{group}.xlsx")
        numeric_df.to_excel(numeric_path, index=False)
        print(f"📊 Exported numeric summary to: {numeric_path}")

        # -------------------------
        # Plot
        # -------------------------
        labels = covariates
        y_pos = np.arange(len(labels))

        fig, axes = plt.subplots(1, 2, figsize=(18, len(labels) * 0.45))

        axes[0].scatter(smd_unw, y_pos, color='red', label="Unweighted")
        axes[0].scatter(smd_w, y_pos, color='blue', label="Weighted")
        axes[0].axvline(0.1, color='gray', linestyle='--', label="Threshold 0.1")
        axes[0].axvline(0.2, color='black', linestyle='--', label="Threshold 0.2")
        axes[0].set_xlim(0, max(max(smd_unw), max(smd_w), 0.25) + 0.05)
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels(labels)
        axes[0].invert_yaxis()
        axes[0].set_title("Standardized Mean Differences (SMD)")
        axes[0].legend(loc="upper right")
        axes[0].grid(True)

        vr_mask = [cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] for cov in covariates]
        filtered_y = [i for i, b in enumerate(vr_mask) if b]
        filtered_labels = [labels[i] for i in filtered_y]
        filtered_vr_unw = [vr_unw[i] for i in filtered_y]
        filtered_vr_w = [vr_w[i] for i in filtered_y]

        axes[1].scatter(filtered_vr_unw, filtered_y, color='blue', marker='o', label="Unweighted")
        axes[1].scatter(filtered_vr_w, filtered_y, color='red', marker='x', label="Weighted")
        axes[1].axvline(2, color='gray', linestyle='--')
        axes[1].axvline(0.5, color='gray', linestyle='--')
        axes[1].set_xlim(0, max(filtered_vr_unw + filtered_vr_w + [2.5]) + 0.5)
        axes[1].set_yticks(filtered_y)
        axes[1].set_yticklabels(filtered_labels)
        axes[1].invert_yaxis()
        axes[1].set_title("Variance Ratio (VR)")
        axes[1].legend()
        axes[1].grid(True)

        fig.suptitle(f"Covariate Balance for {group.replace('CAT_', '')}", fontsize=14, weight='bold')
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        plot_path = os.path.join(group_path, f"love_plot_{group}.pdf")
        fig.savefig(plot_path, dpi=300)
        plt.close()
        print(f"✅ Saved love plot: {plot_path}")
        print(f"📏 Max weighted SMD for {group}: {np.max(smd_w):.3f}")

    except Exception as e:
        print(f"❌ Error in {group}: {e}")


🔍 Processing CAT_Aceetanilidederivaten...
📊 Exported numeric summary to: outputs\CAT_Aceetanilidederivaten\covariate_balance_table_CAT_Aceetanilidederivaten.xlsx
✅ Saved love plot: outputs\CAT_Aceetanilidederivaten\love_plot_CAT_Aceetanilidederivaten.pdf
📏 Max weighted SMD for CAT_Aceetanilidederivaten: 0.536

🔍 Processing CAT_Adhd...
📊 Exported numeric summary to: outputs\CAT_Adhd\covariate_balance_table_CAT_Adhd.xlsx
✅ Saved love plot: outputs\CAT_Adhd\love_plot_CAT_Adhd.pdf
📏 Max weighted SMD for CAT_Adhd: 0.433

🔍 Processing CAT_All...
📊 Exported numeric summary to: outputs\CAT_All\covariate_balance_table_CAT_All.xlsx
✅ Saved love plot: outputs\CAT_All\love_plot_CAT_All.pdf
📏 Max weighted SMD for CAT_All: 0.100

🔍 Processing CAT_All_Psychotropics...
📊 Exported numeric summary to: outputs\CAT_All_Psychotropics\covariate_balance_table_CAT_All_Psychotropics.xlsx
✅ Saved love plot: outputs\CAT_All_Psychotropics\love_plot_CAT_All_Psychotropics.pdf
📏 Max weighted SMD for CAT_All_Psychot

In [46]:
# Heatmap:

In [47]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
#-----------------------------
# Generate heatmaps
# -------------------------------
for treatment_var in medication_groups:
    print(f"\n========== Creating Heatmap for {treatment_var} ==========")

    try:
        output_folder = os.path.join('outputs', treatment_var)
        balance_path = os.path.join(output_folder, f'covariate_balance_table_{treatment_var}.xlsx')

        if not os.path.exists(balance_path):
            print(f"❌ Balance file not found: {balance_path}")
            continue

        balance_df = pd.read_excel(balance_path)

        # ✅ Use finalized covariates + 'Propensity Score'
        covariates = final_covariates_map[treatment_var] + ['Propensity Score']
        balance_df = balance_df[balance_df['Covariate'].isin(covariates)]

        # ✅ Check for CAPS5score_baseline
        highlight_caps = 'CAPS5score_baseline' in balance_df['Covariate'].values

        # ✅ Format for heatmap
        heatmap_df = balance_df[['Covariate', 'SMD_Unweighted', 'SMD_Weighted']].copy()
        heatmap_df.columns = ['Covariate', 'Unweighted', 'Weighted']
        heatmap_df = heatmap_df.set_index('Covariate')
        heatmap_df = heatmap_df.sort_values(by='Unweighted', ascending=False)

        # ✅ Plot
        plt.figure(figsize=(12, max(10, len(heatmap_df) * 0.35)))
        ax = sns.heatmap(
            heatmap_df,
            cmap="coolwarm",
            annot=True,
            fmt=".2f",
            linewidths=0.6,
            linecolor='gray',
            cbar_kws={"label": "Standardized Mean Difference"}
        )

        plt.title(f"Covariate Balance Heatmap (Rubin IPTW)\n{treatment_var}", fontsize=15, weight='bold')
        plt.xlabel("Condition")
        plt.ylabel("Covariate")

        # ✅ Bold CAPS5score_baseline if present
        if highlight_caps:
            ylabels = [label.get_text() for label in ax.get_yticklabels()]
            ax.set_yticklabels([
                f"{label} ←" if label == 'CAPS5score_baseline' else label for label in ylabels
            ])

        plt.tight_layout()

        # ✅ Save image
        save_path = os.path.join(output_folder, f'heatmap_smd_{treatment_var}.png')
        plt.savefig(save_path, dpi=300)
        plt.close()

        print(f"✅ Heatmap saved: {save_path}")

    except Exception as e:
        print(f"⚠️ Error processing {treatment_var}: {e}")


✅ Heatmap saved: outputs\CAT_ADHD\heatmap_smd_CAT_ADHD.png

✅ Heatmap saved: outputs\CAT_Aceetanilidederivaten\heatmap_smd_CAT_Aceetanilidederivaten.png

✅ Heatmap saved: outputs\CAT_Z_drugs\heatmap_smd_CAT_Z_drugs.png

✅ Heatmap saved: outputs\CAT_Opioden\heatmap_smd_CAT_Opioden.png

✅ Heatmap saved: outputs\CAT_NSAIDs\heatmap_smd_CAT_NSAIDs.png

✅ Heatmap saved: outputs\CAT_Benzodiazepine\heatmap_smd_CAT_Benzodiazepine.png

✅ Heatmap saved: outputs\CAT_Antihypertensiva\heatmap_smd_CAT_Antihypertensiva.png

✅ Heatmap saved: outputs\CAT_Antihistaminica\heatmap_smd_CAT_Antihistaminica.png

✅ Heatmap saved: outputs\CAT_Anti_epileptica\heatmap_smd_CAT_Anti_epileptica.png

✅ Heatmap saved: outputs\CAT_Antidepressiva\heatmap_smd_CAT_Antidepressiva.png

✅ Heatmap saved: outputs\CAT_Antipsychotica\heatmap_smd_CAT_Antipsychotica.png

✅ Heatmap saved: outputs\CAT_ALL_PSYCHOTROPICS_EXCL_BENZO\heatmap_smd_CAT_ALL_PSYCHOTROPICS_EXCL_BENZO.png

✅ Heatmap saved: outputs\CAT_ALL_PSYCHOTROPICS_EXCL_S

### Subcat analysis:

In [48]:
covariates_SUBCAT_Antipsychotica_atypisch = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_TCA = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_SSRI = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_SNRI = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Tetracyclische_antidepressiva = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_Antidepressiva_overige = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_Systemische_antihistaminica = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_anxiolytica_Benzodiazepine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_hypnotica_Benzodiazepine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Amfetaminen = [
    'treatmentdurationdays', 'CAPS5score_baseline', 'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD',
    'DIAGNOSIS_SEXUAL_TRAUMA', 'DIAGNOSIS_SUICIDALITY',
    'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Systemische_betablokkers = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Paracetamol_mono = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SUBCAT_Anti_epileptica_stemmingsstabilisatoren = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age', 
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_Opioden = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_Z_drugs = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SUBCAT_NSAIDs = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

In [49]:
from collections import defaultdict

# This finds all variables that start with covariates_SUBCAT_
final_covariates_map = defaultdict(list)
final_covariates_map.update({
    var.replace("covariates_", ""): val
    for var, val in globals().items()
    if var.lower().startswith("covariates_subcat_") and isinstance(val, list)
})

# Show detected group names
print("Groups found:", list(final_covariates_map.keys()))
medication_groups = list(final_covariates_map.keys())
print(medication_groups)

Groups found: ['SUBCAT_Antipsychotica_atypisch', 'SUBCAT_TCA', 'SUBCAT_SSRI', 'SUBCAT_SNRI', 'SUBCAT_Tetracyclische_antidepressiva', 'SUBCAT_Antidepressiva_overige', 'SUBCAT_Systemische_antihistaminica', 'SUBCAT_anxiolytica_Benzodiazepine', 'SUBCAT_hypnotica_Benzodiazepine', 'SUBCAT_Amfetaminen', 'SUBCAT_Systemische_betablokkers', 'SUBCAT_Paracetamol_mono', 'SUBCAT_Anti_epileptica_stemmingsstabilisatoren', 'SUBCAT_Opioden', 'SUBCAT_Z_drugs', 'SUBCAT_NSAIDs']
['SUBCAT_Antipsychotica_atypisch', 'SUBCAT_TCA', 'SUBCAT_SSRI', 'SUBCAT_SNRI', 'SUBCAT_Tetracyclische_antidepressiva', 'SUBCAT_Antidepressiva_overige', 'SUBCAT_Systemische_antihistaminica', 'SUBCAT_anxiolytica_Benzodiazepine', 'SUBCAT_hypnotica_Benzodiazepine', 'SUBCAT_Amfetaminen', 'SUBCAT_Systemische_betablokkers', 'SUBCAT_Paracetamol_mono', 'SUBCAT_Anti_epileptica_stemmingsstabilisatoren', 'SUBCAT_Opioden', 'SUBCAT_Z_drugs', 'SUBCAT_NSAIDs']


In [50]:
import os


def run_all_SUBCAT_group_models(imputed_dfs):
    """
    Runs downstream analysis for each SUBCAT medisubcation group using imputed datasets.

    Parameters:
    - imputed_dfs: list of 5 imputed DataFrames (from df_imputed_final_imp1.pkl ... imp5.pkl)
    
    Notes:
    - Covariate lists must be defined as global variables: covariates_subcat_<group>
    - Outputs are saved in: outputs/SUBCAT_<GROUP>/
    """

    print(" Starting analysis for all SUBCAT groups")

    for var_name in globals():
        if var_name.lower().startswith("covariates_subcat_") and isinstance(globals()[var_name], list):
            group_name = var_name.replace("covariates_", "")
            group_name = group_name.replace("_", " ").title().replace(" ", "_")  # e.g., subcat_z_drugs → Subcat_Z_Drugs
            group_name = group_name.replace("Subcat_", "SUBCAT_")  # force prefix to uppercase

            covariates = globals()[var_name]
            output_dir = f"outputs/{group_name}"
            os.makedirs(output_dir, exist_ok=True)

            print(f"\n Processing {group_name}...")

            for k, df_imp in enumerate(imputed_dfs):
                print(f"  → Using imputation {k+1}")

                # Define X and Y
                X = df_imp[covariates]
                Y = df_imp["caps5_change_baseline"]

                # === Save X and Y as placeholder (replace with modeling later)
                X.to_csv(f"{output_dir}/X_imp{k+1}.csv", index=False)
                Y.to_frame(name="Y").to_csv(f"{output_dir}/Y_imp{k+1}.csv", index=False)

            print(f" Done: {group_name}")

    print("\n All SUBCAT group analyses complete.")

# ========= STEP 4: Execute ========= #
run_all_SUBCAT_group_models(imputed_dfs)

 Starting analysis for all SUBCAT groups

 Processing SUBCAT_Antipsychotica_Atypisch...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: SUBCAT_Antipsychotica_Atypisch

 Processing SUBCAT_Tca...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: SUBCAT_Tca

 Processing SUBCAT_Ssri...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: SUBCAT_Ssri

 Processing SUBCAT_Snri...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: SUBCAT_Snri

 Processing SUBCAT_Tetracyclische_Antidepressiva...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: SUBCAT_Tetracyclische_Antidepressiva

 Processing SUBCAT_Antidepressiva_Overige...
  → Using imputat

In [51]:
for treatment_var in medication_groups:
    print(f"\n {treatment_var}")
    
    for i, df in enumerate(imputed_dfs):
        if treatment_var not in df.columns:
            print(f"  Imp {i+1}:  Not found in columns.")
            continue

        treated = (df[treatment_var] == 1).sum()
        control = (df[treatment_var] == 0).sum()
        missing = df[treatment_var].isna().sum()

        print(f"  Imp {i+1}: Treated = {treated}, Control = {control}, Missing = {missing}")


 SUBCAT_Antipsychotica_atypisch
  Imp 1: Treated = 355, Control = 5770, Missing = 0
  Imp 2: Treated = 355, Control = 5770, Missing = 0
  Imp 3: Treated = 355, Control = 5770, Missing = 0
  Imp 4: Treated = 355, Control = 5770, Missing = 0
  Imp 5: Treated = 355, Control = 5770, Missing = 0

 SUBCAT_TCA
  Imp 1: Treated = 117, Control = 6008, Missing = 0
  Imp 2: Treated = 117, Control = 6008, Missing = 0
  Imp 3: Treated = 117, Control = 6008, Missing = 0
  Imp 4: Treated = 117, Control = 6008, Missing = 0
  Imp 5: Treated = 117, Control = 6008, Missing = 0

 SUBCAT_SSRI
  Imp 1: Treated = 555, Control = 5570, Missing = 0
  Imp 2: Treated = 555, Control = 5570, Missing = 0
  Imp 3: Treated = 555, Control = 5570, Missing = 0
  Imp 4: Treated = 555, Control = 5570, Missing = 0
  Imp 5: Treated = 555, Control = 5570, Missing = 0

 SUBCAT_SNRI
  Imp 1: Treated = 106, Control = 6019, Missing = 0
  Imp 2: Treated = 106, Control = 6019, Missing = 0
  Imp 3: Treated = 106, Control = 6019, Mi

In [52]:
import os
import pandas as pd
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from collections import defaultdict

# ✅ VIF computation function
def compute_vif(X):
    X = sm.add_constant(X, has_constant='add')
    vif_df = pd.DataFrame()
    vif_df["variable"] = X.columns
    vif_df["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
    return vif_df

# ✅ Process each group
for group in medication_groups:
    print(f"\n🔍 Processing VIF for {group}")

    if group not in final_covariates_map:
        print(f" ⚠️ No covariates found for {group}. Skipping.")
        continue

    covariates = final_covariates_map[group]
    vif_list = []

    for i, df_imp in enumerate(imputed_dfs):
        try:
            X = df_imp[covariates].copy()
            vif_df = compute_vif(X)
            vif_df["imputation"] = i + 1
            vif_list.append(vif_df)
        except Exception as e:
            print(f"   ❌ Failed on imputation {i+1} for {group}: {e}")

    if vif_list:
        all_vif = pd.concat(vif_list)
        pooled_vif = all_vif.groupby("variable")["VIF"].mean().reset_index()
        pooled_vif = pooled_vif.sort_values(by="VIF", ascending=False)

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)
        pooled_vif.to_csv(os.path.join(output_folder, "pooled_vif.csv"), index=False)

        print(f" ✅ Saved: {output_folder}/pooled_vif.csv")
    else:
        print(f" ⚠️ Skipped {group}: No valid imputations.")


🔍 Processing VIF for SUBCAT_Antipsychotica_atypisch
 ✅ Saved: outputs\SUBCAT_Antipsychotica_atypisch/pooled_vif.csv

🔍 Processing VIF for SUBCAT_TCA
 ✅ Saved: outputs\SUBCAT_TCA/pooled_vif.csv

🔍 Processing VIF for SUBCAT_SSRI
 ✅ Saved: outputs\SUBCAT_SSRI/pooled_vif.csv

🔍 Processing VIF for SUBCAT_SNRI
 ✅ Saved: outputs\SUBCAT_SNRI/pooled_vif.csv

🔍 Processing VIF for SUBCAT_Tetracyclische_antidepressiva
 ✅ Saved: outputs\SUBCAT_Tetracyclische_antidepressiva/pooled_vif.csv

🔍 Processing VIF for SUBCAT_Antidepressiva_overige
 ✅ Saved: outputs\SUBCAT_Antidepressiva_overige/pooled_vif.csv

🔍 Processing VIF for SUBCAT_Systemische_antihistaminica
 ✅ Saved: outputs\SUBCAT_Systemische_antihistaminica/pooled_vif.csv

🔍 Processing VIF for SUBCAT_anxiolytica_Benzodiazepine
 ✅ Saved: outputs\SUBCAT_anxiolytica_Benzodiazepine/pooled_vif.csv

🔍 Processing VIF for SUBCAT_hypnotica_Benzodiazepine
 ✅ Saved: outputs\SUBCAT_hypnotica_Benzodiazepine/pooled_vif.csv

🔍 Processing VIF for SUBCAT_Amfetami

In [53]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# ---------- PS Estimation Function ----------
def run_logistic_ps_modeling(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f" Running PS estimation for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        covariates = final_covariates_map[group]
        ps_matrix = pd.DataFrame()
        auc_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not found in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop missing treatment rows
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            # Train-test split for ROC
            X_train, X_test, T_train, T_test = train_test_split(
                X, T, stratify=T, test_size=0.3, random_state=42
            )

            try:
                model = LogisticRegression(random_state=42, max_iter=1000)
                model.fit(X_train, T_train)

                ps_scores = model.predict_proba(X)[:, 1]
                ps_matrix[f"ps_imp{i+1}"] = pd.Series(ps_scores, index=valid_idx)

                # ROC & AUC
                auc = roc_auc_score(T_test, model.predict_proba(X_test)[:, 1])
                auc_list.append(auc)

                fpr, tpr, _ = roc_curve(T_test, model.predict_proba(X_test)[:, 1])
                plt.figure()
                plt.plot(fpr, tpr, label=f"AUC = {auc:.3f}")
                plt.plot([0, 1], [0, 1], 'k--')
                plt.xlabel("False Positive Rate")
                plt.ylabel("True Positive Rate")
                plt.title(f"ROC Curve - {group} (Imp {i+1})")
                plt.legend()
                plt.tight_layout()
                plt.savefig(os.path.join(output_folder, f"roc_curve_imp{i+1}.png"))
                plt.close()
                print(f"   Imp {i+1}: AUC = {auc:.3f}, ROC saved.")

            except Exception as e:
                print(f"   Error in {group} (imp {i+1}): {e}")

        # Save AUCs and Composite PS
        if not ps_matrix.empty:
            # Fill NaN rows (from dropped subjects in some imputations) with mean
            ps_matrix["composite_ps"] = ps_matrix.mean(axis=1)
            ps_matrix.to_excel(os.path.join(output_folder, "propensity_scores.xlsx"))

            auc_df = pd.DataFrame({
                "imputation": [f"imp{i+1}" for i in range(len(auc_list))],
                "AUC": auc_list
            })
            auc_df.loc[len(auc_df.index)] = ["mean", np.mean(auc_list) if auc_list else np.nan]
            auc_df.to_excel(os.path.join(output_folder, "auc_scores.xlsx"), index=False)

            print(f" Composite PS + AUC saved for {group}")
        else:
            print(f" No valid PS scores generated for {group}")

# ---------- Run ----------
run_logistic_ps_modeling(imputed_dfs, medication_groups, final_covariates_map)

 Running PS estimation for SUBCAT_Antipsychotica_atypisch
   Imp 1: AUC = 0.998, ROC saved.
   Imp 2: AUC = 0.998, ROC saved.
   Imp 3: AUC = 0.998, ROC saved.
   Imp 4: AUC = 0.998, ROC saved.
   Imp 5: AUC = 0.998, ROC saved.
 Composite PS + AUC saved for SUBCAT_Antipsychotica_atypisch
 Running PS estimation for SUBCAT_TCA
   Imp 1: AUC = 0.717, ROC saved.
   Imp 2: AUC = 0.700, ROC saved.
   Imp 3: AUC = 0.687, ROC saved.
   Imp 4: AUC = 0.700, ROC saved.
   Imp 5: AUC = 0.716, ROC saved.
 Composite PS + AUC saved for SUBCAT_TCA
 Running PS estimation for SUBCAT_SSRI
   Imp 1: AUC = 0.763, ROC saved.
   Imp 2: AUC = 0.768, ROC saved.
   Imp 3: AUC = 0.759, ROC saved.
   Imp 4: AUC = 0.767, ROC saved.
   Imp 5: AUC = 0.769, ROC saved.
 Composite PS + AUC saved for SUBCAT_SSRI
 Running PS estimation for SUBCAT_SNRI
   Imp 1: AUC = 0.730, ROC saved.
   Imp 2: AUC = 0.721, ROC saved.
   Imp 3: AUC = 0.733, ROC saved.
   Imp 4: AUC = 0.747, ROC saved.
   Imp 5: AUC = 0.738, ROC saved.
 C

In [54]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LogisticRegression

def compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n Computing feature importance for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        covariates = final_covariates_map[group]
        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        importance_df_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop NaNs in treatment
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            try:
                model = LogisticRegression(random_state=42, max_iter=1000)
                model.fit(X, T)

                # Get feature importance (absolute coefficients)
                importances = np.abs(model.coef_[0])
                importance_dict = dict(zip(X.columns, importances))
                df_feat = pd.DataFrame.from_dict(importance_dict, orient='index', columns=[f"imp{i+1}"])
                df_feat.index.name = 'feature'
                importance_df_list.append(df_feat)

            except Exception as e:
                print(f"   Error during modeling: {e}")

        if importance_df_list:
            # Combine and average
            all_feat = pd.concat(importance_df_list, axis=1).fillna(0)
            all_feat["mean_importance"] = all_feat.mean(axis=1)

            # Filter top 30 non-zero
            non_zero = all_feat[all_feat["mean_importance"] > 0]
            top30 = non_zero.sort_values(by="mean_importance", ascending=False).head(30)

            # Save to CSV
            top30.to_csv(os.path.join(output_folder, "feature_importance.csv"))

            # Plot
            plt.figure(figsize=(10, 8))
            plt.barh(top30.index[::-1], top30["mean_importance"][::-1])  # plot top → bottom
            plt.xlabel("Mean Gain Importance")
            plt.title(f"Top 30 Feature Importance - {group}")
            plt.tight_layout()
            plt.savefig(os.path.join(output_folder, "feature_importance_top30.png"))
            plt.close()

            print(f" Saved feature importance plot and CSV for {group}")
        else:
            print(f" No valid models for {group}")

#  Run
compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map)


 Computing feature importance for SUBCAT_Antipsychotica_atypisch
 Saved feature importance plot and CSV for SUBCAT_Antipsychotica_atypisch

 Computing feature importance for SUBCAT_TCA
 Saved feature importance plot and CSV for SUBCAT_TCA

 Computing feature importance for SUBCAT_SSRI
 Saved feature importance plot and CSV for SUBCAT_SSRI

 Computing feature importance for SUBCAT_SNRI
 Saved feature importance plot and CSV for SUBCAT_SNRI

 Computing feature importance for SUBCAT_Tetracyclische_antidepressiva
 Saved feature importance plot and CSV for SUBCAT_Tetracyclische_antidepressiva

 Computing feature importance for SUBCAT_Antidepressiva_overige
 Saved feature importance plot and CSV for SUBCAT_Antidepressiva_overige

 Computing feature importance for SUBCAT_Systemische_antihistaminica
 Saved feature importance plot and CSV for SUBCAT_Systemische_antihistaminica

 Computing feature importance for SUBCAT_anxiolytica_Benzodiazepine
 Saved feature importance plot and CSV for SUBCAT

In [55]:
import os
import pandas as pd
import numpy as np


def compute_trimmed_clipped_iptw(ps_df, treatment, lower=0.05, upper=0.95, clip_max=10):
    weights = []
    keep_mask = (ps_df > lower) & (ps_df < upper)

    for i in range(ps_df.shape[1]):
        ps = ps_df.iloc[:, i].clip(lower=1e-6, upper=1 - 1e-6)  # avoid div by zero
        mask = keep_mask.iloc[:, i]
        w = pd.Series(np.nan, index=ps.index)

        w[mask & (treatment == 1)] = 1 / ps[mask & (treatment == 1)]
        w[mask & (treatment == 0)] = 1 / (1 - ps[mask & (treatment == 0)])
        w = w.clip(upper=clip_max)
        weights.append(w)

    return pd.concat(weights, axis=1)


def apply_rubins_rule_to_iptw(iptw_matrix):
    """
    Given an IPTW matrix (n rows × M imputations), return Rubin’s rule pooled mean, SD, SE.
    """
    M = iptw_matrix.shape[1]
    q_bar = iptw_matrix.mean(axis=1)
    u_bar = iptw_matrix.var(axis=1, ddof=1)
    B = iptw_matrix.apply(lambda x: x.mean(), axis=1).var(ddof=1)
    total_var = u_bar + (1 + 1/M) * B
    total_se = np.sqrt(total_var)
    return q_bar, u_bar.pow(0.5), total_se


def run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n🔍 Processing IPTW + trimming + clipping for {group}")
        output_folder = os.path.join("outputs", group)
        ps_path = os.path.join(output_folder, "propensity_scores.xlsx")

        if not os.path.exists(ps_path):
            print(f"⚠️ Missing PS file: {ps_path}. Skipping.")
            continue

        try:
            ps_all = pd.read_excel(ps_path, index_col=0)
            ps_cols = [col for col in ps_all.columns if col.startswith("ps_imp")]
            composite_index = ps_all.index

            # Get treatment from one imputed dataset
            T_full = None
            for df in imputed_dfs:
                if group in df.columns:
                    T_full = df.loc[composite_index, group]
                    break

            if T_full is None:
                print(f"❌ Treatment column {group} not found in any imputed dataset.")
                continue

            # Compute IPTW matrix (shape: n × M)
            iptw_matrix = compute_trimmed_clipped_iptw(ps_all[ps_cols], T_full)
            iptw_matrix.columns = [f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]

            # Apply Rubin’s Rule for mean, SD, SE
            iptw_matrix["iptw_mean"], iptw_matrix["iptw_sd"], iptw_matrix["iptw_se"] = apply_rubins_rule_to_iptw(
                iptw_matrix[[f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]]
            )

            # Save IPTW matrix separately
            iptw_matrix.to_excel(os.path.join(output_folder, "iptw_weights.xlsx"))
            print(f"✅ Saved IPTW weights for {group}")

            # Save trimmed & clipped imputed datasets with IPTW
            for i in range(5):
                df = imputed_dfs[i].copy()
                if group not in df.columns:
                    continue

                trimmed_idx = iptw_matrix.index.intersection(df.index)
                needed_cols = final_covariates_map[group] + [group, "caps5_change_baseline"]

                # Select only necessary columns
                df_trimmed = df.loc[trimmed_idx, needed_cols].copy()
                df_trimmed["iptw"] = iptw_matrix[f"iptw_imp{i+1}"].loc[trimmed_idx]

                # ✅ DROP rows with missing IPTW values
                before = len(df_trimmed)
                df_trimmed = df_trimmed.dropna(subset=["iptw"])
                after = len(df_trimmed)
                print(f"    ℹ️ Retained {after}/{before} rows after IPTW NaN drop.")

                # Save to .pkl
                df_trimmed.to_pickle(os.path.join(output_folder, f"trimmed_data_imp{i+1}.pkl"))
                print(f"  💾 Saved trimmed dataset: {output_folder}/trimmed_data_imp{i+1}.*")

        except Exception as e:
            print(f"❌ Error in {group}: {e}")


run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map)


🔍 Processing IPTW + trimming + clipping for SUBCAT_Antipsychotica_atypisch
✅ Saved IPTW weights for SUBCAT_Antipsychotica_atypisch
    ℹ️ Retained 232/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\SUBCAT_Antipsychotica_atypisch/trimmed_data_imp1.*
    ℹ️ Retained 233/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\SUBCAT_Antipsychotica_atypisch/trimmed_data_imp2.*
    ℹ️ Retained 232/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\SUBCAT_Antipsychotica_atypisch/trimmed_data_imp3.*
    ℹ️ Retained 230/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\SUBCAT_Antipsychotica_atypisch/trimmed_data_imp4.*
    ℹ️ Retained 238/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\SUBCAT_Antipsychotica_atypisch/trimmed_data_imp5.*

🔍 Processing IPTW + trimming + clipping for SUBCAT_TCA
✅ Saved IPTW weights for SUBCAT_TCA
    ℹ️ Retained 356/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\SUBCA

In [56]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def plot_ps_overlap_all_groups(medication_groups):
    for group in medication_groups:
        print(f"\n📊 Plotting PS overlap for {group}")

        folder = os.path.join("outputs", group)
        ps_file = os.path.join(folder, "propensity_scores.xlsx")
        iptw_file = os.path.join(folder, "iptw_weights.xlsx")
        trimmed_file = os.path.join(folder, "trimmed_data_imp1.pkl")

        if not all(os.path.exists(f) for f in [ps_file, iptw_file, trimmed_file]):
            print(f"⚠️ Missing required files for {group}. Skipping.")
            continue

        try:
            ps_df = pd.read_excel(ps_file, index_col=0)
            iptw_df = pd.read_excel(iptw_file, index_col=0)
            trimmed_df = pd.read_pickle(trimmed_file)

            # Extract
            ps = ps_df["composite_ps"].reindex(trimmed_df.index)
            w = iptw_df["iptw_mean"].reindex(trimmed_df.index)
            T = trimmed_df[group]

            # Masks to remove NaNs
            treated_mask = (T == 1) & ps.notna() & w.notna()
            control_mask = (T == 0) & ps.notna() & w.notna()

            treated = ps[treated_mask]
            treated_w = w[treated_mask]

            control = ps[control_mask]
            control_w = w[control_mask]

            # === Unweighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Unweighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_unweighted.png"))
            plt.close()

            # === Weighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, weights=[treated_w, control_w], label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Weighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Weighted Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_weighted.png"))
            plt.close()

            print(f"✅ Saved unweighted and weighted PS plots for {group}")

        except Exception as e:
            print(f"❌ Error processing {group}: {e}")

# 🔁 Run
plot_ps_overlap_all_groups(medication_groups)


📊 Plotting PS overlap for SUBCAT_Antipsychotica_atypisch
✅ Saved unweighted and weighted PS plots for SUBCAT_Antipsychotica_atypisch

📊 Plotting PS overlap for SUBCAT_TCA
✅ Saved unweighted and weighted PS plots for SUBCAT_TCA

📊 Plotting PS overlap for SUBCAT_SSRI
✅ Saved unweighted and weighted PS plots for SUBCAT_SSRI

📊 Plotting PS overlap for SUBCAT_SNRI
✅ Saved unweighted and weighted PS plots for SUBCAT_SNRI

📊 Plotting PS overlap for SUBCAT_Tetracyclische_antidepressiva
✅ Saved unweighted and weighted PS plots for SUBCAT_Tetracyclische_antidepressiva

📊 Plotting PS overlap for SUBCAT_Antidepressiva_overige
✅ Saved unweighted and weighted PS plots for SUBCAT_Antidepressiva_overige

📊 Plotting PS overlap for SUBCAT_Systemische_antihistaminica
✅ Saved unweighted and weighted PS plots for SUBCAT_Systemische_antihistaminica

📊 Plotting PS overlap for SUBCAT_anxiolytica_Benzodiazepine
✅ Saved unweighted and weighted PS plots for SUBCAT_anxiolytica_Benzodiazepine

📊 Plotting PS overl

In [57]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Set up base output folder
output_base = "outputs"
ps_file = "propensity_scores.xlsx"
iptw_file = "iptw_weights.xlsx"
trimmed_data_file = "trimmed_data_imp1.pkl"

# Collect all treatment group folders
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]


# Generate 4-panel overlap plots
for group in groups:
    group_path = os.path.join(output_base, group)
    try:
        # Load trimmed treatment info
        trimmed_df = pd.read_pickle(os.path.join(group_path, trimmed_data_file))
        index = trimmed_df.index

        # Fix: case-insensitive match for treatment variable
        possible_cols = [col for col in trimmed_df.columns if col.upper() == group.upper()]
        if not possible_cols:
            print(f" Treatment variable {group} not found in {group}, skipping.")
            continue
        treatment_var = possible_cols[0]
        T = trimmed_df[treatment_var]

        # Load composite PS (aligned to trimmed_df index)
        ps_df = pd.read_excel(os.path.join(group_path, ps_file), index_col=0)
        if 'composite_ps' not in ps_df.columns:
            print(f" Composite column missing in {ps_file}, skipping {group}.")
            continue
        ps = ps_df.loc[index, 'composite_ps']

        # Load IPTW weights (aligned to trimmed_df index)
        weights_df = pd.read_excel(os.path.join(group_path, iptw_file), index_col=0)
        if 'iptw_mean' not in weights_df.columns:
            print(f" IPTW weight column missing in {iptw_file}, skipping {group}.")
            continue
        weights = weights_df.loc[index, 'iptw_mean']

        # Prepare 4 datasets
        raw_treated = ps[T == 1]
        raw_control = ps[T == 0]
        weighted_treated = (ps[T == 1], weights[T == 1])
        weighted_control = (ps[T == 0], weights[T == 0])

        # Create plot
        fig, axs = plt.subplots(2, 2, figsize=(12, 8))
        fig.suptitle(f"Propensity Score Distribution - {group}", fontsize=14)

        axs[0, 0].hist(raw_treated, bins=20, alpha=0.7, color='blue')
        axs[0, 0].set_title("Raw Treated")

        axs[0, 1].hist(raw_control, bins=20, alpha=0.7, color='green')
        axs[0, 1].set_title("Raw Control")

        axs[1, 0].hist(weighted_treated[0], bins=20, weights=weighted_treated[1], alpha=0.7, color='blue')
        axs[1, 0].set_title("Weighted Treated")

        axs[1, 1].hist(weighted_control[0], bins=20, weights=weighted_control[1], alpha=0.7, color='green')
        axs[1, 1].set_title("Weighted Control")

        for ax in axs.flat:
            ax.set_xlim(0, 1)
            ax.set_xlabel("Propensity Score")
            ax.set_ylabel("Count")

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save figure
        plot_path = os.path.join(group_path, f"four_panel_overlap_{group}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f" ✅ Saved: {plot_path}")

    except Exception as e:
        print(f" ❌ Error in {group}: {e}")

 ✅ Saved: outputs\SUBCAT_Antipsychotica_atypisch\four_panel_overlap_SUBCAT_Antipsychotica_atypisch.png
 ✅ Saved: outputs\SUBCAT_TCA\four_panel_overlap_SUBCAT_TCA.png
 ✅ Saved: outputs\SUBCAT_SSRI\four_panel_overlap_SUBCAT_SSRI.png
 ✅ Saved: outputs\SUBCAT_SNRI\four_panel_overlap_SUBCAT_SNRI.png
 ✅ Saved: outputs\SUBCAT_Tetracyclische_antidepressiva\four_panel_overlap_SUBCAT_Tetracyclische_antidepressiva.png
 ✅ Saved: outputs\SUBCAT_Antidepressiva_overige\four_panel_overlap_SUBCAT_Antidepressiva_overige.png
 ✅ Saved: outputs\SUBCAT_Systemische_antihistaminica\four_panel_overlap_SUBCAT_Systemische_antihistaminica.png
 ✅ Saved: outputs\SUBCAT_anxiolytica_Benzodiazepine\four_panel_overlap_SUBCAT_anxiolytica_Benzodiazepine.png
 ✅ Saved: outputs\SUBCAT_hypnotica_Benzodiazepine\four_panel_overlap_SUBCAT_hypnotica_Benzodiazepine.png
 ✅ Saved: outputs\SUBCAT_Amfetaminen\four_panel_overlap_SUBCAT_Amfetaminen.png
 ✅ Saved: outputs\SUBCAT_Systemische_betablokkers\four_panel_overlap_SUBCAT_Systemis

In [58]:
# ATT calculation:

In [59]:
# Weighted:

In [60]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
import statsmodels.api as sm
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T, weights):
    treated = X[T == 1]
    control = X[T == 0]
    w_treated = weights[T == 1]
    w_control = weights[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = np.average(treated[col], weights=w_treated), np.average(control[col], weights=w_control)
            s1 = np.sqrt(np.average((treated[col] - m1) ** 2, weights=w_treated))
            s0 = np.sqrt(np.average((control[col] - m0) ** 2, weights=w_control))
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# OLS Main Loop
# -----------------------------
def run_dml_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running OLS for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            # Set random seed for this iteration
            np.random.seed(seed)
            
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                # Add bootstrap sampling with seed-based randomization
                n_samples = len(df)
                bootstrap_idx = np.random.choice(n_samples, size=n_samples, replace=True)
                df_bootstrap = df.iloc[bootstrap_idx].reset_index(drop=True)

                X = df_bootstrap[covariates].copy()
                T = df_bootstrap[group]
                Y = df_bootstrap["caps5_change_baseline"]
                W = df_bootstrap["iptw"]

                try:
                    # Create design matrix with treatment variable and covariates
                    X_ols = pd.concat([T, X], axis=1)
                    X_ols = sm.add_constant(X_ols)
                    
                    # Fit weighted OLS with robust standard errors
                    ols_model = sm.WLS(Y, X_ols, weights=W).fit(cov_type='HC1')
                    
                    # Extract treatment effect (coefficient of treatment variable)
                    att = ols_model.params[group]  # Treatment coefficient
                    se = ols_model.bse[group]  # Robust standard error for treatment
                    
                    att_list.append(att)
                    se_list.append(se)

                    # Calculate model fit statistics
                    Y_pred = ols_model.fittedvalues
                    residuals = ols_model.resid
                    rmse = mean_squared_error(Y, Y_pred, squared=False)
                    r2 = ols_model.rsquared
                    r2_list.append(r2)
                    rmse_list.append(rmse)
                    
                    # Collect residuals and fitted values for diagnostic plots
                    group_residuals.append(residuals.values)
                    group_fitted.append(Y_pred.values)

                    smd, vr = calculate_smd_vr(X, T, W)
                    smd_list.append(smd)
                    vr_list.append(vr)
                except Exception as e:
                    print(f"⚠️ Error in {group}, seed {seed}, imp {imp}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("ols_summary_subcats.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_subcats.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_dml_with_trimmed_data(final_covariates_map)



🚀 Running OLS for SUBCAT_Antipsychotica_atypisch
✅ SUBCAT_Antipsychotica_atypisch | Seed 1: ATT = 6.6376, SE = 8.4840, p = 0.47772
✅ SUBCAT_Antipsychotica_atypisch | Seed 2: ATT = 6.1659, SE = 8.4248, p = 0.50482
✅ SUBCAT_Antipsychotica_atypisch | Seed 3: ATT = 4.0529, SE = 5.1990, p = 0.47920
✅ SUBCAT_Antipsychotica_atypisch | Seed 4: ATT = 7.3070, SE = 7.5990, p = 0.39071
✅ SUBCAT_Antipsychotica_atypisch | Seed 5: ATT = 1.8224, SE = 4.3531, p = 0.69697
✅ SUBCAT_Antipsychotica_atypisch | Seed 6: ATT = 5.3180, SE = 7.5841, p = 0.52183
✅ SUBCAT_Antipsychotica_atypisch | Seed 7: ATT = 6.2190, SE = 7.7693, p = 0.46828
✅ SUBCAT_Antipsychotica_atypisch | Seed 8: ATT = 5.0271, SE = 8.4432, p = 0.58362
✅ SUBCAT_Antipsychotica_atypisch | Seed 9: ATT = 2.6389, SE = 7.1649, p = 0.73130
✅ SUBCAT_Antipsychotica_atypisch | Seed 10: ATT = 5.5002, SE = 7.8783, p = 0.52354
📊 Diagnostic plots saved for SUBCAT_Antipsychotica_atypisch
🏆 Best result for SUBCAT_Antipsychotica_atypisch → Seed 5 | SE = 4.35

In [61]:
# Unweighted:

In [62]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
import statsmodels.api as sm
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}_unweighted.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T):
    treated = X[T == 1]
    control = X[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = treated[col].mean(), control[col].mean()
            s1 = treated[col].std()
            s0 = control[col].std()
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# OLS Main Loop
# -----------------------------
def run_dml_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running OLS for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            # Set random seed for this iteration
            np.random.seed(seed)
            
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                # Add bootstrap sampling with seed-based randomization
                n_samples = len(df)
                bootstrap_idx = np.random.choice(n_samples, size=n_samples, replace=True)
                df_bootstrap = df.iloc[bootstrap_idx].reset_index(drop=True)

                X = df_bootstrap[covariates].copy()
                T = df_bootstrap[group]
                Y = df_bootstrap["caps5_change_baseline"]
                #W = df_bootstrap["iptw"]

                try:
                    # Create design matrix with treatment variable and covariates
                    X_ols = pd.concat([T, X], axis=1)
                    X_ols = sm.add_constant(X_ols)
                    
                    # Fit OLS with robust standard errors (unweighted)
                    ols_model = sm.OLS(Y, X_ols).fit(cov_type='HC1')
                    
                    # Extract treatment effect (coefficient of treatment variable)
                    att = ols_model.params[group]  # Treatment coefficient
                    se = ols_model.bse[group]  # Robust standard error for treatment
                    
                    att_list.append(att)
                    se_list.append(se)

                    # Calculate model fit statistics
                    Y_pred = ols_model.fittedvalues
                    residuals = ols_model.resid
                    rmse = mean_squared_error(Y, Y_pred, squared=False)
                    r2 = ols_model.rsquared
                    r2_list.append(r2)
                    rmse_list.append(rmse)
                    
                    # Collect residuals and fitted values for diagnostic plots
                    group_residuals.append(residuals.values)
                    group_fitted.append(Y_pred.values)

                    smd, vr = calculate_smd_vr(X, T)
                    smd_list.append(smd)
                    vr_list.append(vr)
                except Exception as e:
                    print(f"⚠️ Error in {group}, seed {seed}, imp {imp}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("ols_rubin_summary_subcats_unweighted.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_subcats_unweighted.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_dml_with_trimmed_data(final_covariates_map)



🚀 Running OLS for SUBCAT_Antipsychotica_atypisch
✅ SUBCAT_Antipsychotica_atypisch | Seed 1: ATT = 7.6283, SE = 7.2945, p = 0.35470
✅ SUBCAT_Antipsychotica_atypisch | Seed 2: ATT = 6.2406, SE = 7.8969, p = 0.47359
✅ SUBCAT_Antipsychotica_atypisch | Seed 3: ATT = 5.4737, SE = 5.9511, p = 0.40974
✅ SUBCAT_Antipsychotica_atypisch | Seed 4: ATT = 6.9820, SE = 5.8724, p = 0.30021
✅ SUBCAT_Antipsychotica_atypisch | Seed 5: ATT = 0.9987, SE = 5.7341, p = 0.87019
✅ SUBCAT_Antipsychotica_atypisch | Seed 6: ATT = 5.6674, SE = 6.2234, p = 0.41399
✅ SUBCAT_Antipsychotica_atypisch | Seed 7: ATT = 4.7487, SE = 7.0243, p = 0.53609
✅ SUBCAT_Antipsychotica_atypisch | Seed 8: ATT = 5.9543, SE = 6.7890, p = 0.42998
✅ SUBCAT_Antipsychotica_atypisch | Seed 9: ATT = 2.0668, SE = 7.6258, p = 0.79978
✅ SUBCAT_Antipsychotica_atypisch | Seed 10: ATT = 7.5569, SE = 7.5702, p = 0.37465
📊 Diagnostic plots saved for SUBCAT_Antipsychotica_atypisch
🏆 Best result for SUBCAT_Antipsychotica_atypisch → Seed 5 | SE = 5.73

In [63]:
import os
import pandas as pd
import numpy as np
from scipy.stats import sem, ttest_ind

# ----------------------------------
# File paths
# ----------------------------------
output_base = "outputs"
att_file = "ols_summary_subcats.xlsx"
trimmed_file = "trimmed_data_imp1.pkl"
auc_file = "auc_scores.xlsx"  

# ----------------------------------
# Load ATT Summary
# ----------------------------------
if os.path.exists(att_file):
    att_df = pd.read_excel(att_file)
else:
    raise FileNotFoundError("❌ ATT summary file not found: ols_summary_subcats.xlsx")

summary_rows = []

# ----------------------------------
# Loop over medication groups
# ----------------------------------
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

for med in groups:
    try:
        group_path = os.path.join(output_base, med)

        # Load trimmed data
        df = pd.read_pickle(os.path.join(group_path, trimmed_file))

        # Detect treatment column
        treatment_cols = [col for col in df.columns if col.upper() == med.upper()]
        if not treatment_cols:
            print(f"⚠️ Treatment column {med} not found in trimmed data. Skipping.")
            continue
        treatment_var = treatment_cols[0]

        # Extract treatment and outcome
        T = df[treatment_var]
        Y = df["caps5_change_baseline"]

        # Treated and control stats
        treated = Y[T == 1]
        control = Y[T == 0]

        mean_treat = treated.mean()
        se_treat = sem(treated) if len(treated) > 1 else np.nan

        mean_ctrl = control.mean()
        se_ctrl = sem(control) if len(control) > 1 else np.nan

        # Cohen's d (unadjusted)
        pooled_sd = np.sqrt(((treated.std() ** 2) + (control.std() ** 2)) / 2)
        cohen_d = (mean_treat - mean_ctrl) / pooled_sd if pooled_sd > 0 else np.nan

        # E-value (unadjusted)
        delta = mean_treat - mean_ctrl
        E = delta / abs(mean_ctrl) * 100 if mean_ctrl != 0 else np.nan

        # Unadjusted p-value
        try:
            t_stat, p_val = ttest_ind(treated, control, equal_var=False, nan_policy="omit")
            rounded_p = round(p_val, 5)
            formatted_p = "< 0.00001" if rounded_p < 0.00001 else rounded_p
        except Exception:
            formatted_p = np.nan

        # AUC from new auc_scores.xlsx file
        auc_val = np.nan
        auc_path = os.path.join(group_path, auc_file)
        if os.path.exists(auc_path):
            auc_df = pd.read_excel(auc_path)
            if "AUC" in auc_df.columns:
                auc_val = auc_df["AUC"].dropna().mean()

        # Adjusted stats from Rubin summary
        att_row = att_df[att_df["group"].str.strip().str.upper() == med.strip().upper()]
        if not att_row.empty:
            att = att_row.iloc[0]["att"]
            att_se = att_row.iloc[0]["se"]
            att_p_val = att_row.iloc[0]["p_value"]
            r2 = att_row.iloc[0]["r2"]
            rmse = att_row.iloc[0]["rmse"]

            try:
                rounded_att_p = round(float(att_p_val), 5)
                formatted_att_p = "< 0.00001" if rounded_att_p < 0.00001 else rounded_att_p
            except:
                formatted_att_p = att_p_val
        else:
            att, att_se, formatted_att_p, r2, rmse = np.nan, np.nan, np.nan, np.nan, np.nan

        # Append full row
        summary_rows.append({
            'Medication Group': med,
            'Mean Treated': mean_treat,
            'SE Treated': se_treat,
            'Mean Control': mean_ctrl,
            'SE Control': se_ctrl,
            'Cohen d': cohen_d,
            'E (Unadjusted)': E,
            'n Treated': len(treated),
            'n Control': len(control),
            #'Unadjusted p-value': formatted_p,
            'ATT Estimate': att,
            'ATT SE (Robust)': att_se,
            'ATT p-value': formatted_att_p,
            'R²': r2,
            'RMSE': rmse,
            'AUC': auc_val
        })

    except Exception as e:
        print(f"❌ Error in {med}: {e}")

# ----------------------------------
# Save final summary
# ----------------------------------
summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.sort_values("Medication Group")
summary_df.to_excel("Final_ATT_Summary_SubCat.xlsx", index=False)
print("✅ Final_ATT_Summary_SubCat saved")

✅ Final_ATT_Summary_SubCat saved


In [64]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# ✅ Load the final summary table
final_df = pd.read_excel("Final_ATT_Summary_SubCat.xlsx")

# ✅ Parse DML p-values (handle "< 0.00001")
def parse_pval(p):
    try:
        if isinstance(p, str) and "<" in p:
            return 0.000001
        return float(p)
    except:
        return None

final_df['ATT p-value'] = final_df['ATT p-value'].apply(parse_pval)

# ✅ Plot settings
width = 0.35

# ✅ Plotting function for a single medication group
def plot_single_group(row):
    fig, ax = plt.subplots(figsize=(10, 6))
    
    bars1 = ax.bar(-width/2, row['Mean Control'], width, 
                   yerr=row['SE Control'], label='Control', hatch='//', color='gray', capsize=5)
    bars2 = ax.bar(+width/2, row['Mean Treated'], width, 
                   yerr=row['SE Treated'], label='Treated', color='steelblue', capsize=5)

    label = (
        f"ATT = {row['ATT Estimate']:.2f}\n"
        f"d = {row['Cohen d']:.2f}, p = {row['ATT p-value']:.3f}\n"
        f"nT = {row['n Treated']}, nC = {row['n Control']}\n"
        f"E = {row['E (Unadjusted)']:.1f}%"
    )
    max_y = max(row['Mean Control'], row['Mean Treated']) + 1.5
    ax.text(0, max_y, label, ha='center', va='bottom', fontsize=9, color='#FFD700')

    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_xticks([-width/2, +width/2])
    ax.set_xticklabels(['Control', 'Treated'])
    ax.set_title(f"Group: {row['Medication Group']}", fontsize=12, weight='bold')
    ax.set_ylabel("CAPS5 Change Score")
    ax.set_ylim(bottom=0, top=max_y + 2)
    ax.legend()
    fig.tight_layout()
    return fig

# ✅ Generate and save all plots into a multi-page PDF
with PdfPages("ols_att_barplot_subcat.pdf") as pdf:
    for idx, row in final_df.iterrows():
        fig = plot_single_group(row)
        pdf.savefig(fig, dpi=300, bbox_inches='tight')
        plt.close(fig)
print("✅ ols_att_barplot_subcat saved")

✅ ols_att_barplot_subcat saved


In [65]:
# Love plot:

In [66]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# ----------------------------------------
# Functions to calculate balance
# ----------------------------------------
def calculate_smd(x1, x2, w1=None, w2=None):
    def weighted_mean(x, w): return np.average(x, weights=w)
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    m1 = weighted_mean(x1, w1) if w1 is not None else np.mean(x1)
    m2 = weighted_mean(x2, w2) if w2 is not None else np.mean(x2)
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    pooled_sd = np.sqrt((v1 + v2) / 2)
    return np.abs(m1 - m2) / pooled_sd if pooled_sd > 0 else 0

def variance_ratio(x1, x2, w1=None, w2=None):
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    return max(v1 / v2, v2 / v1) if v1 > 0 and v2 > 0 else 1

# ----------------------------------------
# Setup
# ----------------------------------------
output_base = "outputs"
groups = [g for g in os.listdir(output_base) if os.path.isdir(os.path.join(output_base, g))]

# Create a case-insensitive mapping
final_covariates_map_lower = {k.lower(): v for k, v in final_covariates_map.items()}

# ----------------------------------------
# Main Loop
# ----------------------------------------
for group in groups:
    if group.lower() not in final_covariates_map_lower:
        continue

    print(f"\n🔍 Processing {group}...")

    try:
        group_path = os.path.join(output_base, group)
        covariates = final_covariates_map_lower[group.lower()]
        
        column_name = None
        for col in pd.read_pickle(os.path.join(group_path, "trimmed_data_imp1.pkl")).columns:
            if col.lower() == group.lower():
                column_name = col
                break
        if column_name is None:
            print(f"⚠️ Column not found for {group}, skipping.")
            continue

        smd_unw_all, smd_w_all = [], []
        vr_unw_all, vr_w_all = [], []

        for i in range(1, 6):
            df_path = os.path.join(group_path, f"trimmed_data_imp{i}.pkl")
            iptw_path = os.path.join(group_path, "iptw_weights.xlsx")

            if not os.path.exists(df_path) or not os.path.exists(iptw_path):
                print(f"⚠️ Missing data for {group} imp{i}, skipping.")
                continue

            df = pd.read_pickle(df_path)
            iptw_df = pd.read_excel(iptw_path, index_col=0)
            T = df[column_name]
            W = iptw_df.loc[df.index, "iptw_mean"]

            smd_unw_i, smd_w_i, vr_unw_i, vr_w_i = [], [], [], []

            for cov in covariates:
                x1, x0 = df.loc[T == 1, cov], df.loc[T == 0, cov]
                w1, w0 = W[T == 1], W[T == 0]

                su = calculate_smd(x1, x0)
                sw = calculate_smd(x1, x0, w1, w0)

                vu = variance_ratio(x1, x0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan
                vw = variance_ratio(x1, x0, w1, w0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan

                smd_unw_i.append(su)
                smd_w_i.append(sw)
                vr_unw_i.append(vu)
                vr_w_i.append(vw)

            smd_unw_all.append(smd_unw_i)
            smd_w_all.append(smd_w_i)
            vr_unw_all.append(vr_unw_i)
            vr_w_all.append(vr_w_i)

        smd_unw = np.mean(smd_unw_all, axis=0)
        smd_w = np.mean(smd_w_all, axis=0)
        vr_unw = np.nanmean(vr_unw_all, axis=0)
        vr_w = np.nanmean(vr_w_all, axis=0)

        severity = []
        for sw in smd_w:
            if sw <= 0.1:
                severity.append("Good")
            elif sw <= 0.2:
                severity.append("Moderate")
            else:
                severity.append("Poor")

        covariate_names = covariates
        numeric_df = pd.DataFrame({
            "Covariate": covariate_names,
            "SMD_Unweighted": smd_unw,
            "SMD_Weighted": smd_w,
            "Imbalance_Severity": severity,
            "VR_Unweighted": vr_unw,
            "VR_Weighted": vr_w
        })

        numeric_path = os.path.join(group_path, f"covariate_balance_table_{group}.xlsx")
        numeric_df.to_excel(numeric_path, index=False)
        print(f"📊 Exported numeric summary to: {numeric_path}")

        # -------------------------
        # Plot
        # -------------------------
        labels = covariates
        y_pos = np.arange(len(labels))

        fig, axes = plt.subplots(1, 2, figsize=(18, len(labels) * 0.45))

        axes[0].scatter(smd_unw, y_pos, color='red', label="Unweighted")
        axes[0].scatter(smd_w, y_pos, color='blue', label="Weighted")
        axes[0].axvline(0.1, color='gray', linestyle='--', label="Threshold 0.1")
        axes[0].axvline(0.2, color='black', linestyle='--', label="Threshold 0.2")
        axes[0].set_xlim(0, max(max(smd_unw), max(smd_w), 0.25) + 0.05)
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels(labels)
        axes[0].invert_yaxis()
        axes[0].set_title("Standardized Mean Differences (SMD)")
        axes[0].legend(loc="upper right")
        axes[0].grid(True)

        vr_mask = [cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] for cov in covariates]
        filtered_y = [i for i, b in enumerate(vr_mask) if b]
        filtered_labels = [labels[i] for i in filtered_y]
        filtered_vr_unw = [vr_unw[i] for i in filtered_y]
        filtered_vr_w = [vr_w[i] for i in filtered_y]

        axes[1].scatter(filtered_vr_unw, filtered_y, color='blue', marker='o', label="Unweighted")
        axes[1].scatter(filtered_vr_w, filtered_y, color='red', marker='x', label="Weighted")
        axes[1].axvline(2, color='gray', linestyle='--')
        axes[1].axvline(0.5, color='gray', linestyle='--')
        axes[1].set_xlim(0, max(filtered_vr_unw + filtered_vr_w + [2.5]) + 0.5)
        axes[1].set_yticks(filtered_y)
        axes[1].set_yticklabels(filtered_labels)
        axes[1].invert_yaxis()
        axes[1].set_title("Variance Ratio (VR)")
        axes[1].legend()
        axes[1].grid(True)

        fig.suptitle(f"Covariate Balance for {group.replace('CAT_', '')}", fontsize=14, weight='bold')
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        plot_path = os.path.join(group_path, f"love_plot_{group}.pdf")
        fig.savefig(plot_path, dpi=300)
        plt.close()
        print(f"✅ Saved love plot: {plot_path}")
        print(f"📏 Max weighted SMD for {group}: {np.max(smd_w):.3f}")

    except Exception as e:
        print(f"❌ Error in {group}: {e}")


🔍 Processing SUBCAT_Amfetaminen...
📊 Exported numeric summary to: outputs\SUBCAT_Amfetaminen\covariate_balance_table_SUBCAT_Amfetaminen.xlsx
✅ Saved love plot: outputs\SUBCAT_Amfetaminen\love_plot_SUBCAT_Amfetaminen.pdf
📏 Max weighted SMD for SUBCAT_Amfetaminen: 0.635

🔍 Processing SUBCAT_Antidepressiva_Overige...
📊 Exported numeric summary to: outputs\SUBCAT_Antidepressiva_Overige\covariate_balance_table_SUBCAT_Antidepressiva_Overige.xlsx
✅ Saved love plot: outputs\SUBCAT_Antidepressiva_Overige\love_plot_SUBCAT_Antidepressiva_Overige.pdf
📏 Max weighted SMD for SUBCAT_Antidepressiva_Overige: 0.798

🔍 Processing SUBCAT_Antipsychotica_Atypisch...
📊 Exported numeric summary to: outputs\SUBCAT_Antipsychotica_Atypisch\covariate_balance_table_SUBCAT_Antipsychotica_Atypisch.xlsx
✅ Saved love plot: outputs\SUBCAT_Antipsychotica_Atypisch\love_plot_SUBCAT_Antipsychotica_Atypisch.pdf
📏 Max weighted SMD for SUBCAT_Antipsychotica_Atypisch: 0.592

🔍 Processing SUBCAT_Anti_Epileptica_Stemmingsstabil

In [67]:
# Heatmap:

In [68]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
#-----------------------------
# Generate heatmaps
# -------------------------------
for treatment_var in medication_groups:
    print(f"\n========== Creating Heatmap for {treatment_var} ==========")

    try:
        output_folder = os.path.join('outputs', treatment_var)
        balance_path = os.path.join(output_folder, f'covariate_balance_table_{treatment_var}.xlsx')

        if not os.path.exists(balance_path):
            print(f"❌ Balance file not found: {balance_path}")
            continue

        balance_df = pd.read_excel(balance_path)

        # ✅ Use finalized covariates + 'Propensity Score'
        covariates = final_covariates_map[treatment_var] + ['Propensity Score']
        balance_df = balance_df[balance_df['Covariate'].isin(covariates)]

        # ✅ Check for CAPS5score_baseline
        highlight_caps = 'CAPS5score_baseline' in balance_df['Covariate'].values

        # ✅ Format for heatmap
        heatmap_df = balance_df[['Covariate', 'SMD_Unweighted', 'SMD_Weighted']].copy()
        heatmap_df.columns = ['Covariate', 'Unweighted', 'Weighted']
        heatmap_df = heatmap_df.set_index('Covariate')
        heatmap_df = heatmap_df.sort_values(by='Unweighted', ascending=False)

        # ✅ Plot
        plt.figure(figsize=(12, max(10, len(heatmap_df) * 0.35)))
        ax = sns.heatmap(
            heatmap_df,
            cmap="coolwarm",
            annot=True,
            fmt=".2f",
            linewidths=0.6,
            linecolor='gray',
            cbar_kws={"label": "Standardized Mean Difference"}
        )

        plt.title(f"Covariate Balance Heatmap (Rubin IPTW)\n{treatment_var}", fontsize=15, weight='bold')
        plt.xlabel("Condition")
        plt.ylabel("Covariate")

        # ✅ Bold CAPS5score_baseline if present
        if highlight_caps:
            ylabels = [label.get_text() for label in ax.get_yticklabels()]
            ax.set_yticklabels([
                f"{label} ←" if label == 'CAPS5score_baseline' else label for label in ylabels
            ])

        plt.tight_layout()

        # ✅ Save image
        save_path = os.path.join(output_folder, f'heatmap_smd_{treatment_var}.png')
        plt.savefig(save_path, dpi=300)
        plt.close()

        print(f"✅ Heatmap saved: {save_path}")

    except Exception as e:
        print(f"⚠️ Error processing {treatment_var}: {e}")


✅ Heatmap saved: outputs\SUBCAT_Antipsychotica_atypisch\heatmap_smd_SUBCAT_Antipsychotica_atypisch.png

✅ Heatmap saved: outputs\SUBCAT_TCA\heatmap_smd_SUBCAT_TCA.png

✅ Heatmap saved: outputs\SUBCAT_SSRI\heatmap_smd_SUBCAT_SSRI.png

✅ Heatmap saved: outputs\SUBCAT_SNRI\heatmap_smd_SUBCAT_SNRI.png

✅ Heatmap saved: outputs\SUBCAT_Tetracyclische_antidepressiva\heatmap_smd_SUBCAT_Tetracyclische_antidepressiva.png

✅ Heatmap saved: outputs\SUBCAT_Antidepressiva_overige\heatmap_smd_SUBCAT_Antidepressiva_overige.png

✅ Heatmap saved: outputs\SUBCAT_Systemische_antihistaminica\heatmap_smd_SUBCAT_Systemische_antihistaminica.png

✅ Heatmap saved: outputs\SUBCAT_anxiolytica_Benzodiazepine\heatmap_smd_SUBCAT_anxiolytica_Benzodiazepine.png

✅ Heatmap saved: outputs\SUBCAT_hypnotica_Benzodiazepine\heatmap_smd_SUBCAT_hypnotica_Benzodiazepine.png

✅ Heatmap saved: outputs\SUBCAT_Amfetaminen\heatmap_smd_SUBCAT_Amfetaminen.png

✅ Heatmap saved: outputs\SUBCAT_Systemische_betablokkers\heatmap_smd_SUBC

### SubSubCat Analysis:

In [69]:
covariates_SubSubCat_Oxazepam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Diazepam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Paracetamol = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Lorazepam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Mirtazapine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Escitalopram = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Sertraline = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Temazepam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Citalopram = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Quetiapine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]




covariates_SubSubCat_Amitriptyline = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Venlafaxine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Fluoxetine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Topiramaat = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Tramadol = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica', 'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]



covariates_SubSubCat_Zopiclon = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Loprazolam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Alprazolam = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_promethazine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Paroxetine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Bupropion = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Methylfenidaat = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD',
    'DIAGNOSIS_SEXUAL_TRAUMA', 'DIAGNOSIS_SUICIDALITY',
    'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

covariates_SubSubCat_Olanzapine = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]


covariates_SubSubCat_Zolpidem = [
    'treatmentdurationdays', 'CAPS5score_baseline',
    'CAT_Antidepressiva', 'CAT_Antipsychotica',
    'CAT_Benzodiazepine',
    'DIAGNOSIS_CHILDHOOD_TRAUMA', 'DIAGNOSIS_CPTSD', 'DIAGNOSIS_SEXUAL_TRAUMA',
    'DIAGNOSIS_SUICIDALITY', 'age',
    'DIAGNOSIS_ANXIETY_OCD',
    'DIAGNOSIS_PSYCHOTIC',
    'DIAGNOSIS_EATING_DISORDER',
    'DIAGNOSIS_SUBSTANCE_DISORDER', 'SDV_SEXE_1', 'SDV_SEXE_2', 'SDV_SEXE_3', "EB_TRAUMA_FOCUSED_THERAPY",
    "Bipolar_and_Mood_disorder", "ethnicity_Dutch", "ethnicity_other"
]

In [70]:
from collections import defaultdict

# This finds all variables that start with covariates_SUbSubCAT_ or covariates_SubSubcat_
final_covariates_map = defaultdict(list)
final_covariates_map.update({
    var.replace("covariates_", ""): val
    for var, val in globals().items()
    if var.lower().startswith("covariates_subsubcat_") and isinstance(val, list)
})

# Show detected group names
print("Groups found:", list(final_covariates_map.keys()))
medication_groups = list(final_covariates_map.keys())
print(medication_groups)

Groups found: ['SubSubCat_Oxazepam', 'SubSubCat_Diazepam', 'SubSubCat_Paracetamol', 'SubSubCat_Lorazepam', 'SubSubCat_Mirtazapine', 'SubSubCat_Escitalopram', 'SubSubCat_Sertraline', 'SubSubCat_Temazepam', 'SubSubCat_Citalopram', 'SubSubCat_Quetiapine', 'SubSubCat_Amitriptyline', 'SubSubCat_Venlafaxine', 'SubSubCat_Fluoxetine', 'SubSubCat_Topiramaat', 'SubSubCat_Tramadol', 'SubSubCat_Zopiclon', 'SubSubCat_Loprazolam', 'SubSubCat_Alprazolam', 'SubSubCat_promethazine', 'SubSubCat_Paroxetine', 'SubSubCat_Bupropion', 'SubSubCat_Methylfenidaat', 'SubSubCat_Olanzapine', 'SubSubCat_Zolpidem']
['SubSubCat_Oxazepam', 'SubSubCat_Diazepam', 'SubSubCat_Paracetamol', 'SubSubCat_Lorazepam', 'SubSubCat_Mirtazapine', 'SubSubCat_Escitalopram', 'SubSubCat_Sertraline', 'SubSubCat_Temazepam', 'SubSubCat_Citalopram', 'SubSubCat_Quetiapine', 'SubSubCat_Amitriptyline', 'SubSubCat_Venlafaxine', 'SubSubCat_Fluoxetine', 'SubSubCat_Topiramaat', 'SubSubCat_Tramadol', 'SubSubCat_Zopiclon', 'SubSubCat_Loprazolam', '

In [71]:
import os


def run_all_SUBSUBCAT_group_models(imputed_dfs):
    """
    Runs downstream analysis for each SUBSUBCAT medisubsubcation group using imputed datasets.

    Parameters:
    - imputed_dfs: list of 5 imputed DataFrames (from df_imputed_final_imp1.pkl ... imp5.pkl)
    
    Notes:
    - Covariate lists must be defined as global variables: covariates_subsubcat_<group>
    - Outputs are saved in: outputs/SUBSUBCAT_<GROUP>/
    """

    print(" Starting analysis for all SUBSUBCAT groups")

    for var_name in globals():
        if var_name.lower().startswith("covariates_subsubcat_") and isinstance(globals()[var_name], list):
            group_name = var_name.replace("covariates_", "")
            group_name = group_name.replace("_", " ").title().replace(" ", "_")  # e.g., subsubcat_z_drugs → Subsubcat_Z_Drugs
            group_name = group_name.replace("Subsubcat_", "SUBSUBCAT_")  # force prefix to uppercase

            covariates = globals()[var_name]
            output_dir = f"outputs/{group_name}"
            os.makedirs(output_dir, exist_ok=True)

            print(f"\n Processing {group_name}...")

            for k, df_imp in enumerate(imputed_dfs):
                print(f"  → Using imputation {k+1}")

                # Define X and Y
                X = df_imp[covariates]
                Y = df_imp["caps5_change_baseline"]

                # === Save X and Y as placeholder (replace with modeling later)
                X.to_csv(f"{output_dir}/X_imp{k+1}.csv", index=False)
                Y.to_frame(name="Y").to_csv(f"{output_dir}/Y_imp{k+1}.csv", index=False)

            print(f" Done: {group_name}")

    print("\n All SUBSUBCAT group analyses complete.")

# ========= STEP 4: Execute ========= #
run_all_SUBSUBCAT_group_models(imputed_dfs)

 Starting analysis for all SUBSUBCAT groups

 Processing SUBSUBCAT_Oxazepam...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: SUBSUBCAT_Oxazepam

 Processing SUBSUBCAT_Diazepam...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: SUBSUBCAT_Diazepam

 Processing SUBSUBCAT_Paracetamol...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: SUBSUBCAT_Paracetamol

 Processing SUBSUBCAT_Lorazepam...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: SUBSUBCAT_Lorazepam

 Processing SUBSUBCAT_Mirtazapine...
  → Using imputation 1
  → Using imputation 2
  → Using imputation 3
  → Using imputation 4
  → Using imputation 5
 Done: SUBSUBCAT_Mirtazapine

 Processing SUBSUBCAT_Escitalopram...
  → Using imputation 1


In [72]:
for treatment_var in medication_groups:
    print(f"\n {treatment_var}")
    
    for i, df in enumerate(imputed_dfs):
        if treatment_var not in df.columns:
            print(f"  Imp {i+1}:  Not found in columns.")
            continue

        treated = (df[treatment_var] == 1).sum()
        control = (df[treatment_var] == 0).sum()
        missing = df[treatment_var].isna().sum()

        print(f"  Imp {i+1}: Treated = {treated}, Control = {control}, Missing = {missing}")


 SubSubCat_Oxazepam
  Imp 1: Treated = 291, Control = 5834, Missing = 0
  Imp 2: Treated = 291, Control = 5834, Missing = 0
  Imp 3: Treated = 291, Control = 5834, Missing = 0
  Imp 4: Treated = 291, Control = 5834, Missing = 0
  Imp 5: Treated = 291, Control = 5834, Missing = 0

 SubSubCat_Diazepam
  Imp 1: Treated = 55, Control = 6070, Missing = 0
  Imp 2: Treated = 55, Control = 6070, Missing = 0
  Imp 3: Treated = 55, Control = 6070, Missing = 0
  Imp 4: Treated = 55, Control = 6070, Missing = 0
  Imp 5: Treated = 55, Control = 6070, Missing = 0

 SubSubCat_Paracetamol
  Imp 1: Treated = 67, Control = 6058, Missing = 0
  Imp 2: Treated = 67, Control = 6058, Missing = 0
  Imp 3: Treated = 67, Control = 6058, Missing = 0
  Imp 4: Treated = 67, Control = 6058, Missing = 0
  Imp 5: Treated = 67, Control = 6058, Missing = 0

 SubSubCat_Lorazepam
  Imp 1: Treated = 95, Control = 6030, Missing = 0
  Imp 2: Treated = 95, Control = 6030, Missing = 0
  Imp 3: Treated = 95, Control = 6030, M

In [73]:
import os
import pandas as pd
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from collections import defaultdict

# ✅ VIF computation function
def compute_vif(X):
    X = sm.add_constant(X, has_constant='add')
    vif_df = pd.DataFrame()
    vif_df["variable"] = X.columns
    vif_df["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
    return vif_df

# ✅ Process each group
for group in medication_groups:
    print(f"\n🔍 Processing VIF for {group}")

    if group not in final_covariates_map:
        print(f" ⚠️ No covariates found for {group}. Skipping.")
        continue

    covariates = final_covariates_map[group]
    vif_list = []

    for i, df_imp in enumerate(imputed_dfs):
        try:
            X = df_imp[covariates].copy()
            vif_df = compute_vif(X)
            vif_df["imputation"] = i + 1
            vif_list.append(vif_df)
        except Exception as e:
            print(f"   ❌ Failed on imputation {i+1} for {group}: {e}")

    if vif_list:
        all_vif = pd.concat(vif_list)
        pooled_vif = all_vif.groupby("variable")["VIF"].mean().reset_index()
        pooled_vif = pooled_vif.sort_values(by="VIF", ascending=False)

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)
        pooled_vif.to_csv(os.path.join(output_folder, "pooled_vif.csv"), index=False)

        print(f" ✅ Saved: {output_folder}/pooled_vif.csv")
    else:
        print(f" ⚠️ Skipped {group}: No valid imputations.")


🔍 Processing VIF for SubSubCat_Oxazepam
 ✅ Saved: outputs\SubSubCat_Oxazepam/pooled_vif.csv

🔍 Processing VIF for SubSubCat_Diazepam
 ✅ Saved: outputs\SubSubCat_Diazepam/pooled_vif.csv

🔍 Processing VIF for SubSubCat_Paracetamol
 ✅ Saved: outputs\SubSubCat_Paracetamol/pooled_vif.csv

🔍 Processing VIF for SubSubCat_Lorazepam
 ✅ Saved: outputs\SubSubCat_Lorazepam/pooled_vif.csv

🔍 Processing VIF for SubSubCat_Mirtazapine
 ✅ Saved: outputs\SubSubCat_Mirtazapine/pooled_vif.csv

🔍 Processing VIF for SubSubCat_Escitalopram
 ✅ Saved: outputs\SubSubCat_Escitalopram/pooled_vif.csv

🔍 Processing VIF for SubSubCat_Sertraline
 ✅ Saved: outputs\SubSubCat_Sertraline/pooled_vif.csv

🔍 Processing VIF for SubSubCat_Temazepam
 ✅ Saved: outputs\SubSubCat_Temazepam/pooled_vif.csv

🔍 Processing VIF for SubSubCat_Citalopram
 ✅ Saved: outputs\SubSubCat_Citalopram/pooled_vif.csv

🔍 Processing VIF for SubSubCat_Quetiapine
 ✅ Saved: outputs\SubSubCat_Quetiapine/pooled_vif.csv

🔍 Processing VIF for SubSubCat_Am

In [74]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# ---------- PS Estimation Function ----------
def run_logistic_ps_modeling(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f" Running PS estimation for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        covariates = final_covariates_map[group]
        ps_matrix = pd.DataFrame()
        auc_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not found in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop missing treatment rows
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            # Train-test split for ROC
            X_train, X_test, T_train, T_test = train_test_split(
                X, T, stratify=T, test_size=0.3, random_state=42
            )

            try:
                model = LogisticRegression(random_state=42, max_iter=1000)
                model.fit(X_train, T_train)

                ps_scores = model.predict_proba(X)[:, 1]
                ps_matrix[f"ps_imp{i+1}"] = pd.Series(ps_scores, index=valid_idx)

                # ROC & AUC
                auc = roc_auc_score(T_test, model.predict_proba(X_test)[:, 1])
                auc_list.append(auc)

                fpr, tpr, _ = roc_curve(T_test, model.predict_proba(X_test)[:, 1])
                plt.figure()
                plt.plot(fpr, tpr, label=f"AUC = {auc:.3f}")
                plt.plot([0, 1], [0, 1], 'k--')
                plt.xlabel("False Positive Rate")
                plt.ylabel("True Positive Rate")
                plt.title(f"ROC Curve - {group} (Imp {i+1})")
                plt.legend()
                plt.tight_layout()
                plt.savefig(os.path.join(output_folder, f"roc_curve_imp{i+1}.png"))
                plt.close()
                print(f"   Imp {i+1}: AUC = {auc:.3f}, ROC saved.")

            except Exception as e:
                print(f"   Error in {group} (imp {i+1}): {e}")

        # Save AUCs and Composite PS
        if not ps_matrix.empty:
            # Fill NaN rows (from dropped subjects in some imputations) with mean
            ps_matrix["composite_ps"] = ps_matrix.mean(axis=1)
            ps_matrix.to_excel(os.path.join(output_folder, "propensity_scores.xlsx"))

            auc_df = pd.DataFrame({
                "imputation": [f"imp{i+1}" for i in range(len(auc_list))],
                "AUC": auc_list
            })
            auc_df.loc[len(auc_df.index)] = ["mean", np.mean(auc_list) if auc_list else np.nan]
            auc_df.to_excel(os.path.join(output_folder, "auc_scores.xlsx"), index=False)

            print(f" Composite PS + AUC saved for {group}")
        else:
            print(f" No valid PS scores generated for {group}")

# ---------- Run ----------
run_logistic_ps_modeling(imputed_dfs, medication_groups, final_covariates_map)


 Running PS estimation for SubSubCat_Oxazepam
   Imp 1: AUC = 0.777, ROC saved.
   Imp 2: AUC = 0.768, ROC saved.
   Imp 3: AUC = 0.780, ROC saved.
   Imp 4: AUC = 0.778, ROC saved.
   Imp 5: AUC = 0.768, ROC saved.
 Composite PS + AUC saved for SubSubCat_Oxazepam
 Running PS estimation for SubSubCat_Diazepam
   Imp 1: AUC = 0.715, ROC saved.
   Imp 2: AUC = 0.720, ROC saved.
   Imp 3: AUC = 0.730, ROC saved.
   Imp 4: AUC = 0.716, ROC saved.
   Imp 5: AUC = 0.732, ROC saved.
 Composite PS + AUC saved for SubSubCat_Diazepam
 Running PS estimation for SubSubCat_Paracetamol
   Imp 1: AUC = 0.684, ROC saved.
   Imp 2: AUC = 0.755, ROC saved.
   Imp 3: AUC = 0.674, ROC saved.
   Imp 4: AUC = 0.707, ROC saved.
   Imp 5: AUC = 0.691, ROC saved.
 Composite PS + AUC saved for SubSubCat_Paracetamol
 Running PS estimation for SubSubCat_Lorazepam
   Imp 1: AUC = 0.779, ROC saved.
   Imp 2: AUC = 0.776, ROC saved.
   Imp 3: AUC = 0.793, ROC saved.
   Imp 4: AUC = 0.782, ROC saved.
   Imp 5: AUC = 

In [75]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LogisticRegression

def compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n Computing feature importance for {group}")

        if group not in final_covariates_map:
            print(f" No covariates found for {group}. Skipping.")
            continue

        covariates = final_covariates_map[group]
        output_folder = os.path.join("outputs", group)
        os.makedirs(output_folder, exist_ok=True)

        importance_df_list = []

        for i, df_imp in enumerate(imputed_dfs):
            if group not in df_imp.columns:
                print(f" {group} not in imputed dataset {i+1}. Skipping.")
                continue

            X = df_imp[covariates].copy()
            T = df_imp[group]

            # Drop NaNs in treatment
            valid_idx = T.dropna().index
            X = X.loc[valid_idx]
            T = T.loc[valid_idx]

            try:
                model = LogisticRegression(random_state=42, max_iter=1000)
                model.fit(X, T)

                # Get feature importance (absolute coefficients)
                importances = np.abs(model.coef_[0])
                importance_dict = dict(zip(X.columns, importances))
                df_feat = pd.DataFrame.from_dict(importance_dict, orient='index', columns=[f"imp{i+1}"])
                df_feat.index.name = 'feature'
                importance_df_list.append(df_feat)

            except Exception as e:
                print(f"   Error during modeling: {e}")

        if importance_df_list:
            # Combine and average
            all_feat = pd.concat(importance_df_list, axis=1).fillna(0)
            all_feat["mean_importance"] = all_feat.mean(axis=1)

            # Filter top 30 non-zero
            non_zero = all_feat[all_feat["mean_importance"] > 0]
            top30 = non_zero.sort_values(by="mean_importance", ascending=False).head(30)

            # Save to CSV
            top30.to_csv(os.path.join(output_folder, "feature_importance.csv"))

            # Plot
            plt.figure(figsize=(10, 8))
            plt.barh(top30.index[::-1], top30["mean_importance"][::-1])  # plot top → bottom
            plt.xlabel("Mean Gain Importance")
            plt.title(f"Top 30 Feature Importance - {group}")
            plt.tight_layout()
            plt.savefig(os.path.join(output_folder, "feature_importance_top30.png"))
            plt.close()

            print(f" Saved feature importance plot and CSV for {group}")
        else:
            print(f" No valid models for {group}")

#  Run
compute_feature_importance(imputed_dfs, medication_groups, final_covariates_map)



 Computing feature importance for SubSubCat_Oxazepam
 Saved feature importance plot and CSV for SubSubCat_Oxazepam

 Computing feature importance for SubSubCat_Diazepam
 Saved feature importance plot and CSV for SubSubCat_Diazepam

 Computing feature importance for SubSubCat_Paracetamol
 Saved feature importance plot and CSV for SubSubCat_Paracetamol

 Computing feature importance for SubSubCat_Lorazepam
 Saved feature importance plot and CSV for SubSubCat_Lorazepam

 Computing feature importance for SubSubCat_Mirtazapine
 Saved feature importance plot and CSV for SubSubCat_Mirtazapine

 Computing feature importance for SubSubCat_Escitalopram
 Saved feature importance plot and CSV for SubSubCat_Escitalopram

 Computing feature importance for SubSubCat_Sertraline
 Saved feature importance plot and CSV for SubSubCat_Sertraline

 Computing feature importance for SubSubCat_Temazepam
 Saved feature importance plot and CSV for SubSubCat_Temazepam

 Computing feature importance for SubSubCat

In [76]:
import os
import pandas as pd
import numpy as np


def compute_trimmed_clipped_iptw(ps_df, treatment, lower=0.05, upper=0.95, clip_max=10):
    weights = []
    keep_mask = (ps_df > lower) & (ps_df < upper)

    for i in range(ps_df.shape[1]):
        ps = ps_df.iloc[:, i].clip(lower=1e-6, upper=1 - 1e-6)  # avoid div by zero
        mask = keep_mask.iloc[:, i]
        w = pd.Series(np.nan, index=ps.index)

        w[mask & (treatment == 1)] = 1 / ps[mask & (treatment == 1)]
        w[mask & (treatment == 0)] = 1 / (1 - ps[mask & (treatment == 0)])
        w = w.clip(upper=clip_max)
        weights.append(w)

    return pd.concat(weights, axis=1)


def apply_rubins_rule_to_iptw(iptw_matrix):
    """
    Given an IPTW matrix (n rows × M imputations), return Rubin’s rule pooled mean, SD, SE.
    """
    M = iptw_matrix.shape[1]
    q_bar = iptw_matrix.mean(axis=1)
    u_bar = iptw_matrix.var(axis=1, ddof=1)
    B = iptw_matrix.apply(lambda x: x.mean(), axis=1).var(ddof=1)
    total_var = u_bar + (1 + 1/M) * B
    total_se = np.sqrt(total_var)
    return q_bar, u_bar.pow(0.5), total_se


def run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map):
    for group in medication_groups:
        print(f"\n🔍 Processing IPTW + trimming + clipping for {group}")
        output_folder = os.path.join("outputs", group)
        ps_path = os.path.join(output_folder, "propensity_scores.xlsx")

        if not os.path.exists(ps_path):
            print(f"⚠️ Missing PS file: {ps_path}. Skipping.")
            continue

        try:
            ps_all = pd.read_excel(ps_path, index_col=0)
            ps_cols = [col for col in ps_all.columns if col.startswith("ps_imp")]
            composite_index = ps_all.index

            # Get treatment from one imputed dataset
            T_full = None
            for df in imputed_dfs:
                if group in df.columns:
                    T_full = df.loc[composite_index, group]
                    break

            if T_full is None:
                print(f"❌ Treatment column {group} not found in any imputed dataset.")
                continue

            # Compute IPTW matrix (shape: n × M)
            iptw_matrix = compute_trimmed_clipped_iptw(ps_all[ps_cols], T_full)
            iptw_matrix.columns = [f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]

            # Apply Rubin’s Rule for mean, SD, SE
            iptw_matrix["iptw_mean"], iptw_matrix["iptw_sd"], iptw_matrix["iptw_se"] = apply_rubins_rule_to_iptw(
                iptw_matrix[[f"iptw_imp{i+1}" for i in range(iptw_matrix.shape[1])]]
            )

            # Save IPTW matrix separately
            iptw_matrix.to_excel(os.path.join(output_folder, "iptw_weights.xlsx"))
            print(f"✅ Saved IPTW weights for {group}")

            # Save trimmed & clipped imputed datasets with IPTW
            for i in range(5):
                df = imputed_dfs[i].copy()
                if group not in df.columns:
                    continue

                trimmed_idx = iptw_matrix.index.intersection(df.index)
                needed_cols = final_covariates_map[group] + [group, "caps5_change_baseline"]

                # Select only necessary columns
                df_trimmed = df.loc[trimmed_idx, needed_cols].copy()
                df_trimmed["iptw"] = iptw_matrix[f"iptw_imp{i+1}"].loc[trimmed_idx]

                # ✅ DROP rows with missing IPTW values
                before = len(df_trimmed)
                df_trimmed = df_trimmed.dropna(subset=["iptw"])
                after = len(df_trimmed)
                print(f"    ℹ️ Retained {after}/{before} rows after IPTW NaN drop.")

                # Save to .pkl
                df_trimmed.to_pickle(os.path.join(output_folder, f"trimmed_data_imp{i+1}.pkl"))
                print(f"  💾 Saved trimmed dataset: {output_folder}/trimmed_data_imp{i+1}.*")

        except Exception as e:
            print(f"❌ Error in {group}: {e}")


run_trim_clip_save_all(imputed_dfs, medication_groups, final_covariates_map)


🔍 Processing IPTW + trimming + clipping for SubSubCat_Oxazepam
✅ Saved IPTW weights for SubSubCat_Oxazepam
    ℹ️ Retained 1533/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\SubSubCat_Oxazepam/trimmed_data_imp1.*
    ℹ️ Retained 1535/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\SubSubCat_Oxazepam/trimmed_data_imp2.*
    ℹ️ Retained 1543/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\SubSubCat_Oxazepam/trimmed_data_imp3.*
    ℹ️ Retained 1523/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\SubSubCat_Oxazepam/trimmed_data_imp4.*
    ℹ️ Retained 1516/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\SubSubCat_Oxazepam/trimmed_data_imp5.*

🔍 Processing IPTW + trimming + clipping for SubSubCat_Diazepam
✅ Saved IPTW weights for SubSubCat_Diazepam
    ℹ️ Retained 133/6125 rows after IPTW NaN drop.
  💾 Saved trimmed dataset: outputs\SubSubCat_Diazepam/trimmed_data_imp1.*
    ℹ️ Retained 120/6125 rows

In [77]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def plot_ps_overlap_all_groups(medication_groups):
    for group in medication_groups:
        print(f"\n📊 Plotting PS overlap for {group}")

        folder = os.path.join("outputs", group)
        ps_file = os.path.join(folder, "propensity_scores.xlsx")
        iptw_file = os.path.join(folder, "iptw_weights.xlsx")
        trimmed_file = os.path.join(folder, "trimmed_data_imp1.pkl")

        if not all(os.path.exists(f) for f in [ps_file, iptw_file, trimmed_file]):
            print(f"⚠️ Missing required files for {group}. Skipping.")
            continue

        try:
            ps_df = pd.read_excel(ps_file, index_col=0)
            iptw_df = pd.read_excel(iptw_file, index_col=0)
            trimmed_df = pd.read_pickle(trimmed_file)

            # Extract
            ps = ps_df["composite_ps"].reindex(trimmed_df.index)
            w = iptw_df["iptw_mean"].reindex(trimmed_df.index)
            T = trimmed_df[group]

            # Masks to remove NaNs
            treated_mask = (T == 1) & ps.notna() & w.notna()
            control_mask = (T == 0) & ps.notna() & w.notna()

            treated = ps[treated_mask]
            treated_w = w[treated_mask]

            control = ps[control_mask]
            control_w = w[control_mask]

            # === Unweighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Unweighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_unweighted.png"))
            plt.close()

            # === Weighted Plot ===
            plt.figure(figsize=(8, 5))
            plt.hist([treated, control], bins=25, weights=[treated_w, control_w], label=["Treated", "Control"], alpha=0.6)
            plt.title(f"Weighted PS Overlap - {group}")
            plt.xlabel("Composite Propensity Score")
            plt.ylabel("Weighted Count")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(folder, "ps_overlap_weighted.png"))
            plt.close()

            print(f"✅ Saved unweighted and weighted PS plots for {group}")

        except Exception as e:
            print(f"❌ Error processing {group}: {e}")

# 🔁 Run
plot_ps_overlap_all_groups(medication_groups)


📊 Plotting PS overlap for SubSubCat_Oxazepam
✅ Saved unweighted and weighted PS plots for SubSubCat_Oxazepam

📊 Plotting PS overlap for SubSubCat_Diazepam
✅ Saved unweighted and weighted PS plots for SubSubCat_Diazepam

📊 Plotting PS overlap for SubSubCat_Paracetamol
✅ Saved unweighted and weighted PS plots for SubSubCat_Paracetamol

📊 Plotting PS overlap for SubSubCat_Lorazepam
✅ Saved unweighted and weighted PS plots for SubSubCat_Lorazepam

📊 Plotting PS overlap for SubSubCat_Mirtazapine
✅ Saved unweighted and weighted PS plots for SubSubCat_Mirtazapine

📊 Plotting PS overlap for SubSubCat_Escitalopram
✅ Saved unweighted and weighted PS plots for SubSubCat_Escitalopram

📊 Plotting PS overlap for SubSubCat_Sertraline
✅ Saved unweighted and weighted PS plots for SubSubCat_Sertraline

📊 Plotting PS overlap for SubSubCat_Temazepam
✅ Saved unweighted and weighted PS plots for SubSubCat_Temazepam

📊 Plotting PS overlap for SubSubCat_Citalopram
✅ Saved unweighted and weighted PS plots for

In [78]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Set up base output folder
output_base = "outputs"
ps_file = "propensity_scores.xlsx"
iptw_file = "iptw_weights.xlsx"
trimmed_data_file = "trimmed_data_imp1.pkl"

# Collect all treatment group folders
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

# Generate 4-panel overlap plots
for group in groups:
    group_path = os.path.join(output_base, group)
    try:
        # Load trimmed treatment info
        trimmed_df = pd.read_pickle(os.path.join(group_path, trimmed_data_file))
        index = trimmed_df.index

        # Fix: case-insensitive match for treatment variable
        possible_cols = [col for col in trimmed_df.columns if col.upper() == group.upper()]
        if not possible_cols:
            print(f" Treatment variable {group} not found in {group}, skipping.")
            continue
        treatment_var = possible_cols[0]
        T = trimmed_df[treatment_var]

        # Load composite PS (aligned to trimmed_df index)
        ps_df = pd.read_excel(os.path.join(group_path, ps_file), index_col=0)
        if 'composite_ps' not in ps_df.columns:
            print(f" Composite column missing in {ps_file}, skipping {group}.")
            continue
        ps = ps_df.loc[index, 'composite_ps']

        # Load IPTW weights (aligned to trimmed_df index)
        weights_df = pd.read_excel(os.path.join(group_path, iptw_file), index_col=0)
        if 'iptw_mean' not in weights_df.columns:
            print(f" IPTW weight column missing in {iptw_file}, skipping {group}.")
            continue
        weights = weights_df.loc[index, 'iptw_mean']

        # Prepare 4 datasets
        raw_treated = ps[T == 1]
        raw_control = ps[T == 0]
        weighted_treated = (ps[T == 1], weights[T == 1])
        weighted_control = (ps[T == 0], weights[T == 0])

        # Create plot
        fig, axs = plt.subplots(2, 2, figsize=(12, 8))
        fig.suptitle(f"Propensity Score Distribution - {group}", fontsize=14)

        axs[0, 0].hist(raw_treated, bins=20, alpha=0.7, color='blue')
        axs[0, 0].set_title("Raw Treated")

        axs[0, 1].hist(raw_control, bins=20, alpha=0.7, color='green')
        axs[0, 1].set_title("Raw Control")

        axs[1, 0].hist(weighted_treated[0], bins=20, weights=weighted_treated[1], alpha=0.7, color='blue')
        axs[1, 0].set_title("Weighted Treated")

        axs[1, 1].hist(weighted_control[0], bins=20, weights=weighted_control[1], alpha=0.7, color='green')
        axs[1, 1].set_title("Weighted Control")

        for ax in axs.flat:
            ax.set_xlim(0, 1)
            ax.set_xlabel("Propensity Score")
            ax.set_ylabel("Count")

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save figure
        plot_path = os.path.join(group_path, f"four_panel_overlap_{group}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f" ✅ Saved: {plot_path}")

    except Exception as e:
        print(f" ❌ Error in {group}: {e}")

 ✅ Saved: outputs\SubSubCat_Oxazepam\four_panel_overlap_SubSubCat_Oxazepam.png
 ✅ Saved: outputs\SubSubCat_Diazepam\four_panel_overlap_SubSubCat_Diazepam.png
 ✅ Saved: outputs\SubSubCat_Paracetamol\four_panel_overlap_SubSubCat_Paracetamol.png
 ✅ Saved: outputs\SubSubCat_Lorazepam\four_panel_overlap_SubSubCat_Lorazepam.png
 ✅ Saved: outputs\SubSubCat_Mirtazapine\four_panel_overlap_SubSubCat_Mirtazapine.png
 ✅ Saved: outputs\SubSubCat_Escitalopram\four_panel_overlap_SubSubCat_Escitalopram.png
 ✅ Saved: outputs\SubSubCat_Sertraline\four_panel_overlap_SubSubCat_Sertraline.png
 ✅ Saved: outputs\SubSubCat_Temazepam\four_panel_overlap_SubSubCat_Temazepam.png
 ✅ Saved: outputs\SubSubCat_Citalopram\four_panel_overlap_SubSubCat_Citalopram.png
 ✅ Saved: outputs\SubSubCat_Quetiapine\four_panel_overlap_SubSubCat_Quetiapine.png
 ✅ Saved: outputs\SubSubCat_Amitriptyline\four_panel_overlap_SubSubCat_Amitriptyline.png
 ✅ Saved: outputs\SubSubCat_Venlafaxine\four_panel_overlap_SubSubCat_Venlafaxine.png


In [79]:
# ATT calculation:

In [80]:
# Weighted:

In [81]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
import statsmodels.api as sm
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T, weights):
    treated = X[T == 1]
    control = X[T == 0]
    w_treated = weights[T == 1]
    w_control = weights[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = np.average(treated[col], weights=w_treated), np.average(control[col], weights=w_control)
            s1 = np.sqrt(np.average((treated[col] - m1) ** 2, weights=w_treated))
            s0 = np.sqrt(np.average((control[col] - m0) ** 2, weights=w_control))
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# OLS Main Loop
# -----------------------------
def run_dml_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running OLS for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            # Set random seed for this iteration
            np.random.seed(seed)
            
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                # Add bootstrap sampling with seed-based randomization
                n_samples = len(df)
                bootstrap_idx = np.random.choice(n_samples, size=n_samples, replace=True)
                df_bootstrap = df.iloc[bootstrap_idx].reset_index(drop=True)

                X = df_bootstrap[covariates].copy()
                T = df_bootstrap[group]
                Y = df_bootstrap["caps5_change_baseline"]
                W = df_bootstrap["iptw"]

                try:
                    # Create design matrix with treatment variable and covariates
                    X_ols = pd.concat([T, X], axis=1)
                    X_ols = sm.add_constant(X_ols)
                    
                    # Fit weighted OLS with robust standard errors
                    ols_model = sm.WLS(Y, X_ols, weights=W).fit(cov_type='HC1')
                    
                    # Extract treatment effect (coefficient of treatment variable)
                    att = ols_model.params[group]  # Treatment coefficient
                    se = ols_model.bse[group]  # Robust standard error for treatment
                    
                    att_list.append(att)
                    se_list.append(se)

                    # Calculate model fit statistics
                    Y_pred = ols_model.fittedvalues
                    residuals = ols_model.resid
                    rmse = mean_squared_error(Y, Y_pred, squared=False)
                    r2 = ols_model.rsquared
                    r2_list.append(r2)
                    rmse_list.append(rmse)
                    
                    # Collect residuals and fitted values for diagnostic plots
                    group_residuals.append(residuals.values)
                    group_fitted.append(Y_pred.values)

                    smd, vr = calculate_smd_vr(X, T, W)
                    smd_list.append(smd)
                    vr_list.append(vr)
                except Exception as e:
                    print(f"⚠️ Error in {group}, seed {seed}, imp {imp}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("ols_rubin_summary_subsubcats.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_subsubcats.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_dml_with_trimmed_data(final_covariates_map)



🚀 Running OLS for SubSubCat_Oxazepam
✅ SubSubCat_Oxazepam | Seed 1: ATT = 0.0868, SE = 1.4032, p = 0.95366
✅ SubSubCat_Oxazepam | Seed 2: ATT = 0.1236, SE = 1.5250, p = 0.93928
✅ SubSubCat_Oxazepam | Seed 3: ATT = 0.6143, SE = 1.5027, p = 0.70363
✅ SubSubCat_Oxazepam | Seed 4: ATT = 0.1286, SE = 1.5918, p = 0.93947
✅ SubSubCat_Oxazepam | Seed 5: ATT = -0.3906, SE = 1.3402, p = 0.78518
✅ SubSubCat_Oxazepam | Seed 6: ATT = 0.6141, SE = 1.4396, p = 0.69165
✅ SubSubCat_Oxazepam | Seed 7: ATT = -0.0904, SE = 1.8898, p = 0.96414
✅ SubSubCat_Oxazepam | Seed 8: ATT = 0.4791, SE = 1.4898, p = 0.76388
✅ SubSubCat_Oxazepam | Seed 9: ATT = 0.7935, SE = 1.3179, p = 0.57959
✅ SubSubCat_Oxazepam | Seed 10: ATT = 0.9925, SE = 1.3674, p = 0.50811
📊 Diagnostic plots saved for SubSubCat_Oxazepam
🏆 Best result for SubSubCat_Oxazepam → Seed 9 | SE = 1.3179

🚀 Running OLS for SubSubCat_Diazepam
✅ SubSubCat_Diazepam | Seed 1: ATT = -10.2416, SE = 5.4303, p = 0.13236
✅ SubSubCat_Diazepam | Seed 2: ATT = -9.5

In [82]:
# Unweighted:

In [83]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
import statsmodels.api as sm
from scipy.stats import t, probplot
import xgboost as xgb

# -----------------------------
# Configuration
# -----------------------------
seeds = list(range(1, 11))
imputations = 5
output_folder = "outputs"

# -----------------------------
# Diagnostic Plotting Function
# -----------------------------
def create_diagnostic_plots(residuals_data, fitted_data, group_name):
    """Create 4 diagnostic plots for model validation"""
    plots_dir = os.path.join(output_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    # Flatten the collected data
    all_residuals = np.concatenate(residuals_data)
    all_fitted = np.concatenate(fitted_data)
    
    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(f'Diagnostic Plots - {group_name}', fontsize=16, fontweight='bold')
    
    # 1. Residuals vs Fitted
    axes[0,0].scatter(all_fitted, all_residuals, alpha=0.6, s=20)
    axes[0,0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
    axes[0,0].set_xlabel('Fitted Values')
    axes[0,0].set_ylabel('Residuals')
    axes[0,0].set_title('Residuals vs Fitted')
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. QQ Plot
    probplot(all_residuals, dist="norm", plot=axes[0,1])
    axes[0,1].set_title('Q-Q Plot (Normal)')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Residual Histogram
    axes[1,0].hist(all_residuals, bins=30, density=True, alpha=0.7, color='skyblue', edgecolor='black')
    axes[1,0].set_xlabel('Residuals')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Residual Distribution')
    axes[1,0].grid(True, alpha=0.3)
    
    # 4. Scale-Location Plot
    sqrt_abs_residuals = np.sqrt(np.abs(all_residuals))
    axes[1,1].scatter(all_fitted, sqrt_abs_residuals, alpha=0.6, s=20)
    axes[1,1].set_xlabel('Fitted Values')
    axes[1,1].set_ylabel('√|Residuals|')
    axes[1,1].set_title('Scale-Location Plot')
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    plot_filename = os.path.join(plots_dir, f'{group_name}_unweighted.png')
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"📊 Diagnostic plots saved for {group_name}")

# -----------------------------
# Rubin's Rule
# -----------------------------
def rubins_pool(estimates, ses):
    m = len(estimates)
    q_bar = np.mean(estimates)
    u_bar = np.mean(np.square(ses))
    b_m = np.var(estimates, ddof=1)
    total_var = u_bar + ((1 + 1/m) * b_m)
    total_se = np.sqrt(total_var)
    ci_lower = q_bar - 1.96 * total_se
    ci_upper = q_bar + 1.96 * total_se
    p_value = 2 * (1 - t.cdf(np.abs(q_bar / total_se), df=m-1))
    rounded_p = round(p_value, 5)
    formatted_p = "< 0.00001" if rounded_p <= 0.00001 else f"{rounded_p:.5f}"
    return q_bar, total_se, ci_lower, ci_upper, formatted_p

# -----------------------------
# SMD + Variance Ratio
# -----------------------------
def calculate_smd_vr(X, T):
    treated = X[T == 1]
    control = X[T == 0]
    smd, vr = [], []
    for col in X.columns:
        try:
            m1, m0 = treated[col].mean(), control[col].mean()
            s1 = treated[col].std()
            s0 = control[col].std()
            pooled_sd = np.sqrt((s1 ** 2 + s0 ** 2) / 2)
            smd.append((m1 - m0) / pooled_sd if pooled_sd > 0 else 0)
            vr.append(s1**2 / s0**2 if s0**2 > 0 else 0)
        except Exception:
            smd.append(np.nan)
            vr.append(np.nan)
    return np.nanmean(smd), np.nanmean(vr)

# -----------------------------
# OLS Main Loop
# -----------------------------
def run_dml_with_trimmed_data(final_covariates_map):
    att_results = []
    balance_results = []

    for group, covariates in final_covariates_map.items():
        print(f"\n🚀 Running OLS for {group}")
        group_dir = os.path.join(output_folder, group)
        os.makedirs(group_dir, exist_ok=True)

        best_result = None
        best_se = float("inf")
        
        # Initialize lists to collect residuals and fitted values for diagnostic plots
        group_residuals = []
        group_fitted = []

        for seed in seeds:
            # Set random seed for this iteration
            np.random.seed(seed)
            
            att_list, se_list, r2_list, rmse_list, smd_list, vr_list = [], [], [], [], [], []

            for imp in range(1, imputations + 1):
                file_path = os.path.join(group_dir, f"trimmed_data_imp{imp}.pkl")
                if not os.path.exists(file_path):
                    continue

                df = pd.read_pickle(file_path)
                if group not in df.columns or "iptw" not in df.columns:
                    continue

                # Add bootstrap sampling with seed-based randomization
                n_samples = len(df)
                bootstrap_idx = np.random.choice(n_samples, size=n_samples, replace=True)
                df_bootstrap = df.iloc[bootstrap_idx].reset_index(drop=True)

                X = df_bootstrap[covariates].copy()
                T = df_bootstrap[group]
                Y = df_bootstrap["caps5_change_baseline"]
                #W = df_bootstrap["iptw"]

                try:
                    # Create design matrix with treatment variable and covariates
                    X_ols = pd.concat([T, X], axis=1)
                    X_ols = sm.add_constant(X_ols)
                    
                    # Fit OLS with robust standard errors (unweighted)
                    ols_model = sm.OLS(Y, X_ols).fit(cov_type='HC1')
                    
                    # Extract treatment effect (coefficient of treatment variable)
                    att = ols_model.params[group]  # Treatment coefficient
                    se = ols_model.bse[group]  # Robust standard error for treatment
                    
                    att_list.append(att)
                    se_list.append(se)

                    # Calculate model fit statistics
                    Y_pred = ols_model.fittedvalues
                    residuals = ols_model.resid
                    rmse = mean_squared_error(Y, Y_pred, squared=False)
                    r2 = ols_model.rsquared
                    r2_list.append(r2)
                    rmse_list.append(rmse)
                    
                    # Collect residuals and fitted values for diagnostic plots
                    group_residuals.append(residuals.values)
                    group_fitted.append(Y_pred.values)

                    smd, vr = calculate_smd_vr(X, T)
                    smd_list.append(smd)
                    vr_list.append(vr)
                except Exception as e:
                    print(f"⚠️ Error in {group}, seed {seed}, imp {imp}: {e}")

            if att_list:
                att, se, ci_l, ci_u, p_val = rubins_pool(att_list, se_list)
                avg_r2 = np.mean(r2_list)
                avg_rmse = np.mean(rmse_list)
                avg_smd = np.mean(smd_list)
                avg_vr = np.mean(vr_list)

                balance_results.append({
                    "group": group, "seed": seed, "smd": avg_smd, "vr": avg_vr
                })

                if se < best_se:
                    best_se = se
                    best_result = {
                        "group": group, "seed": seed, "att": att, "se": se,
                        "ci_lower": ci_l, "ci_upper": ci_u, "p_value": p_val,
                        "r2": avg_r2, "rmse": avg_rmse
                    }

                print(f"✅ {group} | Seed {seed}: ATT = {att:.4f}, SE = {se:.4f}, p = {p_val}")
            else:
                print(f"⚠️ No valid results for {group} | Seed {seed}")

        # Create diagnostic plots for this group
        if group_residuals and group_fitted:
            create_diagnostic_plots(group_residuals, group_fitted, group)

        if best_result:
            att_results.append(best_result)
            print(f"🏆 Best result for {group} → Seed {best_result['seed']} | SE = {best_result['se']:.4f}")

    # Save final output
    pd.DataFrame(att_results).to_excel("ols_rubin_summary_subsubcats_unweighted.xlsx", index=False)
    pd.DataFrame(balance_results).to_excel("smd_vr_summary_subsubcats_unweighted.xlsx", index=False)
    print("\n🎯 All summary files saved.")

run_dml_with_trimmed_data(final_covariates_map)



🚀 Running OLS for SubSubCat_Oxazepam
✅ SubSubCat_Oxazepam | Seed 1: ATT = -0.0399, SE = 1.3432, p = 0.97774
✅ SubSubCat_Oxazepam | Seed 2: ATT = 0.0996, SE = 1.4561, p = 0.94877
✅ SubSubCat_Oxazepam | Seed 3: ATT = 0.7378, SE = 1.6199, p = 0.67240
✅ SubSubCat_Oxazepam | Seed 4: ATT = 0.1610, SE = 1.7001, p = 0.92911
✅ SubSubCat_Oxazepam | Seed 5: ATT = -0.8644, SE = 1.4484, p = 0.58281
✅ SubSubCat_Oxazepam | Seed 6: ATT = 0.5571, SE = 1.4591, p = 0.72202
✅ SubSubCat_Oxazepam | Seed 7: ATT = -0.3440, SE = 1.8962, p = 0.86487
✅ SubSubCat_Oxazepam | Seed 8: ATT = 0.5488, SE = 1.6198, p = 0.75178
✅ SubSubCat_Oxazepam | Seed 9: ATT = 0.7892, SE = 1.2741, p = 0.56918
✅ SubSubCat_Oxazepam | Seed 10: ATT = 0.8554, SE = 1.3495, p = 0.56062
📊 Diagnostic plots saved for SubSubCat_Oxazepam
🏆 Best result for SubSubCat_Oxazepam → Seed 9 | SE = 1.2741

🚀 Running OLS for SubSubCat_Diazepam
✅ SubSubCat_Diazepam | Seed 1: ATT = -10.6111, SE = 5.1358, p = 0.10771
✅ SubSubCat_Diazepam | Seed 2: ATT = -9.

In [85]:
import os
import pandas as pd
import numpy as np
from scipy.stats import sem, ttest_ind

# ----------------------------------
# File paths
# ----------------------------------
output_base = "outputs"
att_file = "ols_rubin_summary_subsubcats.xlsx"
trimmed_file = "trimmed_data_imp1.pkl"
auc_file = "auc_scores.xlsx"  # NEW

# ----------------------------------
# Load ATT Summary
# ----------------------------------
if os.path.exists(att_file):
    att_df = pd.read_excel(att_file)
else:
    raise FileNotFoundError("❌ ATT summary file not found: ols_summary_subsubcats.xlsx")

summary_rows = []

# ----------------------------------
# Loop over medication groups
# ----------------------------------
groups = [g for g in medication_groups if os.path.isdir(os.path.join(output_base, g))]

for med in groups:
    try:
        group_path = os.path.join(output_base, med)

        # Load trimmed data
        df = pd.read_pickle(os.path.join(group_path, trimmed_file))

        # Detect treatment column
        treatment_cols = [col for col in df.columns if col.upper() == med.upper()]
        if not treatment_cols:
            print(f"⚠️ Treatment column {med} not found in trimmed data. Skipping.")
            continue
        treatment_var = treatment_cols[0]

        # Extract treatment and outcome
        T = df[treatment_var]
        Y = df["caps5_change_baseline"]

        # Treated and control stats
        treated = Y[T == 1]
        control = Y[T == 0]

        mean_treat = treated.mean()
        se_treat = sem(treated) if len(treated) > 1 else np.nan

        mean_ctrl = control.mean()
        se_ctrl = sem(control) if len(control) > 1 else np.nan

        # Cohen's d (unadjusted)
        pooled_sd = np.sqrt(((treated.std() ** 2) + (control.std() ** 2)) / 2)
        cohen_d = (mean_treat - mean_ctrl) / pooled_sd if pooled_sd > 0 else np.nan

        # E-value (unadjusted)
        delta = mean_treat - mean_ctrl
        E = delta / abs(mean_ctrl) * 100 if mean_ctrl != 0 else np.nan

        # Unadjusted p-value
        try:
            t_stat, p_val = ttest_ind(treated, control, equal_var=False, nan_policy="omit")
            rounded_p = round(p_val, 5)
            formatted_p = "< 0.00001" if rounded_p < 0.00001 else rounded_p
        except Exception:
            formatted_p = np.nan

        # AUC from new auc_scores.xlsx file
        auc_val = np.nan
        auc_path = os.path.join(group_path, auc_file)
        if os.path.exists(auc_path):
            auc_df = pd.read_excel(auc_path)
            if "AUC" in auc_df.columns:
                auc_val = auc_df["AUC"].dropna().mean()

        # Adjusted stats from Rubin summary
        att_row = att_df[att_df["group"].str.strip().str.upper() == med.strip().upper()]
        if not att_row.empty:
            att = att_row.iloc[0]["att"]
            att_se = att_row.iloc[0]["se"]
            att_p_val = att_row.iloc[0]["p_value"]
            r2 = att_row.iloc[0]["r2"]
            rmse = att_row.iloc[0]["rmse"]

            try:
                rounded_att_p = round(float(att_p_val), 5)
                formatted_att_p = "< 0.00001" if rounded_att_p < 0.00001 else rounded_att_p
            except:
                formatted_att_p = att_p_val
        else:
            att, att_se, formatted_att_p, r2, rmse = np.nan, np.nan, np.nan, np.nan, np.nan

        # Append full row
        summary_rows.append({
            'Medication Group': med,
            'Mean Treated': mean_treat,
            'SE Treated': se_treat,
            'Mean Control': mean_ctrl,
            'SE Control': se_ctrl,
            'Cohen d': cohen_d,
            'E (Unadjusted)': E,
            'n Treated': len(treated),
            'n Control': len(control),
            #'Unadjusted p-value': formatted_p,
            'ATT Estimate': att,
            'ATT SE (Robust)': att_se,
            'ATT p-value': formatted_att_p,
            'R²': r2,
            'RMSE': rmse,
            'AUC': auc_val
        })

    except Exception as e:
        print(f"❌ Error in {med}: {e}")

# ----------------------------------
# Save final summary
# ----------------------------------
summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.sort_values("Medication Group")
summary_df.to_excel("Final_ATT_Summary_SubSubCat.xlsx", index=False)
print("✅ Final_ATT_Summary_SubSubCat saved")

✅ Final_ATT_Summary_SubSubCat saved


In [86]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# ✅ Load the final summary table
final_df = pd.read_excel("Final_ATT_Summary_SubSubCat.xlsx")

# ✅ Parse DML p-values (handle "< 0.00001")
def parse_pval(p):
    try:
        if isinstance(p, str) and "<" in p:
            return 0.000001
        return float(p)
    except:
        return None

final_df['ATT p-value'] = final_df['ATT p-value'].apply(parse_pval)

# ✅ Plot settings
width = 0.35

# ✅ Plotting function for a single medication group
def plot_single_group(row):
    fig, ax = plt.subplots(figsize=(10, 6))
    
    bars1 = ax.bar(-width/2, row['Mean Control'], width, 
                   yerr=row['SE Control'], label='Control', hatch='//', color='gray', capsize=5)
    bars2 = ax.bar(+width/2, row['Mean Treated'], width, 
                   yerr=row['SE Treated'], label='Treated', color='steelblue', capsize=5)

    label = (
        f"ATT = {row['ATT Estimate']:.2f}\n"
        f"d = {row['Cohen d']:.2f}, p = {row['ATT p-value']:.3f}\n"
        f"nT = {row['n Treated']}, nC = {row['n Control']}\n"
        f"E = {row['E (Unadjusted)']:.1f}%"
    )
    max_y = max(row['Mean Control'], row['Mean Treated']) + 1.5
    ax.text(0, max_y, label, ha='center', va='bottom', fontsize=9, color='#FFD700')

    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_xticks([-width/2, +width/2])
    ax.set_xticklabels(['Control', 'Treated'])
    ax.set_title(f"Group: {row['Medication Group']}", fontsize=12, weight='bold')
    ax.set_ylabel("CAPS5 Change Score")
    ax.set_ylim(bottom=0, top=max_y + 2)
    ax.legend()
    fig.tight_layout()
    return fig

# ✅ Generate and save all plots into a multi-page PDF
with PdfPages("ols_att_barplot_subsubcat.pdf") as pdf:
    for idx, row in final_df.iterrows():
        fig = plot_single_group(row)
        pdf.savefig(fig, dpi=300, bbox_inches='tight')
        plt.close(fig)
print("✅ ols_att_barplot_subsubcat saved")

✅ ols_att_barplot_subsubcat saved


In [87]:
# Love plot:

In [88]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# ----------------------------------------
# Functions to calculate balance
# ----------------------------------------
def calculate_smd(x1, x2, w1=None, w2=None):
    def weighted_mean(x, w): return np.average(x, weights=w)
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    m1 = weighted_mean(x1, w1) if w1 is not None else np.mean(x1)
    m2 = weighted_mean(x2, w2) if w2 is not None else np.mean(x2)
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    pooled_sd = np.sqrt((v1 + v2) / 2)
    return np.abs(m1 - m2) / pooled_sd if pooled_sd > 0 else 0

def variance_ratio(x1, x2, w1=None, w2=None):
    def weighted_var(x, w):
        m = np.average(x, weights=w)
        return np.average((x - m) ** 2, weights=w)
    
    v1 = weighted_var(x1, w1) if w1 is not None else np.var(x1, ddof=1)
    v2 = weighted_var(x2, w2) if w2 is not None else np.var(x2, ddof=1)
    
    return max(v1 / v2, v2 / v1) if v1 > 0 and v2 > 0 else 1

# ----------------------------------------
# Setup
# ----------------------------------------
output_base = "outputs"
groups = [g for g in os.listdir(output_base) if os.path.isdir(os.path.join(output_base, g))]

# Create a case-insensitive mapping
final_covariates_map_lower = {k.lower(): v for k, v in final_covariates_map.items()}

# ----------------------------------------
# Main Loop
# ----------------------------------------
for group in groups:
    if group.lower() not in final_covariates_map_lower:
        continue

    print(f"\n🔍 Processing {group}...")

    try:
        group_path = os.path.join(output_base, group)
        covariates = final_covariates_map_lower[group.lower()]
        
        column_name = None
        for col in pd.read_pickle(os.path.join(group_path, "trimmed_data_imp1.pkl")).columns:
            if col.lower() == group.lower():
                column_name = col
                break
        if column_name is None:
            print(f"⚠️ Column not found for {group}, skipping.")
            continue

        smd_unw_all, smd_w_all = [], []
        vr_unw_all, vr_w_all = [], []

        for i in range(1, 6):
            df_path = os.path.join(group_path, f"trimmed_data_imp{i}.pkl")
            iptw_path = os.path.join(group_path, "iptw_weights.xlsx")

            if not os.path.exists(df_path) or not os.path.exists(iptw_path):
                print(f"⚠️ Missing data for {group} imp{i}, skipping.")
                continue

            df = pd.read_pickle(df_path)
            iptw_df = pd.read_excel(iptw_path, index_col=0)
            T = df[column_name]
            W = iptw_df.loc[df.index, "iptw_mean"]

            smd_unw_i, smd_w_i, vr_unw_i, vr_w_i = [], [], [], []

            for cov in covariates:
                x1, x0 = df.loc[T == 1, cov], df.loc[T == 0, cov]
                w1, w0 = W[T == 1], W[T == 0]

                su = calculate_smd(x1, x0)
                sw = calculate_smd(x1, x0, w1, w0)

                vu = variance_ratio(x1, x0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan
                vw = variance_ratio(x1, x0, w1, w0) if cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] else np.nan

                smd_unw_i.append(su)
                smd_w_i.append(sw)
                vr_unw_i.append(vu)
                vr_w_i.append(vw)

            smd_unw_all.append(smd_unw_i)
            smd_w_all.append(smd_w_i)
            vr_unw_all.append(vr_unw_i)
            vr_w_all.append(vr_w_i)

        smd_unw = np.mean(smd_unw_all, axis=0)
        smd_w = np.mean(smd_w_all, axis=0)
        vr_unw = np.nanmean(vr_unw_all, axis=0)
        vr_w = np.nanmean(vr_w_all, axis=0)

        severity = []
        for sw in smd_w:
            if sw <= 0.1:
                severity.append("Good")
            elif sw <= 0.2:
                severity.append("Moderate")
            else:
                severity.append("Poor")

        covariate_names = covariates
        numeric_df = pd.DataFrame({
            "Covariate": covariate_names,
            "SMD_Unweighted": smd_unw,
            "SMD_Weighted": smd_w,
            "Imbalance_Severity": severity,
            "VR_Unweighted": vr_unw,
            "VR_Weighted": vr_w
        })

        numeric_path = os.path.join(group_path, f"covariate_balance_table_{group}.xlsx")
        numeric_df.to_excel(numeric_path, index=False)
        print(f"📊 Exported numeric summary to: {numeric_path}")

        # -------------------------
        # Plot
        # -------------------------
        labels = covariates
        y_pos = np.arange(len(labels))

        fig, axes = plt.subplots(1, 2, figsize=(18, len(labels) * 0.45))

        axes[0].scatter(smd_unw, y_pos, color='red', label="Unweighted")
        axes[0].scatter(smd_w, y_pos, color='blue', label="Weighted")
        axes[0].axvline(0.1, color='gray', linestyle='--', label="Threshold 0.1")
        axes[0].axvline(0.2, color='black', linestyle='--', label="Threshold 0.2")
        axes[0].set_xlim(0, max(max(smd_unw), max(smd_w), 0.25) + 0.05)
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels(labels)
        axes[0].invert_yaxis()
        axes[0].set_title("Standardized Mean Differences (SMD)")
        axes[0].legend(loc="upper right")
        axes[0].grid(True)

        vr_mask = [cov in ['treatmentdurationdays', 'CAPS5score_baseline', 'age'] for cov in covariates]
        filtered_y = [i for i, b in enumerate(vr_mask) if b]
        filtered_labels = [labels[i] for i in filtered_y]
        filtered_vr_unw = [vr_unw[i] for i in filtered_y]
        filtered_vr_w = [vr_w[i] for i in filtered_y]

        axes[1].scatter(filtered_vr_unw, filtered_y, color='blue', marker='o', label="Unweighted")
        axes[1].scatter(filtered_vr_w, filtered_y, color='red', marker='x', label="Weighted")
        axes[1].axvline(2, color='gray', linestyle='--')
        axes[1].axvline(0.5, color='gray', linestyle='--')
        axes[1].set_xlim(0, max(filtered_vr_unw + filtered_vr_w + [2.5]) + 0.5)
        axes[1].set_yticks(filtered_y)
        axes[1].set_yticklabels(filtered_labels)
        axes[1].invert_yaxis()
        axes[1].set_title("Variance Ratio (VR)")
        axes[1].legend()
        axes[1].grid(True)

        fig.suptitle(f"Covariate Balance for {group.replace('CAT_', '')}", fontsize=14, weight='bold')
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        plot_path = os.path.join(group_path, f"love_plot_{group}.pdf")
        fig.savefig(plot_path, dpi=300)
        plt.close()
        print(f"✅ Saved love plot: {plot_path}")
        print(f"📏 Max weighted SMD for {group}: {np.max(smd_w):.3f}")

    except Exception as e:
        print(f"❌ Error in {group}: {e}")



🔍 Processing SUBSUBCAT_Alprazolam...
📊 Exported numeric summary to: outputs\SUBSUBCAT_Alprazolam\covariate_balance_table_SUBSUBCAT_Alprazolam.xlsx
✅ Saved love plot: outputs\SUBSUBCAT_Alprazolam\love_plot_SUBSUBCAT_Alprazolam.pdf
📏 Max weighted SMD for SUBSUBCAT_Alprazolam: 0.736

🔍 Processing SUBSUBCAT_Amitriptyline...
📊 Exported numeric summary to: outputs\SUBSUBCAT_Amitriptyline\covariate_balance_table_SUBSUBCAT_Amitriptyline.xlsx
✅ Saved love plot: outputs\SUBSUBCAT_Amitriptyline\love_plot_SUBSUBCAT_Amitriptyline.pdf
📏 Max weighted SMD for SUBSUBCAT_Amitriptyline: 0.666

🔍 Processing SUBSUBCAT_Bupropion...
📊 Exported numeric summary to: outputs\SUBSUBCAT_Bupropion\covariate_balance_table_SUBSUBCAT_Bupropion.xlsx
✅ Saved love plot: outputs\SUBSUBCAT_Bupropion\love_plot_SUBSUBCAT_Bupropion.pdf
📏 Max weighted SMD for SUBSUBCAT_Bupropion: 0.719

🔍 Processing SUBSUBCAT_Citalopram...
📊 Exported numeric summary to: outputs\SUBSUBCAT_Citalopram\covariate_balance_table_SUBSUBCAT_Citalopram

In [89]:
# Heatmap:

In [90]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
#-----------------------------
# Generate heatmaps
# -------------------------------
for treatment_var in medication_groups:
    print(f"\n========== Creating Heatmap for {treatment_var} ==========")

    try:
        output_folder = os.path.join('outputs', treatment_var)
        balance_path = os.path.join(output_folder, f'covariate_balance_table_{treatment_var}.xlsx')

        if not os.path.exists(balance_path):
            print(f"❌ Balance file not found: {balance_path}")
            continue

        balance_df = pd.read_excel(balance_path)

        # ✅ Use finalized covariates + 'Propensity Score'
        covariates = final_covariates_map[treatment_var] + ['Propensity Score']
        balance_df = balance_df[balance_df['Covariate'].isin(covariates)]

        # ✅ Check for CAPS5score_baseline
        highlight_caps = 'CAPS5score_baseline' in balance_df['Covariate'].values

        # ✅ Format for heatmap
        heatmap_df = balance_df[['Covariate', 'SMD_Unweighted', 'SMD_Weighted']].copy()
        heatmap_df.columns = ['Covariate', 'Unweighted', 'Weighted']
        heatmap_df = heatmap_df.set_index('Covariate')
        heatmap_df = heatmap_df.sort_values(by='Unweighted', ascending=False)

        # ✅ Plot
        plt.figure(figsize=(12, max(10, len(heatmap_df) * 0.35)))
        ax = sns.heatmap(
            heatmap_df,
            cmap="coolwarm",
            annot=True,
            fmt=".2f",
            linewidths=0.6,
            linecolor='gray',
            cbar_kws={"label": "Standardized Mean Difference"}
        )

        plt.title(f"Covariate Balance Heatmap (Rubin IPTW)\n{treatment_var}", fontsize=15, weight='bold')
        plt.xlabel("Condition")
        plt.ylabel("Covariate")

        # ✅ Bold CAPS5score_baseline if present
        if highlight_caps:
            ylabels = [label.get_text() for label in ax.get_yticklabels()]
            ax.set_yticklabels([
                f"{label} ←" if label == 'CAPS5score_baseline' else label for label in ylabels
            ])

        plt.tight_layout()

        # ✅ Save image
        save_path = os.path.join(output_folder, f'heatmap_smd_{treatment_var}.png')
        plt.savefig(save_path, dpi=300)
        plt.close()

        print(f"✅ Heatmap saved: {save_path}")

    except Exception as e:
        print(f"⚠️ Error processing {treatment_var}: {e}")


✅ Heatmap saved: outputs\SubSubCat_Oxazepam\heatmap_smd_SubSubCat_Oxazepam.png

✅ Heatmap saved: outputs\SubSubCat_Diazepam\heatmap_smd_SubSubCat_Diazepam.png

✅ Heatmap saved: outputs\SubSubCat_Paracetamol\heatmap_smd_SubSubCat_Paracetamol.png

✅ Heatmap saved: outputs\SubSubCat_Lorazepam\heatmap_smd_SubSubCat_Lorazepam.png

✅ Heatmap saved: outputs\SubSubCat_Mirtazapine\heatmap_smd_SubSubCat_Mirtazapine.png

✅ Heatmap saved: outputs\SubSubCat_Escitalopram\heatmap_smd_SubSubCat_Escitalopram.png

✅ Heatmap saved: outputs\SubSubCat_Sertraline\heatmap_smd_SubSubCat_Sertraline.png

✅ Heatmap saved: outputs\SubSubCat_Temazepam\heatmap_smd_SubSubCat_Temazepam.png

✅ Heatmap saved: outputs\SubSubCat_Citalopram\heatmap_smd_SubSubCat_Citalopram.png

✅ Heatmap saved: outputs\SubSubCat_Quetiapine\heatmap_smd_SubSubCat_Quetiapine.png

✅ Heatmap saved: outputs\SubSubCat_Amitriptyline\heatmap_smd_SubSubCat_Amitriptyline.png

✅ Heatmap saved: outputs\SubSubCat_Venlafaxine\heatmap_smd_SubSubCat_Venl