In [None]:
import sys, site

# Packages installed with `pip install --user` live in the per-user
# site-packages directory; append it to sys.path when it is not listed yet.
try:
    user_site = site.getusersitepackages()
    if user_site in sys.path:
        print("User site-packages already on sys.path:", user_site)
    else:
        sys.path.append(user_site)
        print("Added user site-packages to sys.path:", user_site)
except Exception as e:
    # Best effort: report the problem and carry on.
    print("Could not resolve user site-packages:", e)


In [None]:
# --- CELL 1: SETUP ---
import sys
!{sys.executable} -m pip install pandas numpy scipy matplotlib seaborn tqdm scikit-learn lifelines pandas-gbq google-cloud-bigquery fastparquet

import pandas as pd
import numpy as np
import os
import gc
import datetime
import subprocess
import scipy.sparse as sp
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import warnings

# --- Machine Learning & Stats ---
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV, Lasso
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, StackingClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, brier_score_loss
from sklearn.calibration import calibration_curve

import fastparquet

# --- Survival Analysis ---
import sys, subprocess, site

# Make sure the per-user site-packages directory is importable **before** the
# lifelines import is attempted (a `pip install --user` lands there).
try:
    user_site = site.getusersitepackages()
    if user_site not in sys.path:
        sys.path.append(user_site)
except Exception:
    # Deliberate best effort: if the user site cannot be resolved we still
    # attempt the import below.
    pass

try:
    from lifelines import CoxPHFitter, KaplanMeierFitter
    from lifelines.statistics import logrank_test
except ImportError:
    # Install into the environment of the interpreter running this kernel.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "lifelines"])
    # The import system caches directory listings; refresh them so the package
    # that was installed a moment ago is visible to the retry below.
    import importlib
    importlib.invalidate_caches()
    from lifelines import CoxPHFitter, KaplanMeierFitter
    from lifelines.statistics import logrank_test


# --- Configurations ---
# NOTE: this installs an "ignore" filter for every warning category, so no
# warnings are shown for the rest of the kernel session.
warnings.filterwarnings('ignore')
sns.set_style("whitegrid")
# Do not truncate the column list when a DataFrame is displayed.
pd.set_option('display.max_columns', None)
# Turn off pandas' chained-assignment (SettingWithCopy) check.
pd.set_option('mode.chained_assignment', None)
# Dataset prefix used in the BigQuery table references below; empty string when
# the environment variable is not set.
WORKSPACE_CDR = os.environ.get("WORKSPACE_CDR", "")
# Leave two cores free but always keep at least one worker.
# os.cpu_count() returns None when the count cannot be determined, so fall
# back to 1 instead of failing on `None - 2`.
N_CORES = max(1, (os.cpu_count() or 1) - 2)

print(f"Environment Ready. CDR: {WORKSPACE_CDR}, Cores: {N_CORES}")

# --- Helper Functions ---

def to_naive_utc_day(series):
    """Return `series` as timezone-naive timestamps at UTC midnight.

    Values are parsed as UTC (entries that cannot be parsed become NaT), the
    timezone information is then dropped and the time of day reset to 00:00.
    """
    as_utc = pd.to_datetime(series, errors='coerce', utc=True)
    without_tz = as_utc.dt.tz_localize(None)
    return without_tz.dt.normalize()

def clean_mem():
    """Trigger a garbage-collection pass; the collector's result is discarded."""
    _ = gc.collect()
    return None

def calculate_ess(weights):
    """Calculates Kish's Effective Sample Size: (sum w)^2 / sum(w^2).

    Parameters
    ----------
    weights : array-like
        Sample weights (list, tuple, NumPy array or pandas Series).

    Returns
    -------
    float or int
        The effective sample size. 0 is returned when `weights` is empty or
        when the sum of squared weights is zero (e.g. all weights are zero).
    """
    # Coerce first so plain Python lists work (`list ** 2` is a TypeError).
    w = np.asarray(weights, dtype=float).ravel()
    if w.size == 0:
        return 0
    sum_sq = np.sum(w ** 2)
    if sum_sq == 0:
        # All-zero weights would otherwise evaluate 0/0 and return NaN.
        return 0
    return (np.sum(w) ** 2) / sum_sq

def sparse_weighted_mean_var(X, weights):
    """Weighted column means and variances of a sparse matrix without densifying.

    Parameters
    ----------
    X : scipy.sparse matrix of shape (N, P)
    weights : array-like of shape (N,)

    Returns
    -------
    (means, vars_) : tuple of two 1-D ndarrays of length P
        `means` is the weighted mean of each column. `vars_` is the weighted
        population variance E_w[X^2] - (E_w[X])^2, clipped at zero.
    """
    w = np.asarray(weights, dtype=float).ravel()
    sum_w = np.sum(w)

    # Weighted column sums as a matrix-vector product; this avoids building an
    # N x N diagonal weight matrix.
    means = np.asarray(X.T @ w).ravel() / sum_w

    # Var = E[X^2] - (E[X])^2
    means2 = np.asarray(X.power(2).T @ w).ravel() / sum_w

    # Floating-point cancellation can push the difference a hair below zero for
    # (near-)constant columns; clip so a later sqrt() cannot produce NaN.
    vars_ = np.maximum(means2 - (means ** 2), 0.0)
    return means, vars_

In [None]:
# =============================================================================
# CELL 2: Clinical Definitions & Negative Controls (Refined)
# =============================================================================
print("\n--- CELL 2: Clinical Definitions & Negative Controls ---")

# --- 1. EXPOSURES ---
# Concept IDs grouped by modality (CT / MRI) and by contrast vs. non-contrast.
CONTRAST_CT = {
    4139745, 21492176, 4335400, 3047782, 4327032,
    3013610, 36713226, 3053128, 4252907, 3019625,
}
CONTRAST_MRI = {
    4335399, 4161393, 4202274, 4197203, 36717294,
    45765683, 37117806, 37109194, 37109196,
}
CONTRAST_IDS = CONTRAST_CT | CONTRAST_MRI

NON_CONTRAST_CT = {
    37109313, 3049940, 37117305, 3047921, 36713200,
    3018999, 40771605, 36713202, 3035568,
}
NON_CONTRAST_MRI = {
    37109312, 36713204, 36713045, 36713262, 3024397,
    36713243, 3053040, 37109329, 42535581, 42535582,
}
NON_CONTRAST_IDS = NON_CONTRAST_CT | NON_CONTRAST_MRI

# Every exposure concept, contrast or not.
ALL_IMAGING = CONTRAST_IDS | NON_CONTRAST_IDS

# --- 2. MAIN OUTCOMES (The Targets) ---
# Each entry maps an outcome key to (concept spec, window).
# The key doubles as the column suffix (e.g., date_AKI_30). The concept spec is
# either a set of concept IDs or a string marker ('DEATH' / 'COMPOSITE').
ANALYSIS_OUTCOMES = {
    'AKI_30': (
        {761083, 197320, 40481064, 4328366, 37116432, 45757442, 37016366},
        30,
    ),
    'NEW_DIALYSIS_90': (
        {4032243, 4146536, 4324124, 4019967, 40482357},
        90,
    ),
    'MORTALITY_30': ('DEATH', 30),
    # Defined as min(AKI, Dialysis, Death)
    'MAE_30': ('COMPOSITE', 30),
    'THYROID_90': (
        {138384, 37016342, 45757058, 4032331},
        90,
    ),
}

# --- 3. COVARIATE DEFINITIONS (Not Policies) ---
# These feed confounding adjustment only; they are NOT "withholding" rules.
THYROTOXICOSIS_IDS = {37016342, 45757058, 440936, 134438}

# Lab Concepts - STRICT SEPARATION
# eGFR and creatinine are kept as two separate covariates.

# eGFR (CKD-EPI, MDRD, etc.) – ONLY true eGFR concepts
EGFR_CONCEPTS = {333096, 3049187, 3053283, 3029859, 1619026, 1619025}

# Serum Creatinine
CREATININE_CONCEPTS = {3016723, 3020564, 3034485, 3022192}


# --- 4. NEGATIVE CONTROLS (The Calibrators) ---
# Name -> set of condition concept IDs. Every name carries the 'NC_' prefix.
NEGATIVE_CONTROLS = dict([
    ('NC_Ingrown_Nail', {139900}),
    ('NC_Ankle_Sprain', {4196156}),
    ('NC_Cataract', {375545}),
    ('NC_Otitis_Media', {378534}),
    ('NC_T2DM', {201826}),
    ('NC_Hypertension', {320128}),
    ('NC_Hyperlipidemia', {432867}),
    ('NC_Gout', {439392}),
    ('NC_Depression', {4282316}),
    ('NC_Anxiety', {436073}),
    ('NC_Insomnia', {436962}),
    ('NC_Osteoarthritis', {4079750, 4155298}),
    ('NC_Low_Back_Pain', {4213162}),
    ('NC_Carpal_Tunnel', {376918}),
    ('NC_Allergic_Rhinitis', {379805}),
    ('NC_GERD', {192279}),
    ('NC_Migraine', {375527}),
    ('NC_Hypothyroidism', {140673}),
    ('NC_Varicose_Veins', {318800}),
])

# Flat set of every negative-control concept ID.
ALL_NEGATIVE_CONTROL_IDS = {
    concept_id
    for concept_ids in NEGATIVE_CONTROLS.values()
    for concept_id in concept_ids
}

# --- 5. EXCLUSION LIST (Covariates to drop) ---
# Only outcome concepts occurring *after* index are meant to be excluded;
# pre-index history of the same conditions is kept as a confounder.
# Specific logic applied in Cell 5.
# Entries whose concept spec is a string marker rather than a set of IDs
# contribute nothing here.
OUTCOME_CONCEPTS = {
    concept_id
    for definition in ANALYSIS_OUTCOMES.values()
    if isinstance(definition[0], set)
    for concept_id in definition[0]
}

print(f"Definitions Loaded. {len(NEGATIVE_CONTROLS)} Negative Controls defined.")


In [None]:
# =============================================================================
# CELL 2.5: Empirical Null Calibration Engine
# =============================================================================
from scipy.stats import norm

def calibrate_estimates(results_df):
    """
    Fits an Empirical Null distribution to Negative Controls and calibrates
    the P-values and CIs of every row in `results_df`.

    Assumption: Negative Controls have True Log-RR = 0.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Must contain the columns 'Outcome', 'HR_Cox', 'HR_CI_Low' and
        'HR_CI_High'. Rows whose 'Outcome' starts with 'NC_' are treated as
        negative controls.

    Returns
    -------
    pandas.DataFrame
        A copy of `results_df` with the extra columns 'P_Calibrated',
        'HR_Calibrated', 'HR_Cal_Low' and 'HR_Cal_High'. If fewer than 10
        usable negative controls are available, a warning is printed and
        `results_df` itself is returned without those columns.
    """
    print("\n--- Performing Empirical Calibration ---")

    # 1. Identify Negative Controls.
    # na=False keeps rows with a missing Outcome out of the mask instead of
    # producing NaN entries that cannot be used for boolean indexing.
    nc_df = results_df[results_df['Outcome'].str.startswith('NC_', na=False)].copy()

    if len(nc_df) < 10:
        print("WARNING: Too few negative controls (<10) for robust calibration.")
        return results_df

    # 2. Extract Log-RR and Standard Error (from CI)
    # SE = (Log(Upper) - Log(Lower)) / 3.92   (3.92 = 2 * 1.96)
    # Zero or negative inputs give inf/NaN here; they are filtered out below,
    # so the floating-point warnings are suppressed.
    with np.errstate(divide='ignore', invalid='ignore'):
        nc_df['log_rr'] = np.log(nc_df['HR_Cox'].astype(float))
        nc_df['se_log_rr'] = (
            np.log(nc_df['HR_CI_High'].astype(float)) - np.log(nc_df['HR_CI_Low'].astype(float))
        ) / 3.92

    # Drop invalid results (infinite or NaN) and zero-width intervals, which
    # would otherwise receive an infinite inverse-variance weight.
    nc_df = nc_df.replace([np.inf, -np.inf], np.nan).dropna(subset=['log_rr', 'se_log_rr'])
    nc_df = nc_df[nc_df['se_log_rr'] > 0]

    # Re-check after filtering: the threshold applies to *usable* controls, and
    # an empty set would make the weighted averages below fail.
    if len(nc_df) < 10:
        print("WARNING: Too few negative controls (<10) for robust calibration.")
        return results_df

    # 3. Fit the Null Distribution N(mu, sigma^2)
    # Weighted moment estimator (inverse variance weighting)
    weights = 1.0 / (nc_df['se_log_rr'] ** 2)

    # Mean bias (Systematic Shift)
    null_mean = np.average(nc_df['log_rr'], weights=weights)

    # SD bias: excess of the observed weighted variance over the weighted
    # average sampling variance, floored at zero.
    raw_var = np.average((nc_df['log_rr'] - null_mean) ** 2, weights=weights)
    expected_sampling_var = np.average(nc_df['se_log_rr'] ** 2, weights=weights)
    null_var = max(0, raw_var - expected_sampling_var)
    null_sd = np.sqrt(null_var)

    print(f"  Empirical Null Fitted: Mean Bias = {null_mean:.4f}, SD Bias = {null_sd:.4f}")
    print(f"  (Interpretation: Mean!=0 implies systematic error; SD>0 implies unmeasured confounding)")

    # 4. Calibrate All Estimates (Main + NCs)
    calibrated_results = results_df.copy()

    with np.errstate(divide='ignore', invalid='ignore'):
        log_rr = np.log(calibrated_results['HR_Cox'].astype(float))
        se_log_rr = (
            np.log(calibrated_results['HR_CI_High'].astype(float))
            - np.log(calibrated_results['HR_CI_Low'].astype(float))
        ) / 3.92

        # Total uncertainty = sampling + systematic
        calibrated_se = np.sqrt(se_log_rr ** 2 + null_sd ** 2)

        # Calibrated Z-Score: remove the mean bias, scale by the wider uncertainty
        z_cal = (log_rr - null_mean) / calibrated_se

    # Calibrated P-value. The survival function keeps precision for large |z|,
    # where 1 - cdf() rounds to exactly 0.
    calibrated_results['P_Calibrated'] = 2 * norm.sf(np.abs(z_cal))

    # Calibrated CIs (Shifted and Widened)
    calibrated_results['HR_Calibrated'] = np.exp(log_rr - null_mean)
    calibrated_results['HR_Cal_Low'] = np.exp((log_rr - null_mean) - 1.96 * calibrated_se)
    calibrated_results['HR_Cal_High'] = np.exp((log_rr - null_mean) + 1.96 * calibrated_se)

    return calibrated_results

# Funnel plot of the estimates: log hazard ratio against precision.
def plot_calibration(results_df):
    """Scatter negative controls and main outcomes on a funnel plot.

    x is log(HR_Cox) and y is 1/SE, with SE recovered from the confidence
    interval as (log(HR_CI_High) - log(HR_CI_Low)) / 3.92. Rows whose
    'Outcome' starts with 'NC_' are drawn in gray, all other rows in red.
    A dashed vertical line marks log-HR = 0 (HR = 1). The fitted null
    distribution itself is not drawn.
    """
    def _log_hr_and_se(frame):
        # Same transformation for both groups of rows.
        log_hr = np.log(frame['HR_Cox'].astype(float))
        log_ci_width = np.log(frame['HR_CI_High'].astype(float)) - np.log(frame['HR_CI_Low'].astype(float))
        return log_hr, log_ci_width / 3.92

    plt.figure(figsize=(10, 6))

    is_negative_control = results_df['Outcome'].str.startswith('NC_')

    # Negative controls
    nc_log_hr, nc_se = _log_hr_and_se(results_df[is_negative_control])
    plt.scatter(nc_log_hr, 1/nc_se, alpha=0.5, color='gray', label='Negative Controls')

    # Main outcomes
    main_log_hr, main_se = _log_hr_and_se(results_df[~is_negative_control])
    plt.scatter(main_log_hr, 1/main_se, color='red', s=100, label='Main Outcomes', zorder=10)

    # Null line (x=0 is HR=1)
    plt.axvline(0, color='black', linestyle='--')

    plt.xlabel("Log Hazard Ratio")
    plt.ylabel("Precision (1/SE)")
    plt.title("Empirical Calibration Funnel Plot")
    plt.legend()
    plt.show()

print("Done")

In [None]:
# =============================================================================
# CELL 2.5: Empirical Null Calibration Engine
# =============================================================================
from scipy.stats import norm

def calibrate_estimates(results_df):
    """
    Fits an Empirical Null distribution to Negative Controls and calibrates
    the P-values and CIs of every row in `results_df`.

    Assumption: Negative Controls have True Log-RR = 0.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Must contain the columns 'Outcome', 'HR_Cox', 'HR_CI_Low' and
        'HR_CI_High'. Rows whose 'Outcome' starts with 'NC_' are treated as
        negative controls.

    Returns
    -------
    pandas.DataFrame
        A copy of `results_df` with the extra columns 'P_Calibrated',
        'HR_Calibrated', 'HR_Cal_Low' and 'HR_Cal_High'. If fewer than 10
        usable negative controls are available, a warning is printed and
        `results_df` itself is returned without those columns.
    """
    print("\n--- Performing Empirical Calibration ---")

    # 1. Identify Negative Controls.
    # na=False keeps rows with a missing Outcome out of the mask instead of
    # producing NaN entries that cannot be used for boolean indexing.
    nc_df = results_df[results_df['Outcome'].str.startswith('NC_', na=False)].copy()

    if len(nc_df) < 10:
        print("WARNING: Too few negative controls (<10) for robust calibration.")
        return results_df

    # 2. Extract Log-RR and Standard Error (from CI)
    # SE = (Log(Upper) - Log(Lower)) / 3.92   (3.92 = 2 * 1.96)
    # Zero or negative inputs give inf/NaN here; they are filtered out below,
    # so the floating-point warnings are suppressed.
    with np.errstate(divide='ignore', invalid='ignore'):
        nc_df['log_rr'] = np.log(nc_df['HR_Cox'].astype(float))
        nc_df['se_log_rr'] = (
            np.log(nc_df['HR_CI_High'].astype(float)) - np.log(nc_df['HR_CI_Low'].astype(float))
        ) / 3.92

    # Drop invalid results (infinite or NaN) and zero-width intervals, which
    # would otherwise receive an infinite inverse-variance weight.
    nc_df = nc_df.replace([np.inf, -np.inf], np.nan).dropna(subset=['log_rr', 'se_log_rr'])
    nc_df = nc_df[nc_df['se_log_rr'] > 0]

    # Re-check after filtering: the threshold applies to *usable* controls, and
    # an empty set would make the weighted averages below fail.
    if len(nc_df) < 10:
        print("WARNING: Too few negative controls (<10) for robust calibration.")
        return results_df

    # 3. Fit the Null Distribution N(mu, sigma^2)
    # Weighted moment estimator (inverse variance weighting)
    weights = 1.0 / (nc_df['se_log_rr'] ** 2)

    # Mean bias (Systematic Shift)
    null_mean = np.average(nc_df['log_rr'], weights=weights)

    # SD bias: excess of the observed weighted variance over the weighted
    # average sampling variance, floored at zero.
    raw_var = np.average((nc_df['log_rr'] - null_mean) ** 2, weights=weights)
    expected_sampling_var = np.average(nc_df['se_log_rr'] ** 2, weights=weights)
    null_var = max(0, raw_var - expected_sampling_var)
    null_sd = np.sqrt(null_var)

    print(f"  Empirical Null Fitted: Mean Bias = {null_mean:.4f}, SD Bias = {null_sd:.4f}")
    print(f"  (Interpretation: Mean!=0 implies systematic error; SD>0 implies unmeasured confounding)")

    # 4. Calibrate All Estimates (Main + NCs)
    calibrated_results = results_df.copy()

    with np.errstate(divide='ignore', invalid='ignore'):
        log_rr = np.log(calibrated_results['HR_Cox'].astype(float))
        se_log_rr = (
            np.log(calibrated_results['HR_CI_High'].astype(float))
            - np.log(calibrated_results['HR_CI_Low'].astype(float))
        ) / 3.92

        # Total uncertainty = sampling + systematic
        calibrated_se = np.sqrt(se_log_rr ** 2 + null_sd ** 2)

        # Calibrated Z-Score: remove the mean bias, scale by the wider uncertainty
        z_cal = (log_rr - null_mean) / calibrated_se

    # Calibrated P-value. The survival function keeps precision for large |z|,
    # where 1 - cdf() rounds to exactly 0.
    calibrated_results['P_Calibrated'] = 2 * norm.sf(np.abs(z_cal))

    # Calibrated CIs (Shifted and Widened)
    calibrated_results['HR_Calibrated'] = np.exp(log_rr - null_mean)
    calibrated_results['HR_Cal_Low'] = np.exp((log_rr - null_mean) - 1.96 * calibrated_se)
    calibrated_results['HR_Cal_High'] = np.exp((log_rr - null_mean) + 1.96 * calibrated_se)

    return calibrated_results

# Funnel plot of the estimates: log hazard ratio against precision.
def plot_calibration(results_df):
    """Scatter negative controls and main outcomes on a funnel plot.

    x is log(HR_Cox) and y is 1/SE, with SE recovered from the confidence
    interval as (log(HR_CI_High) - log(HR_CI_Low)) / 3.92. Rows whose
    'Outcome' starts with 'NC_' are drawn in gray, all other rows in red.
    A dashed vertical line marks log-HR = 0 (HR = 1). The fitted null
    distribution itself is not drawn.
    """
    def _log_hr_and_se(frame):
        # Same transformation for both groups of rows.
        log_hr = np.log(frame['HR_Cox'].astype(float))
        log_ci_width = np.log(frame['HR_CI_High'].astype(float)) - np.log(frame['HR_CI_Low'].astype(float))
        return log_hr, log_ci_width / 3.92

    plt.figure(figsize=(10, 6))

    is_negative_control = results_df['Outcome'].str.startswith('NC_')

    # Negative controls
    nc_log_hr, nc_se = _log_hr_and_se(results_df[is_negative_control])
    plt.scatter(nc_log_hr, 1/nc_se, alpha=0.5, color='gray', label='Negative Controls')

    # Main outcomes
    main_log_hr, main_se = _log_hr_and_se(results_df[~is_negative_control])
    plt.scatter(main_log_hr, 1/main_se, color='red', s=100, label='Main Outcomes', zorder=10)

    # Null line (x=0 is HR=1)
    plt.axvline(0, color='black', linestyle='--')

    plt.xlabel("Log Hazard Ratio")
    plt.ylabel("Precision (1/SE)")
    plt.title("Empirical Calibration Funnel Plot")
    plt.legend()
    plt.show()

print("Done")

In [None]:
# =============================================================================
# CELL 4: Clinical Specifics (Labs & Outcomes) [Chunked Version]
# =============================================================================
print("\n--- CELL 4: Clinical Specifics (Labs & Outcomes) ---")

# --- Helper Function for Chunking ---
def get_data_in_chunks(sql_template, all_ids, chunk_size=5000):
    """
    Runs `sql_template` once per chunk of person_ids and concatenates the
    results, to avoid BigQuery query length limits (1MB).

    Parameters
    ----------
    sql_template : str
        SQL text containing the literal token PLACEHOLDER_IDS; every occurrence
        is replaced by a parenthesised, comma-separated list of ids.
    all_ids : iterable
        Person ids; duplicates are removed and the ids are sorted.
    chunk_size : int
        Maximum number of ids per query.

    Returns
    -------
    pandas.DataFrame
        Concatenation of all chunks that downloaded successfully. An empty,
        column-less DataFrame is returned when nothing was downloaded.
        Failed chunks are skipped (best effort) and summarised in a warning.
    """
    ids_list = sorted(set(all_ids))  # unique, deterministic order

    print(f"  Fetching data for {len(ids_list)} patients in chunks of {chunk_size}...")

    # Use the notebook's `read_gbq` when it has already been imported; otherwise
    # import it here. Without this, a fresh kernel raises NameError inside the
    # try-block below for every chunk and the function silently returns nothing.
    run_query = globals().get("read_gbq")
    if run_query is None and ids_list:
        from pandas_gbq import read_gbq as run_query

    results = []
    failed_chunks = []
    starts = range(0, len(ids_list), chunk_size)

    for chunk_no, start in enumerate(starts, start=1):
        chunk = ids_list[start : start + chunk_size]
        chunk_str = "(" + ",".join(map(str, chunk)) + ")"

        # Inject the chunk of IDs into the placeholder
        query = sql_template.replace("PLACEHOLDER_IDS", chunk_str)

        try:
            results.append(run_query(query, dialect="standard"))
        except Exception as e:
            failed_chunks.append(chunk_no)
            print(f"    Error in chunk {chunk_no} (sorted ids {start}-{start + len(chunk) - 1}): {e}")

    if failed_chunks:
        # Make partial downloads impossible to overlook.
        print(f"  WARNING: {len(failed_chunks)} of {len(starts)} chunks failed {failed_chunks}; "
              f"the returned data is INCOMPLETE.")

    if not results:
        return pd.DataFrame()

    return pd.concat(results, ignore_index=True)

# List of all patients in cohort
# NOTE(review): df_cohort is not created in any cell above this one; it must
# already exist, presumably indexed by person_id — confirm which cell builds it.
all_cohort_ids = df_cohort.index.tolist()

# --- A. Measurements: Strict eGFR vs Creatinine Separation ---

# 1. Fetch eGFR
# Note: We use PLACEHOLDER_IDS instead of injecting the huge list immediately
# The value filter keeps only results strictly between 0 and 200.
sql_egfr_template = f"""
SELECT person_id, measurement_date as date, value_as_number
FROM `{WORKSPACE_CDR}.measurement`
WHERE measurement_concept_id IN ({','.join(map(str, EGFR_CONCEPTS))})
AND person_id IN PLACEHOLDER_IDS
AND value_as_number > 0 AND value_as_number < 200
"""
print("Fetching eGFR...")
df_egfr_raw = get_data_in_chunks(sql_egfr_template, all_cohort_ids)
# NOTE(review): get_data_in_chunks returns a column-less DataFrame when nothing
# was downloaded, in which case the ['date'] lookup raises KeyError.
df_egfr_raw['date'] = to_naive_utc_day(df_egfr_raw['date'])

# 2. Fetch Creatinine
# The value filter keeps only results strictly between 0.1 and 20.
sql_creat_template = f"""
SELECT person_id, measurement_date as date, value_as_number
FROM `{WORKSPACE_CDR}.measurement`
WHERE measurement_concept_id IN ({','.join(map(str, CREATININE_CONCEPTS))})
AND person_id IN PLACEHOLDER_IDS
AND value_as_number > 0.1 AND value_as_number < 20
"""
print("Fetching Creatinine...")
df_creat_raw = get_data_in_chunks(sql_creat_template, all_cohort_ids)
df_creat_raw['date'] = to_naive_utc_day(df_creat_raw['date'])

# Merge to find baseline (Closest prior to index)
# reset_index() turns the cohort index into a regular column so the lab tables
# can be merged against it.
df_dates = df_cohort[['index_date']].reset_index()

# Baseline lab = most recent value recorded before the index date
def get_baseline_lab(df_lab, df_index, col_name):
    """Return each person's latest lab value dated strictly before their index date.

    `df_lab` needs the columns person_id, date and value_as_number; `df_index`
    needs person_id and index_date. The result is a Series indexed by
    person_id and named `col_name`; people without a qualifying measurement
    are absent. An empty `df_lab` yields an empty, unnamed float Series.
    """
    if df_lab.empty:
        return pd.Series(dtype=float)

    joined = df_lab.merge(df_index, on='person_id')
    # Measurements on the index date itself do not count as baseline.
    is_prior = joined['date'] < joined['index_date']
    prior_sorted = joined[is_prior].sort_values('date')
    latest = prior_sorted.groupby('person_id')['value_as_number'].last()
    return latest.rename(col_name)

# Attach baseline labs. Assignment aligns the returned Series with df_cohort's
# index, so people without a prior measurement get NaN.
df_cohort['baseline_egfr'] = get_baseline_lab(df_egfr_raw, df_dates, 'baseline_egfr')
df_cohort['baseline_creat'] = get_baseline_lab(df_creat_raw, df_dates, 'baseline_creat')

# Categorical Definitions
# Bands: 0 = <30, 1 = 30-<45, 2 = 45-<60, 3 = >=60.
# Comparisons with NaN are False, so missing values match no band and fall
# through to the default.
conditions = [
    (df_cohort['baseline_egfr'] < 30),
    (df_cohort['baseline_egfr'] >= 30) & (df_cohort['baseline_egfr'] < 45),
    (df_cohort['baseline_egfr'] >= 45) & (df_cohort['baseline_egfr'] < 60),
    (df_cohort['baseline_egfr'] >= 60)
]
df_cohort['egfr_cat'] = np.select(conditions, [0, 1, 2, 3], default=4) # 4 is Missing

# --- B. Outcomes & Pre-Existing Conditions ---
# One query covers outcome concepts, thyrotoxicosis history and negative controls.
all_outcome_concepts = OUTCOME_CONCEPTS.union(THYROTOXICOSIS_IDS).union(ALL_NEGATIVE_CONTROL_IDS)

# Note: We construct the UNION ALL inside the template, but both parts need PLACEHOLDER_IDS
# Death rows are tagged with condition_concept_id = 0 so they can be told apart
# from condition rows after the download.
sql_outcomes_template = f"""
SELECT person_id, condition_start_date as event_date, condition_concept_id
FROM `{WORKSPACE_CDR}.condition_occurrence`
WHERE condition_concept_id IN ({','.join(map(str, all_outcome_concepts))})
AND person_id IN PLACEHOLDER_IDS
UNION ALL
SELECT person_id, death_date as event_date, 0 as condition_concept_id
FROM `{WORKSPACE_CDR}.death`
WHERE person_id IN PLACEHOLDER_IDS
"""
print("Fetching Outcomes & Events...")
df_events = get_data_in_chunks(sql_outcomes_template, all_cohort_ids)
# NOTE(review): a column-less result (nothing downloaded) raises KeyError here.
df_events['event_date'] = to_naive_utc_day(df_events['event_date'])
# Bring each person's index_date onto every event row.
df_events = df_events.merge(df_dates, on='person_id')

# 1. Pre-existing Thyrotoxicosis
# History flag: any thyrotoxicosis record dated strictly before the index date.
pre_thyro = df_events[
    (df_events['condition_concept_id'].isin(THYROTOXICOSIS_IDS)) & 
    (df_events['event_date'] < df_events['index_date'])
]
df_cohort['hx_thyrotoxicosis'] = 0
df_cohort.loc[df_cohort.index.isin(pre_thyro['person_id']), 'hx_thyrotoxicosis'] = 1

# 2. Standardized Outcome Dates
# Death
# Earliest death record per person; unlike the outcomes below this is not
# restricted to dates on or after the index date.
df_death = df_events[df_events['condition_concept_id'] == 0]
df_cohort['date_DEATH'] = df_death.groupby('person_id')['event_date'].min()

# Map Specific Outcomes
# First event on or after the index date (index-day events are included).
# `window` is unpacked but not applied in this cell.
for outcome, (concepts, window) in ANALYSIS_OUTCOMES.items():
    # These two have string markers instead of concept sets; handled separately.
    if outcome in ['MORTALITY_30', 'MAE_30']: continue 
    
    events = df_events[
        (df_events['condition_concept_id'].isin(concepts)) & 
        (df_events['event_date'] >= df_events['index_date'])
    ]
    df_cohort[f'date_{outcome}'] = events.groupby('person_id')['event_date'].min()

# Composite MAE_30
# Earliest of the available component dates; min(axis=1) skips missing values.
mae_cols = ['date_AKI_30', 'date_NEW_DIALYSIS_90', 'date_DEATH']
mae_valid = [c for c in mae_cols if c in df_cohort.columns]
if mae_valid:
    df_cohort['date_MAE_30'] = df_cohort[mae_valid].min(axis=1)

# Negative Controls
# Same rule as the main outcomes: first event on or after the index date.
for name, concepts in NEGATIVE_CONTROLS.items():
    events = df_events[
        (df_events['condition_concept_id'].isin(concepts)) & 
        (df_events['event_date'] >= df_events['index_date'])
    ]
    df_cohort[f'date_{name}'] = events.groupby('person_id')['event_date'].min()

print("Clinical Specifics Attached (Chunked).")
clean_mem()

In [None]:
# =============================================================================
# CELL 5: High-Dimensional Feature Extraction (Robust High-RAM Version)
# =============================================================================
import pandas as pd
import numpy as np
import scipy.sparse as sp
from pandas_gbq import read_gbq
import gc

print("\n--- CELL 5: High-Dimensional Feature Extraction ---")

# 1. Setup & Helper Functions
# ---------------------------------------------------------
# The exposure concepts themselves must not become covariates.
exclusions = ALL_IMAGING
# Rendered as "(id1,id2,...)" for use in SQL `NOT IN` clauses.
exclude_str = "(" + ",".join(str(concept_id) for concept_id in exclusions) + ")"

# Every patient ID in the cohort dataframe; defined here so this cell does not
# depend on an earlier cell having set it.
all_cohort_ids = df_cohort.index.tolist()

def get_data_in_chunks(sql_template, all_ids, chunk_size=4000):
    """Downloads data in chunks to satisfy BigQuery 1MB query limit.

    Parameters
    ----------
    sql_template : str
        SQL text containing the literal token PLACEHOLDER_IDS; every occurrence
        is replaced by a parenthesised, comma-separated list of ids.
    all_ids : iterable
        Person ids; duplicates are removed and the ids are sorted.
    chunk_size : int
        Maximum number of ids per query.

    Returns
    -------
    pandas.DataFrame
        Concatenation of all chunks that downloaded successfully. An empty,
        column-less DataFrame is returned when nothing was downloaded.
        Failed chunks are skipped (best effort) and summarised in a warning.
    """
    ids_list = sorted(set(all_ids))
    print(f"  Downloading data for {len(ids_list)} patients (Chunks of {chunk_size})...")

    results = []
    failed_chunks = []
    starts = range(0, len(ids_list), chunk_size)

    for chunk_no, start in enumerate(starts, start=1):
        chunk = ids_list[start : start + chunk_size]
        # Robust string formatting for the IN clause
        chunk_str = "(" + ",".join(map(str, chunk)) + ")"
        query = sql_template.replace("PLACEHOLDER_IDS", chunk_str)
        try:
            df_chunk = read_gbq(query, dialect="standard")
        except Exception as e:
            # Report the same 1-based chunk number the success message uses.
            failed_chunks.append(chunk_no)
            print(f"    Error in chunk {chunk_no} (sorted ids {start}-{start + len(chunk) - 1}): {e}")
            continue
        results.append(df_chunk)
        print(f"    Chunk {chunk_no} downloaded ({len(df_chunk)} rows)")

    if failed_chunks:
        # Make partial downloads impossible to overlook.
        print(f"  WARNING: {len(failed_chunks)} of {len(starts)} chunks failed {failed_chunks}; "
              f"the returned data is INCOMPLETE.")

    if not results:
        return pd.DataFrame()
    return pd.concat(results, ignore_index=True)

# 2. SQL Construction (Template)
# ---------------------------------------------------------
# We use PLACEHOLDER_IDS instead of injecting the huge list
# The template is one UNION ALL over six domains (COND, DRUG, PROC, MEAS, OBS,
# DEV); each branch keeps records dated strictly before `index_date` and drops
# the exposure concepts listed in `exclude_str`.
# NOTE(review): the Cohort CTE returns one row per imaging procedure, not one
# row per person, and takes index_date from procedure_occurrence rather than
# from df_cohort['index_date']. A person with several imaging procedures is
# therefore compared against each of those dates, and the join can emit the
# same (person, feature) pair many times. Confirm this matches the intended
# index date.
sql_features_template = f"""
WITH Cohort AS (
    SELECT 
        p.person_id, 
        CAST(p.procedure_datetime AS DATE) as index_date
    FROM `{WORKSPACE_CDR}.procedure_occurrence` p
    WHERE p.person_id IN PLACEHOLDER_IDS
    AND p.procedure_concept_id IN ({','.join(map(str, ALL_IMAGING))})
)
SELECT 
    c.person_id, 
    CAST(c.condition_concept_id AS STRING) as feature_id, 
    'COND' as domain
FROM `{WORKSPACE_CDR}.condition_occurrence` c
JOIN Cohort i ON c.person_id = i.person_id
WHERE c.condition_start_date < i.index_date 
  AND c.condition_concept_id NOT IN {exclude_str}

UNION ALL

SELECT 
    d.person_id, 
    CAST(d.drug_concept_id AS STRING) as feature_id, 
    'DRUG' as domain
FROM `{WORKSPACE_CDR}.drug_exposure` d
JOIN Cohort i ON d.person_id = i.person_id
WHERE d.drug_exposure_start_date < i.index_date
  AND d.drug_concept_id NOT IN {exclude_str}

UNION ALL

SELECT 
    p.person_id, 
    CAST(p.procedure_concept_id AS STRING) as feature_id, 
    'PROC' as domain
FROM `{WORKSPACE_CDR}.procedure_occurrence` p
JOIN Cohort i ON p.person_id = i.person_id
WHERE p.procedure_date < i.index_date
  AND p.procedure_concept_id NOT IN {exclude_str}

UNION ALL

SELECT 
    m.person_id, 
    CAST(m.measurement_concept_id AS STRING) as feature_id, 
    'MEAS' as domain
FROM `{WORKSPACE_CDR}.measurement` m
JOIN Cohort i ON m.person_id = i.person_id
WHERE m.measurement_date < i.index_date
  AND m.measurement_concept_id NOT IN {exclude_str}

UNION ALL

SELECT 
    o.person_id, 
    CONCAT(CAST(o.observation_concept_id AS STRING), '_', CAST(COALESCE(o.value_as_concept_id, 0) AS STRING)) as feature_id, 
    'OBS' as domain
FROM `{WORKSPACE_CDR}.observation` o
JOIN Cohort i ON o.person_id = i.person_id
WHERE o.observation_date < i.index_date
  AND o.observation_concept_id NOT IN {exclude_str}

UNION ALL

SELECT 
    dv.person_id, 
    CAST(dv.device_concept_id AS STRING) as feature_id, 
    'DEV' as domain
FROM `{WORKSPACE_CDR}.device_exposure` dv
JOIN Cohort i ON dv.person_id = i.person_id
WHERE dv.device_exposure_start_date < i.index_date
  AND dv.device_concept_id NOT IN {exclude_str}
"""

# 3. Execution (Chunked Download + High-Speed Processing)
# ---------------------------------------------------------
print("Starting High-RAM Feature Extraction...")

# Minimum number of distinct patients that must carry a feature for it to be kept.
MIN_FEATURE_PREVALENCE = 50

# A. Download
df_features = get_data_in_chunks(sql_features_template, all_cohort_ids)

# person_id -> matrix row. Built before branching so it exists on both paths.
# An existing mapping is reused only when its size matches the cohort;
# otherwise it is stale and would misalign rows.
if 'pid_to_idx' not in locals() or len(pid_to_idx) != len(df_cohort):
    pid_to_idx = {pid: i for i, pid in enumerate(df_cohort.index)}

# B. Process (Vectorized)
if not df_features.empty:
    print(f"Total raw features rows: {len(df_features):,}")

    # Filter to cohort (safety)
    df_features = df_features[df_features['person_id'].isin(pid_to_idx)]

    # Keep one row per (patient, feature). The download contains one row per
    # matching record, so without this step the prevalence filter below would
    # count records instead of patients.
    df_features = df_features.drop_duplicates(subset=['person_id', 'domain', 'feature_id']).copy()

    # Vectorized string concat (Fast)
    df_features['feature_name'] = df_features['domain'] + '_' + df_features['feature_id']

    # Number of distinct patients per feature
    feature_counts = df_features['feature_name'].value_counts()

    # Filter Prevalence >= MIN_FEATURE_PREVALENCE patients
    valid_feats_set = set(feature_counts[feature_counts >= MIN_FEATURE_PREVALENCE].index)
    # Sorted so the column order is reproducible between runs.
    feat_to_idx = {feat: i for i, feat in enumerate(sorted(valid_feats_set))}

    print(f"Unique Features: {len(feature_counts)}. Retained (>={MIN_FEATURE_PREVALENCE}): {len(valid_feats_set)}")

    # Final Filter
    df_features_valid = df_features[df_features['feature_name'].isin(valid_feats_set)]

    # C. Build Matrix
    print("Building Sparse Matrix...")
    row_indices = df_features_valid['person_id'].map(pid_to_idx).values
    col_indices = df_features_valid['feature_name'].map(feat_to_idx).values
    values = np.ones(len(row_indices))

    X_sparse = sp.coo_matrix(
        (values, (row_indices, col_indices)),
        shape=(len(df_cohort), len(valid_feats_set))
    ).tocsr()

    # Binarize (COO -> CSR sums duplicate coordinates; force every stored entry to 1)
    X_sparse.data = np.ones_like(X_sparse.data)

    print(f"Final Sparse Matrix Shape: {X_sparse.shape}")

    # Clean Memory
    del df_features, df_features_valid, row_indices, col_indices
    gc.collect()

else:
    print("CRITICAL WARNING: No feature history found for this cohort.")
    X_sparse = sp.csr_matrix((len(df_cohort), 0))
    # Define the column mapping on this path too, so later cells do not fail
    # with NameError.
    feat_to_idx = {}

print("Done.")

In [None]:
import os
import psutil

def print_disk_usage(tag=""):
    """Report total, used and free space (in GB) of the filesystem mounted at '/'.

    `tag` is echoed on the first line so successive reports can be told apart.
    """
    bytes_per_gb = 1024**3
    usage = psutil.disk_usage('/')
    print(f"[DISK] {tag}")
    print(f"       Total: {usage.total / bytes_per_gb:6.1f} GB")
    print(f"       Used : {usage.used / bytes_per_gb:6.1f} GB ({usage.percent:4.1f}%)")
    print(f"       Free : {usage.free / bytes_per_gb:6.1f} GB")
    print("")

def list_dir_with_sizes(path=".", max_files=50):
    """Print the entries of `path`, with file sizes in MB, for at most `max_files` entries.

    Regular files are shown with their size; everything else is labelled
    <DIR>. When entries are left out, a final line reports how many.
    """
    print(f"[FILES] Listing '{os.path.abspath(path)}' (showing up to {max_files} entries):")
    entries = os.listdir(path)
    bytes_per_mb = 1024**2
    for entry in entries[:max_files]:
        full_path = os.path.join(path, entry)
        if not os.path.isfile(full_path):
            print(f"   {entry:40s} <DIR>")
            continue
        print(f"   {entry:40s} {os.path.getsize(full_path) / bytes_per_mb:10.2f} MB")
    n_hidden = len(entries) - max_files
    if n_hidden > 0:
        print(f"   ... ({n_hidden} more files)")
    print("")


In [None]:
import os
import psutil

def print_disk_usage(tag=""):
    """Report total, used and free space (in GB) of the filesystem mounted at '/'.

    `tag` is echoed on the first line so successive reports can be told apart.
    """
    bytes_per_gb = 1024**3
    usage = psutil.disk_usage('/')
    print(f"[DISK] {tag}")
    print(f"       Total: {usage.total / bytes_per_gb:6.1f} GB")
    print(f"       Used : {usage.used / bytes_per_gb:6.1f} GB ({usage.percent:4.1f}%)")
    print(f"       Free : {usage.free / bytes_per_gb:6.1f} GB")
    print("")

def list_dir_with_sizes(path=".", max_files=50):
    """Print the entries of `path`, with file sizes in MB, for at most `max_files` entries.

    Regular files are shown with their size; everything else is labelled
    <DIR>. When entries are left out, a final line reports how many.
    """
    print(f"[FILES] Listing '{os.path.abspath(path)}' (showing up to {max_files} entries):")
    entries = os.listdir(path)
    bytes_per_mb = 1024**2
    for entry in entries[:max_files]:
        full_path = os.path.join(path, entry)
        if not os.path.isfile(full_path):
            print(f"   {entry:40s} <DIR>")
            continue
        print(f"   {entry:40s} {os.path.getsize(full_path) / bytes_per_mb:10.2f} MB")
    n_hidden = len(entries) - max_files
    if n_hidden > 0:
        print(f"   ... ({n_hidden} more files)")
    print("")


MIDPOINT!!

In [None]:
import pandas as pd
import joblib
import scipy.sparse as sp


import os
import psutil

def print_disk_usage(tag=""):
    """Report total, used and free space (in GB) of the filesystem mounted at '/'.

    `tag` is echoed on the first line so successive reports can be told apart.
    """
    bytes_per_gb = 1024**3
    usage = psutil.disk_usage('/')
    print(f"[DISK] {tag}")
    print(f"       Total: {usage.total / bytes_per_gb:6.1f} GB")
    print(f"       Used : {usage.used / bytes_per_gb:6.1f} GB ({usage.percent:4.1f}%)")
    print(f"       Free : {usage.free / bytes_per_gb:6.1f} GB")
    print("")

def list_dir_with_sizes(path=".", max_files=50):
    """Print the entries of `path`, with file sizes in MB, for at most `max_files` entries.

    Regular files are shown with their size; everything else is labelled
    <DIR>. When entries are left out, a final line reports how many.
    """
    print(f"[FILES] Listing '{os.path.abspath(path)}' (showing up to {max_files} entries):")
    entries = os.listdir(path)
    bytes_per_mb = 1024**2
    for entry in entries[:max_files]:
        full_path = os.path.join(path, entry)
        if not os.path.isfile(full_path):
            print(f"   {entry:40s} <DIR>")
            continue
        print(f"   {entry:40s} {os.path.getsize(full_path) / bytes_per_mb:10.2f} MB")
    n_hidden = len(entries) - max_files
    if n_hidden > 0:
        print(f"   ... ({n_hidden} more files)")
    print("")


# Restore the cohort table, feature matrix and both index mappings from files
# in the current working directory.
# NOTE(review): no cell above writes these four files; presumably a checkpoint
# cell that is not part of this view creates them — confirm before relying on
# a fresh-kernel run.
print("=== RELOADING CHECKPOINT ===")
print_disk_usage("Before loading")
list_dir_with_sizes(".")


df_cohort   = pd.read_parquet("df_cohort.parquet")
X_sparse    = sp.load_npz("X_sparse.npz")
# joblib.load unpickles its input; only load files that this project wrote.
pid_to_idx  = joblib.load("pid_to_idx.joblib")
feat_to_idx = joblib.load("feat_to_idx.joblib")

print("Reload complete.")

print_disk_usage("After loading")
list_dir_with_sizes(".")

print("df_cohort shape:", df_cohort.shape)
print("X_sparse shape :", X_sparse.shape)

In [None]:
import pandas as pd
import scipy.sparse as sp
import joblib
import numpy as np
import os

print("=== Sanity check: reload from disk into new variables ===")
print("Files on disk:", os.listdir("."))

# 1. Reload the four checkpoint artifacts into *_disk names so they can be
#    compared with the objects currently bound in the kernel.
df_cohort_disk = pd.read_parquet("df_cohort.parquet")
X_sparse_disk  = sp.load_npz("X_sparse.npz")
pid_to_idx_disk  = joblib.load("pid_to_idx.joblib")
feat_to_idx_disk = joblib.load("feat_to_idx.joblib")

print("df_cohort RAM shape :", df_cohort.shape)
print("df_cohort DISK shape:", df_cohort_disk.shape)

print("X_sparse RAM shape  :", X_sparse.shape)
print("X_sparse DISK shape :", X_sparse_disk.shape)

print("pid_to_idx sizes    :", len(pid_to_idx), len(pid_to_idx_disk))
print("feat_to_idx sizes   :", len(feat_to_idx), len(feat_to_idx_disk))

# 2. Quick content checks (cheap)
print("df_cohort columns equal?:", list(df_cohort.columns) == list(df_cohort_disk.columns))
print("First 5 rows equal?     :", df_cohort.head().reset_index(drop=True).equals(
    df_cohort_disk.head().reset_index(drop=True)
))

# Sparse: check a few random rows instead of full matrix diff.
rng = np.random.default_rng(42)
n_rows_ram = X_sparse.shape[0]
# Cap the sample at the row count: rng.choice(..., replace=False) raises when
# asked for more rows than exist. For 5 or more rows the draw is unchanged.
if n_rows_ram:
    idx_sample = rng.choice(n_rows_ram, size=min(5, n_rows_ram), replace=False)
else:
    idx_sample = np.array([], dtype=int)

# Rows can only match when both matrices have the same shape; checking that
# first also keeps the row lookups below inside the bounds of X_sparse_disk.
ok_sparse = X_sparse.shape == X_sparse_disk.shape
if not ok_sparse:
    print("Sparse shape mismatch")
else:
    for i in idx_sample:
        row_ram  = X_sparse[i].toarray()
        row_disk = X_sparse_disk[i].toarray()
        if not np.array_equal(row_ram, row_disk):
            ok_sparse = False
            print(f"Row {i} mismatch")
            break

print("Sampled sparse rows equal?:", ok_sparse)


In [None]:
# =============================================================================
# CELL 6: Feature Engineering & Alignment (Optimized + TQDM)
# =============================================================================
print("\n--- CELL 6: Feature Engineering & Alignment ---")

import scipy.sparse as sp
from sklearn.preprocessing import StandardScaler
from tqdm.auto import tqdm

# Setup Progress Bar (six stages; each stage below calls pbar.update(1) once)
pbar = tqdm(total=6, desc="Initializing")

# 1. Align Dataframe to Sparse Matrix
pbar.set_description("Aligning Dataframes")
if 'pid_to_idx' not in locals():
    raise ValueError("Missing pid_to_idx! Please run Cell 5.5.")

# Order the ids by their mapped row position so that row i of the dataframe
# corresponds to row i of X_sparse.
sorted_pids = sorted(pid_to_idx, key=pid_to_idx.get)

# Sorting only yields the matrix row order if the mapped positions are exactly
# 0..N-1. A gap or duplicate position would shift every later row and silently
# pair patients with the wrong feature rows.
assert all(pid_to_idx[p] == i for i, p in enumerate(sorted_pids)), \
    "pid_to_idx is not a contiguous 0..N-1 row mapping!"

# .loc raises KeyError if any id is missing from the df_cohort index.
df_cohort_aligned = df_cohort.loc[sorted_pids].copy()

# Safety Check
assert len(df_cohort_aligned) == X_sparse.shape[0], "Row count mismatch between Dense and Sparse!"
# Compare the whole index, not just its first entry; this also catches
# duplicate index labels and does not fail on an empty cohort.
assert df_cohort_aligned.index.tolist() == sorted_pids, "Index alignment error!"

print(f"Aligned Cohort N={len(df_cohort_aligned)}")
pbar.update(1)

# 2. Feature Engineering: Dense Covariates

# A. Site Rate Smoothing
# Per-zip_code rate of contrast_received, shrunk towards the cohort-wide rate:
#   smoothed = (site_mean * site_n + global_mean * C_smooth) / (site_n + C_smooth)
# so zip codes with few rows are pulled towards global_mean.
# NOTE(review): the rate is computed from contrast_received over all rows,
# including the row being encoded, and contrast_received is also the target of
# the propensity model fitted later. Confirm this leakage is acceptable, or
# compute the rate inside the cross-fitting folds.
pbar.set_description("Smoothing Site Rates")
if 'zip_code' in df_cohort_aligned.columns:
    # One row per zip_code: observed rate ('mean') and number of rows ('count').
    site_counts = df_cohort_aligned.groupby('zip_code')['contrast_received'].agg(['mean', 'count'])
    global_mean = df_cohort_aligned['contrast_received'].mean()
    # Weight of the global rate in the shrinkage formula above.
    C_smooth = 10 
    site_counts['smoothed_rate'] = (site_counts['mean'] * site_counts['count'] + global_mean * C_smooth) / (site_counts['count'] + C_smooth)
    # Rows whose zip_code has no entry in site_counts fall back to the global rate.
    df_cohort_aligned['site_contrast_rate'] = df_cohort_aligned['zip_code'].map(site_counts['smoothed_rate']).fillna(global_mean)
else:
    # No site information: every row gets the same (global) rate.
    df_cohort_aligned['site_contrast_rate'] = df_cohort_aligned['contrast_received'].mean()
pbar.update(1)

# B. Imputation & Scaling
pbar.set_description("Imputing & Scaling")
# Handle Imputation: for each lab column add a 0/1 missingness flag and a
# median-imputed copy; the original column is left untouched.
for col in ['baseline_egfr', 'baseline_creat']:
    if col in df_cohort_aligned.columns:
        df_cohort_aligned[f'{col}_missing'] = df_cohort_aligned[col].isna().astype(float) # flag stored as float
        df_cohort_aligned[f'{col}_imputed'] = df_cohort_aligned[col].fillna(df_cohort_aligned[col].median())
    else:
        # Column absent altogether: flag every row as missing, impute a constant 0.
        df_cohort_aligned[f'{col}_missing'] = 1.0
        df_cohort_aligned[f'{col}_imputed'] = 0.0

# Handle Scaling: standardise the continuous covariates that are present.
dense_cols_to_scale = ['age', 'site_contrast_rate', 'baseline_egfr_imputed', 'baseline_creat_imputed']
dense_cols_to_scale = [c for c in dense_cols_to_scale if c in df_cohort_aligned.columns]

scaler = StandardScaler()
# Any NaN still present (e.g. a missing age) is replaced by 0 *before*
# standardisation, i.e. on the raw scale.
X_dense_scaled = scaler.fit_transform(df_cohort_aligned[dense_cols_to_scale].fillna(0))
pbar.update(1)

# C. Categorical Dummies 
# Each get_dummies call requests dtype=float so the columns need no later cast.
# drop_first is not set, so every observed level gets its own column.
pbar.set_description("Generating Dummies")

# egfr_cat is optional; without it an empty frame is used as a placeholder.
if 'egfr_cat' in df_cohort_aligned.columns:
    egfr_dummies = pd.get_dummies(df_cohort_aligned['egfr_cat'], prefix='egfr_cat', dtype=float)
else:
    egfr_dummies = pd.DataFrame()

# Age deciles as integer bin codes; duplicates='drop' merges repeated quantile
# edges, so fewer than 10 bins are possible.
# Unlike egfr_cat, 'age', 'gender_concept_id' and 'hx_thyrotoxicosis' are
# accessed unconditionally and must exist in df_cohort_aligned.
df_cohort_aligned['age_decile'] = pd.qcut(df_cohort_aligned['age'], q=10, labels=False, duplicates='drop')
age_dummies = pd.get_dummies(df_cohort_aligned['age_decile'], prefix='age_decile', dtype=float)

gender_dummies = pd.get_dummies(df_cohort_aligned['gender_concept_id'], prefix='gender', dtype=float)

# Single-column frame holding the thyrotoxicosis history flag as float.
thyro_dummy = df_cohort_aligned[['hx_thyrotoxicosis']].astype(float)
# Every column whose name ends in '_missing', including the flags added during imputation.
missing_flags = df_cohort_aligned[[c for c in df_cohort_aligned.columns if c.endswith('_missing')]].astype(float)
pbar.update(1)

# 3. Combine All Dense Features
# Column order of the dense block: scaled continuous covariates, eGFR dummies,
# age-decile dummies, gender dummies, thyrotoxicosis flag, missingness flags.
pbar.set_description("Concatenating Dense Matrix")
X_dense_list = [
    # Re-attach index and column names to the scaler output so concat aligns on the index.
    pd.DataFrame(X_dense_scaled, index=df_cohort_aligned.index, columns=dense_cols_to_scale),
    egfr_dummies,
    age_dummies,
    gender_dummies,
    thyro_dummy,
    missing_flags
]

# All parts were created as float, so no extra cast is applied here.
df_dense_final = pd.concat(X_dense_list, axis=1)
X_dense = sp.csr_matrix(df_dense_final.values)
pbar.update(1)

# 4. Final Stack (Dense + Sparse)
# Dense columns come first, then the X_sparse columns; downstream code that
# maps column positions to feature names relies on this order.
pbar.set_description("Final Stack (HStack)")
X_all = sp.hstack([X_dense, X_sparse], format='csr')
pbar.update(1)
pbar.close()

# clean_mem is defined outside this cell -- presumably triggers garbage collection; verify.
clean_mem()
print(f"Final Input Matrix Shape: {X_all.shape}")
print(f"Dense Features included ({df_dense_final.shape[1]}): {list(df_dense_final.columns[:10])} ...")


In [None]:
# --- CELL 7A: Hyperparameter Tuning on 10k Subsample ---

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedShuffleSplit, GridSearchCV
import numpy as np

# Treatment indicator used as the label of the propensity model.
T = df_cohort_aligned['contrast_received'].values
N = len(T)

print(f"Full cohort: N={N}, P={X_all.shape[1]}")

# --- 1. Choose a ~10k stratified subsample ---
target_n = 10_000
if N <= target_n:
    print("N <= 10k, using full cohort for tuning.")
    idx_sub = np.arange(N)
else:
    # Ask for a "test" split whose share of the cohort equals target_n / N and
    # keep only that part; stratification preserves the treated/untreated mix.
    frac = target_n / N
    sss = StratifiedShuffleSplit(
        n_splits=1,
        test_size=frac,
        random_state=42
    )
    # The first element (the complementary "train" indices) is discarded.
    _, idx_sub = next(sss.split(X_all, T))
    print(f"Using subsample of size {len(idx_sub)} for tuning.")

X_sub = X_all[idx_sub]
T_sub = T[idx_sub]

# --- 2. Define base logistic model ---
# saga is used because it supports the L1 penalty on sparse input.
base_logit = LogisticRegression(
    penalty='l1',              # L1 (lasso) penalty; C is tuned below
    solver='saga',             # solver that supports l1 on sparse input
    max_iter=3000,
    # NOTE(review): 'balanced' reweights classes inversely to their frequency,
    # which shifts predicted probabilities away from the observed prevalence.
    # Harmless for AUC-based tuning here, but confirm it is intended wherever
    # the probabilities are used as propensity scores.
    class_weight='balanced',
    # NOTE(review): sklearn documents n_jobs as parallelising over classes
    # (one-vs-rest); it likely has no effect for this binary fit -- confirm.
    n_jobs=-1,
    random_state=42
)

# --- 3. Grid over C (log-scale) ---
param_grid = {
    "C": np.logspace(-3, 0, 6)   # 6 values from 1e-3 to 1, evenly spaced in log10
}

print("Tuning C on subsample using 3-fold CV (scoring=roc_auc)...")
# 6 candidate values x 3 folds = 18 fits, dispatched in parallel by GridSearchCV.
gs = GridSearchCV(
    estimator=base_logit,
    param_grid=param_grid,
    cv=3,
    scoring="roc_auc",
    n_jobs=-1,
    verbose=2
)

gs.fit(X_sub, T_sub)

best_C = gs.best_params_["C"]
best_score = gs.best_score_

print(f"Best C: {best_C:.4g} (mean CV AUC: {best_score:.4f})")

# Plain Python float copy of the winning C, reused by the cross-fitting cell.
C_best = float(best_C)


In [None]:
# --- CELL 7B: Cross-Fitted PS (Parallel Over Folds, Tunable Core Use) ---

from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LogisticRegression
from joblib import Parallel, delayed
import numpy as np
import time
import os

# Treatment labels and matrix dimensions for the full cohort.
T = df_cohort_aligned['contrast_received'].values
N, P = X_all.shape
print(f"Training cross-fitted PS on full cohort (N={N}, P={P}, C={C_best})")

# Outer folds: each row receives its score from a model that never saw it.
outer_k = 5
cv_outer = StratifiedKFold(
    n_splits=outer_k,
    shuffle=True,
    random_state=42
)

# ---- CORE / THREAD SETTINGS ----
threads_per_fold = 6  
# Parsed as min(outer_k, (32 // threads_per_fold) or 1).
# NOTE(review): 32 looks like an assumed core count for the host -- confirm,
# or derive it from os.cpu_count().
n_folds_workers = min(outer_k, 32 // threads_per_fold or 1)

print(f"Using up to {n_folds_workers} folds in parallel, "
      f"{threads_per_fold} threads per fold (target ~{n_folds_workers * threads_per_fold} cores)")

# NOTE(review): these variables are set after numpy/scipy were imported, so
# they may only affect worker processes started later -- confirm.
os.environ["OMP_NUM_THREADS"] = str(threads_per_fold)
os.environ["MKL_NUM_THREADS"] = str(threads_per_fold)

def fit_one_fold(fold_id, train_idx, test_idx, X, T, C_best):
    """Train an L1-logistic propensity model on one fold and score its held-out rows.

    Args:
        fold_id: Label used in progress messages and to offset the random seed.
        train_idx: Row positions used for fitting.
        test_idx: Row positions to score.
        X: Feature matrix indexable by row positions.
        T: Treatment labels aligned with the rows of X.
        C_best: Inverse regularisation strength for the L1 penalty.

    Returns:
        Tuple ``(fold_id, test_idx, ps_fold)`` where ``ps_fold`` holds the
        predicted probability of class 1 for each row in ``test_idx``.
    """
    t0 = time.time()
    print(f"[Fold {fold_id}] start: train={len(train_idx)}, test={len(test_idx)}")

    max_iter = 1500
    logit = LogisticRegression(
        penalty='l1',
        solver='saga',
        C=C_best,
        max_iter=max_iter,
        class_weight='balanced',
        n_jobs=threads_per_fold,   # read from the notebook-level setting
        random_state=42 + fold_id
    )
    logit.fit(X[train_idx], T[train_idx])

    # Warnings are silenced notebook-wide, so a solver that stops at the
    # iteration cap would otherwise go unnoticed. Report it explicitly.
    if int(max(logit.n_iter_)) >= max_iter:
        print(f"[Fold {fold_id}] WARNING: solver reached max_iter={max_iter}; "
              f"coefficients may not have converged")

    ps_fold = logit.predict_proba(X[test_idx])[:, 1]

    t1 = time.time()
    print(f"[Fold {fold_id}] done in {(t1 - t0)/60:.2f} min")

    return fold_id, test_idx, ps_fold

# Prepare fold tasks: materialise the splits so each job receives explicit
# index arrays; fold ids start at 1.
fold_tasks = [
    (fold_id, train_idx, test_idx)
    for fold_id, (train_idx, test_idx)
    in enumerate(cv_outer.split(X_all, T), start=1)
]

# Run folds in parallel
# (Parallel and delayed are already imported at the top of this cell.)
from joblib import Parallel, delayed

print(f"Running {outer_k} folds with joblib (n_jobs={n_folds_workers})...")
results = Parallel(
    n_jobs=n_folds_workers,
    verbose=10,
    backend="loky"
)(
    delayed(fit_one_fold)(fold_id, train_idx, test_idx, X_all, T, C_best)
    for (fold_id, train_idx, test_idx) in fold_tasks
)

# Merge PS: write each fold's held-out scores back to their row positions.
# A row not covered by any fold would keep the initial 0.0.
ps = np.zeros(N, dtype=float)
for fold_id, test_idx, ps_fold in results:
    ps[test_idx] = ps_fold

df_cohort_aligned['ps'] = ps

# The global model is fitted on all rows and is kept for inspecting
# coefficients; the cross-fitted 'ps' column above is not derived from it.
# NOTE(review): max_iter is 1000 here versus 1500 in the per-fold fits --
# confirm the difference is intended.
print("Fitting final global model on full data for interpretation...")
lsps_model = LogisticRegression(
    penalty='l1',
    solver='saga',
    C=C_best,
    max_iter=1000,
    class_weight='balanced',
    n_jobs=-1,
    random_state=999
)
lsps_model.fit(X_all, T)

print("Cross-fitting complete.")

In [None]:
import joblib

print("\n=== CHECKPOINT: Saving Model & Cross-Fitted Scores ===")

# 1. Save the Global Model (for coefficients/interpretation)
joblib.dump(lsps_model, 'lsps_model_global.joblib')
print("Saved Global Logistic Model (lsps_model_global.joblib)")

# 2. Save the Dataframe WITH the 'ps' column
# The aligned version is saved because its rows match the rows of the feature
# matrix and it carries the cross-fitted 'ps' column. The 'iptw' weights are
# added later (to df_final) and are therefore NOT part of this file.
df_cohort_aligned.to_parquet('df_cohort_with_ps.parquet')
print("Saved Cohort with PS scores (df_cohort_with_ps.parquet)")

# 3. Save the C_best parameter so a reload does not require re-tuning
joblib.dump(C_best, 'C_best_param.joblib')

print("Checkpoint Complete.")

In [None]:
# =============================================================================
# CELL 7_LOAD: Reload Model & Scores (Skip Training)
# =============================================================================
import joblib
import pandas as pd
import numpy as np

print("--- RELOADING PS MODEL & SCORES ---")

# 1. Load the Dataframe with PS
# This restores the Cross-Fitted scores (crucial for valid inference)
df_cohort_aligned = pd.read_parquet('df_cohort_with_ps.parquet')
print(f"Restored dataframe with PS. Shape: {df_cohort_aligned.shape}")

# 2. Load the Global Model
# This restores the coefficients for interpretation
lsps_model = joblib.load('lsps_model_global.joblib')
print("Restored Global Model.")

# 3. Load Hyperparams
# Only a missing file triggers the fallback. Any other failure (corrupt file,
# interrupt) should surface instead of silently switching to the default.
try:
    C_best = joblib.load('C_best_param.joblib')
    print(f"Restored C_best: {C_best}")
except FileNotFoundError:
    print("C_best not found, setting default.")
    C_best = 0.1 # Default fallback

# 4. Consistency Check
# Ensure the dataframe aligns with the X_all matrix generated in Cell 6
# (This assumes Cells 1-6 were run first)
if 'X_all' in locals():
    assert len(df_cohort_aligned) == X_all.shape[0], "Row count mismatch! Did you re-run Cell 6?"

    # Re-define T and N for downstream cells. Both are otherwise only set by
    # the training cells, which this reload path skips.
    T = df_cohort_aligned['contrast_received'].values
    N = len(df_cohort_aligned)
    print("Consistency check passed. Ready for Diagnostics (Cell 7C).")
else:
    print("WARNING: X_all variable not found in RAM. Please run Cell 6 (Feature Engineering) before proceeding to Love Plots.")

In [None]:
# --- CELL 7C: Diagnostics, Trimming, Weights, ESS ---

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# PS diagnostics: density of the cross-fitted score in each treatment arm
plt.figure(figsize=(10,4))
sns.kdeplot(
    df_cohort_aligned.loc[df_cohort_aligned['contrast_received'] == 0, 'ps'],
    fill=True, alpha=0.3, label='Control'
)
sns.kdeplot(
    df_cohort_aligned.loc[df_cohort_aligned['contrast_received'] == 1, 'ps'],
    fill=True, alpha=0.3, label='Treated'
)
plt.title("Propensity Score Overlap (Cross-Fitted, saga/l1)")
plt.legend()
plt.show()

# Derive T and N from the dataframe instead of relying on values left in the
# kernel by the training cells, so this cell also works after a reload.
T = df_cohort_aligned['contrast_received'].values
N = len(df_cohort_aligned)
assert X_all.shape[0] == N, "Row count mismatch between df_cohort_aligned and X_all!"

# Trimming: keep rows whose score lies strictly inside (0.025, 0.975)
ps = df_cohort_aligned['ps'].values
mask_keep = (ps > 0.025) & (ps < 0.975)
df_final = df_cohort_aligned[mask_keep].copy()
X_final = X_all[mask_keep]
T_final = T[mask_keep]
ps_final = ps[mask_keep]

print(f"Original N: {N}")
print(f"Trimmed  N: {len(df_final)} (Removed {N - len(df_final)})")

# IPW weights (stabilized): the marginal treatment probability in the trimmed
# cohort is the numerator, the individual score the denominator.
p_t = T_final.mean()
weights = np.where(
    T_final == 1,
    p_t / ps_final,
    (1 - p_t) / (1 - ps_final)
)
df_final['iptw'] = weights

# Effective sample size (calculate_ess is defined outside this cell)
ess = calculate_ess(weights)
print(f"Effective Sample Size (ESS): {ess:.0f}")


In [None]:
# --- DIAGNOSTIC ---
import pandas as pd
import numpy as np

# 1. Get Feature Names
# Dense names come from df_dense_final (built in Cell 6)
dense_names = list(df_dense_final.columns)
# Sparse names: invert feat_to_idx (name -> column position) and list the
# names in column order 0..len-1.
idx_to_feat = {v: k for k, v in feat_to_idx.items()}
sparse_names = [idx_to_feat[i] for i in range(len(feat_to_idx))]
# Dense names first, then sparse -- the same order in which X_all was stacked.
all_feature_names = dense_names + sparse_names

# 2. Get Coefficients from the global L1-logistic model (lsps_model)
# coef_ has one row for a binary problem; take it as a flat vector.
coefs = lsps_model.coef_[0]

# 3. Sort and Display
# The DataFrame constructor raises if the number of names and coefficients differ.
coef_df = pd.DataFrame({
    'Feature': all_feature_names,
    'Coefficient': coefs,
    'Abs_Coef': np.abs(coefs)  # not used by the two printouts below
})

# Largest positive coefficients: features associated with receiving contrast.
print("\n--- TOP 30 PREDICTORS OF RECEIVING CONTRAST ---")
print(coef_df.sort_values('Coefficient', ascending=False).head(30)[['Feature', 'Coefficient']])

# Most negative coefficients. If fewer than 30 are negative, the list also
# contains features with a zero coefficient.
print("\n--- TOP 30 PREDICTORS OF WITHHOLDING CONTRAST ---")
print(coef_df.sort_values('Coefficient', ascending=True).head(30)[['Feature', 'Coefficient']])

In [None]:
# --- CELL 7.5: POST-WEIGHTING EQUIPOISE & OVERLAP DIAGNOSTICS ---
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def calculate_overlap_coefficient(data0, data1, weights0, weights1, bins=100):
    """Estimate the overlapping area of two weighted distributions.

    Both samples are binned on a shared grid of ``bins`` equal-width bins
    spanning their combined range. The overlap is the integral of the
    bin-wise minimum of the two weighted density histograms.

    Args:
        data0: Values of the first sample (array-like).
        data1: Values of the second sample (array-like).
        weights0: Weights for ``data0``, same length.
        weights1: Weights for ``data1``, same length.
        bins: Number of histogram bins on the shared grid.

    Returns:
        Overlap area as a float between 0 (disjoint) and 1 (identical).
        Returns 0.0 if either sample is empty and 1.0 if every value in both
        samples is the same.
    """
    data0 = np.asarray(data0, dtype=float)
    data1 = np.asarray(data1, dtype=float)
    if data0.size == 0 or data1.size == 0:
        return 0.0

    # Create common bin edges
    min_val = min(data0.min(), data1.min())
    max_val = max(data0.max(), data1.max())
    if min_val == max_val:
        # Zero-width range: both samples sit on the same single value.
        return 1.0
    # bins + 1 edges delimit exactly `bins` bins.
    bins_edges = np.linspace(min_val, max_val, bins + 1)

    # Calculate weighted histograms (density=True makes each integrate to 1)
    hist0, _ = np.histogram(data0, bins=bins_edges, weights=weights0, density=True)
    hist1, _ = np.histogram(data1, bins=bins_edges, weights=weights1, density=True)

    # Calculate intersection area (approximate integration)
    bin_width = bins_edges[1] - bins_edges[0]
    overlap_area = np.sum(np.minimum(hist0, hist1)) * bin_width
    return float(overlap_area)

# Prepare Data: scores and IPTW weights of the trimmed cohort, split by arm
ps_control = df_final[df_final['contrast_received']==0]['ps']
w_control = df_final[df_final['contrast_received']==0]['iptw']
ps_treated = df_final[df_final['contrast_received']==1]['ps']
w_treated = df_final[df_final['contrast_received']==1]['iptw']

# Calculate Metrics: overlap of the two weighted score distributions
overlap_score = calculate_overlap_coefficient(ps_control, ps_treated, w_control, w_treated)

print(f"\n--- EQUIPOISE DIAGNOSTICS ---")
print(f"Distribution Overlap Coefficient: {overlap_score:.3f} (0=Separated, 1=Perfect Match)")
# Scores between 0.1 and 0.5 (inclusive) produce no message.
if overlap_score < 0.1:
    print("WARNING: Poor overlap. Estimates relies heavily on extrapolation.")
elif overlap_score > 0.5:
    print("SUCCESS: Strong clinical equipoise between groups.")

# --- PLOT ---
plt.figure(figsize=(12, 6))

# Plot 1: Unweighted (Raw Propensity)
# Data is passed via the x= keyword, as in the weighted panel.
plt.subplot(1, 2, 1)
sns.kdeplot(x=ps_control, fill=True, label='Withheld (Raw)', color='blue', alpha=0.3)
sns.kdeplot(x=ps_treated, fill=True, label='Received (Raw)', color='red', alpha=0.3)
plt.title("Before Weighting: Selection Bias")
plt.xlabel("Propensity Score")
plt.legend(loc='upper center')

# Plot 2: Weighted (Pseudo-Population)
# Same scores, with the IPTW weights applied to the density estimate.
plt.subplot(1, 2, 2)
sns.kdeplot(x=ps_control, weights=w_control, fill=True, label='Withheld (Weighted)', color='blue', alpha=0.3)
sns.kdeplot(x=ps_treated, weights=w_treated, fill=True, label='Received (Weighted)', color='red', alpha=0.3)
plt.title(f"After Weighting: Pseudo-Population\nOverlap Coeff: {overlap_score:.2f}")
plt.xlabel("Propensity Score")
plt.legend(loc='upper center')

plt.tight_layout()
plt.show()

In [None]:
# --- NEW DIAGNOSTIC CELL: Comprehensive Structural Positivity Test ---
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

print("Running Comprehensive Structural Positivity Test...")

# 1. Define the Full Sets
# Four sets of procedure concept ids, one per modality (CT / MRI) and
# contrast status. They are matched against 'procedure_concept_id' below.
CONTRAST_CT = {4139745, 21492176, 4335400, 3047782, 4327032, 3013610, 36713226, 3053128, 4252907, 3019625}
CONTRAST_MRI = {4335399, 4161393, 4202274, 4197203, 36717294, 45765683, 37117806, 37109194, 37109196}
NON_CONTRAST_CT = {37109313, 3049940, 37117305, 3047921, 36713200, 3018999, 40771605, 36713202, 3035568}
NON_CONTRAST_MRI = {37109312, 36713204, 36713045, 36713262, 3024397, 36713243, 3053040, 37109329, 42535581, 42535582}

# 2. Map Every ID to a Human-Readable Label & Category
# NOTE(review): proc_meta is created empty and is not filled anywhere in this
# cell -- confirm whether it is still needed.
proc_meta = {}

def classify_proc(pid):
    """Return ``(modality group, short label)`` for a procedure concept id.

    The id is looked up in the four concept-id sets in a fixed order
    (contrast CT, non-contrast CT, contrast MRI, non-contrast MRI); the first
    set containing it decides the result. Ids found in none of the sets map
    to ``('Other', str(pid))``.
    """
    lookup_order = (
        (CONTRAST_CT, 'CT (Contrast)', 'CT_Con'),
        (NON_CONTRAST_CT, 'CT (Withheld)', 'CT_Non'),
        (CONTRAST_MRI, 'MRI (Contrast)', 'MRI_Con'),
        (NON_CONTRAST_MRI, 'MRI (Withheld)', 'MRI_Non'),
    )
    for members, group, prefix in lookup_order:
        if pid in members:
            return group, f"{prefix}_{pid}"
    return 'Other', str(pid)

# Work on a copy so df_final itself is not modified
df_viz = df_final.copy()
# Assign the modality group with vectorised set-membership masks rather than
# a row-wise function call. Rows matching none of the sets stay 'Other'.
# Later assignments overwrite earlier ones, which only matters if the sets overlap.
df_viz['modality_group'] = 'Other'
df_viz.loc[df_viz['procedure_concept_id'].isin(CONTRAST_CT), 'modality_group'] = 'CT (Contrast)'
df_viz.loc[df_viz['procedure_concept_id'].isin(NON_CONTRAST_CT), 'modality_group'] = 'CT (Withheld)'
df_viz.loc[df_viz['procedure_concept_id'].isin(CONTRAST_MRI), 'modality_group'] = 'MRI (Contrast)'
df_viz.loc[df_viz['procedure_concept_id'].isin(NON_CONTRAST_MRI), 'modality_group'] = 'MRI (Withheld)'

# String form of the concept id, used as the categorical axis label below
df_viz['proc_label'] = df_viz['procedure_concept_id'].astype(str)

# 3. VIZ 1: The Macro View (Modality Overlap)
# One violin of propensity scores per modality group; the dotted line marks 0.5.
plt.figure(figsize=(10, 6))
sns.violinplot(data=df_viz, x='modality_group', y='ps', palette="muted", inner="quartile")
plt.title("Macro Check: Propensity Score by Modality & Status")
plt.ylabel("Propensity Score (Prob. of Contrast)")
plt.xlabel("Modality Group")
plt.axhline(0.5, color='gray', linestyle=':')
plt.show()

# 4. VIZ 2: The Micro View (All Procedures Sorted)
# Median PS per procedure, ascending; this fixes the left-to-right order of boxes
proc_stats = df_viz.groupby('procedure_concept_id')['ps'].median().sort_values()
sorted_procs = proc_stats.index.astype(str).tolist()

# Split into CT and MRI for readability
ct_ids = [str(x) for x in CONTRAST_CT.union(NON_CONTRAST_CT)]
mri_ids = [str(x) for x in CONTRAST_MRI.union(NON_CONTRAST_MRI)]

fig, axes = plt.subplots(2, 1, figsize=(16, 14), sharey=True)


def _boxplot_ps_by_procedure(ax, proc_ids, title):
    """Draw per-procedure PS boxplots, split by treatment, for `proc_ids` on `ax`."""
    wanted = set(proc_ids)
    sns.boxplot(
        data=df_viz[df_viz['proc_label'].isin(proc_ids)],
        x='proc_label', y='ps', hue='contrast_received',
        order=[p for p in sorted_procs if p in wanted],
        ax=ax, palette={0:'blue', 1:'red'}, showfliers=False
    )
    ax.set_title(title)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
    ax.grid(True, axis='y', alpha=0.3)


# Plot CT (keeps the shared legend)
_boxplot_ps_by_procedure(
    axes[0], ct_ids,
    "CT Procedures: Propensity Distribution (Sorted by Median PS)"
)
axes[0].legend(loc='upper left', title='Treatment')

# Plot MRI (legend removed; it would duplicate the CT one)
_boxplot_ps_by_procedure(
    axes[1], mri_ids,
    "MRI Procedures: Propensity Distribution (Sorted by Median PS)"
)
# get_legend() returns None when nothing was drawn (e.g. no MRI rows), so
# guard the removal instead of calling .remove() on None.
mri_legend = axes[1].get_legend()
if mri_legend is not None:
    mri_legend.remove()

plt.tight_layout()
plt.show()

# 5. The "Red Flag" Report
# Lists procedures whose *median* propensity score is below 0.10 or above 0.90.
print("\n--- STRUCTURAL ZERO WARNINGS ---")
print("Identifying procedures with extreme Propensity Scores (PS < 0.10 or PS > 0.90)")
print("These indicate Protocol-Based decisions (Bias) rather than Clinical-Based decisions (Confounding).")
print("-" * 80)

# proc_stats holds the per-procedure median PS.
low_ps = proc_stats[proc_stats < 0.10]
high_ps = proc_stats[proc_stats > 0.90]

if len(low_ps) > 0:
    print(f"⚠️  PROTOCOL EXCLUSION (Almost Never Get Contrast): {len(low_ps)} procedures")
    print(f"    IDs: {low_ps.index.tolist()}")
    print("    -> These likely represent non-contrast protocols (e.g., C-Spine, Stroke, Stones).")
else:
    print("✅  No extreme 'Never Contrast' protocols found (Median PS > 0.10).")

if len(high_ps) > 0:
    print(f"⚠️  PROTOCOL INCLUSION (Almost Always Get Contrast): {len(high_ps)} procedures")
    print(f"    IDs: {high_ps.index.tolist()}")
    print("    -> These likely represent obligatory contrast protocols (e.g., Tumor Staging, Angio).")
else:
    print("✅  No extreme 'Always Contrast' protocols found (Median PS < 0.90).")

In [None]:
# --- CELL 8: LOVE PLOT ---

def get_smd(X, t, w):
    """Absolute standardized mean difference of every column of X between arms.

    Args:
        X: Feature matrix; rows are selected with boolean masks built from t.
        t: Treatment indicator per row (1 = treated, 0 = control).
        w: Weight per row.

    Returns:
        Array of ``|mean_treated - mean_control| / pooled_sd`` per column,
        where ``pooled_sd = sqrt((var_treated + var_control) / 2)`` and a
        pooled SD of exactly 0 is replaced by 1e-6.
    """
    is_treated = (t == 1)
    is_control = (t == 0)

    # Weighted mean and variance per column in each arm
    mean_trt, var_trt = sparse_weighted_mean_var(X[is_treated], w[is_treated])
    mean_ctl, var_ctl = sparse_weighted_mean_var(X[is_control], w[is_control])

    pooled_sd = np.sqrt((var_trt + var_ctl) / 2)
    # Constant columns have zero spread; substitute a tiny SD to avoid dividing by zero
    zero_spread = pooled_sd == 0
    pooled_sd[zero_spread] = 1e-6

    diff = mean_trt - mean_ctl
    return np.abs(diff / pooled_sd)

print("Calculating Balance...")
# Unweighted: every row gets weight 1
smd_unw = get_smd(X_final, T_final, np.ones(len(T_final)))
# Weighted: rows weighted by their IPTW weight
smd_w = get_smd(X_final, T_final, df_final['iptw'].values)

# Plot top 50 imbalanced features (ranked by unadjusted SMD)
top_idx = np.argsort(smd_unw)[-50:]
# With fewer than 50 features the slice is shorter; size the y positions from
# it so x and y always have the same length.
n_top = len(top_idx)
plt.figure(figsize=(8, 10))
plt.scatter(smd_unw[top_idx], range(n_top), label='Unadjusted', alpha=0.6)
plt.scatter(smd_w[top_idx], range(n_top), label='Adjusted', alpha=0.8)
# Dashed reference line at SMD = 0.1
plt.axvline(0.1, color='r', linestyle='--')
plt.title("Covariate Balance (Top 50 Variates)")
plt.xlabel("Absolute SMD")
plt.legend()
plt.show()

In [None]:
# =============================================================================
# CELL 9: Cross-Fitted SuperLearner AIPW Engine (Sklearn Stacking)
# =============================================================================
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier, RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from joblib import Parallel, delayed
from scipy.stats import norm
import numpy as np
import scipy.sparse as sp

print("\n--- CELL 9: Super Learner AIPW Engine (Sklearn Stacking) ---")

def calculate_e_value(rr_or_hr):
    """E-value for unmeasured confounding, from a risk or hazard ratio.

    Returns 1 when the ratio is at most 1; otherwise
    ``ratio + sqrt(ratio * (ratio - 1))``. Callers are expected to invert
    protective ratios (below 1) before calling.
    """
    at_or_below_null = rr_or_hr <= 1
    return 1 if at_or_below_null else rr_or_hr + np.sqrt(rr_or_hr * (rr_or_hr - 1))

def get_super_learner(n_jobs_inner=1, random_state=42):
    """
    Build an unfitted sklearn StackingClassifier designed for sparse data.

    Combines:
      1. Lasso (Linear, good for rare codes)
      2. Random Forest (Non-linear, good for interactions)

    Args:
        n_jobs_inner: Parallel jobs for the random forest and for the
            stacking cross-validation.
        random_state: Seed passed to the randomised base learners so that
            repeated runs give the same predictions. Pass None to restore
            unseeded behaviour.

    Returns:
        The configured, unfitted StackingClassifier.
    """
    # 1. Base Learners
    # 'liblinear' is used for the Lasso as it handles sparse matrices efficiently.
    # RF depth is limited to prevent memory explosion with 70k features.
    estimators = [
        ('lasso', LogisticRegression(penalty='l1', solver='liblinear', C=0.2,
                                     class_weight='balanced', max_iter=2000,
                                     random_state=random_state)),
        ('rf', RandomForestClassifier(n_estimators=100, max_depth=20,
                                      class_weight='balanced', n_jobs=n_jobs_inner,
                                      random_state=random_state))
    ]

    # 2. The Stack (Meta-Learner)
    # Uses internal CV to learn how to best combine Lasso and RF predictions.
    # passthrough=False means the meta-learner only sees the predictions of base learners.
    stack = StackingClassifier(
        estimators=estimators,
        final_estimator=LogisticRegression(),
        cv=3,  # Internal 3-fold CV to train the combiner
        n_jobs=n_jobs_inner,
        passthrough=False
    )
    return stack

def _fit_super_learner_fold(train_idx, eval_idx, X_sparse, T_full, Y_full):
    """
    Worker function for a single fold of Cross-Fitting.
    Fits the Super Learner on training data, predicts on eval data.
    """
    # 1. Slice Data (Sparse Slicing)
    X_train, X_eval = X_sparse[train_idx], X_sparse[eval_idx]
    T_train, T_eval = T_full[train_idx], T_full[eval_idx]
    Y_train, Y_eval = Y_full[train_idx], Y_full[eval_idx]
    
    # Note: We use n_jobs_inner=4 for models to speed up RF training inside the fold.
    # Total threads = n_folds (outer) * n_jobs_inner.
    
    # --- 2. Propensity Score Model (Pi) ---
    sl_ps = get_super_learner(n_jobs_inner=4)
    sl_ps.fit(X_train, T_train)
    pi_hat = sl_ps.predict_proba(X_eval)[:, 1]
    # Clip for stability (AIPW requirement)
    pi_hat = np.clip(pi_hat, 0.025, 0.975)
    
    # --- 3. Outcome Models (Mu) ---
    # We must separate T=0 and T=1 to learn the counterfactuals
    mask0 = (T_train == 0)
    mask1 = (T_train == 1)
    
    # Mu0 (Outcome if No Contrast)
    sl_mu0 = get_super_learner(n_jobs_inner=4)
    sl_mu0.fit(X_train[mask0], Y_train[mask0])
    mu0_hat = sl_mu0.predict_proba(X_eval)[:, 1]
    
    # Mu1 (Outcome if Contrast)
    sl_mu1 = get_super_learner(n_jobs_inner=4)
    sl_mu1.fit(X_train[mask1], Y_train[mask1])
    mu1_hat = sl_mu1.predict_proba(X_eval)[:, 1]
    
    # --- 4. Compute Efficient Influence Function (EIF) ---
    # Formula: (mu1 - mu0) + T(Y-mu1)/pi - (1-T)(Y-mu0)/(1-pi)
    term1 = mu1_hat - mu0_hat # Risk Difference
    term2 = (T_eval * (Y_eval - mu1_hat)) / pi_hat
    term3 = ((1 - T_eval) * (Y_eval - mu0_hat)) / (1 - pi_hat)
    eif_chunk = term1 + term2 - term3
    
    return eval_idx, mu0_hat, mu1_hat, pi_hat, eif_chunk

def run_cross_fitted_aipw(X_sparse_matrix, T_full, Y_full, n_folds=5):
    """
    Main driver for K-Fold Cross-Fitting of the AIPW estimator.

    Args:
        X_sparse_matrix: Feature matrix, one row per subject.
        T_full: Treatment indicator (0/1) per row.
        Y_full: Binary outcome per row.
        n_folds: Number of stratified folds; also the number of parallel jobs.

    Returns:
        ``(stats, predictions)``.
        ``stats`` is a dict with keys ATE, SE, P_Value, Risk_1, Risk_0, RR,
        ESS, E_Value, CI_Lower, CI_Upper. RR is NaN when the mean predicted
        risk under no treatment is not positive, and E_Value is then 1.0.
        ``predictions`` is a dict of per-row arrays: mu0, mu1, pi, eif.
    """
    kf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    n = len(T_full)

    # Storage arrays, filled fold by fold at each fold's evaluation positions
    mu0_hat = np.zeros(n)
    mu1_hat = np.zeros(n)
    pi_hat  = np.zeros(n)
    eif_val = np.zeros(n)

    print(f"  Running {n_folds}-Fold Super Learner (StackingClassifier) in PARALLEL...")
    print(f"  (This utilizes high compute: Lasso + Random Forest Stacking)")

    # --- Parallel Execution ---
    # One job per fold; all folds run simultaneously.
    results = Parallel(n_jobs=n_folds, verbose=10)(
        delayed(_fit_super_learner_fold)(train_idx, eval_idx, X_sparse_matrix, T_full, Y_full)
        for train_idx, eval_idx in kf.split(X_sparse_matrix, T_full)
    )

    # --- Aggregate Results ---
    for eval_idx, mu0_c, mu1_c, pi_c, eif_c in results:
        mu0_hat[eval_idx] = mu0_c
        mu1_hat[eval_idx] = mu1_c
        pi_hat[eval_idx]  = pi_c
        eif_val[eval_idx] = eif_c

    # --- Statistics & Inference ---
    ate = np.mean(eif_val)
    se = np.std(eif_val) / np.sqrt(n)
    # Two-sided Z-test p-value. The survival function is used instead of
    # 1 - cdf, which rounds to exactly 0 for large |z|.
    p_value = 2 * norm.sf(np.abs(ate / se)) if se > 0 else 0.0

    # Risk Estimates (Population Averages of the outcome-model predictions)
    risk_1 = np.mean(mu1_hat)
    risk_0 = np.mean(mu0_hat)
    # The ratio is undefined without a positive baseline risk; report NaN
    # rather than a 0.0 that would read as a real estimate.
    rr = risk_1 / risk_0 if risk_0 > 0 else np.nan

    # Effective Sample Size (Kish) of the inverse-probability weights
    weights = np.where(T_full==1, 1/pi_hat, 1/(1-pi_hat))
    ess = (np.sum(weights) ** 2) / np.sum(weights ** 2)

    # E-Value Calculation
    if np.isnan(rr):
        e_val = 1.0
    else:
        # Maps RR < 1 to the equivalent risk increase for the formula
        e_calc_rr = 1/rr if (rr < 1 and rr > 0) else rr
        e_val = calculate_e_value(e_calc_rr)

    stats = {
        'ATE': ate, 'SE': se, 'P_Value': p_value,
        'Risk_1': risk_1, 'Risk_0': risk_0,
        'RR': rr, 'ESS': ess, 'E_Value': e_val,
        'CI_Lower': ate - 1.96*se, 'CI_Upper': ate + 1.96*se
    }

    predictions = {
        'mu0': mu0_hat, 'mu1': mu1_hat, 'pi': pi_hat, 'eif': eif_val
    }

    return stats, predictions

print("Super Learner Engine (Sklearn) Ready.")

In [None]:
# 1. Define Policies
# Policies act on the 'df_final' cohort (covariates) to output a decision vector d ∈ {0, 1}^n
# 0 = Withhold all contrast, 1 = Give contrast
#
# NOTE: ACR 2024 guidance is about *which* contrast (iodinated vs GBCA),
# not "withhold vs give". The eGFR-based rules below are *toy* withholding
# rules for counterfactual exploration, not literal guideline implementations.

def policy_current(df):
    """Observed practice: return the recorded 'contrast_received' decisions."""
    observed = df['contrast_received']
    return observed.values

def policy_always(df):
    """Treat everyone: an integer vector of ones, one entry per row of df."""
    n_rows = len(df)
    return np.full(n_rows, 1, dtype=int)

def policy_never(df):
    """Treat no one: an integer vector of zeros, one entry per row of df."""
    n_rows = len(df)
    return np.full(n_rows, 0, dtype=int)

def policy_egfr_rule_30(df):
    """Toy rule: withhold contrast where egfr_cat == 0 (eGFR < 30), give otherwise.

    Returns a 0/1 integer array: 1 wherever 'egfr_cat' is not 0.
    """
    give_contrast = df['egfr_cat'].ne(0)
    return give_contrast.astype(int).values

def policy_egfr_rule_45(df):
    """Toy rule: decision is 0 (withhold) where egfr_cat is 0 or 1, and 1 (give) everywhere else.

    Rows whose egfr_cat is not in {0, 1} (including missing values) map to 1.
    """
    withhold = df['egfr_cat'].isin([0, 1])
    give_contrast = ~withhold
    return give_contrast.astype(int).values

# Registry of candidate policies: display label -> function mapping a cohort
# DataFrame to a 0/1 decision vector (0 = withhold, 1 = give contrast).
# The 'Current Practice' label is looked up by exact string when the
# current-practice row is printed after the frontier plot, so keep labels stable.
policies = {
    'Current Practice': policy_current,
    'Always Contrast (100%)': policy_always,
    'Never Contrast (0%)': policy_never,
    'eGFR Rule: withhold if <30': policy_egfr_rule_30,
    'eGFR Rule: withhold if <45': policy_egfr_rule_45,
}


# 2. Execution Loop
# We evaluate on the primary outcome (AKI_30)
outcome_name = 'AKI_30'
col_date = f"date_{outcome_name}"
print(f"Evaluating policies for outcome: {outcome_name}")

# Prepare Arrays
T_vec = df_final['contrast_received'].values
# Define Y (Binary outcome within 30 days).
# A missing outcome date gives a NaN day-difference, which compares False against
# the 30-day threshold and is therefore encoded as 0 (no event).
Y_vec = ((df_final[col_date] - df_final['index_date']).dt.days <= 30).astype(int).values

# Step A: Get Nuisance Parameters (Mu0, Mu1, Pi) via Cross-Fitting
# We use the X_final sparse matrix
stats, preds = run_cross_fitted_aipw(X_final, T_vec, Y_vec, n_folds=5)

mu1 = preds['mu1']
mu0 = preds['mu0']
pi  = preds['pi']

# Step B: Evaluate Policies
#
# Doubly robust (AIPW) scores for the two potential outcomes:
#   Gamma_1 = mu1 + T/pi       * (Y - mu1)
#   Gamma_0 = mu0 + (1-T)/(1-pi) * (Y - mu0)
# They depend only on (T, Y, mu0, mu1, pi), not on the policy, so they are
# computed once here instead of once per policy.
# NOTE(review): assumes pi is bounded away from 0 and 1 by run_cross_fitted_aipw
# (otherwise these divisions blow up) — TODO confirm.
gamma_1 = mu1 + (T_vec / pi) * (Y_vec - mu1)
gamma_0 = mu0 + ((1 - T_vec) / (1 - pi)) * (Y_vec - mu0)

policy_results = []

for name, func in policies.items():
    d_vec = func(df_final) # 0/1 vector

    # Policy value: each patient contributes the score of the arm the policy assigns,
    #   V(d) = mean( d * Gamma_1 + (1-d) * Gamma_0 )
    # NOTE(review): for 'Current Practice' d equals the observed treatment itself;
    # np.mean(Y_vec) is the natural cross-check for that row.
    psi_i = d_vec * gamma_1 + (1 - d_vec) * gamma_0

    risk_val = np.mean(psi_i)
    risk_se  = np.std(psi_i) / np.sqrt(len(psi_i))

    # Withholding Rate (W): share of patients the policy assigns to "no contrast"
    withhold_rate = np.mean(1 - d_vec)

    policy_results.append({
        'Policy': name,
        'Risk': risk_val,
        'Risk_SE': risk_se,
        'Withholding': withhold_rate
    })

df_pol_res = pd.DataFrame(policy_results)
print(df_pol_res.round(5))

# 3. Visualization (XY Plot)
plt.figure(figsize=(10, 7))

# One marker per policy: x = share withheld, y = estimated risk,
# whiskers = +/- 1.96 standard errors; the policy name is also written next to the marker.
for pol in df_pol_res.itertuples(index=False):
    half_width = 1.96 * pol.Risk_SE
    plt.errorbar(
        x=pol.Withholding,
        y=pol.Risk,
        yerr=half_width,
        fmt='o',
        markersize=10,
        capsize=5,
        label=pol.Policy
    )
    # Small vertical offset keeps the label off the marker
    plt.text(pol.Withholding, pol.Risk + 0.0005, f"  {pol.Policy}", fontsize=9)

# Titles, axes, legend
plt.title(f"Policy Frontier: Harm vs. Withholding\nOutcome: {outcome_name} (30-Day Risk)")
plt.xlabel("Proportion of Patients Withheld Contrast")
plt.ylabel(f"Estimated Risk of {outcome_name}")
plt.grid(True, linestyle=':', alpha=0.6)
plt.legend(loc='best')
plt.tight_layout()
plt.show()

# Headline AIPW numbers for the contrast vs no-contrast comparison,
# emitted in a single print call (one line per statistic).
aipw_summary_lines = [
    f"ATE (Contrast vs No-Contrast): {stats['ATE']:.4f} "
    f"[{stats['CI_Lower']:.4f}, {stats['CI_Upper']:.4f}], p={stats['P_Value']:.3g}",
    f"Risk_1 (E[Y | do(T=1)]): {stats['Risk_1']:.4f}",
    f"Risk_0 (E[Y | do(T=0)]): {stats['Risk_0']:.4f}",
    f"RR (Contrast vs No-Contrast): {stats['RR']:.3f}",
    f"E-Value (Contrast vs No-Contrast): {stats['E_Value']:.2f}",
    f"AIPW Effective Sample Size (AIPW weights): {stats['ESS']:.1f}",
]
print("\n".join(aipw_summary_lines))

# Current-practice policy value, taken from the frontier table by its label
is_current_practice = df_pol_res['Policy'] == 'Current Practice'
cp_row = df_pol_res[is_current_practice].iloc[0]
print(f"Current Practice Policy Risk (DR): {cp_row['Risk']:.4f}")


In [None]:
# --- CELL 11: SURVIVAL PLOTS ---
# IPTW-weighted Kaplan-Meier curves per treatment arm, followed by a weighted
# Cox model on treatment alone and a proportional-hazards check.

outcome = 'AKI_30'
print(f"Diagnostic Plots for {outcome} (Observed Trial)...")

# Work on a copy so the helper columns below do not leak into df_final
df_viz = df_final.copy()

# Time-to-event for AKI within 30 days
event_date = df_viz['date_AKI_30']
idx_date = df_viz['index_date']

# If no event, censor at 30 days
event_date_filled = event_date.fillna(idx_date + pd.Timedelta(days=30))
t_days = (event_date_filled - idx_date).dt.days
# Durations are bounded to the [0, 30]-day follow-up window
t_days = t_days.clip(lower=0, upper=30)

df_viz['T_viz'] = t_days
# Event indicator: an event date exists AND it falls within 30 days of index
df_viz['E_viz'] = ((event_date.notnull()) &
                   ((event_date - idx_date).dt.days <= 30)).astype(int)

# 1. KM Curves with IPTW
kmf0 = KaplanMeierFitter()
kmf1 = KaplanMeierFitter()

# Slice each arm once instead of re-filtering for every fit argument
arm_withheld = df_viz[df_viz['contrast_received'] == 0]
arm_contrast = df_viz[df_viz['contrast_received'] == 1]

plt.figure(figsize=(10, 6))
km_ax = plt.gca()
kmf0.fit(
    arm_withheld['T_viz'],
    arm_withheld['E_viz'],
    weights=arm_withheld['iptw'],
    label='Withheld'
)
kmf1.fit(
    arm_contrast['T_viz'],
    arm_contrast['E_viz'],
    weights=arm_contrast['iptw'],
    label='Contrast'
)

# Pass the axes explicitly so both curves are guaranteed to share one panel
kmf0.plot_survival_function(ax=km_ax)
kmf1.plot_survival_function(ax=km_ax)
plt.title(f"Adjusted Survival Curve: {outcome}")
plt.xlabel("Days since index")
plt.ylabel("Survival probability (no AKI)")
plt.show()

# 2. Cox model & PH diagnostics
cox_df = df_viz[['T_viz', 'E_viz', 'contrast_received', 'iptw']]
cph = CoxPHFitter(penalizer=0.1)
# robust=True: the weights are propensity-based (non-integer), for which the
# naive variance estimate is biased; lifelines recommends the sandwich estimator.
cph.fit(
    cox_df,
    duration_col='T_viz',
    event_col='E_viz',
    weights_col='iptw',
    robust=True
)
cph.check_assumptions(
    cox_df,
    show_plots=True
)

In [None]:
# --- CELL 12: ECOLOGICAL CONFOUNDING CHECK (CORRECTED) ---
import seaborn as sns
from scipy.stats import pearsonr
import matplotlib.pyplot as plt

print("\n--- ECOLOGICAL ANALYSIS (INSTRUMENTAL VARIABLE CHECK) ---")

# 1. Aggregate by Site (Zip)
site_stats = df_final.groupby('zip_code').agg({
    'contrast_received': 'mean', # Observed Rate
    'ps': 'mean',                # Expected Rate (Patient Risk Profile)
    'iptw': 'count'              # Volume (number of non-missing iptw rows)
}).rename(columns={'iptw': 'volume'})

# Filter for meaningful sites (e.g., >20 patients)
site_stats = site_stats[site_stats['volume'] > 20].copy()

# 2. Calculate "Practice Preference" (Instrument)
# Positive = site gives contrast more often than its patients' mean propensity predicts
site_stats['practice_preference'] = site_stats['contrast_received'] - site_stats['ps']

# 3. Calculate Site-Level Outcome Rate (Raw)
aki_events = ((df_final['date_AKI_30'] - df_final['index_date']).dt.days <= 30)
# NOTE: this adds a 'has_aki_30' column to df_final itself (kept for any later cell that reads it)
df_final['has_aki_30'] = aki_events.fillna(False).astype(int)
outcome_stats = df_final.groupby('zip_code')['has_aki_30'].mean()

# Inner join on zip_code: only sites that passed the volume filter remain
eco_df = site_stats.merge(outcome_stats, left_index=True, right_index=True)

# --- FIX: Convert to standard numpy floats to prevent Seaborn TypeError ---
eco_df['practice_preference'] = eco_df['practice_preference'].astype(float)
eco_df['has_aki_30'] = eco_df['has_aki_30'].astype(float)
eco_df['volume'] = eco_df['volume'].astype(float)

print(f"Number of Sites Analyzed: {len(eco_df)}")

if len(eco_df) < 2:
    # pearsonr needs at least two observations; report and skip instead of crashing the cell
    print("Fewer than 2 sites with >20 patients; skipping correlation test and quadrant plot.")
else:
    # 4. Correlation Test
    corr, p_val = pearsonr(eco_df['practice_preference'], eco_df['has_aki_30'])

    print(f"Correlation (r): {corr:.3f}")
    print(f"P-Value: {p_val:.4f}")

    # Interpretation
    if np.isnan(corr):
        # Happens when one of the two site-level series is constant
        print("INCONCLUSIVE: Correlation is undefined (no variation across sites).")
    elif p_val < 0.05 and abs(corr) > 0.2:
        print("WARNING: Significant Ecological Correlation detected.")
        print("This suggests Unmeasured Confounding at the site level.")
    else:
        print("PASS: No significant correlation between site preference and outcomes.")
        print("Variation in withholding appears random regarding unmeasured site quality.")

    # 5. The Quadrant Plot
    plt.figure(figsize=(10, 8))
    sns.scatterplot(
        data=eco_df,
        x='practice_preference',
        y='has_aki_30',
        size='volume',
        sizes=(20, 600),
        alpha=0.6,
        hue='practice_preference',
        palette='vlag'
    )

    # Add Quadrant Lines (at the across-site means)
    mean_pref = eco_df['practice_preference'].mean()
    mean_out = eco_df['has_aki_30'].mean()
    plt.axhline(mean_out, linestyle='--', color='gray', alpha=0.7)
    plt.axvline(mean_pref, linestyle='--', color='gray', alpha=0.7)

    # Labels
    plt.title("Ecological Analysis: Site Preference vs. Outcomes")
    plt.xlabel("Site Preference (Actual - Expected)\n<-- Conservative (Withholds) | Aggressive (Gives) -->")
    plt.ylabel("Raw AKI Rate (30-Day)")

    # Quadrant Annotations
    plt.text(mean_pref + 0.02, mean_out - 0.005, "Aggressive & Low Risk", color='green', fontsize=9)
    plt.text(mean_pref - 0.08, mean_out + 0.005, "Conservative & High Risk", color='red', fontsize=9)

    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title="Preference")
    plt.tight_layout()
    plt.show()

In [None]:
# --- CELL 10.5: RISK-BASED CONFOUNDING-BY-INDICATION CHECK ---
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

print("\n--- RISK-BASED CONFOUNDING CHECK (mu0 vs Treatment/Propensity) ---")

# Assemble diagnostic frame (one row per patient, aligned to df_final's index)
df_diag = pd.DataFrame({
    "mu_0": preds["mu0"],           # Estimated risk if untreated
    "mu_1": preds["mu1"],           # Estimated risk if treated
    "pi": preds["pi"],              # Propensity score from nuisance model
    "contrast_received": T_vec      # Actual treatment (0/1)
}, index=df_final.index)

# 1. Correlation: baseline risk vs propensity score
mask_pi = df_diag["mu_0"].notna() & df_diag["pi"].notna()
r_mu0_pi, p_mu0_pi = pearsonr(
    df_diag.loc[mask_pi, "mu_0"],
    df_diag.loc[mask_pi, "pi"]
)

# 2. Correlation: baseline risk vs observed treatment
mask_t = df_diag["mu_0"].notna() & df_diag["contrast_received"].notna()
r_mu0_t, p_mu0_t = pearsonr(
    df_diag.loc[mask_t, "mu_0"],
    df_diag.loc[mask_t, "contrast_received"]
)

print(f"Corr(mu_0, pi):               r = {r_mu0_pi:.3f}, p = {p_mu0_pi:.2e}")
print(f"Corr(mu_0, contrast_received): r = {r_mu0_t:.3f}, p = {p_mu0_t:.2e}")

# A link in EITHER direction is a signal: previously a strong positive
# correlation fell through to the "not clearly linked" message.
signal_negative = (r_mu0_pi < -0.2 and p_mu0_pi < 0.05) or (r_mu0_t < -0.2 and p_mu0_t < 0.05)
signal_positive = (r_mu0_pi > 0.2 and p_mu0_pi < 0.05) or (r_mu0_t > 0.2 and p_mu0_t < 0.05)

if signal_negative:
    print("SIGNAL: Higher baseline risk associated with lower treatment probability (confounding by indication).")
elif signal_positive:
    print("SIGNAL: Higher baseline risk associated with higher treatment probability (confounding by indication).")
else:
    print("NO STRONG SIGNAL: Baseline risk not clearly linked to treatment probability.")

# 3. Visualization: Baseline risk vs propensity
# Plot at most 10,000 rows (fixed seed) to keep the scatter readable and fast
sample_df = df_diag.sample(n=min(10000, len(df_diag)), random_state=42)

plt.figure(figsize=(8, 6))
sns.scatterplot(
    data=sample_df,
    x="mu_0",
    y="pi",
    hue="contrast_received",
    alpha=0.3,
    s=20
)
try:
    sns.regplot(
        data=sample_df,
        x="mu_0",
        y="pi",
        scatter=False,
        lowess=True
    )
except (ImportError, RuntimeError) as lowess_err:
    # seaborn's lowess option depends on the optional statsmodels package;
    # without it, keep the scatter and finish the figure rather than aborting the cell.
    print(f"Skipping LOWESS trend line: {lowess_err}")

plt.title("Baseline Risk vs Propensity to Receive Contrast")
plt.xlabel("Estimated Risk if Untreated (mu_0)")
plt.ylabel("Propensity Score (pi)")
plt.legend(title="Contrast Received", bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
plt.show()


In [None]:
# =============================================================================
# CELL 13: Method Validation (Negative Controls) & Calibration
# =============================================================================
# For each negative-control outcome, fit a treatment-only logistic regression
# twice (unweighted, then IPTW-weighted) and compare the resulting log odds
# ratios against the null value of 0.
print("\n--- CELL 13: Negative Control Validation ---")

nc_res_list = []

for nc_name in NEGATIVE_CONTROLS.keys():
    col_nc = f"date_{nc_name}"
    # Controls without a date column in the cohort are silently ignored
    if col_nc not in df_final.columns:
        continue

    # Binary outcome: control event dated within 30 days of index
    days_to_nc = (df_final[col_nc] - df_final['index_date']).dt.days
    Y_nc = (days_to_nc <= 30).astype(int).values
    T_nc = df_final['contrast_received'].values

    # Weighted logistic regression (rather than full AIPW) for speed/stability on NCs
    try:
        # Unadjusted: treatment as the only feature, no weights
        lr_unadj = LogisticRegression(solver='lbfgs')
        lr_unadj.fit(T_nc.reshape(-1, 1), Y_nc)
        or_unadj = np.exp(lr_unadj.coef_[0][0])

        # Adjusted: same model, weighted by 'iptw' when that column exists,
        # otherwise unit weights (which reproduces the unadjusted fit)
        weights = df_final['iptw'].values if 'iptw' in df_final.columns else np.ones_like(T_nc)

        lr_adj = LogisticRegression(solver='lbfgs')
        lr_adj.fit(T_nc.reshape(-1, 1), Y_nc, sample_weight=weights)
        or_adj = np.exp(lr_adj.coef_[0][0])

        # No confidence interval is computed here (summary only)
        nc_row = {
            'Outcome': nc_name,
            'OR_Unadj': or_unadj,
            'OR_Adj': or_adj,
            'LogOR_Adj': lr_adj.coef_[0][0]
        }
        nc_res_list.append(nc_row)
    except Exception as e:
        # Any fitting failure skips this control and is reported
        print(f"Skipping {nc_name}: {e}")

df_nc = pd.DataFrame(nc_res_list)

# Visualization: Calibration Funnel / Scatter
if not df_nc.empty:
    print(df_nc.round(3))

    plt.figure(figsize=(12, 5))

    # Left panel: unadjusted vs adjusted log OR, with null lines and the identity diagonal
    plt.subplot(1, 2, 1)
    log_or_unadj = np.log(df_nc['OR_Unadj'])
    log_or_adj = np.log(df_nc['OR_Adj'])
    plt.scatter(log_or_unadj, log_or_adj, alpha=0.7, c='purple')
    plt.axhline(0, color='black', linestyle='--')
    plt.axvline(0, color='black', linestyle='--')
    plt.plot([-2, 2], [-2, 2], 'k:', alpha=0.3)
    plt.title("Bias Correction: Unadjusted vs Adjusted Estimates")
    plt.xlabel("Unadjusted Log OR")
    plt.ylabel("Adjusted Log OR")
    plt.grid(True)

    # Right panel: distribution of adjusted estimates around the null
    plt.subplot(1, 2, 2)
    sns.histplot(df_nc['LogOR_Adj'], kde=True, bins=10)
    plt.axvline(0, color='red', linestyle='--', label='Null')
    plt.title("Distribution of Negative Control Estimates (Should center at 0)")
    plt.xlabel("Adjusted Log OR")
    plt.legend()

    plt.tight_layout()
    plt.show()

    # Summary metric: root-mean-square distance of adjusted log ORs from 0
    rmse = np.sqrt(np.mean(df_nc['LogOR_Adj']**2))
    print(f"Calibration RMSE (Distance from Null): {rmse:.4f}")
    if rmse < 0.1:
        print("VALIDATION: Excellent Calibration.")
    elif rmse < 0.2:
        print("VALIDATION: Acceptable Calibration.")
    else:
        print("VALIDATION: WARNING - Possible Residual Confounding.")

# Persist the policy table when it exists in this session
if 'df_pol_res' in locals():
    df_pol_res.to_csv("final_policy_results.csv", index=False)
    print("Saved final results to CSV.")

In [None]:
# =============================================================================
# NEW CELL: Subgroup Policy Analysis (eGFR Categories)
# =============================================================================
print("\n--- NEW CELL: Subgroup Analysis by eGFR Category ---")

# Define Subgroups based on egfr_cat
# 0=<30, 1=30-44, 2=45-59, 3=60+
subgroups = {
    'eGFR < 30': [0],
    'eGFR 30-44': [1],
    'eGFR 45-59': [2],
    'eGFR >= 60': [3]
}

# We perform the analysis on the subsets
# Note: For valid inference, we calculate risks WITHIN the subgroup
# utilizing the predictions from the global model (Predictions are conditional on X)
# V(d | S) = E[ Psi_i | i in S ]

plt.figure(figsize=(15, 10))

# Next free panel in the 2x2 grid; only advanced for subgroups that are actually plotted
plot_idx = 1
for label, cats in subgroups.items():
    # Positional row numbers of the subgroup members (preds / T_vec / Y_vec are
    # plain arrays in df_final's row order, so positions, not labels, are needed)
    mask = df_final['egfr_cat'].isin(cats)
    locs = np.where(mask.values)[0]
    n_sub = len(locs)

    if n_sub < 50:
        print(f"Skipping {label}: N={n_sub} (Too small)")
        continue

    print(f"Analyzing Subgroup: {label} (N={n_sub})")

    # Slice the global nuisance predictions and observed data once per subgroup
    mu1_sub = preds['mu1'][locs]
    mu0_sub = preds['mu0'][locs]
    pi_sub = preds['pi'][locs]
    T_sub = T_vec[locs]
    Y_sub = Y_vec[locs]

    # Doubly robust scores for "treat" and "withhold" within the subgroup
    gamma_1_sub = mu1_sub + (T_sub / pi_sub) * (Y_sub - mu1_sub)
    gamma_0_sub = mu0_sub + ((1 - T_sub) / (1 - pi_sub)) * (Y_sub - mu0_sub)

    # The policy functions take a DataFrame; the subgroup slice does not depend
    # on the policy, so build it once rather than inside the policy loop
    df_sub = df_final.iloc[locs]

    # Evaluate Policies on this subgroup
    res_sub = []
    for pol_name, func in policies.items():
        d_sub = func(df_sub)

        # Each patient contributes the score of the arm the policy assigns
        psi_sub = d_sub * gamma_1_sub + (1 - d_sub) * gamma_0_sub

        risk = np.mean(psi_sub)
        se = np.std(psi_sub) / np.sqrt(len(psi_sub))
        withhold = np.mean(1 - d_sub)

        res_sub.append({'Policy': pol_name, 'Risk': risk, 'SE': se, 'W': withhold})

    df_res_sub = pd.DataFrame(res_sub)

    # Plot: one panel per subgroup, one marker (+/- 1.96 SE) per policy
    plt.subplot(2, 2, plot_idx)
    for i, row in df_res_sub.iterrows():
        plt.errorbar(row['W'], row['Risk'], yerr=1.96*row['SE'], fmt='o', label=row['Policy'], capsize=5)
        plt.text(row['W'], row['Risk'], f" {row['Policy']}", fontsize=8)

    plt.title(f"Subgroup: {label}")
    plt.xlabel("Withholding Rate")
    plt.ylabel(f"{outcome_name} Risk")
    plt.grid(True, alpha=0.3)
    # The legend is identical in every panel, so draw it only on the first
    if plot_idx == 1: plt.legend(loc='upper left', fontsize=8)

    plot_idx += 1

plt.tight_layout()
plt.show()